summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/speak.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server/handlers/speak.rs')
-rw-r--r--makima/src/server/handlers/speak.rs274
1 files changed, 274 insertions, 0 deletions
diff --git a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs
new file mode 100644
index 0000000..75e7780
--- /dev/null
+++ b/makima/src/server/handlers/speak.rs
@@ -0,0 +1,274 @@
+//! WebSocket handler for TTS streaming (direct in-process inference).
+//!
+//! This module implements the `/api/v1/speak` endpoint which performs
+//! text-to-speech synthesis directly using the candle-based TTS engine.
+//! No external Python service or proxy — the model runs in-process.
+//!
+//! ## Architecture
+//!
+//! The speak handler will:
+//! 1. Accept a WebSocket connection from the client
+//! 2. Lazily load the TTS model (candle) on first request
+//! 3. Parse JSON control messages (start, speak, stop, cancel)
+//! 4. Run inference directly and stream audio chunks back
+//!
+//! See `makima/src/tts/` for the TTS engine implementation.
+//! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification.
+
+use axum::{
+ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
+ response::Response,
+};
+use futures::{SinkExt, StreamExt};
+use serde::Deserialize;
+use uuid::Uuid;
+
+use crate::server::state::SharedState;
+
+/// Client-to-server control messages.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+enum ClientMessage {
+ /// Request speech synthesis for the given text.
+ Speak {
+ text: String,
+ /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection.
+ #[serde(default)]
+ #[allow(dead_code)]
+ voice: Option<String>,
+ },
+ /// Cancel any in-progress synthesis.
+ Cancel,
+ /// Graceful close.
+ Stop,
+}
+
+/// WebSocket upgrade handler for TTS streaming.
+///
+/// This endpoint accepts WebSocket connections for text-to-speech synthesis.
+/// The TTS model runs directly in-process using candle — no external service.
+#[utoipa::path(
+ get,
+ path = "/api/v1/speak",
+ responses(
+ (status = 101, description = "WebSocket connection established"),
+ (status = 503, description = "TTS engine not available"),
+ ),
+ tag = "Speak"
+)]
+pub async fn websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_speak_socket(socket, state))
+}
+
+/// Handle TTS WebSocket session with direct in-process inference.
+///
+/// Protocol:
+/// - Client sends JSON `{ "type": "speak", "text": "..." }` messages
+/// - Server responds with binary audio chunks (16-bit PCM @ 24kHz)
+/// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete
+/// - Server sends JSON `{ "type": "error", ... }` on failures
+async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
+ let session_id = Uuid::new_v4().to_string();
+ tracing::info!(session_id = %session_id, "New TTS WebSocket connection");
+
+ let (mut sender, mut receiver) = socket.split();
+
+ // Process incoming messages
+ while let Some(msg) = receiver.next().await {
+ let msg = match msg {
+ Ok(m) => m,
+ Err(e) => {
+ tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error");
+ break;
+ }
+ };
+
+ match msg {
+ Message::Text(text) => {
+ let client_msg: ClientMessage = match serde_json::from_str(&text) {
+ Ok(m) => m,
+ Err(e) => {
+ let _ = send_error(
+ &mut sender,
+ "INVALID_MESSAGE",
+ &format!("Failed to parse message: {e}"),
+ )
+ .await;
+ continue;
+ }
+ };
+
+ match client_msg {
+ ClientMessage::Speak { text, .. } => {
+ tracing::info!(
+ session_id = %session_id,
+ text_len = text.len(),
+ "TTS speak request"
+ );
+
+ // Get or lazily load the TTS engine
+ let engine = match state.get_tts_engine().await {
+ Ok(e) => e,
+ Err(e) => {
+ tracing::error!(
+ session_id = %session_id,
+ error = %e,
+ "Failed to load TTS engine"
+ );
+ let _ = send_error(
+ &mut sender,
+ "TTS_LOAD_FAILED",
+ &format!("Failed to load TTS engine: {e}"),
+ )
+ .await;
+ continue;
+ }
+ };
+
+ if !engine.is_ready() {
+ let _ = send_error(
+ &mut sender,
+ "TTS_NOT_READY",
+ "TTS engine is not ready yet",
+ )
+ .await;
+ continue;
+ }
+
+ // Run TTS inference (no voice reference for now — uses default)
+ match engine.generate(&text, None, None).await {
+ Ok(chunks) => {
+ for chunk in &chunks {
+ // Send binary PCM audio data
+ let pcm_bytes = chunk.to_pcm16_bytes();
+ if sender
+ .send(Message::Binary(pcm_bytes.into()))
+ .await
+ .is_err()
+ {
+ tracing::warn!(
+ session_id = %session_id,
+ "Failed to send audio chunk — client disconnected"
+ );
+ return;
+ }
+ }
+
+ // Signal end of audio
+ let end_msg = serde_json::json!({
+ "type": "audio_end",
+ "sample_rate": engine.sample_rate(),
+ "format": "pcm_s16le",
+ "channels": 1,
+ });
+ let _ = sender
+ .send(Message::Text(end_msg.to_string().into()))
+ .await;
+ }
+ Err(e) => {
+ tracing::error!(
+ session_id = %session_id,
+ error = %e,
+ "TTS inference failed"
+ );
+ let _ = send_error(
+ &mut sender,
+ "TTS_INFERENCE_FAILED",
+ &format!("TTS inference failed: {e}"),
+ )
+ .await;
+ }
+ }
+ }
+ ClientMessage::Cancel => {
+ tracing::info!(session_id = %session_id, "TTS cancel requested");
+ // TODO: support cancellation of in-progress inference
+ }
+ ClientMessage::Stop => {
+ tracing::info!(session_id = %session_id, "TTS stop requested, closing");
+ break;
+ }
+ }
+ }
+ Message::Close(_) => {
+ tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
+ break;
+ }
+ _ => {
+ // Ignore ping/pong/binary from client
+ }
+ }
+ }
+
+ tracing::info!(session_id = %session_id, "TTS WebSocket connection closed");
+}
+
+/// Send an error message to the client.
+async fn send_error<S>(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error>
+where
+ S: SinkExt<Message> + Unpin,
+ <S as futures::Sink<Message>>::Error: std::error::Error,
+{
+ let error_msg = serde_json::json!({
+ "type": "error",
+ "code": code,
+ "message": message,
+ "recoverable": false
+ });
+
+ sender
+ .send(Message::Text(error_msg.to_string().into()))
+ .await
+ .ok();
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_error_message_format() {
+ let error = serde_json::json!({
+ "type": "error",
+ "code": "TEST_ERROR",
+ "message": "Test message",
+ "recoverable": false
+ });
+
+ assert_eq!(error["type"], "error");
+ assert_eq!(error["code"], "TEST_ERROR");
+ assert_eq!(error["message"], "Test message");
+ assert_eq!(error["recoverable"], false);
+ }
+
+ #[test]
+ fn test_client_message_parse_speak() {
+ let json = r#"{"type": "speak", "text": "Hello world"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ match msg {
+ ClientMessage::Speak { text, voice } => {
+ assert_eq!(text, "Hello world");
+ assert!(voice.is_none());
+ }
+ _ => panic!("Expected Speak message"),
+ }
+ }
+
+ #[test]
+ fn test_client_message_parse_cancel() {
+ let json = r#"{"type": "cancel"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ assert!(matches!(msg, ClientMessage::Cancel));
+ }
+
+ #[test]
+ fn test_client_message_parse_stop() {
+ let json = r#"{"type": "stop"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ assert!(matches!(msg, ClientMessage::Stop));
+ }
+}