summaryrefslogtreecommitdiff
path: root/makima/src
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src')
-rw-r--r--makima/src/listen.rs131
-rw-r--r--makima/src/main.rs7
2 files changed, 92 insertions, 46 deletions
diff --git a/makima/src/listen.rs b/makima/src/listen.rs
index d478bd9..a0f4246 100644
--- a/makima/src/listen.rs
+++ b/makima/src/listen.rs
@@ -1,57 +1,102 @@
-use std::path::PathBuf;
-use transcribe_rs::TranscriptionEngine;
-use transcribe_rs::engines::parakeet::{ParakeetEngine, ParakeetModelParams};
+use std::cmp::Ordering;
+use std::path::Path;
-pub struct Listen {
- name: String,
- text: Vec<Dialogue>,
-}
-struct Dialogue {
- character: Character,
- text: String,
-}
+use hound::WavReader;
+use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment};
+use parakeet_rs::{ParakeetTDT, TimedToken, TimestampMode};
-struct Character {
- name: String,
+const SAMPLE_RATE: u32 = 16_000;
+
+pub struct DialogueSegment {
+ pub speaker: String,
+ pub start: f32,
+ pub end: f32,
+ pub text: String,
}
-fn dialogue(character: &str, text: &str) -> Dialogue {
- Dialogue {
- character: Character {
- name: character.to_string(),
- },
- text: "".to_string(),
+pub(crate) fn listen() -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error>> {
+ let audio_path = Path::new("audio.wav");
+
+ let (audio, sample_rate, channels) = load_audio(audio_path)?;
+
+ let mut parakeet = ParakeetTDT::from_pretrained("models/parakeet-tdt-0.6b-v3", None)?;
+ let transcription = parakeet.transcribe_samples(
+ audio.clone(),
+ sample_rate,
+ channels,
+ Some(TimestampMode::Sentences),
+ )?;
+
+ let mut sortformer = Sortformer::with_config(
+ "models/diarization/diar_streaming_sortformer_4spk-v2.onnx",
+ None,
+ DiarizationConfig::callhome(),
+ )?;
+ let diarization_segments = sortformer.diarize(audio, sample_rate, channels)?;
+
+ let segments = align_speakers(&transcription.tokens, &diarization_segments);
+ for segment in &segments {
+ println!(
+ "[{:.2}s - {:.2}s] {}: {}",
+ segment.start, segment.end, segment.speaker, segment.text
+ );
}
+
+ Ok(segments)
}
-pub(crate) fn listen() -> Result<Listen, Box<dyn std::error::Error>> {
- let mut engine = ParakeetEngine::new();
- engine.load_model_with_params(
- &PathBuf::from("models/parakeet-tdt-0.6b-v3-int8"),
- ParakeetModelParams::int8(),
- )?;
+fn load_audio(path: &Path) -> Result<(Vec<f32>, u32, u16), Box<dyn std::error::Error>> {
+ let mut reader = WavReader::open(path)?;
+ let spec = reader.spec();
- // Only works with 16000 Hz and mono ( ffmpeg -i audio.mp3 -ar 16000 -ac 1 audio.wav)
- let result = engine.transcribe_file(&PathBuf::from("audio.wav"), None)?;
- println!("Transcription: {}", result.text);
-
- let mut txt: String = "".to_string();
- if let Some(segments) = result.segments {
- for segment in segments {
- println!(
- "[{:.2}s - {:.2}s]: {}",
- segment.start, segment.end, segment.text
- );
- txt = txt + segment.text.as_str();
- }
+ if spec.sample_rate != SAMPLE_RATE {
+ return Err(format!(
+ "Expected {} Hz audio, got {} Hz",
+ SAMPLE_RATE, spec.sample_rate
+ )
+ .into());
}
- println!("Transcription: {}", txt);
+ let samples = match spec.sample_format {
+ hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?,
+ hound::SampleFormat::Int => reader
+ .samples::<i16>()
+ .map(|s| s.map(|s| s as f32 / 32768.0))
+ .collect::<Result<Vec<_>, _>>()?,
+ };
+ Ok((samples, spec.sample_rate, spec.channels))
+}
+fn align_speakers(tokens: &[TimedToken], speakers: &[SpeakerSegment]) -> Vec<DialogueSegment> {
+ tokens
+ .iter()
+ .map(|token| {
+ let speaker = speaker_for_span(token.start, token.end, speakers)
+ .unwrap_or_else(|| "UNKNOWN".to_string());
+ DialogueSegment {
+ speaker,
+ start: token.start,
+ end: token.end,
+ text: token.text.trim().to_string(),
+ }
+ })
+ .collect()
+}
- Ok(Listen {
- name: String::from("example"),
- text: vec![dialogue("a", "part a"), dialogue("b", "part b")],
- })
+fn speaker_for_span(start: f32, end: f32, speakers: &[SpeakerSegment]) -> Option<String> {
+ speakers
+ .iter()
+ .filter_map(|segment| {
+ let overlap_start = start.max(segment.start);
+ let overlap_end = end.min(segment.end);
+ let overlap = overlap_end - overlap_start;
+ if overlap > 0.0 {
+ Some((segment.speaker_id, overlap))
+ } else {
+ None
+ }
+ })
+ .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
+ .map(|(id, _)| format!("Speaker {}", id))
}
diff --git a/makima/src/main.rs b/makima/src/main.rs
index 8b4845d..9097ef6 100644
--- a/makima/src/main.rs
+++ b/makima/src/main.rs
@@ -1,6 +1,7 @@
mod listen;
-fn main() {
- println!("Hello, world!");
- listen::listen().unwrap();
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+ let segments = listen::listen()?;
+ println!("Captured {} diarized segments", segments.len());
+ Ok(())
}