summaryrefslogblamecommitdiff
path: root/makima/src/server/handlers/speak.rs
blob: 0f94b40311a18b8a095fd2cb602795e0009f39aa (plain) (tree)
1
2
3
4
5
6
7
8
9
10


                                                                      
                                                                           





                                                                      
                                                                   



                                                               
 


                                              
















                                                                   

                                                                                               
                         










                                                                             
                                                                              




























                                                                             



                                                                         

























                                                                                                




                                                                       


                                                     
                                                 


                                               






















                                                                                                




























                                                                               










                                                                                             
                                           


                                                                                        















                                                                                                
                                                                                 




                                                                        
                                                               





















                                                                                         
                                                                   


                                                                                                
                                                                   





                                                                                           
                                                           











































































                                                                                            












                                                                              
 
//! WebSocket handler for TTS streaming (direct in-process inference).
//!
//! This module implements the `/api/v1/speak` endpoint which performs
//! text-to-speech synthesis directly using the Chatterbox ONNX TTS engine.
//! No external Python service or proxy — the model runs in-process.
//!
//! ## Architecture
//!
//! The speak handler:
//! 1. Accepts a WebSocket connection from the client
//! 2. Lazily loads the TTS model (Chatterbox ONNX) on the first speak request
//! 3. Parses JSON control messages (`speak`, `cancel`, `stop`)
//! 4. Runs inference directly and streams audio chunks back
//!
//! See `makima/src/tts/` for the TTS engine implementation.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use axum::{
    extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
    response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use uuid::Uuid;

use crate::server::state::SharedState;

/// Control messages a client may send over the speak WebSocket.
///
/// Each message is a JSON object discriminated by its `"type"` field, whose
/// value is the snake_case variant name: `"speak"`, `"cancel"` or `"stop"`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
enum ClientMessage {
    /// Synthesize `text` as speech.
    Speak {
        /// The text to synthesize.
        text: String,
        /// Voice ID (e.g. `"makima"`) whose reference audio is used for voice
        /// cloning. May be omitted; the handler then falls back to
        /// `voice::DEFAULT_VOICE_ID`.
        #[serde(default)]
        voice: Option<String>,
    },
    /// Abort the synthesis currently in progress, if any.
    Cancel,
    /// End the session gracefully.
    Stop,
}

/// WebSocket upgrade handler for TTS streaming.
///
/// This endpoint accepts WebSocket connections for text-to-speech synthesis.
/// The TTS model runs directly in-process using ONNX — no external service.
#[utoipa::path(
    get,
    path = "/api/v1/speak",
    responses(
        (status = 101, description = "WebSocket connection established"),
        (status = 503, description = "TTS engine not available"),
    ),
    tag = "Speak"
)]
pub async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<SharedState>,
) -> Response {
    ws.on_upgrade(|socket| handle_speak_socket(socket, state))
}

/// Handle TTS WebSocket session with direct in-process inference.
///
/// Protocol:
/// - Client sends JSON `{ "type": "speak", "text": "..." }` messages
/// - Server responds with binary audio chunks (16-bit PCM @ 24kHz)
/// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete
/// - Server sends JSON `{ "type": "error", ... }` on failures
async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
    let session_id = Uuid::new_v4().to_string();
    tracing::info!(session_id = %session_id, "New TTS WebSocket connection");

    let (mut sender, mut receiver) = socket.split();

    // Cancellation flag shared between the message loop and inference.
    // Each new Speak request resets it to false; Cancel sets it to true.
    let cancel_flag: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

    // Process incoming messages
    while let Some(msg) = receiver.next().await {
        let msg = match msg {
            Ok(m) => m,
            Err(e) => {
                tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error");
                break;
            }
        };

        match msg {
            Message::Text(text) => {
                let client_msg: ClientMessage = match serde_json::from_str(&text) {
                    Ok(m) => m,
                    Err(e) => {
                        let _ = send_error(
                            &mut sender,
                            "INVALID_MESSAGE",
                            &format!("Failed to parse message: {e}"),
                        )
                        .await;
                        continue;
                    }
                };

                match client_msg {
                    ClientMessage::Speak { text, voice } => {
                        let voice_id = voice
                            .as_deref()
                            .unwrap_or(super::voice::DEFAULT_VOICE_ID);

                        tracing::info!(
                            session_id = %session_id,
                            text_len = text.len(),
                            voice_id = %voice_id,
                            "TTS speak request"
                        );

                        // Load voice reference audio for cloning
                        let voice_ref = match super::voice::load_reference_audio(voice_id) {
                            Ok(v) => {
                                tracing::debug!(
                                    session_id = %session_id,
                                    voice_id = %voice_id,
                                    voice_name = %v.manifest.name,
                                    samples = v.samples.len(),
                                    "Voice reference loaded"
                                );
                                Some(v)
                            }
                            Err(e) => {
                                tracing::warn!(
                                    session_id = %session_id,
                                    voice_id = %voice_id,
                                    error = %e,
                                    "Failed to load voice reference, proceeding without cloning"
                                );
                                None
                            }
                        };

                        // Get or lazily load the TTS engine
                        let engine = match state.get_tts_engine().await {
                            Ok(e) => e,
                            Err(e) => {
                                tracing::error!(
                                    session_id = %session_id,
                                    error = %e,
                                    "Failed to load TTS engine"
                                );
                                let _ = send_error(
                                    &mut sender,
                                    "TTS_LOAD_FAILED",
                                    &format!("Failed to load TTS engine: {e}"),
                                )
                                .await;
                                continue;
                            }
                        };

                        if !engine.is_ready() {
                            let _ = send_error(
                                &mut sender,
                                "TTS_NOT_READY",
                                "TTS engine is not ready yet",
                            )
                            .await;
                            continue;
                        }

                        // Reset the cancel flag for this new generation request
                        cancel_flag.store(false, Ordering::Relaxed);

                        // Run TTS inference with optional voice reference for cloning
                        // and the cancel flag so it can be stopped early.
                        let (ref_audio, ref_rate) = match &voice_ref {
                            Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)),
                            None => (None, None),
                        };
                        let flag = cancel_flag.clone();
                        match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await {
                            Ok(chunks) => {
                                // Check if generation was cancelled
                                let was_cancelled = cancel_flag.load(Ordering::Relaxed);

                                for chunk in &chunks {
                                    // Send binary PCM audio data
                                    let pcm_bytes = chunk.to_pcm16_bytes();
                                    if sender
                                        .send(Message::Binary(pcm_bytes.into()))
                                        .await
                                        .is_err()
                                    {
                                        tracing::warn!(
                                            session_id = %session_id,
                                            "Failed to send audio chunk — client disconnected"
                                        );
                                        return;
                                    }
                                }

                                // Signal end of audio (include cancelled status)
                                let end_msg = serde_json::json!({
                                    "type": "audio_end",
                                    "sample_rate": engine.sample_rate(),
                                    "format": "pcm_s16le",
                                    "channels": 1,
                                    "cancelled": was_cancelled,
                                });
                                let _ = sender
                                    .send(Message::Text(end_msg.to_string().into()))
                                    .await;
                            }
                            Err(e) => {
                                tracing::error!(
                                    session_id = %session_id,
                                    error = %e,
                                    "TTS inference failed"
                                );
                                let _ = send_error(
                                    &mut sender,
                                    "TTS_INFERENCE_FAILED",
                                    &format!("TTS inference failed: {e}"),
                                )
                                .await;
                            }
                        }
                    }
                    ClientMessage::Cancel => {
                        tracing::info!(session_id = %session_id, "TTS cancel requested");
                        cancel_flag.store(true, Ordering::Relaxed);
                    }
                    ClientMessage::Stop => {
                        tracing::info!(session_id = %session_id, "TTS stop requested, closing");
                        cancel_flag.store(true, Ordering::Relaxed);
                        break;
                    }
                }
            }
            Message::Close(_) => {
                tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
                cancel_flag.store(true, Ordering::Relaxed);
                break;
            }
            _ => {
                // Ignore ping/pong/binary from client
            }
        }
    }

    tracing::info!(session_id = %session_id, "TTS WebSocket connection closed");
}

/// Send an error message to the client.
async fn send_error<S>(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error>
where
    S: SinkExt<Message> + Unpin,
    <S as futures::Sink<Message>>::Error: std::error::Error,
{
    let error_msg = serde_json::json!({
        "type": "error",
        "code": code,
        "message": message,
        "recoverable": false
    });

    sender
        .send(Message::Text(error_msg.to_string().into()))
        .await
        .ok();
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the JSON layout of the error frame emitted by `send_error`.
    #[test]
    fn test_error_message_format() {
        let error = serde_json::json!({
            "type": "error",
            "code": "TEST_ERROR",
            "message": "Test message",
            "recoverable": false
        });

        assert_eq!(error["type"], "error");
        assert_eq!(error["code"], "TEST_ERROR");
        assert_eq!(error["message"], "Test message");
        assert_eq!(error["recoverable"], false);
    }

    #[test]
    fn test_client_message_parse_speak() {
        let json = r#"{"type": "speak", "text": "Hello world"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Speak { text, voice } => {
                assert_eq!(text, "Hello world");
                assert!(voice.is_none());
            }
            _ => panic!("Expected Speak message"),
        }
    }

    #[test]
    fn test_client_message_parse_cancel() {
        let json = r#"{"type": "cancel"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Cancel));
    }

    #[test]
    fn test_client_message_parse_stop() {
        let json = r#"{"type": "stop"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Stop));
    }

    #[test]
    fn test_client_message_parse_speak_with_voice() {
        let json = r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Speak { text, voice } => {
                assert_eq!(text, "Hello");
                assert_eq!(voice.as_deref(), Some("makima"));
            }
            _ => panic!("Expected Speak message"),
        }
    }

    /// Unknown message types must fail to parse so the handler replies with
    /// an `INVALID_MESSAGE` error instead of silently ignoring them.
    #[test]
    fn test_client_message_rejects_unknown_type() {
        let json = r#"{"type": "start"}"#;
        assert!(serde_json::from_str::<ClientMessage>(json).is_err());
    }

    /// `text` is required for a speak request; only `voice` is optional.
    #[test]
    fn test_client_message_rejects_speak_without_text() {
        let json = r#"{"type": "speak", "voice": "makima"}"#;
        assert!(serde_json::from_str::<ClientMessage>(json).is_err());
    }

    /// The `type` tag is mandatory.
    #[test]
    fn test_client_message_rejects_missing_type() {
        let json = r#"{"text": "Hello"}"#;
        assert!(serde_json::from_str::<ClientMessage>(json).is_err());
    }
}