//! WebSocket handler for TTS streaming (direct in-process inference). //! //! This module implements the `/api/v1/speak` endpoint which performs //! text-to-speech synthesis directly using the candle-based TTS engine. //! No external Python service or proxy — the model runs in-process. //! //! ## Architecture //! //! The speak handler will: //! 1. Accept a WebSocket connection from the client //! 2. Lazily load the TTS model (candle) on first request //! 3. Parse JSON control messages (start, speak, stop, cancel) //! 4. Run inference directly and stream audio chunks back //! //! See `makima/src/tts/` for the TTS engine implementation. //! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification. use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, response::Response, }; use futures::{SinkExt, StreamExt}; use serde::Deserialize; use uuid::Uuid; use crate::server::state::SharedState; /// Client-to-server control messages. #[derive(Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum ClientMessage { /// Request speech synthesis for the given text. Speak { text: String, /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection. #[serde(default)] #[allow(dead_code)] voice: Option, }, /// Cancel any in-progress synthesis. Cancel, /// Graceful close. Stop, } /// WebSocket upgrade handler for TTS streaming. /// /// This endpoint accepts WebSocket connections for text-to-speech synthesis. /// The TTS model runs directly in-process using candle — no external service. #[utoipa::path( get, path = "/api/v1/speak", responses( (status = 101, description = "WebSocket connection established"), (status = 503, description = "TTS engine not available"), ), tag = "Speak" )] pub async fn websocket_handler( ws: WebSocketUpgrade, State(state): State, ) -> Response { ws.on_upgrade(|socket| handle_speak_socket(socket, state)) } /// Handle TTS WebSocket session with direct in-process inference. /// /// Protocol: /// - Client sends JSON `{ "type": "speak", "text": "..." }` messages /// - Server responds with binary audio chunks (16-bit PCM @ 24kHz) /// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete /// - Server sends JSON `{ "type": "error", ... }` on failures async fn handle_speak_socket(socket: WebSocket, state: SharedState) { let session_id = Uuid::new_v4().to_string(); tracing::info!(session_id = %session_id, "New TTS WebSocket connection"); let (mut sender, mut receiver) = socket.split(); // Process incoming messages while let Some(msg) = receiver.next().await { let msg = match msg { Ok(m) => m, Err(e) => { tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error"); break; } }; match msg { Message::Text(text) => { let client_msg: ClientMessage = match serde_json::from_str(&text) { Ok(m) => m, Err(e) => { let _ = send_error( &mut sender, "INVALID_MESSAGE", &format!("Failed to parse message: {e}"), ) .await; continue; } }; match client_msg { ClientMessage::Speak { text, .. } => { tracing::info!( session_id = %session_id, text_len = text.len(), "TTS speak request" ); // Get or lazily load the TTS engine let engine = match state.get_tts_engine().await { Ok(e) => e, Err(e) => { tracing::error!( session_id = %session_id, error = %e, "Failed to load TTS engine" ); let _ = send_error( &mut sender, "TTS_LOAD_FAILED", &format!("Failed to load TTS engine: {e}"), ) .await; continue; } }; if !engine.is_ready() { let _ = send_error( &mut sender, "TTS_NOT_READY", "TTS engine is not ready yet", ) .await; continue; } // Run TTS inference (no voice reference for now — uses default) match engine.generate(&text, None, None).await { Ok(chunks) => { for chunk in &chunks { // Send binary PCM audio data let pcm_bytes = chunk.to_pcm16_bytes(); if sender .send(Message::Binary(pcm_bytes.into())) .await .is_err() { tracing::warn!( session_id = %session_id, "Failed to send audio chunk — client disconnected" ); return; } } // Signal end of audio let end_msg = serde_json::json!({ "type": "audio_end", "sample_rate": engine.sample_rate(), "format": "pcm_s16le", "channels": 1, }); let _ = sender .send(Message::Text(end_msg.to_string().into())) .await; } Err(e) => { tracing::error!( session_id = %session_id, error = %e, "TTS inference failed" ); let _ = send_error( &mut sender, "TTS_INFERENCE_FAILED", &format!("TTS inference failed: {e}"), ) .await; } } } ClientMessage::Cancel => { tracing::info!(session_id = %session_id, "TTS cancel requested"); // TODO: support cancellation of in-progress inference } ClientMessage::Stop => { tracing::info!(session_id = %session_id, "TTS stop requested, closing"); break; } } } Message::Close(_) => { tracing::info!(session_id = %session_id, "TTS WebSocket closed by client"); break; } _ => { // Ignore ping/pong/binary from client } } } tracing::info!(session_id = %session_id, "TTS WebSocket connection closed"); } /// Send an error message to the client. async fn send_error(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error> where S: SinkExt + Unpin, >::Error: std::error::Error, { let error_msg = serde_json::json!({ "type": "error", "code": code, "message": message, "recoverable": false }); sender .send(Message::Text(error_msg.to_string().into())) .await .ok(); Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_error_message_format() { let error = serde_json::json!({ "type": "error", "code": "TEST_ERROR", "message": "Test message", "recoverable": false }); assert_eq!(error["type"], "error"); assert_eq!(error["code"], "TEST_ERROR"); assert_eq!(error["message"], "Test message"); assert_eq!(error["recoverable"], false); } #[test] fn test_client_message_parse_speak() { let json = r#"{"type": "speak", "text": "Hello world"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); match msg { ClientMessage::Speak { text, voice } => { assert_eq!(text, "Hello world"); assert!(voice.is_none()); } _ => panic!("Expected Speak message"), } } #[test] fn test_client_message_parse_cancel() { let json = r#"{"type": "cancel"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); assert!(matches!(msg, ClientMessage::Cancel)); } #[test] fn test_client_message_parse_stop() { let json = r#"{"type": "stop"}"#; let msg: ClientMessage = serde_json::from_str(json).unwrap(); assert!(matches!(msg, ClientMessage::Stop)); } }