diff options
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/bin/server.rs | 3 | ||||
| -rw-r--r-- | makima/src/listen.rs | 2 | ||||
| -rw-r--r-- | makima/src/server/handlers/listen.rs | 229 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 12 |
4 files changed, 201 insertions, 45 deletions
diff --git a/makima/src/bin/server.rs b/makima/src/bin/server.rs index 1964cae..06b6585 100644 --- a/makima/src/bin/server.rs +++ b/makima/src/bin/server.rs @@ -10,6 +10,7 @@ use makima::server::{run_server, state::AppState}; /// Default model paths relative to the working directory. const PARAKEET_MODEL_DIR: &str = "models/parakeet-tdt-0.6b-v3"; +const PARAKEET_EOU_DIR: &str = "models/realtime_eou_120m-v1-onnx"; const SORTFORMER_MODEL_PATH: &str = "models/diarization/diar_streaming_sortformer_4spk-v2.onnx"; #[tokio::main] @@ -28,7 +29,7 @@ async fn main() -> anyhow::Result<()> { // Load ML models let state = Arc::new( - AppState::new(PARAKEET_MODEL_DIR, SORTFORMER_MODEL_PATH) + AppState::new(PARAKEET_MODEL_DIR, PARAKEET_EOU_DIR, SORTFORMER_MODEL_PATH) .map_err(|e| anyhow::anyhow!("Failed to load models: {}", e))?, ); diff --git a/makima/src/listen.rs b/makima/src/listen.rs index 24bf9da..91d616c 100644 --- a/makima/src/listen.rs +++ b/makima/src/listen.rs @@ -2,7 +2,7 @@ use std::cmp::Ordering; use std::path::Path; pub use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment}; -pub use parakeet_rs::{ParakeetTDT, TimedToken, TimestampMode, Transcriber}; +pub use parakeet_rs::{ParakeetEOU, ParakeetTDT, TimedToken, TimestampMode}; use crate::audio; diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs index b1c1ad9..bf6746c 100644 --- a/makima/src/server/handlers/listen.rs +++ b/makima/src/server/handlers/listen.rs @@ -1,4 +1,4 @@ -//! WebSocket handler for streaming speech-to-text. +//! WebSocket handler for streaming speech-to-text with sliding window optimization. use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, @@ -9,15 +9,27 @@ use tokio::sync::mpsc; use uuid::Uuid; use crate::audio::{resample_and_mixdown, TARGET_CHANNELS, TARGET_SAMPLE_RATE}; -use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode, Transcriber}; +use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode}; use crate::server::messages::{ AudioEncoding, ClientMessage, ServerMessage, StartMessage, TranscriptMessage, }; use crate::server::state::SharedState; -/// Chunk size in milliseconds for streaming transcription. +/// Chunk size in milliseconds for triggering transcription processing. const STREAM_CHUNK_MS: u32 = 5_000; +/// Maximum window size in seconds for sliding window processing. +const MAX_WINDOW_SECONDS: f32 = 30.0; + +/// Maximum window size in samples at 16kHz. +const MAX_WINDOW_SAMPLES: usize = (MAX_WINDOW_SECONDS as usize) * (TARGET_SAMPLE_RATE as usize); + +/// EOU chunk size in samples (160ms at 16kHz). +const EOU_CHUNK_SIZE: usize = 2560; + +/// Context overlap in seconds to keep when trimming finalized audio. +const CONTEXT_OVERLAP_SECONDS: f32 = 2.0; + /// WebSocket upgrade handler for STT streaming. /// /// This endpoint accepts WebSocket connections for real-time speech-to-text @@ -28,7 +40,7 @@ const STREAM_CHUNK_MS: u32 = 5_000; responses( (status = 101, description = "WebSocket connection established"), ), - tag = "STT" + tag = "Listen" )] pub async fn websocket_handler( ws: WebSocketUpgrade, @@ -73,10 +85,25 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { // Audio format state let mut audio_format: Option<StartMessage> = None; - // Audio buffer for accumulating samples + // Main audio buffer for transcription (accumulates resampled 16kHz mono audio) let mut audio_buffer: Vec<f32> = Vec::new(); - let mut last_sent_end_time: f32 = 0.0; // Track the end time of last sent segment - let mut last_processed_len: usize = 0; // Track how much audio we've processed + + // EOU detection buffer (resampled audio for utterance detection) + let mut eou_buffer: Vec<f32> = Vec::new(); + let mut last_eou_text: String = String::new(); + let mut utterance_ended: bool = false; + + // Tracking state + let mut last_sent_end_time: f32 = 0.0; + let mut last_processed_len: usize = 0; + let mut audio_offset: f32 = 0.0; // Time offset from trimmed audio + let mut finalized_segments: Vec<DialogueSegment> = Vec::new(); + + // Reset Sortformer state for new session + { + let mut sortformer = state.sortformer.lock().await; + sortformer.reset_state(); + } // Process incoming messages while let Some(msg_result) = receiver.next().await { @@ -102,8 +129,17 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { ); audio_format = Some(start); audio_buffer.clear(); + eou_buffer.clear(); + last_eou_text.clear(); + utterance_ended = false; last_sent_end_time = 0.0; last_processed_len = 0; + audio_offset = 0.0; + finalized_segments.clear(); + + // Reset models for new session + let mut sortformer = state.sortformer.lock().await; + sortformer.reset_state(); } Ok(ClientMessage::Stop(stop)) => { tracing::info!( @@ -113,25 +149,52 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { "Session stopped by client" ); - if let Some(ref format) = audio_format { + if audio_format.is_some() { if !audio_buffer.is_empty() { tracing::debug!( session_id = %session_id, samples = audio_buffer.len(), "Processing final audio buffer" ); - match process_audio(&audio_buffer, format, &state).await { + + // Process remaining audio with sliding window + match process_audio_window(&audio_buffer, audio_offset, &state).await { Ok(segments) => { tracing::debug!( session_id = %session_id, total_segments = segments.len(), + finalized_count = finalized_segments.len(), last_sent_end = last_sent_end_time, "Final transcription complete" ); - // Step 1: Send any NEW segments as interim (is_final: false) - // These are segments that weren't sent during streaming + // Combine finalized segments with new segments + let mut all_segments = finalized_segments.clone(); + + // Add segments from current window that weren't finalized for seg in &segments { + // Adjust timestamps with offset + let adjusted_seg = DialogueSegment { + speaker: seg.speaker.clone(), + start: seg.start + audio_offset, + end: seg.end + audio_offset, + text: seg.text.clone(), + }; + + // Only add if not already finalized + if !finalized_segments.iter().any(|f| + (f.start - adjusted_seg.start).abs() < 0.1 && + f.text == adjusted_seg.text + ) { + all_segments.push(adjusted_seg); + } + } + + // Sort by start time + all_segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap()); + + // Send any NEW segments as interim first + for seg in &all_segments { if seg.end > last_sent_end_time { let _ = response_tx .send(ServerMessage::Transcript(TranscriptMessage { @@ -145,9 +208,8 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { } } - // Step 2: Send ALL segments as final (is_final: true) - // This is the complete authoritative transcript - for seg in &segments { + // Send ALL segments as final + for seg in &all_segments { let _ = response_tx .send(ServerMessage::Transcript(TranscriptMessage { speaker: seg.speaker.clone(), @@ -207,21 +269,56 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { // Decode binary audio data to f32 samples let samples = decode_audio_chunk(&data, format); - audio_buffer.extend(samples); - // Process when we have accumulated another chunk's worth of NEW audio - let chunk_samples = samples_per_chunk(format.sample_rate, STREAM_CHUNK_MS); + // Resample to 16kHz mono for all processing + let resampled = if format.sample_rate != TARGET_SAMPLE_RATE || format.channels != TARGET_CHANNELS { + resample_and_mixdown(&samples, format.sample_rate, format.channels) + } else { + samples + }; + + audio_buffer.extend(&resampled); + eou_buffer.extend(&resampled); + + // Process EOU detection in 160ms chunks + while eou_buffer.len() >= EOU_CHUNK_SIZE { + let chunk: Vec<f32> = eou_buffer.drain(..EOU_CHUNK_SIZE).collect(); + + let mut eou = state.parakeet_eou.lock().await; + if let Ok(text) = eou.transcribe(&chunk, false) { + // Detect utterance boundary (sentence-ending punctuation) + if !text.is_empty() && text != last_eou_text { + if last_eou_text.ends_with('.') + || last_eou_text.ends_with('?') + || last_eou_text.ends_with('!') + { + utterance_ended = true; + tracing::debug!( + session_id = %session_id, + "Utterance boundary detected via EOU" + ); + } + last_eou_text = text; + } + } + } + + // Calculate if we should process (utterance ended OR enough new audio) + let chunk_samples = samples_per_chunk(TARGET_SAMPLE_RATE, STREAM_CHUNK_MS); let new_audio_len = audio_buffer.len() - last_processed_len; + let should_process = utterance_ended || new_audio_len >= chunk_samples; - if new_audio_len >= chunk_samples { + if should_process { tracing::debug!( session_id = %session_id, total_samples = audio_buffer.len(), new_samples = new_audio_len, - "Processing audio chunk" + utterance_ended = utterance_ended, + audio_offset = audio_offset, + "Processing audio with sliding window" ); - match process_audio(&audio_buffer, format, &state).await { + match process_audio_window(&audio_buffer, audio_offset, &state).await { Ok(segments) => { tracing::debug!( session_id = %session_id, @@ -230,23 +327,57 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { "Transcription produced segments" ); - // Send segments that end after our last sent time - // This handles re-segmentation by the model + // Send segments with adjusted timestamps for seg in &segments { - if seg.end > last_sent_end_time { + let adjusted_end = seg.end + audio_offset; + if adjusted_end > last_sent_end_time { let _ = response_tx .send(ServerMessage::Transcript(TranscriptMessage { speaker: seg.speaker.clone(), - start: seg.start, - end: seg.end, + start: seg.start + audio_offset, + end: adjusted_end, text: seg.text.clone(), is_final: false, })) .await; - last_sent_end_time = seg.end; + last_sent_end_time = adjusted_end; } } + + // If utterance ended, finalize and trim + if utterance_ended && segments.len() > 1 { + // Finalize all but the last segment + let to_finalize = &segments[..segments.len() - 1]; + for seg in to_finalize { + finalized_segments.push(DialogueSegment { + speaker: seg.speaker.clone(), + start: seg.start + audio_offset, + end: seg.end + audio_offset, + text: seg.text.clone(), + }); + } + + // Trim audio buffer + if let Some(last_finalized) = to_finalize.last() { + let trim_to_time = (last_finalized.end - CONTEXT_OVERLAP_SECONDS).max(0.0); + let trim_samples = (trim_to_time * TARGET_SAMPLE_RATE as f32) as usize; + + if trim_samples > 0 && trim_samples < audio_buffer.len() { + audio_buffer.drain(..trim_samples); + audio_offset += trim_to_time; + tracing::debug!( + session_id = %session_id, + trimmed_samples = trim_samples, + new_offset = audio_offset, + remaining_samples = audio_buffer.len(), + "Trimmed audio buffer after finalization" + ); + } + } + } + last_processed_len = audio_buffer.len(); + utterance_ended = false; } Err(e) => { tracing::error!(session_id = %session_id, error = %e, "Transcription error"); @@ -291,31 +422,37 @@ fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> { } } -/// Process accumulated audio through STT and diarization models. -async fn process_audio( +/// Process audio using sliding window through STT and streaming diarization models. +/// +/// Only processes the last MAX_WINDOW_SECONDS of audio to maintain constant +/// processing time regardless of total audio length. +async fn process_audio_window( samples: &[f32], - format: &StartMessage, + _audio_offset: f32, state: &SharedState, ) -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error + Send + Sync>> { - // Resample to 16kHz mono if needed - let resampled = if format.sample_rate != TARGET_SAMPLE_RATE || format.channels != TARGET_CHANNELS - { - resample_and_mixdown(samples, format.sample_rate, format.channels) - } else { - samples.to_vec() - }; + // Apply sliding window - only process the last 30 seconds + let window_start = samples.len().saturating_sub(MAX_WINDOW_SAMPLES); + let window = &samples[window_start..]; + + tracing::trace!( + total_samples = samples.len(), + window_samples = window.len(), + window_start = window_start, + "Using sliding window for processing" + ); // Acquire model locks and run inference let mut parakeet = state.parakeet.lock().await; let mut sortformer = state.sortformer.lock().await; - // Run diarization + // Run streaming diarization (maintains speaker cache across calls) let diarization_segments = - sortformer.diarize(resampled.clone(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?; + sortformer.diarize_streaming(window.to_vec(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?; // Run transcription let transcription = parakeet.transcribe_samples( - resampled, + window.to_vec(), TARGET_SAMPLE_RATE, TARGET_CHANNELS, Some(TimestampMode::Sentences), @@ -324,5 +461,17 @@ async fn process_audio( // Align speakers with transcription let aligned = align_speakers(&transcription.tokens, &diarization_segments); - Ok(aligned) + // Adjust timestamps for window offset within the buffer + let window_offset = window_start as f32 / TARGET_SAMPLE_RATE as f32; + let adjusted: Vec<DialogueSegment> = aligned + .into_iter() + .map(|seg| DialogueSegment { + speaker: seg.speaker, + start: seg.start + window_offset, + end: seg.end + window_offset, + text: seg.text, + }) + .collect(); + + Ok(adjusted) } diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index c38359d..31e1518 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -3,14 +3,16 @@ use std::sync::Arc; use tokio::sync::Mutex; -use crate::listen::{DiarizationConfig, ParakeetTDT, Sortformer}; +use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; /// Shared application state containing ML models. /// /// Models are wrapped in `Mutex` for thread-safe mutable access during inference. pub struct AppState { - /// Speech-to-text model (Parakeet) + /// Speech-to-text model (Parakeet TDT) pub parakeet: Mutex<ParakeetTDT>, + /// End-of-Utterance detection model for streaming + pub parakeet_eou: Mutex<ParakeetEOU>, /// Speaker diarization model (Sortformer) pub sortformer: Mutex<Sortformer>, } @@ -19,13 +21,16 @@ impl AppState { /// Load all ML models from the specified directories. /// /// # Arguments - /// * `parakeet_model_dir` - Path to the Parakeet STT model directory + /// * `parakeet_model_dir` - Path to the Parakeet TDT model directory + /// * `parakeet_eou_dir` - Path to the Parakeet EOU model directory /// * `sortformer_model_path` - Path to the Sortformer diarization model file pub fn new( parakeet_model_dir: &str, + parakeet_eou_dir: &str, sortformer_model_path: &str, ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?; + let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?; let sortformer = Sortformer::with_config( sortformer_model_path, None, @@ -34,6 +39,7 @@ impl AppState { Ok(Self { parakeet: Mutex::new(parakeet), + parakeet_eou: Mutex::new(parakeet_eou), sortformer: Mutex::new(sortformer), }) } |
