summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/generate.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-28 03:50:45 +0000
committerGitHub <noreply@github.com>2026-01-28 03:50:45 +0000
commit9b53f6c6b01da85ef73bd5960b32ec319df0b947 (patch)
tree8c5e9983e1a5e75afab4a7d7a18ba22b75211628 /makima/src/tts/qwen3/generate.rs
parentc14192cc8b0e82369c93c1aee615fcc9cfad5911 (diff)
downloadsoryu-9b53f6c6b01da85ef73bd5960b32ec319df0b947.tar.gz
soryu-9b53f6c6b01da85ef73bd5960b32ec319df0b947.zip
Replace TTS endpoint with Rust-native Qwen3-TTS (#41)
* chore: fix unused import warnings in qwen3-tts module - Remove unused import 'IndexOp' in model.rs - Remove unused import 'DType' in speech_tokenizer.rs - Add #[allow(dead_code)] to codebook_dim field in RvqCodebook Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat: add voice loading and selection for TTS cloning Add voice reference audio loading so the TTS speak handler can perform voice cloning using reference WAV files from the voices/ directory. - Add voice.rs module: loads manifest.json and reference.wav for a given voice_id, decodes via symphonia, resamples to 24kHz for the TTS engine - Update speak.rs: resolve voice_id from the speak request (default "makima"), load reference audio, pass it to engine.generate() - Add voices/makima/README.md with instructions for obtaining reference audio (extraction from YouTube, recording, ffmpeg conversion) - Graceful fallback: if reference audio is missing, TTS proceeds without voice cloning using the model's default voice Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * [WIP] Heartbeat checkpoint - 2026-01-28 03:49:13 UTC --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'makima/src/tts/qwen3/generate.rs')
-rw-r--r--makima/src/tts/qwen3/generate.rs32
1 files changed, 31 insertions, 1 deletions
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs
index 02161e6..30d165b 100644
--- a/makima/src/tts/qwen3/generate.rs
+++ b/makima/src/tts/qwen3/generate.rs
@@ -7,6 +7,9 @@
//! 4. Code predictor → remaining 15 codebook tokens per frame
//! 5. Speech tokenizer decoder → waveform audio
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
use candle_core::{DType, Device, IndexOp, Result, Tensor};
use tokenizers::Tokenizer;
@@ -60,6 +63,9 @@ pub struct GenerationContext<'a> {
tokenizer: &'a Tokenizer,
device: &'a Device,
config: GenerationConfig,
+ /// Optional cancellation flag. When set to `true`, the generation loop
+ /// will break early and return whatever audio has been produced so far.
+ cancel_flag: Option<Arc<AtomicBool>>,
}
impl<'a> GenerationContext<'a> {
@@ -70,6 +76,7 @@ impl<'a> GenerationContext<'a> {
tokenizer: &'a Tokenizer,
device: &'a Device,
config: GenerationConfig,
+ cancel_flag: Option<Arc<AtomicBool>>,
) -> Self {
Self {
model,
@@ -78,9 +85,17 @@ impl<'a> GenerationContext<'a> {
tokenizer,
device,
config,
+ cancel_flag,
}
}
+ /// Check whether cancellation has been requested.
+ fn is_cancelled(&self) -> bool {
+ self.cancel_flag
+ .as_ref()
+ .map_or(false, |f| f.load(Ordering::Relaxed))
+ }
+
/// Generate audio from text, optionally with a voice reference.
///
/// Returns a list of audio chunks. If `streaming` is false, returns
@@ -194,6 +209,12 @@ impl<'a> GenerationContext<'a> {
// === Subsequent iterations: one token at a time ===
for _step in 1..self.config.max_new_tokens {
+ // Check for cancellation each iteration
+ if self.is_cancelled() {
+ tracing::info!("TTS generation cancelled after {} frames", generated_frames.len());
+ break;
+ }
+
let past_len = kv_caches[0].seq_len();
// Input: just the last generated zeroth codebook token
@@ -340,13 +361,22 @@ impl<'a> GenerationContext<'a> {
&self,
frames: &[Vec<u32>],
) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- let mut chunks = Vec::new();
+ let mut chunks: Vec<AudioChunk> = Vec::new();
// Decode in groups of frames for efficiency
let chunk_size = 10; // ~800ms per chunk at 12.5Hz
let num_codebooks = self.speech_tokenizer.num_codebooks();
for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
+ // Check for cancellation between streaming chunks
+ if self.is_cancelled() {
+ tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len());
+ if let Some(last) = chunks.last_mut() {
+ last.is_final = true;
+ }
+ return Ok(chunks);
+ }
+
let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
// Transpose chunk frames