1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
|
//! WebSocket handler for TTS streaming (direct in-process inference).
//!
//! This module implements the `/api/v1/speak` endpoint which performs
//! text-to-speech synthesis directly using the candle-based TTS engine.
//! No external Python service or proxy — the model runs in-process.
//!
//! ## Architecture
//!
//! The speak handler:
//! 1. Accepts a WebSocket connection from the client
//! 2. Lazily loads the TTS model (candle) on the first request
//! 3. Parses JSON control messages (`speak`, `cancel`, `stop`)
//! 4. Runs inference directly and sends the resulting audio chunks back
//!
//! See `makima/src/tts/` for the TTS engine implementation.
//! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification.
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use uuid::Uuid;
use crate::server::state::SharedState;
/// Control messages a client may send over the speak WebSocket.
///
/// Each message is a JSON object discriminated by its `"type"` field, whose
/// value is the snake_case variant name: `"speak"`, `"cancel"` or `"stop"`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
enum ClientMessage {
    /// Ask the server to synthesize speech.
    Speak {
        /// The text to synthesize.
        text: String,
        /// Voice ID whose reference audio is loaded for voice cloning.
        /// When omitted (or `null`), the handler falls back to
        /// `super::voice::DEFAULT_VOICE_ID`.
        #[serde(default)]
        voice: Option<String>,
    },
    /// Request that any in-progress synthesis be cancelled.
    Cancel,
    /// Ask the server to end the session.
    Stop,
}
/// Upgrade an HTTP request on `/api/v1/speak` to a TTS WebSocket session.
///
/// All synthesis happens in-process (candle); no external service is
/// contacted. The shared server state is moved into the per-connection task,
/// which lazily obtains the TTS engine on the first `speak` request.
///
/// NOTE(review): the upgrade itself never yields a 503. Engine load/readiness
/// failures are reported to the client as JSON `error` frames after the
/// upgrade, so the documented 503 response is not produced by this handler.
#[utoipa::path(
    get,
    path = "/api/v1/speak",
    responses(
        (status = 101, description = "WebSocket connection established"),
        (status = 503, description = "TTS engine not available"),
    ),
    tag = "Speak"
)]
pub async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<SharedState>,
) -> Response {
    let session_state = state;
    ws.on_upgrade(move |upgraded| handle_speak_socket(upgraded, session_state))
}
/// Handle TTS WebSocket session with direct in-process inference.
///
/// Protocol:
/// - Client sends JSON `{ "type": "speak", "text": "..." }` messages
/// - Server responds with binary audio chunks (16-bit PCM @ 24kHz)
/// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete
/// - Server sends JSON `{ "type": "error", ... }` on failures
async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
let session_id = Uuid::new_v4().to_string();
tracing::info!(session_id = %session_id, "New TTS WebSocket connection");
let (mut sender, mut receiver) = socket.split();
// Cancellation flag shared between the message loop and inference.
// Each new Speak request resets it to false; Cancel sets it to true.
let cancel_flag: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
// Process incoming messages
while let Some(msg) = receiver.next().await {
let msg = match msg {
Ok(m) => m,
Err(e) => {
tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error");
break;
}
};
match msg {
Message::Text(text) => {
let client_msg: ClientMessage = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
let _ = send_error(
&mut sender,
"INVALID_MESSAGE",
&format!("Failed to parse message: {e}"),
)
.await;
continue;
}
};
match client_msg {
ClientMessage::Speak { text, voice } => {
let voice_id = voice
.as_deref()
.unwrap_or(super::voice::DEFAULT_VOICE_ID);
tracing::info!(
session_id = %session_id,
text_len = text.len(),
voice_id = %voice_id,
"TTS speak request"
);
// Load voice reference audio for cloning
let voice_ref = match super::voice::load_reference_audio(voice_id) {
Ok(v) => {
tracing::debug!(
session_id = %session_id,
voice_id = %voice_id,
voice_name = %v.manifest.name,
samples = v.samples.len(),
"Voice reference loaded"
);
Some(v)
}
Err(e) => {
tracing::warn!(
session_id = %session_id,
voice_id = %voice_id,
error = %e,
"Failed to load voice reference, proceeding without cloning"
);
None
}
};
// Get or lazily load the TTS engine
let engine = match state.get_tts_engine().await {
Ok(e) => e,
Err(e) => {
tracing::error!(
session_id = %session_id,
error = %e,
"Failed to load TTS engine"
);
let _ = send_error(
&mut sender,
"TTS_LOAD_FAILED",
&format!("Failed to load TTS engine: {e}"),
)
.await;
continue;
}
};
if !engine.is_ready() {
let _ = send_error(
&mut sender,
"TTS_NOT_READY",
"TTS engine is not ready yet",
)
.await;
continue;
}
// Reset the cancel flag for this new generation request
cancel_flag.store(false, Ordering::Relaxed);
// Run TTS inference with optional voice reference for cloning
// and the cancel flag so it can be stopped early.
let (ref_audio, ref_rate) = match &voice_ref {
Some(v) => (Some(v.samples.as_slice()), Some(v.sample_rate)),
None => (None, None),
};
let flag = cancel_flag.clone();
match engine.generate(&text, ref_audio, ref_rate, Some(flag)).await {
Ok(chunks) => {
// Check if generation was cancelled
let was_cancelled = cancel_flag.load(Ordering::Relaxed);
for chunk in &chunks {
// Send binary PCM audio data
let pcm_bytes = chunk.to_pcm16_bytes();
if sender
.send(Message::Binary(pcm_bytes.into()))
.await
.is_err()
{
tracing::warn!(
session_id = %session_id,
"Failed to send audio chunk — client disconnected"
);
return;
}
}
// Signal end of audio (include cancelled status)
let end_msg = serde_json::json!({
"type": "audio_end",
"sample_rate": engine.sample_rate(),
"format": "pcm_s16le",
"channels": 1,
"cancelled": was_cancelled,
});
let _ = sender
.send(Message::Text(end_msg.to_string().into()))
.await;
}
Err(e) => {
tracing::error!(
session_id = %session_id,
error = %e,
"TTS inference failed"
);
let _ = send_error(
&mut sender,
"TTS_INFERENCE_FAILED",
&format!("TTS inference failed: {e}"),
)
.await;
}
}
}
ClientMessage::Cancel => {
tracing::info!(session_id = %session_id, "TTS cancel requested");
cancel_flag.store(true, Ordering::Relaxed);
}
ClientMessage::Stop => {
tracing::info!(session_id = %session_id, "TTS stop requested, closing");
cancel_flag.store(true, Ordering::Relaxed);
break;
}
}
}
Message::Close(_) => {
tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
cancel_flag.store(true, Ordering::Relaxed);
break;
}
_ => {
// Ignore ping/pong/binary from client
}
}
}
tracing::info!(session_id = %session_id, "TTS WebSocket connection closed");
}
/// Send an error message to the client.
async fn send_error<S>(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error>
where
S: SinkExt<Message> + Unpin,
<S as futures::Sink<Message>>::Error: std::error::Error,
{
let error_msg = serde_json::json!({
"type": "error",
"code": code,
"message": message,
"recoverable": false
});
sender
.send(Message::Text(error_msg.to_string().into()))
.await
.ok();
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shape of the JSON error payload the handler emits.
    #[test]
    fn test_error_message_format() {
        let error = serde_json::json!({
            "type": "error",
            "code": "TEST_ERROR",
            "message": "Test message",
            "recoverable": false
        });
        assert_eq!(error["type"], "error");
        assert_eq!(error["code"], "TEST_ERROR");
        assert_eq!(error["message"], "Test message");
        assert_eq!(error["recoverable"], false);
    }

    #[test]
    fn test_client_message_parse_speak() {
        let json = r#"{"type": "speak", "text": "Hello world"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Speak { text, voice } => {
                assert_eq!(text, "Hello world");
                assert!(voice.is_none());
            }
            _ => panic!("Expected Speak message"),
        }
    }

    #[test]
    fn test_client_message_parse_cancel() {
        let json = r#"{"type": "cancel"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Cancel));
    }

    #[test]
    fn test_client_message_parse_stop() {
        let json = r#"{"type": "stop"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Stop));
    }

    #[test]
    fn test_client_message_parse_speak_with_voice() {
        let json = r#"{"type": "speak", "text": "Hello", "voice": "makima"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Speak { text, voice } => {
                assert_eq!(text, "Hello");
                assert_eq!(voice.as_deref(), Some("makima"));
            }
            _ => panic!("Expected Speak message"),
        }
    }

    /// An explicit `null` voice is treated the same as an omitted one.
    #[test]
    fn test_client_message_parse_speak_with_null_voice() {
        let json = r#"{"type": "speak", "text": "Hi", "voice": null}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Speak { text, voice } => {
                assert_eq!(text, "Hi");
                assert!(voice.is_none());
            }
            _ => panic!("Expected Speak message"),
        }
    }

    /// `text` is mandatory for a speak request.
    #[test]
    fn test_client_message_rejects_speak_without_text() {
        let json = r#"{"type": "speak"}"#;
        assert!(serde_json::from_str::<ClientMessage>(json).is_err());
    }

    /// Only `speak`, `cancel` and `stop` are accepted; e.g. `start` is not.
    #[test]
    fn test_client_message_rejects_unknown_type() {
        let json = r#"{"type": "start"}"#;
        assert!(serde_json::from_str::<ClientMessage>(json).is_err());
    }

    /// The `type` tag is required to select a variant.
    #[test]
    fn test_client_message_rejects_missing_type() {
        let json = r#"{"text": "Hello"}"#;
        assert!(serde_json::from_str::<ClientMessage>(json).is_err());
    }
}
|