use std::cmp::Ordering;
use std::path::Path;

pub use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment};
pub use parakeet_rs::{ParakeetEOU, ParakeetTDT, TimedToken, TimestampMode};

use crate::audio;

/// Length of each simulated streaming chunk, in milliseconds.
const STREAM_CHUNK_MS: u32 = 5_000;

/// Recording processed by [`listen`].
const AUDIO_PATH: &str = "audio-ftc.mp3";
/// Model location passed to `ParakeetTDT::from_pretrained`.
const PARAKEET_MODEL_DIR: &str = "models/parakeet-tdt-0.6b-v3";
/// ONNX model file passed to `Sortformer::with_config`.
const SORTFORMER_MODEL_PATH: &str = "models/diarization/diar_streaming_sortformer_4spk-v2.onnx";

/// A segment of dialogue with speaker identification and timing.
#[derive(Debug, Clone)]
pub struct DialogueSegment {
    /// Speaker label such as `"Speaker 0"`, or `"UNKNOWN"` when no diarization segment matched.
    pub speaker: String,
    /// Start of the segment (printed as seconds by [`listen`]).
    pub start: f32,
    /// End of the segment (printed as seconds by [`listen`]).
    pub end: f32,
    /// Transcribed text with surrounding whitespace trimmed.
    pub text: String,
}

/// Transcribes and diarizes [`AUDIO_PATH`] while simulating live, chunked streaming.
///
/// The normalized audio is fed in [`STREAM_CHUNK_MS`]-sized chunks. After every chunk, the
/// *whole* audio received so far is diarized and transcribed again (sentence timestamps), so
/// total work grows roughly with the square of the number of chunks. Segments beyond those
/// already printed are written to stdout. Earlier segments that a later pass revises are not
/// re-printed.
///
/// Returns the speaker-aligned segments from the final (complete-audio) pass, or an empty
/// `Vec` if the audio contains no samples.
///
/// # Errors
///
/// Propagates any failure from loading/normalizing the audio, loading either model, or
/// diarizing/transcribing a chunk.
pub(crate) fn listen() -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error>> {
    let normalized = audio::to_16k_mono_from_path(Path::new(AUDIO_PATH))?;

    let mut parakeet = ParakeetTDT::from_pretrained(PARAKEET_MODEL_DIR, None)?;
    let mut sortformer = Sortformer::with_config(
        SORTFORMER_MODEL_PATH,
        None,
        DiarizationConfig::callhome(),
    )?;

    let chunk_samples = samples_per_chunk(normalized.sample_rate, STREAM_CHUNK_MS);
    // The buffer eventually holds every sample, so reserve that space once up front.
    let mut cumulative_audio = Vec::with_capacity(normalized.samples.len());
    let mut last_printed_tokens = 0usize;
    let mut final_segments = Vec::new();

    for (chunk_idx, chunk) in normalized.samples.chunks(chunk_samples).enumerate() {
        cumulative_audio.extend_from_slice(chunk);

        // Both model calls take the samples by value, and the buffer is still needed for the
        // next chunk, so each call receives its own copy.
        let diarization_segments = sortformer.diarize(
            cumulative_audio.clone(),
            normalized.sample_rate,
            normalized.channels,
        )?;
        let transcription = parakeet.transcribe_samples(
            cumulative_audio.clone(),
            normalized.sample_rate,
            normalized.channels,
            Some(TimestampMode::Sentences),
        )?;

        final_segments = align_speakers(&transcription.tokens, &diarization_segments);

        // Simulate "live" output by printing only newly emitted tokens.
        // `align_speakers` yields exactly one segment per token, so the index range matches.
        if transcription.tokens.len() > last_printed_tokens {
            for segment in &final_segments[last_printed_tokens..] {
                println!(
                    "[chunk {}] [{:.2}s - {:.2}s] {}: {}",
                    chunk_idx, segment.start, segment.end, segment.speaker, segment.text
                );
            }
            last_printed_tokens = transcription.tokens.len();
        }
    }

    Ok(final_segments)
}

/// Align transcription tokens with speaker diarization segments.
pub fn align_speakers(tokens: &[TimedToken], speakers: &[SpeakerSegment]) -> Vec { tokens .iter() .map(|token| { let speaker = speaker_for_span(token.start, token.end, speakers) .unwrap_or_else(|| "UNKNOWN".to_string()); DialogueSegment { speaker, start: token.start, end: token.end, text: token.text.trim().to_string(), } }) .collect() } /// Calculate the number of samples in a chunk of given duration. pub fn samples_per_chunk(sample_rate: u32, chunk_ms: u32) -> usize { let samples = (sample_rate as u64) .saturating_mul(chunk_ms as u64) .saturating_div(1_000); samples.max(1) as usize } #[cfg(test)] mod tests { use super::*; #[test] fn samples_per_chunk_rounds_down_and_clamps() { assert_eq!(samples_per_chunk(16_000, 1_000), 16_000); assert_eq!(samples_per_chunk(16_000, 160), 2_560); assert_eq!(samples_per_chunk(16_000, 0), 1); } } fn speaker_for_span(start: f32, end: f32, speakers: &[SpeakerSegment]) -> Option { speakers .iter() .filter_map(|segment| { let overlap_start = start.max(segment.start); let overlap_end = end.min(segment.end); let overlap = overlap_end - overlap_start; if overlap > 0.0 { Some((segment.speaker_id, overlap)) } else { None } }) .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)) .map(|(id, _)| format!("Speaker {}", id)) }