summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: soryu <soryu@soryu.co> 2026-01-28 03:49:13 +0000
committer: soryu <soryu@soryu.co> 2026-01-28 03:49:54 +0000
commit: d0436686f047f1d82c30da26cf83f9eca6727292 (patch)
tree: f4f7b6bc7c6cc410d90908d3adf7c519bdf6a2ad
parent: c3de071511de5e8a8d63ea4ca47c815cb6450215 (diff)
download: soryu-d0436686f047f1d82c30da26cf83f9eca6727292.tar.gz
soryu-d0436686f047f1d82c30da26cf83f9eca6727292.zip
feat: add inference cancellation support for TTS generation
Add cooperative cancellation via an Arc<AtomicBool> cancel flag that threads through TtsEngine::generate -> Qwen3Tts -> GenerationContext. The autoregressive loop and streaming decoder check the flag each iteration and break early when set. The speak WebSocket handler creates a per-session flag, passes it to generate, and sets it on Cancel/Stop/Close messages.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
-rw-r--r-- makima/src/tts/chatterbox.rs 4
-rw-r--r-- makima/src/tts/mod.rs 8
-rw-r--r-- makima/src/tts/qwen3/generate.rs 32
-rw-r--r-- makima/src/tts/qwen3/mod.rs 6
4 files changed, 47 insertions, 3 deletions
diff --git a/makima/src/tts/chatterbox.rs b/makima/src/tts/chatterbox.rs
index e26bc06..712910f 100644
--- a/makima/src/tts/chatterbox.rs
+++ b/makima/src/tts/chatterbox.rs
@@ -6,7 +6,8 @@
use std::borrow::Cow;
use std::fs;
use std::path::{Path, PathBuf};
-use std::sync::Mutex;
+use std::sync::atomic::AtomicBool;
+use std::sync::{Arc, Mutex};
use hf_hub::api::sync::Api;
use ndarray::{Array2, Array3, Array4, ArrayD, IxDyn};
@@ -427,6 +428,7 @@ impl TtsEngine for ChatterboxTTS {
text: &str,
reference_audio: Option<&[f32]>,
reference_sample_rate: Option<u32>,
+ _cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<Vec<AudioChunk>, TtsError> {
let samples = match reference_audio {
Some(audio) => {
diff --git a/makima/src/tts/mod.rs b/makima/src/tts/mod.rs
index 2cd0412..b66f4a5 100644
--- a/makima/src/tts/mod.rs
+++ b/makima/src/tts/mod.rs
@@ -5,6 +5,8 @@
//! - **Qwen3**: Pure Rust candle-based Qwen3-TTS-12Hz-0.6B
use std::path::Path;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
pub mod chatterbox;
pub mod qwen3;
@@ -109,11 +111,17 @@ pub enum TtsBackend {
#[async_trait::async_trait]
pub trait TtsEngine: Send + Sync {
/// Generate complete audio from text with a voice reference.
+ ///
+ /// The optional `cancel_flag` can be set to `true` by another thread/task
+ /// to request early termination of the generation loop. Engines that
+ /// support cancellation will check this flag periodically and return
+ /// whatever audio has been produced so far.
async fn generate(
&self,
text: &str,
reference_audio: Option<&[f32]>,
reference_sample_rate: Option<u32>,
+ cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<Vec<AudioChunk>, TtsError>;
/// Check if the engine is loaded and ready.
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs
index 02161e6..30d165b 100644
--- a/makima/src/tts/qwen3/generate.rs
+++ b/makima/src/tts/qwen3/generate.rs
@@ -7,6 +7,9 @@
//! 4. Code predictor → remaining 15 codebook tokens per frame
//! 5. Speech tokenizer decoder → waveform audio
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
use candle_core::{DType, Device, IndexOp, Result, Tensor};
use tokenizers::Tokenizer;
@@ -60,6 +63,9 @@ pub struct GenerationContext<'a> {
tokenizer: &'a Tokenizer,
device: &'a Device,
config: GenerationConfig,
+ /// Optional cancellation flag. When set to `true`, the generation loop
+ /// will break early and return whatever audio has been produced so far.
+ cancel_flag: Option<Arc<AtomicBool>>,
}
impl<'a> GenerationContext<'a> {
@@ -70,6 +76,7 @@ impl<'a> GenerationContext<'a> {
tokenizer: &'a Tokenizer,
device: &'a Device,
config: GenerationConfig,
+ cancel_flag: Option<Arc<AtomicBool>>,
) -> Self {
Self {
model,
@@ -78,9 +85,17 @@ impl<'a> GenerationContext<'a> {
tokenizer,
device,
config,
+ cancel_flag,
}
}
+ /// Check whether cancellation has been requested. Returns `false` when no cancel flag was supplied.
+ fn is_cancelled(&self) -> bool {
+ self.cancel_flag
+ .as_ref()
+ .map_or(false, |f| f.load(Ordering::Relaxed)) // `None` => not cancelled; otherwise a relaxed read of the flag's current value
+ }
+
/// Generate audio from text, optionally with a voice reference.
///
/// Returns a list of audio chunks. If `streaming` is false, returns
@@ -194,6 +209,12 @@ impl<'a> GenerationContext<'a> {
// === Subsequent iterations: one token at a time ===
for _step in 1..self.config.max_new_tokens {
+ // Check for cancellation each iteration
+ if self.is_cancelled() {
+ tracing::info!("TTS generation cancelled after {} frames", generated_frames.len());
+ break;
+ }
+
let past_len = kv_caches[0].seq_len();
// Input: just the last generated zeroth codebook token
@@ -340,13 +361,22 @@ impl<'a> GenerationContext<'a> {
&self,
frames: &[Vec<u32>],
) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- let mut chunks = Vec::new();
+ let mut chunks: Vec<AudioChunk> = Vec::new();
// Decode in groups of frames for efficiency
let chunk_size = 10; // ~800ms per chunk at 12.5Hz
let num_codebooks = self.speech_tokenizer.num_codebooks();
for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
+ // Check for cancellation between streaming chunks
+ if self.is_cancelled() {
+ tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len());
+ if let Some(last) = chunks.last_mut() {
+ last.is_final = true;
+ }
+ return Ok(chunks);
+ }
+
let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
// Transpose chunk frames
diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs
index c55c118..9bac794 100644
--- a/makima/src/tts/qwen3/mod.rs
+++ b/makima/src/tts/qwen3/mod.rs
@@ -30,6 +30,7 @@ pub mod speech_tokenizer;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
@@ -168,6 +169,7 @@ impl Qwen3Tts {
text: &str,
reference_audio: Option<&[f32]>,
gen_config: Option<GenerationConfig>,
+ cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<Vec<AudioChunk>, TtsError> {
let config = gen_config.unwrap_or_default();
@@ -178,6 +180,7 @@ impl Qwen3Tts {
&self.tokenizer,
&self.device,
config,
+ cancel_flag,
);
ctx.generate(text, reference_audio)
@@ -250,11 +253,12 @@ impl TtsEngine for Qwen3Tts {
text: &str,
reference_audio: Option<&[f32]>,
_reference_sample_rate: Option<u32>,
+ cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<Vec<AudioChunk>, TtsError> {
// Note: reference audio should already be resampled to 24kHz
// by the caller. If a different sample rate is provided,
// the caller should resample using `resample_to_24k()`.
- self.generate_speech(text, reference_audio, None)
+ self.generate_speech(text, reference_audio, None, cancel_flag)
}
fn is_ready(&self) -> bool {