diff options
| author | soryu <soryu@soryu.co> | 2026-01-28 03:50:45 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-01-28 03:50:45 +0000 |
| commit | 9b53f6c6b01da85ef73bd5960b32ec319df0b947 (patch) | |
| tree | 8c5e9983e1a5e75afab4a7d7a18ba22b75211628 /makima/src/server/handlers/speak.rs | |
| parent | c14192cc8b0e82369c93c1aee615fcc9cfad5911 (diff) | |
| download | soryu-9b53f6c6b01da85ef73bd5960b32ec319df0b947.tar.gz soryu-9b53f6c6b01da85ef73bd5960b32ec319df0b947.zip | |
Replace TTS endpoint with Rust-native Qwen3-TTS (#41)
* chore: fix unused import warnings in qwen3-tts module
- Remove unused import 'IndexOp' in model.rs
- Remove unused import 'DType' in speech_tokenizer.rs
- Add #[allow(dead_code)] to codebook_dim field in RvqCodebook
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* feat: add voice loading and selection for TTS cloning
Add voice reference audio loading so the TTS speak handler can perform
voice cloning using reference WAV files from the voices/ directory.
- Add voice.rs module: loads manifest.json and reference.wav for a given
voice_id, decodes via symphonia, resamples to 24kHz for the TTS engine
- Update speak.rs: resolve voice_id from the speak request (default
"makima"), load reference audio, pass it to engine.generate()
- Add voices/makima/README.md with instructions for obtaining reference
audio (extraction from YouTube, recording, ffmpeg conversion)
- Graceful fallback: if reference audio is missing, TTS proceeds without
voice cloning using the model's default voice
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* [WIP] Heartbeat checkpoint - 2026-01-28 03:49:13 UTC
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'makima/src/server/handlers/speak.rs')
| -rw-r--r-- | makima/src/server/handlers/speak.rs | 77 |
1 files changed, 70 insertions, 7 deletions
diff --git a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs index 75e7780..3ed2620 100644 --- a/makima/src/server/handlers/speak.rs +++ b/makima/src/server/handlers/speak.rs @@ -15,6 +15,9 @@ //! See `makima/src/tts/` for the TTS engine implementation. //! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification. +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, response::Response, @@ -32,9 +35,9 @@ enum ClientMessage { /// Request speech synthesis for the given text. Speak { text: String, - /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection. + /// Optional voice ID (e.g., "makima"). Used to load reference audio for voice cloning. + /// Defaults to "makima" if not specified. #[serde(default)] - #[allow(dead_code)] voice: Option<String>, }, /// Cancel any in-progress synthesis. @@ -76,6 +79,10 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { let (mut sender, mut receiver) = socket.split(); + // Cancellation flag shared between the message loop and inference. + // Each new Speak request resets it to false; Cancel sets it to true. + let cancel_flag: Arc<AtomicBool> = Arc::new(AtomicBool::new(false)); + // Process incoming messages while let Some(msg) = receiver.next().await { let msg = match msg { @@ -102,13 +109,41 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { }; match client_msg { - ClientMessage::Speak { text, .. } => { + ClientMessage::Speak { text, voice } => { + let voice_id = voice + .as_deref() + .unwrap_or(super::voice::DEFAULT_VOICE_ID); + tracing::info!( session_id = %session_id, text_len = text.len(), + voice_id = %voice_id, "TTS speak request" ); + // Load voice reference audio for cloning + let voice_ref = match super::voice::load_reference_audio(voice_id) { + Ok(v) => { + tracing::debug!( + session_id = %session_id, + voice_id = %voice_id, + voice_name = %v.manifest.name, + samples = v.samples.len(), + "Voice reference loaded" + ); + Some(v) + } + Err(e) => { + tracing::warn!( + session_id = %session_id, + voice_id = %voice_id, + error = %e, + "Failed to load voice reference, proceeding without cloning" + ); + None + } + }; + // Get or lazily load the TTS engine let engine = match state.get_tts_engine().await { Ok(e) => e, @@ -138,9 +173,21 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { continue; } - // Run TTS inference (no voice reference for now — uses default) - match engine.generate(&text, None, None).await { + // Reset the cancel flag for this new generation request + cancel_flag.store(false, Ordering::Relaxed); + + // Run TTS inference with optional voice reference for cloning + // and the cancel flag so it can be stopped early. + let (ref_audio, ref_rate) = match &voice_ref { + Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)), + None => (None, None), + }; + let flag = cancel_flag.clone(); + match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await { Ok(chunks) => { + // Check if generation was cancelled + let was_cancelled = cancel_flag.load(Ordering::Relaxed); + for chunk in &chunks { // Send binary PCM audio data let pcm_bytes = chunk.to_pcm16_bytes(); @@ -157,12 +204,13 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { } } - // Signal end of audio + // Signal end of audio (include cancelled status) let end_msg = serde_json::json!({ "type": "audio_end", "sample_rate": engine.sample_rate(), "format": "pcm_s16le", "channels": 1, + "cancelled": was_cancelled, }); let _ = sender .send(Message::Text(end_msg.to_string().into())) @@ -185,16 +233,18 @@ async fn handle_speak_socket(socket: WebSocket, state: SharedState) { } ClientMessage::Cancel => { tracing::info!(session_id = %session_id, "TTS cancel requested"); - // TODO: support cancellation of in-progress inference + cancel_flag.store(true, Ordering::Relaxed); } ClientMessage::Stop => { tracing::info!(session_id = %session_id, "TTS stop requested, closing"); + cancel_flag.store(true, Ordering::Relaxed); break; } } } Message::Close(_) => { tracing::info!(session_id = %session_id, "TTS WebSocket closed by client"); + cancel_flag.store(true, Ordering::Relaxed); break; } _ => { @@ -271,4 +321,17 @@ mod tests { let msg: ClientMessage = serde_json::from_str(json).unwrap(); assert!(matches!(msg, ClientMessage::Stop)); } + + #[test] + fn test_client_message_parse_speak_with_voice() { + let json = r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#; + let msg: ClientMessage = serde_json::from_str(json).unwrap(); + match msg { + ClientMessage::Speak { text, voice } => { + assert_eq!(text, "Hello"); + assert_eq!(voice.as_deref(), Some("makima")); + } + _ => panic!("Expected Speak message"), + } + } } |
