diff options
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/listen.rs | 131 | ||||
| -rw-r--r-- | makima/src/main.rs | 7 |
2 files changed, 92 insertions, 46 deletions
diff --git a/makima/src/listen.rs b/makima/src/listen.rs index d478bd9..a0f4246 100644 --- a/makima/src/listen.rs +++ b/makima/src/listen.rs @@ -1,57 +1,102 @@ -use std::path::PathBuf; -use transcribe_rs::TranscriptionEngine; -use transcribe_rs::engines::parakeet::{ParakeetEngine, ParakeetModelParams}; +use std::cmp::Ordering; +use std::path::Path; -pub struct Listen { - name: String, - text: Vec<Dialogue>, -} -struct Dialogue { - character: Character, - text: String, -} +use hound::WavReader; +use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment}; +use parakeet_rs::{ParakeetTDT, TimedToken, TimestampMode}; -struct Character { - name: String, +const SAMPLE_RATE: u32 = 16_000; + +pub struct DialogueSegment { + pub speaker: String, + pub start: f32, + pub end: f32, + pub text: String, } -fn dialogue(character: &str, text: &str) -> Dialogue { - Dialogue { - character: Character { - name: character.to_string(), - }, - text: "".to_string(), +pub(crate) fn listen() -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error>> { + let audio_path = Path::new("audio.wav"); + + let (audio, sample_rate, channels) = load_audio(audio_path)?; + + let mut parakeet = ParakeetTDT::from_pretrained("models/parakeet-tdt-0.6b-v3", None)?; + let transcription = parakeet.transcribe_samples( + audio.clone(), + sample_rate, + channels, + Some(TimestampMode::Sentences), + )?; + + let mut sortformer = Sortformer::with_config( + "models/diarization/diar_streaming_sortformer_4spk-v2.onnx", + None, + DiarizationConfig::callhome(), + )?; + let diarization_segments = sortformer.diarize(audio, sample_rate, channels)?; + + let segments = align_speakers(&transcription.tokens, &diarization_segments); + for segment in &segments { + println!( + "[{:.2}s - {:.2}s] {}: {}", + segment.start, segment.end, segment.speaker, segment.text + ); } + + Ok(segments) } -pub(crate) fn listen() -> Result<Listen, Box<dyn std::error::Error>> { - let mut engine = ParakeetEngine::new(); - engine.load_model_with_params( - &PathBuf::from("models/parakeet-tdt-0.6b-v3-int8"), - ParakeetModelParams::int8(), - )?; +fn load_audio(path: &Path) -> Result<(Vec<f32>, u32, u16), Box<dyn std::error::Error>> { + let mut reader = WavReader::open(path)?; + let spec = reader.spec(); - // Only works with 16000 Hz and mono ( ffmpeg -i audio.mp3 -ar 16000 -ac 1 audio.wav) - let result = engine.transcribe_file(&PathBuf::from("audio.wav"), None)?; - println!("Transcription: {}", result.text); - - let mut txt: String = "".to_string(); - if let Some(segments) = result.segments { - for segment in segments { - println!( - "[{:.2}s - {:.2}s]: {}", - segment.start, segment.end, segment.text - ); - txt = txt + segment.text.as_str(); - } + if spec.sample_rate != SAMPLE_RATE { + return Err(format!( + "Expected {} Hz audio, got {} Hz", + SAMPLE_RATE, spec.sample_rate + ) + .into()); } - println!("Transcription: {}", txt); + let samples = match spec.sample_format { + hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + Ok((samples, spec.sample_rate, spec.channels)) +} +fn align_speakers(tokens: &[TimedToken], speakers: &[SpeakerSegment]) -> Vec<DialogueSegment> { + tokens + .iter() + .map(|token| { + let speaker = speaker_for_span(token.start, token.end, speakers) + .unwrap_or_else(|| "UNKNOWN".to_string()); + DialogueSegment { + speaker, + start: token.start, + end: token.end, + text: token.text.trim().to_string(), + } + }) + .collect() +} - Ok(Listen { - name: String::from("example"), - text: vec![dialogue("a", "part a"), dialogue("b", "part b")], - }) +fn speaker_for_span(start: f32, end: f32, speakers: &[SpeakerSegment]) -> Option<String> { + speakers + .iter() + .filter_map(|segment| { + let overlap_start = start.max(segment.start); + let overlap_end = end.min(segment.end); + let overlap = overlap_end - overlap_start; + if overlap > 0.0 { + Some((segment.speaker_id, overlap)) + } else { + None + } + }) + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)) + .map(|(id, _)| format!("Speaker {}", id)) } diff --git a/makima/src/main.rs b/makima/src/main.rs index 8b4845d..9097ef6 100644 --- a/makima/src/main.rs +++ b/makima/src/main.rs @@ -1,6 +1,7 @@ mod listen; -fn main() { - println!("Hello, world!"); - listen::listen().unwrap(); +fn main() -> Result<(), Box<dyn std::error::Error>> { + let segments = listen::listen()?; + println!("Captured {} diarized segments", segments.len()); + Ok(()) } |
