summary | refs | log | tree | commit | diff
path: root/makima/src
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src')
-rw-r--r-- makima/src/bin/server.rs 3
-rw-r--r-- makima/src/listen.rs 2
-rw-r--r-- makima/src/server/handlers/listen.rs 229
-rw-r--r-- makima/src/server/state.rs 12
4 files changed, 201 insertions, 45 deletions
diff --git a/makima/src/bin/server.rs b/makima/src/bin/server.rs
index 1964cae..06b6585 100644
--- a/makima/src/bin/server.rs
+++ b/makima/src/bin/server.rs
@@ -10,6 +10,7 @@ use makima::server::{run_server, state::AppState};
/// Default model paths relative to the working directory.
const PARAKEET_MODEL_DIR: &str = "models/parakeet-tdt-0.6b-v3";
+const PARAKEET_EOU_DIR: &str = "models/realtime_eou_120m-v1-onnx";
const SORTFORMER_MODEL_PATH: &str = "models/diarization/diar_streaming_sortformer_4spk-v2.onnx";
#[tokio::main]
@@ -28,7 +29,7 @@ async fn main() -> anyhow::Result<()> {
// Load ML models
let state = Arc::new(
- AppState::new(PARAKEET_MODEL_DIR, SORTFORMER_MODEL_PATH)
+ AppState::new(PARAKEET_MODEL_DIR, PARAKEET_EOU_DIR, SORTFORMER_MODEL_PATH)
.map_err(|e| anyhow::anyhow!("Failed to load models: {}", e))?,
);
diff --git a/makima/src/listen.rs b/makima/src/listen.rs
index 24bf9da..91d616c 100644
--- a/makima/src/listen.rs
+++ b/makima/src/listen.rs
@@ -2,7 +2,7 @@ use std::cmp::Ordering;
use std::path::Path;
pub use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment};
-pub use parakeet_rs::{ParakeetTDT, TimedToken, TimestampMode, Transcriber};
+pub use parakeet_rs::{ParakeetEOU, ParakeetTDT, TimedToken, TimestampMode};
use crate::audio;
diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs
index b1c1ad9..bf6746c 100644
--- a/makima/src/server/handlers/listen.rs
+++ b/makima/src/server/handlers/listen.rs
@@ -1,4 +1,4 @@
-//! WebSocket handler for streaming speech-to-text.
+//! WebSocket handler for streaming speech-to-text with sliding window optimization.
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
@@ -9,15 +9,27 @@ use tokio::sync::mpsc;
use uuid::Uuid;
use crate::audio::{resample_and_mixdown, TARGET_CHANNELS, TARGET_SAMPLE_RATE};
-use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode, Transcriber};
+use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode};
use crate::server::messages::{
AudioEncoding, ClientMessage, ServerMessage, StartMessage, TranscriptMessage,
};
use crate::server::state::SharedState;
-/// Chunk size in milliseconds for streaming transcription.
+/// Chunk size in milliseconds for triggering transcription processing.
const STREAM_CHUNK_MS: u32 = 5_000;
+/// Maximum window size in seconds for sliding window processing.
+const MAX_WINDOW_SECONDS: f32 = 30.0;
+
+/// Maximum window size in samples at 16kHz.
+const MAX_WINDOW_SAMPLES: usize = (MAX_WINDOW_SECONDS as usize) * (TARGET_SAMPLE_RATE as usize);
+
+/// EOU chunk size in samples (160ms at 16kHz).
+const EOU_CHUNK_SIZE: usize = 2560;
+
+/// Context overlap in seconds to keep when trimming finalized audio.
+const CONTEXT_OVERLAP_SECONDS: f32 = 2.0;
+
/// WebSocket upgrade handler for STT streaming.
///
/// This endpoint accepts WebSocket connections for real-time speech-to-text
@@ -28,7 +40,7 @@ const STREAM_CHUNK_MS: u32 = 5_000;
responses(
(status = 101, description = "WebSocket connection established"),
),
- tag = "STT"
+ tag = "Listen"
)]
pub async fn websocket_handler(
ws: WebSocketUpgrade,
@@ -73,10 +85,25 @@ async fn handle_socket(socket: WebSocket, state: SharedState) {
// Audio format state
let mut audio_format: Option<StartMessage> = None;
- // Audio buffer for accumulating samples
+ // Main audio buffer for transcription (accumulates resampled 16kHz mono audio)
let mut audio_buffer: Vec<f32> = Vec::new();
- let mut last_sent_end_time: f32 = 0.0; // Track the end time of last sent segment
- let mut last_processed_len: usize = 0; // Track how much audio we've processed
+
+ // EOU detection buffer (resampled audio for utterance detection)
+ let mut eou_buffer: Vec<f32> = Vec::new();
+ let mut last_eou_text: String = String::new();
+ let mut utterance_ended: bool = false;
+
+ // Tracking state
+ let mut last_sent_end_time: f32 = 0.0;
+ let mut last_processed_len: usize = 0;
+ let mut audio_offset: f32 = 0.0; // Time offset from trimmed audio
+ let mut finalized_segments: Vec<DialogueSegment> = Vec::new();
+
+ // Reset Sortformer state for new session
+ {
+ let mut sortformer = state.sortformer.lock().await;
+ sortformer.reset_state();
+ }
// Process incoming messages
while let Some(msg_result) = receiver.next().await {
@@ -102,8 +129,17 @@ async fn handle_socket(socket: WebSocket, state: SharedState) {
);
audio_format = Some(start);
audio_buffer.clear();
+ eou_buffer.clear();
+ last_eou_text.clear();
+ utterance_ended = false;
last_sent_end_time = 0.0;
last_processed_len = 0;
+ audio_offset = 0.0;
+ finalized_segments.clear();
+
+ // Reset Sortformer state for new session
+ let mut sortformer = state.sortformer.lock().await;
+ sortformer.reset_state();
}
Ok(ClientMessage::Stop(stop)) => {
tracing::info!(
@@ -113,25 +149,52 @@ async fn handle_socket(socket: WebSocket, state: SharedState) {
"Session stopped by client"
);
- if let Some(ref format) = audio_format {
+ if audio_format.is_some() {
if !audio_buffer.is_empty() {
tracing::debug!(
session_id = %session_id,
samples = audio_buffer.len(),
"Processing final audio buffer"
);
- match process_audio(&audio_buffer, format, &state).await {
+
+ // Process remaining audio with sliding window
+ match process_audio_window(&audio_buffer, audio_offset, &state).await {
Ok(segments) => {
tracing::debug!(
session_id = %session_id,
total_segments = segments.len(),
+ finalized_count = finalized_segments.len(),
last_sent_end = last_sent_end_time,
"Final transcription complete"
);
- // Step 1: Send any NEW segments as interim (is_final: false)
- // These are segments that weren't sent during streaming
+ // Combine finalized segments with new segments
+ let mut all_segments = finalized_segments.clone();
+
+ // Add segments from current window that weren't finalized
for seg in &segments {
+ // Adjust timestamps with offset
+ let adjusted_seg = DialogueSegment {
+ speaker: seg.speaker.clone(),
+ start: seg.start + audio_offset,
+ end: seg.end + audio_offset,
+ text: seg.text.clone(),
+ };
+
+ // Only add if not already finalized
+ if !finalized_segments.iter().any(|f|
+ (f.start - adjusted_seg.start).abs() < 0.1 &&
+ f.text == adjusted_seg.text
+ ) {
+ all_segments.push(adjusted_seg);
+ }
+ }
+
+ // Sort by start time
+ all_segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap());
+
+ // Send any NEW segments as interim first
+ for seg in &all_segments {
if seg.end > last_sent_end_time {
let _ = response_tx
.send(ServerMessage::Transcript(TranscriptMessage {
@@ -145,9 +208,8 @@ async fn handle_socket(socket: WebSocket, state: SharedState) {
}
}
- // Step 2: Send ALL segments as final (is_final: true)
- // This is the complete authoritative transcript
- for seg in &segments {
+ // Send ALL segments as final
+ for seg in &all_segments {
let _ = response_tx
.send(ServerMessage::Transcript(TranscriptMessage {
speaker: seg.speaker.clone(),
@@ -207,21 +269,56 @@ async fn handle_socket(socket: WebSocket, state: SharedState) {
// Decode binary audio data to f32 samples
let samples = decode_audio_chunk(&data, format);
- audio_buffer.extend(samples);
- // Process when we have accumulated another chunk's worth of NEW audio
- let chunk_samples = samples_per_chunk(format.sample_rate, STREAM_CHUNK_MS);
+ // Resample to 16kHz mono for all processing
+ let resampled = if format.sample_rate != TARGET_SAMPLE_RATE || format.channels != TARGET_CHANNELS {
+ resample_and_mixdown(&samples, format.sample_rate, format.channels)
+ } else {
+ samples
+ };
+
+ audio_buffer.extend(&resampled);
+ eou_buffer.extend(&resampled);
+
+ // Process EOU detection in 160ms chunks
+ while eou_buffer.len() >= EOU_CHUNK_SIZE {
+ let chunk: Vec<f32> = eou_buffer.drain(..EOU_CHUNK_SIZE).collect();
+
+ let mut eou = state.parakeet_eou.lock().await;
+ if let Ok(text) = eou.transcribe(&chunk, false) {
+ // Detect utterance boundary: the text changed and the previous text ended with sentence-ending punctuation
+ if !text.is_empty() && text != last_eou_text {
+ if last_eou_text.ends_with('.')
+ || last_eou_text.ends_with('?')
+ || last_eou_text.ends_with('!')
+ {
+ utterance_ended = true;
+ tracing::debug!(
+ session_id = %session_id,
+ "Utterance boundary detected via EOU"
+ );
+ }
+ last_eou_text = text;
+ }
+ }
+ }
+
+ // Calculate if we should process (utterance ended OR enough new audio)
+ let chunk_samples = samples_per_chunk(TARGET_SAMPLE_RATE, STREAM_CHUNK_MS);
let new_audio_len = audio_buffer.len() - last_processed_len;
+ let should_process = utterance_ended || new_audio_len >= chunk_samples;
- if new_audio_len >= chunk_samples {
+ if should_process {
tracing::debug!(
session_id = %session_id,
total_samples = audio_buffer.len(),
new_samples = new_audio_len,
- "Processing audio chunk"
+ utterance_ended = utterance_ended,
+ audio_offset = audio_offset,
+ "Processing audio with sliding window"
);
- match process_audio(&audio_buffer, format, &state).await {
+ match process_audio_window(&audio_buffer, audio_offset, &state).await {
Ok(segments) => {
tracing::debug!(
session_id = %session_id,
@@ -230,23 +327,57 @@ async fn handle_socket(socket: WebSocket, state: SharedState) {
"Transcription produced segments"
);
- // Send segments that end after our last sent time
- // This handles re-segmentation by the model
+ // Send segments with adjusted timestamps
for seg in &segments {
- if seg.end > last_sent_end_time {
+ let adjusted_end = seg.end + audio_offset;
+ if adjusted_end > last_sent_end_time {
let _ = response_tx
.send(ServerMessage::Transcript(TranscriptMessage {
speaker: seg.speaker.clone(),
- start: seg.start,
- end: seg.end,
+ start: seg.start + audio_offset,
+ end: adjusted_end,
text: seg.text.clone(),
is_final: false,
}))
.await;
- last_sent_end_time = seg.end;
+ last_sent_end_time = adjusted_end;
}
}
+
+ // If utterance ended, finalize and trim
+ if utterance_ended && segments.len() > 1 {
+ // Finalize all but the last segment
+ let to_finalize = &segments[..segments.len() - 1];
+ for seg in to_finalize {
+ finalized_segments.push(DialogueSegment {
+ speaker: seg.speaker.clone(),
+ start: seg.start + audio_offset,
+ end: seg.end + audio_offset,
+ text: seg.text.clone(),
+ });
+ }
+
+ // Trim audio buffer
+ if let Some(last_finalized) = to_finalize.last() {
+ let trim_to_time = (last_finalized.end - CONTEXT_OVERLAP_SECONDS).max(0.0);
+ let trim_samples = (trim_to_time * TARGET_SAMPLE_RATE as f32) as usize;
+
+ if trim_samples > 0 && trim_samples < audio_buffer.len() {
+ audio_buffer.drain(..trim_samples);
+ audio_offset += trim_to_time;
+ tracing::debug!(
+ session_id = %session_id,
+ trimmed_samples = trim_samples,
+ new_offset = audio_offset,
+ remaining_samples = audio_buffer.len(),
+ "Trimmed audio buffer after finalization"
+ );
+ }
+ }
+ }
+
last_processed_len = audio_buffer.len();
+ utterance_ended = false;
}
Err(e) => {
tracing::error!(session_id = %session_id, error = %e, "Transcription error");
@@ -291,31 +422,37 @@ fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> {
}
}
-/// Process accumulated audio through STT and diarization models.
-async fn process_audio(
+/// Process audio using sliding window through STT and streaming diarization models.
+///
+/// Only processes the last MAX_WINDOW_SECONDS of audio so that processing time
+/// stays bounded regardless of total audio length.
+async fn process_audio_window(
samples: &[f32],
- format: &StartMessage,
+ _audio_offset: f32,
state: &SharedState,
) -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error + Send + Sync>> {
- // Resample to 16kHz mono if needed
- let resampled = if format.sample_rate != TARGET_SAMPLE_RATE || format.channels != TARGET_CHANNELS
- {
- resample_and_mixdown(samples, format.sample_rate, format.channels)
- } else {
- samples.to_vec()
- };
+ // Apply sliding window - only process the last 30 seconds
+ let window_start = samples.len().saturating_sub(MAX_WINDOW_SAMPLES);
+ let window = &samples[window_start..];
+
+ tracing::trace!(
+ total_samples = samples.len(),
+ window_samples = window.len(),
+ window_start = window_start,
+ "Using sliding window for processing"
+ );
// Acquire model locks and run inference
let mut parakeet = state.parakeet.lock().await;
let mut sortformer = state.sortformer.lock().await;
- // Run diarization
+ // Run streaming diarization (maintains speaker cache across calls)
let diarization_segments =
- sortformer.diarize(resampled.clone(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?;
+ sortformer.diarize_streaming(window.to_vec(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?;
// Run transcription
let transcription = parakeet.transcribe_samples(
- resampled,
+ window.to_vec(),
TARGET_SAMPLE_RATE,
TARGET_CHANNELS,
Some(TimestampMode::Sentences),
@@ -324,5 +461,17 @@ async fn process_audio(
// Align speakers with transcription
let aligned = align_speakers(&transcription.tokens, &diarization_segments);
- Ok(aligned)
+ // Adjust timestamps for window offset within the buffer
+ let window_offset = window_start as f32 / TARGET_SAMPLE_RATE as f32;
+ let adjusted: Vec<DialogueSegment> = aligned
+ .into_iter()
+ .map(|seg| DialogueSegment {
+ speaker: seg.speaker,
+ start: seg.start + window_offset,
+ end: seg.end + window_offset,
+ text: seg.text,
+ })
+ .collect();
+
+ Ok(adjusted)
}
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
index c38359d..31e1518 100644
--- a/makima/src/server/state.rs
+++ b/makima/src/server/state.rs
@@ -3,14 +3,16 @@
use std::sync::Arc;
use tokio::sync::Mutex;
-use crate::listen::{DiarizationConfig, ParakeetTDT, Sortformer};
+use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
/// Shared application state containing ML models.
///
/// Models are wrapped in `Mutex` for thread-safe mutable access during inference.
pub struct AppState {
- /// Speech-to-text model (Parakeet)
+ /// Speech-to-text model (Parakeet TDT)
pub parakeet: Mutex<ParakeetTDT>,
+ /// End-of-Utterance detection model for streaming
+ pub parakeet_eou: Mutex<ParakeetEOU>,
/// Speaker diarization model (Sortformer)
pub sortformer: Mutex<Sortformer>,
}
@@ -19,13 +21,16 @@ impl AppState {
/// Load all ML models from the specified directories.
///
/// # Arguments
- /// * `parakeet_model_dir` - Path to the Parakeet STT model directory
+ /// * `parakeet_model_dir` - Path to the Parakeet TDT model directory
+ /// * `parakeet_eou_dir` - Path to the Parakeet EOU model directory
/// * `sortformer_model_path` - Path to the Sortformer diarization model file
pub fn new(
parakeet_model_dir: &str,
+ parakeet_eou_dir: &str,
sortformer_model_path: &str,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?;
+ let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?;
let sortformer = Sortformer::with_config(
sortformer_model_path,
None,
@@ -34,6 +39,7 @@ impl AppState {
Ok(Self {
parakeet: Mutex::new(parakeet),
+ parakeet_eou: Mutex::new(parakeet_eou),
sortformer: Mutex::new(sortformer),
})
}