diff options
Diffstat (limited to 'parakeet-rs/examples')
| -rw-r--r-- | parakeet-rs/examples/diarization.rs | 137 | ||||
| -rw-r--r-- | parakeet-rs/examples/raw.rs | 86 | ||||
| -rw-r--r-- | parakeet-rs/examples/streaming.rs | 129 | ||||
| -rw-r--r-- | parakeet-rs/examples/transcribe.rs | 106 |
4 files changed, 0 insertions, 458 deletions
diff --git a/parakeet-rs/examples/diarization.rs b/parakeet-rs/examples/diarization.rs deleted file mode 100644 index 5982ecb..0000000 --- a/parakeet-rs/examples/diarization.rs +++ /dev/null @@ -1,137 +0,0 @@ -/* -Speaker Diarization with NVIDIA Sortformer v2 (Streaming) - -Download the Sortformer v2 model: -https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx -Download test audio: -wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav - -Usage: -cargo run --example diarization --features sortformer 6_speakers.wav - -NOTE: This example combines two NVIDIA models: -- Parakeet-TDT: Provides transcription with sentence-level timestamps -- Sortformer v2: Provides streaming speaker identification (4 speakers max) -- We use TDT's sentence timestamps + Sortformer's speaker IDs -- Even if Sortformer can't detect a segment, we still get the transcription (marked UNKNOWN) -- For more information: -https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 - -WARNING: Sortformer handles long audio natively (streaming), but TDT has sequence -length limitations (~8-10 minutes max). For production use with long audio files, -run Sortformer on the full audio for diarization, then chunk the audio into -~5-minute segments for TDT transcription, and map the results back together. -*/ - -#[cfg(feature = "sortformer")] -use parakeet_rs::sortformer::{DiarizationConfig, Sortformer}; -#[cfg(feature = "sortformer")] -use parakeet_rs::TimestampMode; -#[cfg(feature = "sortformer")] -use hound; -#[cfg(feature = "sortformer")] -use std::env; -#[cfg(feature = "sortformer")] -use std::time::Instant; - -#[allow(unreachable_code)] -fn main() -> Result<(), Box<dyn std::error::Error>> { - #[cfg(not(feature = "sortformer"))] - { - eprintln!("Error: This example requires the 'sortformer' feature."); - eprintln!("Please run with: cargo run --example diarization --features sortformer <audio.wav>"); - return Err("sortformer feature not enabled".into()); - } - - #[cfg(feature = "sortformer")] - { - let start_time = Instant::now(); - let args: Vec<String> = env::args().collect(); - let audio_path = args.get(1) - .expect("Please specify audio file: cargo run --example diarization --features sortformer <audio.wav>"); - - println!("{}", "=".repeat(80)); - println!("Step 1/3: Loading audio..."); - - let mut reader = hound::WavReader::open(audio_path)?; - let spec = reader.spec(); - - let audio: Vec<f32> = match spec.sample_format { - hound::SampleFormat::Float => reader - .samples::<f32>() - .collect::<Result<Vec<_>, _>>()?, - hound::SampleFormat::Int => reader - .samples::<i16>() - .map(|s| s.map(|s| s as f32 / 32768.0)) - .collect::<Result<Vec<_>, _>>()?, - }; - - let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32; - println!("Loaded {} samples ({} Hz, {} channels, {:.1}s)", - audio.len(), spec.sample_rate, spec.channels, duration); - - println!("{}", "=".repeat(80)); - println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)..."); - - // Create Sortformer with default config (callhome) - let mut sortformer = Sortformer::with_config( - "diar_streaming_sortformer_4spk-v2.onnx", - None, // default exec config - DiarizationConfig::callhome(), - )?; - - let speaker_segments = sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?; - - println!("Found {} speaker segments from Sortformer", speaker_segments.len()); - - // Print raw diarization segments - println!("\nRaw diarization segments:"); - for seg in &speaker_segments { - println!(" [{:06.2}s - {:06.2}s] Speaker {}", seg.start, seg.end, seg.speaker_id); - } - - println!("\n{}", "=".repeat(80)); - println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n"); - - // Use TDT for transcription with sentence-level timestamps - let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; - - // Transcribe with Sentences mode (TDT provides punctuation for proper segmentation) - if let Ok(result) = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences)) { - // For each sentence from TDT, find the corresponding speaker from Sortformer - for segment in &result.tokens { - // Find speaker with maximum overlap - let speaker = speaker_segments - .iter() - .filter_map(|s| { - // Calculate overlap between transcription and diarization segment - let overlap_start = segment.start.max(s.start); - let overlap_end = segment.end.min(s.end); - let overlap = (overlap_end - overlap_start).max(0.0); - if overlap > 0.0 { - Some((s.speaker_id, overlap)) - } else { - None - } - }) - .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) - .map(|(id, _)| format!("Speaker {}", id)) - .unwrap_or_else(|| "UNKNOWN".to_string()); - - println!("[{:.2}s - {:.2}s] {}: {}", - segment.start, segment.end, speaker, segment.text); - } - } - - println!("\n{}", "=".repeat(80)); - let elapsed = start_time.elapsed(); - println!("\n✓ Diarization and transcription completed in {:.2}s", elapsed.as_secs_f32()); - println!("• UNKNOWN: Segments where no speaker was detected by Sortformer"); - println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)"); - - Ok(()) - } - - #[cfg(not(feature = "sortformer"))] - unreachable!() -} diff --git a/parakeet-rs/examples/raw.rs b/parakeet-rs/examples/raw.rs deleted file mode 100644 index a1a2adc..0000000 --- a/parakeet-rs/examples/raw.rs +++ /dev/null @@ -1,86 +0,0 @@ -/* -Demonstrates using transcribe_samples() - -This example shows manual audio loading and calling transcribe_samples() directly -with sample_rate and channels instead of using transcribe_file() - -Usage: -cargo run --example raw 6_speakers.wav -cargo run --example raw 6_speakers.wav tdt - -WARNING: TDT model has sequence length limitations (~8-10 minutes max). -For longer audio files, you must split into chunks (e.g., 5-minute segments) -and transcribe each chunk separately. Attempting to transcribe 25+ minute -audio files in one call will cause ONNX runtime errors. -Otherwise you will likely get a error like: -"Error: Ort(Error { code: RuntimeException, msg: "Non-zero status code returned while running Add node. Name:'/layers.0/self_attn/Add_2' Status Message: /Users/runner/work/ort-artifacts/ort-artifacts/onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. })" -*/ - -use parakeet_rs::{Parakeet, ParakeetTDT, TimestampMode}; -use std::env; -use std::time::Instant; - -fn main() -> Result<(), Box<dyn std::error::Error>> { - let start_time = Instant::now(); - let args: Vec<String> = env::args().collect(); - let audio_path = if args.len() > 1 { - &args[1] - } else { - "6_speakers.wav" - }; - - let use_tdt = args.len() > 2 && args[2] == "tdt"; - - // Load audio manually using hound (or any other audio library) - // remember if you use raw audio API, you need to handle audio preprocessing yourself! - let mut reader = hound::WavReader::open(audio_path)?; - let spec = reader.spec(); - - println!("Audio info: {}Hz, {} channel(s)", spec.sample_rate, spec.channels); - - let audio: Vec<f32> = match spec.sample_format { - hound::SampleFormat::Float => reader - .samples::<f32>() - .collect::<Result<Vec<_>, _>>()?, - hound::SampleFormat::Int => reader - .samples::<i16>() - .map(|s| s.map(|s| s as f32 / 32768.0)) - .collect::<Result<Vec<_>, _>>()?, - }; - - if use_tdt { - println!("Loading TDT model..."); - let mut parakeet = ParakeetTDT::from_pretrained("./tdt", None)?; - - // Use transcribe_samples() with raw parameters and timestamp mode - let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences))?; - - println!("{}", result.text); - println!("\nSentencess:"); - for segment in result.tokens.iter() { - println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); - } - } else { - println!("Loading CTC model..."); - let mut parakeet = Parakeet::from_pretrained(".", None)?; - - // CTC model doesn't predict punctuation (lowercase alphabet only) - // This means no sentence boundaries. we use Words mode instead of Sentences - let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Words))?; - - println!("{}", result.text); - - // Access word-level timestamps (showing first 10 for brevity) - // Note: CTC generates word-level timestamps but cannot segment into sentences - // due to lack of punctuation prediction - this is a model limitation if I not mistake - println!("\nWords (first 10):"); - for word in result.tokens.iter().take(10) { - println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text); - } - } - - let elapsed = start_time.elapsed(); - println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); - - Ok(()) -} diff --git a/parakeet-rs/examples/streaming.rs b/parakeet-rs/examples/streaming.rs deleted file mode 100644 index f5d36c9..0000000 --- a/parakeet-rs/examples/streaming.rs +++ /dev/null @@ -1,129 +0,0 @@ -/* -Demonstrates streaming ASR with Parakeet RealTime EOU - -Download models files from: -https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx - -This example -- Maintains 4-second ring buffer for feature extraction context -- Processes 160ms chunks (2560 samples at 16kHz) -- Extracts features from full buffer, then slices last 25 frames -- Encoder receives: 9 frames (pre-encode cache) + 16 frames (new) = 25 total -- Cache states (cache_last_channel/time) maintain temporal context - -Model files required in ./fullstr/: - - encoder.onnx (cache_aware_stream_step export) - - decoder_joint.onnx - - tokenizer.json - -Additional notes: -let reset_on_eou: bool = false; -I must admit that this is not work very well on my real world tests :/ - - -Usage: -cargo run --release --example streaming <audio.wav> -*/ - -use hound; -use parakeet_rs::ParakeetEOU; -use std::env; -use std::time::Instant; - -fn main() -> Result<(), Box<dyn std::error::Error>> { - let start_time = Instant::now(); - - let args: Vec<String> = env::args().collect(); - let audio_path = args - .get(1) - .expect("Usage: cargo run --release --example streaming <audio.wav>"); - - println!("Loading model from ./fullstr..."); - let mut parakeet = ParakeetEOU::from_pretrained("./fullstr", None)?; - - println!("Loading audio: {}", audio_path); - let mut reader = hound::WavReader::open(audio_path)?; - let spec = reader.spec(); - - let mut audio: Vec<f32> = match spec.sample_format { - hound::SampleFormat::Float => reader - .samples::<f32>() - .collect::<Result<Vec<_>, _>>()?, - hound::SampleFormat::Int => reader - .samples::<i16>() - .map(|s| s.map(|s| s as f32 / 32768.0)) - .collect::<Result<Vec<_>, _>>()?, - }; - - if spec.sample_rate != 16000 { - return Err(format!( - "Expected 16kHz audio, got {}Hz. Please resample first.", - spec.sample_rate - ) - .into()); - } - - if spec.channels > 1 { - audio = audio - .chunks(spec.channels as usize) - .map(|chunk| chunk.iter().sum::<f32>() / spec.channels as f32) - .collect(); - } - - let max_val = audio.iter().fold(0.0f32, |a, &b| a.max(b.abs())); - if max_val > 1e-6 { - let norm_factor = max_val + 1e-5; - for sample in &mut audio { - *sample /= norm_factor; - } - } - - let duration = audio.len() as f32 / 16000.0; - // 160ms at 16kHz - const CHUNK_SIZE: usize = 2560; - let reset_on_eou: bool = false; - - println!("Streaming transcription (160ms chunks with 4s buffer)...\n"); - - let mut full_text = String::new(); - - for chunk in audio.chunks(CHUNK_SIZE) { - let chunk_vec = if chunk.len() < CHUNK_SIZE { - let mut padded = chunk.to_vec(); - padded.resize(CHUNK_SIZE, 0.0); - padded - } else { - chunk.to_vec() - }; - - let text = parakeet.transcribe(&chunk_vec, reset_on_eou)?; - if !text.is_empty() { - print!("{}", text); - std::io::Write::flush(&mut std::io::stdout())?; - full_text.push_str(&text); - } - } - - println!("\n\nFlushing decoder..."); - let silence = vec![0.0f32; CHUNK_SIZE]; - for _ in 0..3 { - let text = parakeet.transcribe(&silence, reset_on_eou)?; - if !text.is_empty() { - print!("{}", text); - std::io::Write::flush(&mut std::io::stdout())?; - full_text.push_str(&text); - } - } - - println!("\n\nFinal Transcription:\n{}", full_text.trim()); - - let elapsed = start_time.elapsed(); - println!( - "\nTranscription completed in {:.2}s (audio: {:.2}s, RTF: {:.2}x)", - elapsed.as_secs_f32(), - duration, - duration / elapsed.as_secs_f32() - ); - - Ok(()) -} diff --git a/parakeet-rs/examples/transcribe.rs b/parakeet-rs/examples/transcribe.rs deleted file mode 100644 index 685e8de..0000000 --- a/parakeet-rs/examples/transcribe.rs +++ /dev/null @@ -1,106 +0,0 @@ -/* -transcribes entire audio, no diarization -wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav - -CTC (English-only): -cargo run --example transcribe 6_speakers.wav - -TDT (Multilingual): -cargo run --example transcribe 6_speakers.wav tdt - -NOTE: For manual audio loading without using transcribe_file(), see examples/raw.rs -- Shows transcribe_samples(audio, sample_rate, channels, timestamps) usage - -WARNING: This may fail on very long audio files (>8 min). -For longer audio, use the pyannote example which processes segments, or split your audio into chunks. - -Note: The coreml feature flag is only for reproducing a known ONNX Runtime bug. -Just ignore it :). See: https://github.com/microsoft/onnxruntime/issues/26355 -*/ -use parakeet_rs::{Parakeet, TimestampMode}; -use std::env; -use std::time::Instant; - -#[cfg(feature = "coreml")] -use parakeet_rs::{ExecutionConfig, ExecutionProvider}; - -fn main() -> Result<(), Box<dyn std::error::Error>> { - let start_time = Instant::now(); - let args: Vec<String> = env::args().collect(); - let audio_path = if args.len() > 1 { - &args[1] - } else { - "6_speakers.wav" - }; - - let use_tdt = args.len() > 2 && args[2] == "tdt"; - - // TDT model (multilingual, 25 languages) - if use_tdt { - #[cfg(feature = "coreml")] - { - let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML); - let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", Some(config))?; - let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?; - println!("{}", result.text); - - println!("\nSentencess:"); - for segment in result.tokens.iter() { - println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); - } - - let elapsed = start_time.elapsed(); - println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); - return Ok(()); - } - - #[cfg(not(feature = "coreml"))] - { - let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; - let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?; - println!("{}", result.text); - - println!("\nSentencess:"); - for segment in result.tokens.iter() { - println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); - } - - let elapsed = start_time.elapsed(); - println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); - return Ok(()); - } - } - - // CTC model (English-only) - #[cfg(feature = "coreml")] - let mut parakeet = { - let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML); - Parakeet::from_pretrained(".", Some(config))? - }; - - // Default: CPU execution provider (works correctly) - // Auto-detects model with priority: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx - // Or specify exact model: Parakeet::from_pretrained("model_q4.onnx", None)? - #[cfg(not(feature = "coreml"))] - let mut parakeet = Parakeet::from_pretrained(".", None)?; - - // CTC model doesn't predict punctuation (lowercase alphabet only) - // This means no sentence boundaries - use Words mode instead of Sentences - let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Words))?; - - // Print transcription - println!("{}", result.text); - - // Access word-level timestamps (showing first 10 for brevity) - // Note: CTC generates word-level timestamps but cannot segment into sentences - // due to lack of punctuation prediction - this is a model limitation - println!("\nWords (first 10):"); - for word in result.tokens.iter().take(10) { - println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text); - } - - let elapsed = start_time.elapsed(); - println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); - - Ok(()) -} |
