diff options
Diffstat (limited to 'makima/src/server/handlers')
| -rw-r--r-- | makima/src/server/handlers/listen.rs | 328 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 4 | ||||
| -rw-r--r-- | makima/src/server/handlers/tts.rs | 185 |
3 files changed, 517 insertions, 0 deletions
diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs new file mode 100644 index 0000000..b1c1ad9 --- /dev/null +++ b/makima/src/server/handlers/listen.rs @@ -0,0 +1,328 @@ +//! WebSocket handler for streaming speech-to-text. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use tokio::sync::mpsc; +use uuid::Uuid; + +use crate::audio::{resample_and_mixdown, TARGET_CHANNELS, TARGET_SAMPLE_RATE}; +use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode, Transcriber}; +use crate::server::messages::{ + AudioEncoding, ClientMessage, ServerMessage, StartMessage, TranscriptMessage, +}; +use crate::server::state::SharedState; + +/// Chunk size in milliseconds for streaming transcription. +const STREAM_CHUNK_MS: u32 = 5_000; + +/// WebSocket upgrade handler for STT streaming. +/// +/// This endpoint accepts WebSocket connections for real-time speech-to-text +/// transcription with speaker diarization. +#[utoipa::path( + get, + path = "/api/v1/listen", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "STT" +)] +pub async fn websocket_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, +) -> Response { + ws.on_upgrade(|socket| handle_socket(socket, state)) +} + +async fn handle_socket(socket: WebSocket, state: SharedState) { + let session_id = Uuid::new_v4().to_string(); + tracing::info!(session_id = %session_id, "New WebSocket connection"); + + // Split socket for concurrent read/write + let (mut sender, mut receiver) = socket.split(); + + // Channel for sending responses back to client + let (response_tx, mut response_rx) = mpsc::channel::<ServerMessage>(32); + + // Spawn task to forward responses to WebSocket + let sender_task = tokio::spawn(async move { + while let Some(msg) = response_rx.recv().await { + let json = match serde_json::to_string(&msg) { + Ok(j) => j, + Err(e) => { + tracing::error!("Failed to serialize message: {}", e); + continue; + } + }; + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + }); + + // Send ready message + let _ = response_tx + .send(ServerMessage::Ready { + session_id: session_id.clone(), + }) + .await; + + // Audio format state + let mut audio_format: Option<StartMessage> = None; + + // Audio buffer for accumulating samples + let mut audio_buffer: Vec<f32> = Vec::new(); + let mut last_sent_end_time: f32 = 0.0; // Track the end time of last sent segment + let mut last_processed_len: usize = 0; // Track how much audio we've processed + + // Process incoming messages + while let Some(msg_result) = receiver.next().await { + let msg = match msg_result { + Ok(m) => m, + Err(e) => { + tracing::error!("WebSocket error: {}", e); + break; + } + }; + + match msg { + Message::Text(text) => { + // Parse JSON control messages + match serde_json::from_str::<ClientMessage>(&text) { + Ok(ClientMessage::Start(start)) => { + tracing::info!( + session_id = %session_id, + sample_rate = start.sample_rate, + channels = start.channels, + encoding = ?start.encoding, + "Session started" + ); + audio_format = Some(start); + audio_buffer.clear(); + last_sent_end_time = 0.0; + last_processed_len = 0; + } + Ok(ClientMessage::Stop(stop)) => { + tracing::info!( + session_id = %session_id, + reason = ?stop.reason, + audio_buffer_len = audio_buffer.len(), + "Session stopped by client" + ); + + if let Some(ref format) = audio_format { + if !audio_buffer.is_empty() { + tracing::debug!( + session_id = %session_id, + samples = audio_buffer.len(), + "Processing final audio buffer" + ); + match process_audio(&audio_buffer, format, &state).await { + Ok(segments) => { + tracing::debug!( + session_id = %session_id, + total_segments = segments.len(), + last_sent_end = last_sent_end_time, + "Final transcription complete" + ); + + // Step 1: Send any NEW segments as interim (is_final: false) + // These are segments that weren't sent during streaming + for seg in &segments { + if seg.end > last_sent_end_time { + let _ = response_tx + .send(ServerMessage::Transcript(TranscriptMessage { + speaker: seg.speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text.clone(), + is_final: false, + })) + .await; + } + } + + // Step 2: Send ALL segments as final (is_final: true) + // This is the complete authoritative transcript + for seg in &segments { + let _ = response_tx + .send(ServerMessage::Transcript(TranscriptMessage { + speaker: seg.speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text.clone(), + is_final: true, + })) + .await; + } + } + Err(e) => { + tracing::error!( + session_id = %session_id, + error = %e, + "Final transcription failed" + ); + let _ = response_tx + .send(ServerMessage::Error { + code: "TRANSCRIPTION_ERROR".into(), + message: e.to_string(), + }) + .await; + } + } + } + } + + let _ = response_tx + .send(ServerMessage::Stopped { + reason: stop.reason.unwrap_or_else(|| "client_requested".into()), + }) + .await; + break; + } + Err(e) => { + tracing::warn!(session_id = %session_id, error = %e, "Failed to parse message"); + let _ = response_tx + .send(ServerMessage::Error { + code: "PARSE_ERROR".into(), + message: format!("Failed to parse message: {}", e), + }) + .await; + } + } + } + Message::Binary(data) => { + let Some(ref format) = audio_format else { + let _ = response_tx + .send(ServerMessage::Error { + code: "NO_FORMAT".into(), + message: "Received audio before start message".into(), + }) + .await; + continue; + }; + + // Decode binary audio data to f32 samples + let samples = decode_audio_chunk(&data, format); + audio_buffer.extend(samples); + + // Process when we have accumulated another chunk's worth of NEW audio + let chunk_samples = samples_per_chunk(format.sample_rate, STREAM_CHUNK_MS); + let new_audio_len = audio_buffer.len() - last_processed_len; + + if new_audio_len >= chunk_samples { + tracing::debug!( + session_id = %session_id, + total_samples = audio_buffer.len(), + new_samples = new_audio_len, + "Processing audio chunk" + ); + + match process_audio(&audio_buffer, format, &state).await { + Ok(segments) => { + tracing::debug!( + session_id = %session_id, + total_segments = segments.len(), + last_sent_end = last_sent_end_time, + "Transcription produced segments" + ); + + // Send segments that end after our last sent time + // This handles re-segmentation by the model + for seg in &segments { + if seg.end > last_sent_end_time { + let _ = response_tx + .send(ServerMessage::Transcript(TranscriptMessage { + speaker: seg.speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text.clone(), + is_final: false, + })) + .await; + last_sent_end_time = seg.end; + } + } + last_processed_len = audio_buffer.len(); + } + Err(e) => { + tracing::error!(session_id = %session_id, error = %e, "Transcription error"); + let _ = response_tx + .send(ServerMessage::Error { + code: "TRANSCRIPTION_ERROR".into(), + message: e.to_string(), + }) + .await; + } + } + } + } + Message::Close(_) => { + tracing::info!(session_id = %session_id, "WebSocket closed by client"); + break; + } + _ => {} + } + } + + // Cleanup + drop(response_tx); + let _ = sender_task.await; + tracing::info!(session_id = %session_id, "WebSocket connection closed"); +} + +/// Decode binary audio chunk to f32 samples based on encoding format. +fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> { + match format.encoding { + AudioEncoding::Pcm32f => data + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(), + AudioEncoding::Pcm16 | AudioEncoding::Raw => data + .chunks_exact(2) + .map(|chunk| { + let sample = i16::from_le_bytes([chunk[0], chunk[1]]); + sample as f32 / 32768.0 + }) + .collect(), + } +} + +/// Process accumulated audio through STT and diarization models. +async fn process_audio( + samples: &[f32], + format: &StartMessage, + state: &SharedState, +) -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error + Send + Sync>> { + // Resample to 16kHz mono if needed + let resampled = if format.sample_rate != TARGET_SAMPLE_RATE || format.channels != TARGET_CHANNELS + { + resample_and_mixdown(samples, format.sample_rate, format.channels) + } else { + samples.to_vec() + }; + + // Acquire model locks and run inference + let mut parakeet = state.parakeet.lock().await; + let mut sortformer = state.sortformer.lock().await; + + // Run diarization + let diarization_segments = + sortformer.diarize(resampled.clone(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?; + + // Run transcription + let transcription = parakeet.transcribe_samples( + resampled, + TARGET_SAMPLE_RATE, + TARGET_CHANNELS, + Some(TimestampMode::Sentences), + )?; + + // Align speakers with transcription + let aligned = align_speakers(&transcription.tokens, &diarization_segments); + + Ok(aligned) +} diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs new file mode 100644 index 0000000..90798f9 --- /dev/null +++ b/makima/src/server/handlers/mod.rs @@ -0,0 +1,4 @@ +//! HTTP and WebSocket request handlers. + +pub mod listen; +pub mod tts; diff --git a/makima/src/server/handlers/tts.rs b/makima/src/server/handlers/tts.rs new file mode 100644 index 0000000..94a261f --- /dev/null +++ b/makima/src/server/handlers/tts.rs @@ -0,0 +1,185 @@ +//! HTTP handler for text-to-speech synthesis. + +use axum::{ + body::Body, + extract::{Multipart, State}, + http::{header, StatusCode}, + response::Response, + Json, +}; +use std::io::Cursor; + +use crate::audio::to_16k_mono_from_reader; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; +use crate::tts::SAMPLE_RATE; + +/// POST /api/v1/tts/synthesize +/// +/// Synthesize speech from text using voice cloning. +/// +/// Accepts multipart form data with: +/// - `text`: The text to synthesize (required) +/// - `voice`: Audio file for voice cloning reference (required) +/// +/// Returns: WAV audio file (24kHz mono) +#[utoipa::path( + post, + path = "/api/v1/tts/synthesize", + request_body(content_type = "multipart/form-data", description = "Text and voice audio for synthesis"), + responses( + (status = 200, description = "Generated audio file", content_type = "audio/wav"), + (status = 400, description = "Bad request", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + tag = "TTS" +)] +pub async fn synthesize_handler( + State(state): State<SharedState>, + mut multipart: Multipart, +) -> Result<Response, (StatusCode, Json<ApiError>)> { + let mut text: Option<String> = None; + let mut voice_samples: Option<Vec<f32>> = None; + let mut voice_sample_rate: u32 = 16_000; + + // Parse multipart fields + while let Some(field) = multipart.next_field().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MULTIPART_ERROR", e.to_string())), + ) + })? { + let name = field.name().unwrap_or("").to_string(); + + match name.as_str() { + "text" => { + text = Some(field.text().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("TEXT_FIELD_ERROR", e.to_string())), + ) + })?); + } + "voice" => { + let data = field.bytes().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("VOICE_FIELD_ERROR", e.to_string())), + ) + })?; + + // Decode audio file to PCM samples + let cursor = Cursor::new(data.to_vec()); + let pcm = to_16k_mono_from_reader(cursor).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "AUDIO_DECODE_ERROR", + format!("Failed to decode voice audio: {}", e), + )), + ) + })?; + + voice_samples = Some(pcm.samples); + voice_sample_rate = pcm.sample_rate; + } + _ => { + // Ignore unknown fields + tracing::debug!("Ignoring unknown field: {}", name); + } + } + } + + // Validate required fields + let text = text.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MISSING_TEXT", "Text field is required")), + ) + })?; + + let samples = voice_samples.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "MISSING_VOICE", + "Voice audio file is required", + )), + ) + })?; + + tracing::info!( + text_len = text.len(), + voice_samples = samples.len(), + "Generating TTS" + ); + + // Generate TTS with voice cloning + let mut chatterbox = state.chatterbox.lock().await; + let audio = chatterbox + .generate_tts_with_samples(&text, &samples, voice_sample_rate) + .map_err(|e| { + tracing::error!(error = %e, "TTS generation failed"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("TTS_ERROR", e.to_string())), + ) + })?; + + tracing::info!(samples = audio.len(), "TTS generation complete"); + + // Encode as WAV + let wav_data = encode_wav(&audio, SAMPLE_RATE); + + // Return WAV response + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "audio/wav") + .header( + header::CONTENT_DISPOSITION, + "attachment; filename=\"output.wav\"", + ) + .body(Body::from(wav_data)) + .unwrap()) +} + +/// Encode f32 samples as a WAV file in memory. +fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> { + let mut buf = Vec::new(); + + let num_samples = samples.len() as u32; + let num_channels: u16 = 1; + let bits_per_sample: u16 = 16; + let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; + let block_align = num_channels * bits_per_sample / 8; + let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8; + let file_size = 36 + data_size; + + // RIFF header + buf.extend_from_slice(b"RIFF"); + buf.extend_from_slice(&file_size.to_le_bytes()); + buf.extend_from_slice(b"WAVE"); + + // fmt chunk + buf.extend_from_slice(b"fmt "); + buf.extend_from_slice(&16u32.to_le_bytes()); + buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format + buf.extend_from_slice(&num_channels.to_le_bytes()); + buf.extend_from_slice(&sample_rate.to_le_bytes()); + buf.extend_from_slice(&byte_rate.to_le_bytes()); + buf.extend_from_slice(&block_align.to_le_bytes()); + buf.extend_from_slice(&bits_per_sample.to_le_bytes()); + + // data chunk + buf.extend_from_slice(b"data"); + buf.extend_from_slice(&data_size.to_le_bytes()); + + // Convert f32 samples to i16 PCM + for &sample in samples { + let clamped = sample.clamp(-1.0, 1.0); + let int_sample = (clamped * 32767.0) as i16; + buf.extend_from_slice(&int_sample.to_le_bytes()); + } + + buf +} |
