diff options
Diffstat (limited to 'makima/src/tts/qwen3/generate.rs')
| -rw-r--r-- | makima/src/tts/qwen3/generate.rs | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs index 02161e6..30d165b 100644 --- a/makima/src/tts/qwen3/generate.rs +++ b/makima/src/tts/qwen3/generate.rs @@ -7,6 +7,9 @@ //! 4. Code predictor → remaining 15 codebook tokens per frame //! 5. Speech tokenizer decoder → waveform audio +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + use candle_core::{DType, Device, IndexOp, Result, Tensor}; use tokenizers::Tokenizer; @@ -60,6 +63,9 @@ pub struct GenerationContext<'a> { tokenizer: &'a Tokenizer, device: &'a Device, config: GenerationConfig, + /// Optional cancellation flag. When set to `true`, the generation loop + /// will break early and return whatever audio has been produced so far. + cancel_flag: Option<Arc<AtomicBool>>, } impl<'a> GenerationContext<'a> { @@ -70,6 +76,7 @@ impl<'a> GenerationContext<'a> { tokenizer: &'a Tokenizer, device: &'a Device, config: GenerationConfig, + cancel_flag: Option<Arc<AtomicBool>>, ) -> Self { Self { model, @@ -78,9 +85,17 @@ impl<'a> GenerationContext<'a> { tokenizer, device, config, + cancel_flag, } } + /// Check whether cancellation has been requested. + fn is_cancelled(&self) -> bool { + self.cancel_flag + .as_ref() + .map_or(false, |f| f.load(Ordering::Relaxed)) + } + /// Generate audio from text, optionally with a voice reference. /// /// Returns a list of audio chunks. If `streaming` is false, returns @@ -194,6 +209,12 @@ impl<'a> GenerationContext<'a> { // === Subsequent iterations: one token at a time === for _step in 1..self.config.max_new_tokens { + // Check for cancellation each iteration + if self.is_cancelled() { + tracing::info!("TTS generation cancelled after {} frames", generated_frames.len()); + break; + } + let past_len = kv_caches[0].seq_len(); // Input: just the last generated zeroth codebook token @@ -340,13 +361,22 @@ impl<'a> GenerationContext<'a> { &self, frames: &[Vec<u32>], ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let mut chunks = Vec::new(); + let mut chunks: Vec<AudioChunk> = Vec::new(); // Decode in groups of frames for efficiency let chunk_size = 10; // ~800ms per chunk at 12.5Hz let num_codebooks = self.speech_tokenizer.num_codebooks(); for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { + // Check for cancellation between streaming chunks + if self.is_cancelled() { + tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len()); + if let Some(last) = chunks.last_mut() { + last.is_final = true; + } + return Ok(chunks); + } + let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); // Transpose chunk frames |
