summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/speak.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server/handlers/speak.rs')
-rw-r--r--makima/src/server/handlers/speak.rs77
1 files changed, 70 insertions, 7 deletions
diff --git a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs
index 75e7780..3ed2620 100644
--- a/makima/src/server/handlers/speak.rs
+++ b/makima/src/server/handlers/speak.rs
@@ -15,6 +15,9 @@
//! See `makima/src/tts/` for the TTS engine implementation.
//! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification.
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
response::Response,
@@ -32,9 +35,9 @@ enum ClientMessage {
/// Request speech synthesis for the given text.
Speak {
text: String,
- /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection.
+ /// Optional voice ID (e.g., "makima"). Used to load reference audio for voice cloning.
+ /// Defaults to "makima" if not specified.
#[serde(default)]
- #[allow(dead_code)]
voice: Option<String>,
},
/// Cancel any in-progress synthesis.
@@ -76,6 +79,10 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
let (mut sender, mut receiver) = socket.split();
+ // Cancellation flag shared between the message loop and inference.
+ // Each new Speak request resets it to false; Cancel sets it to true.
+ let cancel_flag: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
+
// Process incoming messages
while let Some(msg) = receiver.next().await {
let msg = match msg {
@@ -102,13 +109,41 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
};
match client_msg {
- ClientMessage::Speak { text, .. } => {
+ ClientMessage::Speak { text, voice } => {
+ let voice_id = voice
+ .as_deref()
+ .unwrap_or(super::voice::DEFAULT_VOICE_ID);
+
tracing::info!(
session_id = %session_id,
text_len = text.len(),
+ voice_id = %voice_id,
"TTS speak request"
);
+ // Load voice reference audio for cloning
+ let voice_ref = match super::voice::load_reference_audio(voice_id) {
+ Ok(v) => {
+ tracing::debug!(
+ session_id = %session_id,
+ voice_id = %voice_id,
+ voice_name = %v.manifest.name,
+ samples = v.samples.len(),
+ "Voice reference loaded"
+ );
+ Some(v)
+ }
+ Err(e) => {
+ tracing::warn!(
+ session_id = %session_id,
+ voice_id = %voice_id,
+ error = %e,
+ "Failed to load voice reference, proceeding without cloning"
+ );
+ None
+ }
+ };
+
// Get or lazily load the TTS engine
let engine = match state.get_tts_engine().await {
Ok(e) => e,
@@ -138,9 +173,21 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
continue;
}
- // Run TTS inference (no voice reference for now — uses default)
- match engine.generate(&text, None, None).await {
+ // Reset the cancel flag for this new generation request
+ cancel_flag.store(false, Ordering::Relaxed);
+
+ // Run TTS inference with optional voice reference for cloning
+ // and the cancel flag so it can be stopped early.
+ let (ref_audio, ref_rate) = match &voice_ref {
+ Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)),
+ None => (None, None),
+ };
+ let flag = cancel_flag.clone();
+ match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await {
Ok(chunks) => {
+ // Check if generation was cancelled
+ let was_cancelled = cancel_flag.load(Ordering::Relaxed);
+
for chunk in &chunks {
// Send binary PCM audio data
let pcm_bytes = chunk.to_pcm16_bytes();
@@ -157,12 +204,13 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
}
}
- // Signal end of audio
+ // Signal end of audio (include cancelled status)
let end_msg = serde_json::json!({
"type": "audio_end",
"sample_rate": engine.sample_rate(),
"format": "pcm_s16le",
"channels": 1,
+ "cancelled": was_cancelled,
});
let _ = sender
.send(Message::Text(end_msg.to_string().into()))
@@ -185,16 +233,18 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
}
ClientMessage::Cancel => {
tracing::info!(session_id = %session_id, "TTS cancel requested");
- // TODO: support cancellation of in-progress inference
+ cancel_flag.store(true, Ordering::Relaxed);
}
ClientMessage::Stop => {
tracing::info!(session_id = %session_id, "TTS stop requested, closing");
+ cancel_flag.store(true, Ordering::Relaxed);
break;
}
}
}
Message::Close(_) => {
tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
+ cancel_flag.store(true, Ordering::Relaxed);
break;
}
_ => {
@@ -271,4 +321,17 @@ mod tests {
let msg: ClientMessage = serde_json::from_str(json).unwrap();
assert!(matches!(msg, ClientMessage::Stop));
}
+
+ #[test]
+ fn test_client_message_parse_speak_with_voice() {
+ let json = r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ match msg {
+ ClientMessage::Speak { text, voice } => {
+ assert_eq!(text, "Hello");
+ assert_eq!(voice.as_deref(), Some("makima"));
+ }
+ _ => panic!("Expected Speak message"),
+ }
+ }
}