diff options
| author | soryu <soryu@soryu.co> | 2025-12-21 00:40:04 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch) | |
| tree | 0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/examples | |
| parent | 84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff) | |
| download | soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip | |
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/examples')
| -rw-r--r-- | parakeet-rs/examples/diarization.rs | 137 | ||||
| -rw-r--r-- | parakeet-rs/examples/raw.rs | 86 | ||||
| -rw-r--r-- | parakeet-rs/examples/streaming.rs | 129 | ||||
| -rw-r--r-- | parakeet-rs/examples/transcribe.rs | 106 |
4 files changed, 458 insertions, 0 deletions
diff --git a/parakeet-rs/examples/diarization.rs b/parakeet-rs/examples/diarization.rs new file mode 100644 index 0000000..5982ecb --- /dev/null +++ b/parakeet-rs/examples/diarization.rs @@ -0,0 +1,137 @@ +/* +Speaker Diarization with NVIDIA Sortformer v2 (Streaming) + +Download the Sortformer v2 model: +https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx +Download test audio: +wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav + +Usage: +cargo run --example diarization --features sortformer 6_speakers.wav + +NOTE: This example combines two NVIDIA models: +- Parakeet-TDT: Provides transcription with sentence-level timestamps +- Sortformer v2: Provides streaming speaker identification (4 speakers max) +- We use TDT's sentence timestamps + Sortformer's speaker IDs +- Even if Sortformer can't detect a segment, we still get the transcription (marked UNKNOWN) +- For more information: +https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 + +WARNING: Sortformer handles long audio natively (streaming), but TDT has sequence +length limitations (~8-10 minutes max). For production use with long audio files, +run Sortformer on the full audio for diarization, then chunk the audio into +~5-minute segments for TDT transcription, and map the results back together. +*/ + +#[cfg(feature = "sortformer")] +use parakeet_rs::sortformer::{DiarizationConfig, Sortformer}; +#[cfg(feature = "sortformer")] +use parakeet_rs::TimestampMode; +#[cfg(feature = "sortformer")] +use hound; +#[cfg(feature = "sortformer")] +use std::env; +#[cfg(feature = "sortformer")] +use std::time::Instant; + +#[allow(unreachable_code)] +fn main() -> Result<(), Box<dyn std::error::Error>> { + #[cfg(not(feature = "sortformer"))] + { + eprintln!("Error: This example requires the 'sortformer' feature."); + eprintln!("Please run with: cargo run --example diarization --features sortformer <audio.wav>"); + return Err("sortformer feature not enabled".into()); + } + + #[cfg(feature = "sortformer")] + { + let start_time = Instant::now(); + let args: Vec<String> = env::args().collect(); + let audio_path = args.get(1) + .expect("Please specify audio file: cargo run --example diarization --features sortformer <audio.wav>"); + + println!("{}", "=".repeat(80)); + println!("Step 1/3: Loading audio..."); + + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + let audio: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + + let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32; + println!("Loaded {} samples ({} Hz, {} channels, {:.1}s)", + audio.len(), spec.sample_rate, spec.channels, duration); + + println!("{}", "=".repeat(80)); + println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)..."); + + // Create Sortformer with default config (callhome) + let mut sortformer = Sortformer::with_config( + "diar_streaming_sortformer_4spk-v2.onnx", + None, // default exec config + DiarizationConfig::callhome(), + )?; + + let speaker_segments = sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?; + + println!("Found {} speaker segments from Sortformer", speaker_segments.len()); + + // Print raw diarization segments + println!("\nRaw diarization segments:"); + for seg in &speaker_segments { + println!(" [{:06.2}s - {:06.2}s] Speaker {}", seg.start, seg.end, seg.speaker_id); + } + + println!("\n{}", "=".repeat(80)); + println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n"); + + // Use TDT for transcription with sentence-level timestamps + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; + + // Transcribe with Sentences mode (TDT provides punctuation for proper segmentation) + if let Ok(result) = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences)) { + // For each sentence from TDT, find the corresponding speaker from Sortformer + for segment in &result.tokens { + // Find speaker with maximum overlap + let speaker = speaker_segments + .iter() + .filter_map(|s| { + // Calculate overlap between transcription and diarization segment + let overlap_start = segment.start.max(s.start); + let overlap_end = segment.end.min(s.end); + let overlap = (overlap_end - overlap_start).max(0.0); + if overlap > 0.0 { + Some((s.speaker_id, overlap)) + } else { + None + } + }) + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .map(|(id, _)| format!("Speaker {}", id)) + .unwrap_or_else(|| "UNKNOWN".to_string()); + + println!("[{:.2}s - {:.2}s] {}: {}", + segment.start, segment.end, speaker, segment.text); + } + } + + println!("\n{}", "=".repeat(80)); + let elapsed = start_time.elapsed(); + println!("\n✓ Diarization and transcription completed in {:.2}s", elapsed.as_secs_f32()); + println!("• UNKNOWN: Segments where no speaker was detected by Sortformer"); + println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)"); + + Ok(()) + } + + #[cfg(not(feature = "sortformer"))] + unreachable!() +} diff --git a/parakeet-rs/examples/raw.rs b/parakeet-rs/examples/raw.rs new file mode 100644 index 0000000..a1a2adc --- /dev/null +++ b/parakeet-rs/examples/raw.rs @@ -0,0 +1,86 @@ +/* +Demonstrates using transcribe_samples() + +This example shows manual audio loading and calling transcribe_samples() directly +with sample_rate and channels instead of using transcribe_file() + +Usage: +cargo run --example raw 6_speakers.wav +cargo run --example raw 6_speakers.wav tdt + +WARNING: TDT model has sequence length limitations (~8-10 minutes max). +For longer audio files, you must split into chunks (e.g., 5-minute segments) +and transcribe each chunk separately. Attempting to transcribe 25+ minute +audio files in one call will cause ONNX runtime errors. +Otherwise you will likely get a error like: +"Error: Ort(Error { code: RuntimeException, msg: "Non-zero status code returned while running Add node. Name:'/layers.0/self_attn/Add_2' Status Message: /Users/runner/work/ort-artifacts/ort-artifacts/onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. })" +*/ + +use parakeet_rs::{Parakeet, ParakeetTDT, TimestampMode}; +use std::env; +use std::time::Instant; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let start_time = Instant::now(); + let args: Vec<String> = env::args().collect(); + let audio_path = if args.len() > 1 { + &args[1] + } else { + "6_speakers.wav" + }; + + let use_tdt = args.len() > 2 && args[2] == "tdt"; + + // Load audio manually using hound (or any other audio library) + // remember if you use raw audio API, you need to handle audio preprocessing yourself! + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + println!("Audio info: {}Hz, {} channel(s)", spec.sample_rate, spec.channels); + + let audio: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + + if use_tdt { + println!("Loading TDT model..."); + let mut parakeet = ParakeetTDT::from_pretrained("./tdt", None)?; + + // Use transcribe_samples() with raw parameters and timestamp mode + let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences))?; + + println!("{}", result.text); + println!("\nSentencess:"); + for segment in result.tokens.iter() { + println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); + } + } else { + println!("Loading CTC model..."); + let mut parakeet = Parakeet::from_pretrained(".", None)?; + + // CTC model doesn't predict punctuation (lowercase alphabet only) + // This means no sentence boundaries. we use Words mode instead of Sentences + let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Words))?; + + println!("{}", result.text); + + // Access word-level timestamps (showing first 10 for brevity) + // Note: CTC generates word-level timestamps but cannot segment into sentences + // due to lack of punctuation prediction - this is a model limitation if I not mistake + println!("\nWords (first 10):"); + for word in result.tokens.iter().take(10) { + println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text); + } + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + + Ok(()) +} diff --git a/parakeet-rs/examples/streaming.rs b/parakeet-rs/examples/streaming.rs new file mode 100644 index 0000000..f5d36c9 --- /dev/null +++ b/parakeet-rs/examples/streaming.rs @@ -0,0 +1,129 @@ +/* +Demonstrates streaming ASR with Parakeet RealTime EOU + +Download models files from: +https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx + +This example +- Maintains 4-second ring buffer for feature extraction context +- Processes 160ms chunks (2560 samples at 16kHz) +- Extracts features from full buffer, then slices last 25 frames +- Encoder receives: 9 frames (pre-encode cache) + 16 frames (new) = 25 total +- Cache states (cache_last_channel/time) maintain temporal context + +Model files required in ./fullstr/: + - encoder.onnx (cache_aware_stream_step export) + - decoder_joint.onnx + - tokenizer.json + +Additional notes: +let reset_on_eou: bool = false; +I must admit that this is not work very well on my real world tests :/ + + +Usage: +cargo run --release --example streaming <audio.wav> +*/ + +use hound; +use parakeet_rs::ParakeetEOU; +use std::env; +use std::time::Instant; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let start_time = Instant::now(); + + let args: Vec<String> = env::args().collect(); + let audio_path = args + .get(1) + .expect("Usage: cargo run --release --example streaming <audio.wav>"); + + println!("Loading model from ./fullstr..."); + let mut parakeet = ParakeetEOU::from_pretrained("./fullstr", None)?; + + println!("Loading audio: {}", audio_path); + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + let mut audio: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + + if spec.sample_rate != 16000 { + return Err(format!( + "Expected 16kHz audio, got {}Hz. Please resample first.", + spec.sample_rate + ) + .into()); + } + + if spec.channels > 1 { + audio = audio + .chunks(spec.channels as usize) + .map(|chunk| chunk.iter().sum::<f32>() / spec.channels as f32) + .collect(); + } + + let max_val = audio.iter().fold(0.0f32, |a, &b| a.max(b.abs())); + if max_val > 1e-6 { + let norm_factor = max_val + 1e-5; + for sample in &mut audio { + *sample /= norm_factor; + } + } + + let duration = audio.len() as f32 / 16000.0; + // 160ms at 16kHz + const CHUNK_SIZE: usize = 2560; + let reset_on_eou: bool = false; + + println!("Streaming transcription (160ms chunks with 4s buffer)...\n"); + + let mut full_text = String::new(); + + for chunk in audio.chunks(CHUNK_SIZE) { + let chunk_vec = if chunk.len() < CHUNK_SIZE { + let mut padded = chunk.to_vec(); + padded.resize(CHUNK_SIZE, 0.0); + padded + } else { + chunk.to_vec() + }; + + let text = parakeet.transcribe(&chunk_vec, reset_on_eou)?; + if !text.is_empty() { + print!("{}", text); + std::io::Write::flush(&mut std::io::stdout())?; + full_text.push_str(&text); + } + } + + println!("\n\nFlushing decoder..."); + let silence = vec![0.0f32; CHUNK_SIZE]; + for _ in 0..3 { + let text = parakeet.transcribe(&silence, reset_on_eou)?; + if !text.is_empty() { + print!("{}", text); + std::io::Write::flush(&mut std::io::stdout())?; + full_text.push_str(&text); + } + } + + println!("\n\nFinal Transcription:\n{}", full_text.trim()); + + let elapsed = start_time.elapsed(); + println!( + "\nTranscription completed in {:.2}s (audio: {:.2}s, RTF: {:.2}x)", + elapsed.as_secs_f32(), + duration, + duration / elapsed.as_secs_f32() + ); + + Ok(()) +} diff --git a/parakeet-rs/examples/transcribe.rs b/parakeet-rs/examples/transcribe.rs new file mode 100644 index 0000000..685e8de --- /dev/null +++ b/parakeet-rs/examples/transcribe.rs @@ -0,0 +1,106 @@ +/* +transcribes entire audio, no diarization +wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav + +CTC (English-only): +cargo run --example transcribe 6_speakers.wav + +TDT (Multilingual): +cargo run --example transcribe 6_speakers.wav tdt + +NOTE: For manual audio loading without using transcribe_file(), see examples/raw.rs +- Shows transcribe_samples(audio, sample_rate, channels, timestamps) usage + +WARNING: This may fail on very long audio files (>8 min). +For longer audio, use the pyannote example which processes segments, or split your audio into chunks. + +Note: The coreml feature flag is only for reproducing a known ONNX Runtime bug. +Just ignore it :). See: https://github.com/microsoft/onnxruntime/issues/26355 +*/ +use parakeet_rs::{Parakeet, TimestampMode}; +use std::env; +use std::time::Instant; + +#[cfg(feature = "coreml")] +use parakeet_rs::{ExecutionConfig, ExecutionProvider}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let start_time = Instant::now(); + let args: Vec<String> = env::args().collect(); + let audio_path = if args.len() > 1 { + &args[1] + } else { + "6_speakers.wav" + }; + + let use_tdt = args.len() > 2 && args[2] == "tdt"; + + // TDT model (multilingual, 25 languages) + if use_tdt { + #[cfg(feature = "coreml")] + { + let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML); + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", Some(config))?; + let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?; + println!("{}", result.text); + + println!("\nSentencess:"); + for segment in result.tokens.iter() { + println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + return Ok(()); + } + + #[cfg(not(feature = "coreml"))] + { + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; + let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?; + println!("{}", result.text); + + println!("\nSentencess:"); + for segment in result.tokens.iter() { + println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + return Ok(()); + } + } + + // CTC model (English-only) + #[cfg(feature = "coreml")] + let mut parakeet = { + let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML); + Parakeet::from_pretrained(".", Some(config))? + }; + + // Default: CPU execution provider (works correctly) + // Auto-detects model with priority: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx + // Or specify exact model: Parakeet::from_pretrained("model_q4.onnx", None)? + #[cfg(not(feature = "coreml"))] + let mut parakeet = Parakeet::from_pretrained(".", None)?; + + // CTC model doesn't predict punctuation (lowercase alphabet only) + // This means no sentence boundaries - use Words mode instead of Sentences + let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Words))?; + + // Print transcription + println!("{}", result.text); + + // Access word-level timestamps (showing first 10 for brevity) + // Note: CTC generates word-level timestamps but cannot segment into sentences + // due to lack of punctuation prediction - this is a model limitation + println!("\nWords (first 10):"); + for word in result.tokens.iter().take(10) { + println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text); + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + + Ok(()) +} |
