//! WebSocket handler for TTS streaming (direct in-process inference).
//!
//! This module implements the `/api/v1/speak` endpoint which performs
//! text-to-speech synthesis directly using the candle-based TTS engine.
//! No external Python service or proxy — the model runs in-process.
//!
//! ## Architecture
//!
//! The speak handler will:
//! 1. Accept a WebSocket connection from the client
//! 2. Lazily load the TTS model (candle) on first request
//! 3. Parse JSON control messages (speak, cancel, stop)
//! 4. Run inference directly and stream audio chunks back
//!
//! See `makima/src/tts/` for the TTS engine implementation.
//! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use uuid::Uuid;
use crate::server::state::SharedState;
/// Client-to-server control messages.
///
/// Messages are JSON objects discriminated by a `"type"` field whose value is
/// the snake_case variant name: `"speak"`, `"cancel"`, or `"stop"`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage {
/// Request speech synthesis for the given text.
Speak {
/// The text to synthesize.
text: String,
/// Optional voice ID (e.g., "makima"). Used to load reference audio for voice cloning.
/// When absent, the handler falls back to `super::voice::DEFAULT_VOICE_ID`.
#[serde(default)]
voice: Option<String>,
},
/// Cancel any in-progress synthesis. The connection stays open.
Cancel,
/// Graceful close: the handler stops processing and ends the session.
Stop,
}
/// WebSocket upgrade handler for TTS streaming.
///
/// This endpoint accepts WebSocket connections for text-to-speech synthesis.
/// The TTS model runs directly in-process using candle — no external service.
#[utoipa::path(
get,
path = "/api/v1/speak",
responses(
(status = 101, description = "WebSocket connection established"),
(status = 503, description = "TTS engine not available"),
),
tag = "Speak"
)]
pub async fn websocket_handler(
ws: WebSocketUpgrade,
State(state): State<SharedState>,
) -> Response {
ws.on_upgrade(|socket| handle_speak_socket(socket, state))
}
/// Handle TTS WebSocket session with direct in-process inference.
///
/// Protocol:
/// - Client sends JSON `{ "type": "speak", "text": "..." }` messages
/// - Server responds with binary audio chunks (16-bit PCM @ 24kHz)
/// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete
/// - Server sends JSON `{ "type": "error", ... }` on failures
async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
let session_id = Uuid::new_v4().to_string();
tracing::info!(session_id = %session_id, "New TTS WebSocket connection");
let (mut sender, mut receiver) = socket.split();
// Cancellation flag shared between the message loop and inference.
// Each new Speak request resets it to false; Cancel sets it to true.
let cancel_flag: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
// Process incoming messages
while let Some(msg) = receiver.next().await {
let msg = match msg {
Ok(m) => m,
Err(e) => {
tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error");
break;
}
};
match msg {
Message::Text(text) => {
let client_msg: ClientMessage = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
let _ = send_error(
&mut sender,
"INVALID_MESSAGE",
&format!("Failed to parse message: {e}"),
)
.await;
continue;
}
};
match client_msg {
ClientMessage::Speak { text, voice } => {
let voice_id = voice
.as_deref()
.unwrap_or(super::voice::DEFAULT_VOICE_ID);
tracing::info!(
session_id = %session_id,
text_len = text.len(),
voice_id = %voice_id,
"TTS speak request"
);
// Load voice reference audio for cloning
let voice_ref = match super::voice::load_reference_audio(voice_id) {
Ok(v) => {
tracing::debug!(
session_id = %session_id,
voice_id = %voice_id,
voice_name = %v.manifest.name,
samples = v.samples.len(),
"Voice reference loaded"
);
Some(v)
}
Err(e) => {
tracing::warn!(
session_id = %session_id,
voice_id = %voice_id,
error = %e,
"Failed to load voice reference, proceeding without cloning"
);
None
}
};
// Get or lazily load the TTS engine
let engine = match state.get_tts_engine().await {
Ok(e) => e,
Err(e) => {
tracing::error!(
session_id = %session_id,
error = %e,
"Failed to load TTS engine"
);
let _ = send_error(
&mut sender,
"TTS_LOAD_FAILED",
&format!("Failed to load TTS engine: {e}"),
)
.await;
continue;
}
};
if !engine.is_ready() {
let _ = send_error(
&mut sender,
"TTS_NOT_READY",
"TTS engine is not ready yet",
)
.await;
continue;
}
// Reset the cancel flag for this new generation request
cancel_flag.store(false, Ordering::Relaxed);
// Run TTS inference with optional voice reference for cloning
// and the cancel flag so it can be stopped early.
let (ref_audio, ref_rate) = match &voice_ref {
Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)),
None => (None, None),
};
let flag = cancel_flag.clone();
match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await {
Ok(chunks) => {
// Check if generation was cancelled
let was_cancelled = cancel_flag.load(Ordering::Relaxed);
for chunk in &chunks {
// Send binary PCM audio data
let pcm_bytes = chunk.to_pcm16_bytes();
if sender
.send(Message::Binary(pcm_bytes.into()))
.await
.is_err()
{
tracing::warn!(
session_id = %session_id,
"Failed to send audio chunk — client disconnected"
);
return;
}
}
// Signal end of audio (include cancelled status)
let end_msg = serde_json::json!({
"type": "audio_end",
"sample_rate": engine.sample_rate(),
"format": "pcm_s16le",
"channels": 1,
"cancelled": was_cancelled,
});
let _ = sender
.send(Message::Text(end_msg.to_string().into()))
.await;
}
Err(e) => {
tracing::error!(
session_id = %session_id,
error = %e,
"TTS inference failed"
);
let _ = send_error(
&mut sender,
"TTS_INFERENCE_FAILED",
&format!("TTS inference failed: {e}"),
)
.await;
}
}
}
ClientMessage::Cancel => {
tracing::info!(session_id = %session_id, "TTS cancel requested");
cancel_flag.store(true, Ordering::Relaxed);
}
ClientMessage::Stop => {
tracing::info!(session_id = %session_id, "TTS stop requested, closing");
cancel_flag.store(true, Ordering::Relaxed);
break;
}
}
}
Message::Close(_) => {
tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
cancel_flag.store(true, Ordering::Relaxed);
break;
}
_ => {
// Ignore ping/pong/binary from client
}
}
}
tracing::info!(session_id = %session_id, "TTS WebSocket connection closed");
}
/// Send an error message to the client.
async fn send_error<S>(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error>
where
S: SinkExt<Message> + Unpin,
<S as futures::Sink<Message>>::Error: std::error::Error,
{
let error_msg = serde_json::json!({
"type": "error",
"code": code,
"message": message,
"recoverable": false
});
sender
.send(Message::Text(error_msg.to_string().into()))
.await
.ok();
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialize a control message, panicking if the JSON is rejected.
    fn decode(raw: &str) -> ClientMessage {
        serde_json::from_str(raw).unwrap()
    }

    /// Pins the field names and values of the error frame layout.
    /// This checks a locally built literal; it does not call `send_error`.
    #[test]
    fn test_error_message_format() {
        let frame = serde_json::json!({
            "type": "error",
            "code": "TEST_ERROR",
            "message": "Test message",
            "recoverable": false
        });

        assert_eq!(frame["type"], "error");
        assert_eq!(frame["code"], "TEST_ERROR");
        assert_eq!(frame["message"], "Test message");
        assert_eq!(frame["recoverable"], false);
    }

    /// A `speak` message without a `voice` field leaves the voice unset.
    #[test]
    fn test_client_message_parse_speak() {
        if let ClientMessage::Speak { text, voice } =
            decode(r#"{"type": "speak", "text": "Hello world"}"#)
        {
            assert_eq!(text, "Hello world");
            assert!(voice.is_none());
        } else {
            panic!("Expected Speak message");
        }
    }

    /// `{"type": "cancel"}` maps to the `Cancel` variant.
    #[test]
    fn test_client_message_parse_cancel() {
        let parsed = decode(r#"{"type": "cancel"}"#);
        assert!(matches!(parsed, ClientMessage::Cancel));
    }

    /// `{"type": "stop"}` maps to the `Stop` variant.
    #[test]
    fn test_client_message_parse_stop() {
        let parsed = decode(r#"{"type": "stop"}"#);
        assert!(matches!(parsed, ClientMessage::Stop));
    }

    /// An explicit `voice` field is carried through unchanged.
    #[test]
    fn test_client_message_parse_speak_with_voice() {
        if let ClientMessage::Speak { text, voice } =
            decode(r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#)
        {
            assert_eq!(text, "Hello");
            assert_eq!(voice.as_deref(), Some("makima"));
        } else {
            panic!("Expected Speak message");
        }
    }
}