From 55cacf6e1a087c0fa6950a1ddeb09060f787e541 Mon Sep 17 00:00:00 2001 From: soryu Date: Sun, 21 Dec 2025 00:40:04 +0000 Subject: Add EOU detection and streaming diarization --- parakeet-rs/examples/diarization.rs | 137 ++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 parakeet-rs/examples/diarization.rs (limited to 'parakeet-rs/examples/diarization.rs') diff --git a/parakeet-rs/examples/diarization.rs b/parakeet-rs/examples/diarization.rs new file mode 100644 index 0000000..5982ecb --- /dev/null +++ b/parakeet-rs/examples/diarization.rs @@ -0,0 +1,137 @@ +/* +Speaker Diarization with NVIDIA Sortformer v2 (Streaming) + +Download the Sortformer v2 model: +https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx +Download test audio: +wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav + +Usage: +cargo run --example diarization --features sortformer 6_speakers.wav + +NOTE: This example combines two NVIDIA models: +- Parakeet-TDT: Provides transcription with sentence-level timestamps +- Sortformer v2: Provides streaming speaker identification (4 speakers max) +- We use TDT's sentence timestamps + Sortformer's speaker IDs +- Even if Sortformer can't detect a segment, we still get the transcription (marked UNKNOWN) +- For more information: +https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 + +WARNING: Sortformer handles long audio natively (streaming), but TDT has sequence +length limitations (~8-10 minutes max). For production use with long audio files, +run Sortformer on the full audio for diarization, then chunk the audio into +~5-minute segments for TDT transcription, and map the results back together. +*/ + +#[cfg(feature = "sortformer")] +use parakeet_rs::sortformer::{DiarizationConfig, Sortformer}; +#[cfg(feature = "sortformer")] +use parakeet_rs::TimestampMode; +#[cfg(feature = "sortformer")] +use hound; +#[cfg(feature = "sortformer")] +use std::env; +#[cfg(feature = "sortformer")] +use std::time::Instant; + +#[allow(unreachable_code)] +fn main() -> Result<(), Box> { + #[cfg(not(feature = "sortformer"))] + { + eprintln!("Error: This example requires the 'sortformer' feature."); + eprintln!("Please run with: cargo run --example diarization --features sortformer "); + return Err("sortformer feature not enabled".into()); + } + + #[cfg(feature = "sortformer")] + { + let start_time = Instant::now(); + let args: Vec = env::args().collect(); + let audio_path = args.get(1) + .expect("Please specify audio file: cargo run --example diarization --features sortformer "); + + println!("{}", "=".repeat(80)); + println!("Step 1/3: Loading audio..."); + + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + let audio: Vec = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::() + .collect::, _>>()?, + hound::SampleFormat::Int => reader + .samples::() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::, _>>()?, + }; + + let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32; + println!("Loaded {} samples ({} Hz, {} channels, {:.1}s)", + audio.len(), spec.sample_rate, spec.channels, duration); + + println!("{}", "=".repeat(80)); + println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)..."); + + // Create Sortformer with default config (callhome) + let mut sortformer = Sortformer::with_config( + "diar_streaming_sortformer_4spk-v2.onnx", + None, // default exec config + DiarizationConfig::callhome(), + )?; + + let speaker_segments = sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?; + + println!("Found {} speaker segments from Sortformer", speaker_segments.len()); + + // Print raw diarization segments + println!("\nRaw diarization segments:"); + for seg in &speaker_segments { + println!(" [{:06.2}s - {:06.2}s] Speaker {}", seg.start, seg.end, seg.speaker_id); + } + + println!("\n{}", "=".repeat(80)); + println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n"); + + // Use TDT for transcription with sentence-level timestamps + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; + + // Transcribe with Sentences mode (TDT provides punctuation for proper segmentation) + if let Ok(result) = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences)) { + // For each sentence from TDT, find the corresponding speaker from Sortformer + for segment in &result.tokens { + // Find speaker with maximum overlap + let speaker = speaker_segments + .iter() + .filter_map(|s| { + // Calculate overlap between transcription and diarization segment + let overlap_start = segment.start.max(s.start); + let overlap_end = segment.end.min(s.end); + let overlap = (overlap_end - overlap_start).max(0.0); + if overlap > 0.0 { + Some((s.speaker_id, overlap)) + } else { + None + } + }) + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .map(|(id, _)| format!("Speaker {}", id)) + .unwrap_or_else(|| "UNKNOWN".to_string()); + + println!("[{:.2}s - {:.2}s] {}: {}", + segment.start, segment.end, speaker, segment.text); + } + } + + println!("\n{}", "=".repeat(80)); + let elapsed = start_time.elapsed(); + println!("\n✓ Diarization and transcription completed in {:.2}s", elapsed.as_secs_f32()); + println!("• UNKNOWN: Segments where no speaker was detected by Sortformer"); + println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)"); + + Ok(()) + } + + #[cfg(not(feature = "sortformer"))] + unreachable!() +} -- cgit v1.2.3