//! WebSocket handler for TTS streaming (direct in-process inference). //! //! This module implements the `/api/v1/speak` endpoint which performs //! text-to-speech synthesis directly using the Chatterbox ONNX TTS engine. //! No external Python service or proxy — the model runs in-process. //! //! ## Architecture //! //! The speak handler will: //! 1. Accept a WebSocket connection from the client //! 2. Lazily load the TTS model (Chatterbox ONNX) on first request //! 3. Parse JSON control messages (start, speak, stop, cancel) //! 4. Run inference directly and stream audio chunks back //! //! See `makima/src/tts/` for the TTS engine implementation. use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, response::Response, }; use futures::{SinkExt, StreamExt}; use serde::Deserialize; use uuid::Uuid; use crate::server::state::SharedState; /// Client-to-server control messages. #[derive(Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum ClientMessage { /// Request speech synthesis for the given text. Speak { text: String, /// Optional voice ID (e.g., "makima"). Used to load reference audio for voice cloning. /// Defaults to "makima" if not specified. #[serde(default)] voice: Option, }, /// Cancel any in-progress synthesis. Cancel, /// Graceful close. Stop, } /// WebSocket upgrade handler for TTS streaming. /// /// This endpoint accepts WebSocket connections for text-to-speech synthesis. /// The TTS model runs directly in-process using candle — no external service. #[utoipa::path( get, path = "/api/v1/speak", responses( (status = 101, description = "WebSocket connection established"), (status = 503, description = "TTS engine not available"), ), tag = "Speak" )] pub async fn websocket_handler( ws: WebSocketUpgrade, State(state): State, ) -> Response { ws.on_upgrade(|socket| handle_speak_socket(socket, state)) } /// Handle TTS WebSocket session with direct in-process inference. /// /// Protocol: /// - Client sends JSON `{ "type": "speak", "text": "..." }` messages /// - Server responds with binary audio chunks (16-bit PCM @ 24kHz) /// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete /// - Server sends JSON `{ "type": "error", ... }` on failures async fn handle_speak_socket(socket: WebSocket, state: SharedState) { let session_id = Uuid::new_v4().to_string(); tracing::info!(session_id = %session_id, "New TTS WebSocket connection"); let (mut sender, mut receiver) = socket.split(); // Cancellation flag shared between the message loop and inference. // Each new Speak request resets it to false; Cancel sets it to true. let cancel_flag: Arc = Arc::new(AtomicBool::new(false)); // Process incoming messages while let Some(msg) = receiver.next().await { let msg = match msg { Ok(m) => m, Err(e) => { tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error"); break; } }; match msg { Message::Text(text) => { let client_msg: ClientMessage = match serde_json::from_str(&text) { Ok(m) => m, Err(e) => { let _ = send_error( &mut sender, "INVALID_MESSAGE", &format!("Failed to parse message: {e}"), ) .await; continue; } }; match client_msg { ClientMessage::Speak { text, voice } => { let voice_id = voice .as_deref() .unwrap_or(super::voice::DEFAULT_VOICE_ID); tracing::info!( session_id = %session_id, text_len = text.len(), voice_id = %voice_id, "TTS speak request" ); // Load voice reference audio for cloning let voice_ref = match super::voice::load_reference_audio(voice_id) { Ok(v) => { tracing::debug!( session_id = %session_id, voice_id = %voice_id, voice_name = %v.manifest.name, samples = v.samples.len(), "Voice reference loaded" ); Some(v) } Err(e) => { tracing::warn!( session_id = %session_id, voice_id = %voice_id, error = %e, "Failed to load voice reference, proceeding without cloning" ); None } }; // Get or lazily load the TTS engine let engine = match state.get_tts_engine().await { Ok(e) => e, Err(e) => { tracing::error!( session_id = %session_id, error = %e, "Failed to load TTS engine" ); let _ = send_error( &mut sender, "TTS_LOAD_FAILED", &format!("Failed to load TTS engine: {e}"), ) .await; continue; } }; if !engine.is_ready() { let _ = send_error( &mut sender, "TTS_NOT_READY", "TTS engine is not ready yet", ) .await; continue; } // Reset the cancel flag for this new generation request cancel_flag.store(false, Ordering::Relaxed); // Run TTS inference with optional voice reference for cloning // and the cancel flag so it can be stopped early. let (ref_audio, ref_rate) = match &voice_ref { Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)), None => (None, None), }; let flag = cancel_flag.clone(); match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await { Ok(chunks) => { // Check if generation was cancelled let was_cancelled = cancel_flag.load(Ordering::Relaxed); for chunk in &chunks { // Send binary PCM audio data let pcm_bytes = chunk.to_pcm16_bytes(); if sender .send(Message::Binary(pcm_bytes.into())) .await .is_err() { tracing::warn!( session_id = %session_id, "Failed to send audio chunk — client disconnected" ); return; } } // Signal end of audio (include cancelled status) let end_msg = serde_json::json!({ "type": "audio_end", "sample_rate": engine.sample_rate(), "format": "pcm_s16le", "channels": 1, "cancelled": was_cancelled, }); let _ = sender .send(Message::Text(end_msg.to_string().into())) .await; } Err(e) => { tracing::error!( session_id = %session_id, error = %e, "TTS inference failed" ); let _ = send_error( &mut sender, "TTS_INFERENCE_FAILED", &format!("TTS inference failed: {e}"), ) .await; } } } ClientMessage::Cancel => { tracing::info!(session_id = %session_id, "TTS cancel requested"); cancel_flag.store(true, Ordering::Relaxed); } ClientMessage::Stop => { tracing::info!(session_id = %session_id, "TTS stop requested, closing"); cancel_flag.store(true, Ordering::Relaxed); break; } } } Message::Close(_) => { tracing::info!(session_id = %session_id, "TTS WebSocket closed by client"); cancel_flag.store(true, Ordering::Relaxed); break; } _ => { // Ignore ping/pong/binary from client } } } tracing::info!(session_id = %session_id, "TTS WebSocket connection closed"); } /// Send an error message to the client. async fn send_error(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error> where S: SinkExt + Unpin, >::Error: std::error::Error, { let error_msg = serde_json::json!({ "type": "error", "code": code, "message": message, "recoverable": false }); sender .send(Message::Text(error_msg.to_string().into())) .await .ok(); Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_error_message_format() { let error = serde_json::json!({ "type": "error", "code": "TEST_ERROR", "message": "Test message", "recoverable": false }); assert_eq!(error["type"], "error"); assert_eq!(error["code"], "TEST_ERROR"); assert_eq!(error["message"], "Test message"); assert_eq!(error["recoverable"], false); } #[test] fn test_client_message_parse_speak() { let json = r#"{"type": "speak", "text": "Hello world"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); match msg { ClientMessage::Speak { text, voice } => { assert_eq!(text, "Hello world"); assert!(voice.is_none()); } _ => panic!("Expected Speak message"), } } #[test] fn test_client_message_parse_cancel() { let json = r#"{"type": "cancel"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); assert!(matches!(msg, ClientMessage::Cancel)); } #[test] fn test_client_message_parse_stop() { let json = r#"{"type": "stop"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); assert!(matches!(msg, ClientMessage::Stop)); } #[test] fn test_client_message_parse_speak_with_voice() { let json = r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); match msg { ClientMessage::Speak { text, voice } => { assert_eq!(text, "Hello"); assert_eq!(voice.as_deref(), Some("makima")); } _ => panic!("Expected Speak message"), } } }