From 9b53f6c6b01da85ef73bd5960b32ec319df0b947 Mon Sep 17 00:00:00 2001 From: soryu Date: Wed, 28 Jan 2026 03:50:45 +0000 Subject: Replace TTS endpoint with Rust-native Qwen3-TTS (#41) * chore: fix unused import warnings in qwen3-tts module - Remove unused import 'IndexOp' in model.rs - Remove unused import 'DType' in speech_tokenizer.rs - Add #[allow(dead_code)] to codebook_dim field in RvqCodebook Co-Authored-By: Claude Opus 4.5 * feat: add voice loading and selection for TTS cloning Add voice reference audio loading so the TTS speak handler can perform voice cloning using reference WAV files from the voices/ directory. - Add voice.rs module: loads manifest.json and reference.wav for a given voice_id, decodes via symphonia, resamples to 24kHz for the TTS engine - Update speak.rs: resolve voice_id from the speak request (default "makima"), load reference audio, pass it to engine.generate() - Add voices/makima/README.md with instructions for obtaining reference audio (extraction from YouTube, recording, ffmpeg conversion) - Graceful fallback: if reference audio is missing, TTS proceeds without voice cloning using the model's default voice Co-Authored-By: Claude Opus 4.5 * [WIP] Heartbeat checkpoint - 2026-01-28 03:49:13 UTC --------- Co-authored-by: Claude Opus 4.5 --- makima/src/tts/chatterbox.rs | 4 +++- makima/src/tts/mod.rs | 8 ++++++++ makima/src/tts/qwen3/generate.rs | 32 +++++++++++++++++++++++++++++++- makima/src/tts/qwen3/mod.rs | 6 +++++- makima/src/tts/qwen3/model.rs | 2 +- makima/src/tts/qwen3/speech_tokenizer.rs | 3 ++- 6 files changed, 50 insertions(+), 5 deletions(-) (limited to 'makima/src/tts') diff --git a/makima/src/tts/chatterbox.rs b/makima/src/tts/chatterbox.rs index e26bc06..712910f 100644 --- a/makima/src/tts/chatterbox.rs +++ b/makima/src/tts/chatterbox.rs @@ -6,7 +6,8 @@ use std::borrow::Cow; use std::fs; use std::path::{Path, PathBuf}; -use std::sync::Mutex; +use std::sync::atomic::AtomicBool; +use std::sync::{Arc, Mutex}; use hf_hub::api::sync::Api; use ndarray::{Array2, Array3, Array4, ArrayD, IxDyn}; @@ -427,6 +428,7 @@ impl TtsEngine for ChatterboxTTS { text: &str, reference_audio: Option<&[f32]>, reference_sample_rate: Option, + _cancel_flag: Option>, ) -> Result, TtsError> { let samples = match reference_audio { Some(audio) => { diff --git a/makima/src/tts/mod.rs b/makima/src/tts/mod.rs index 2cd0412..b66f4a5 100644 --- a/makima/src/tts/mod.rs +++ b/makima/src/tts/mod.rs @@ -5,6 +5,8 @@ //! - **Qwen3**: Pure Rust candle-based Qwen3-TTS-12Hz-0.6B use std::path::Path; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; pub mod chatterbox; pub mod qwen3; @@ -109,11 +111,17 @@ pub enum TtsBackend { #[async_trait::async_trait] pub trait TtsEngine: Send + Sync { /// Generate complete audio from text with a voice reference. + /// + /// The optional `cancel_flag` can be set to `true` by another thread/task + /// to request early termination of the generation loop. Engines that + /// support cancellation will check this flag periodically and return + /// whatever audio has been produced so far. async fn generate( &self, text: &str, reference_audio: Option<&[f32]>, reference_sample_rate: Option, + cancel_flag: Option>, ) -> Result, TtsError>; /// Check if the engine is loaded and ready. diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs index 02161e6..30d165b 100644 --- a/makima/src/tts/qwen3/generate.rs +++ b/makima/src/tts/qwen3/generate.rs @@ -7,6 +7,9 @@ //! 4. Code predictor → remaining 15 codebook tokens per frame //! 5. Speech tokenizer decoder → waveform audio +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + use candle_core::{DType, Device, IndexOp, Result, Tensor}; use tokenizers::Tokenizer; @@ -60,6 +63,9 @@ pub struct GenerationContext<'a> { tokenizer: &'a Tokenizer, device: &'a Device, config: GenerationConfig, + /// Optional cancellation flag. When set to `true`, the generation loop + /// will break early and return whatever audio has been produced so far. + cancel_flag: Option>, } impl<'a> GenerationContext<'a> { @@ -70,6 +76,7 @@ impl<'a> GenerationContext<'a> { tokenizer: &'a Tokenizer, device: &'a Device, config: GenerationConfig, + cancel_flag: Option>, ) -> Self { Self { model, @@ -78,9 +85,17 @@ impl<'a> GenerationContext<'a> { tokenizer, device, config, + cancel_flag, } } + /// Check whether cancellation has been requested. + fn is_cancelled(&self) -> bool { + self.cancel_flag + .as_ref() + .map_or(false, |f| f.load(Ordering::Relaxed)) + } + /// Generate audio from text, optionally with a voice reference. /// /// Returns a list of audio chunks. If `streaming` is false, returns @@ -194,6 +209,12 @@ impl<'a> GenerationContext<'a> { // === Subsequent iterations: one token at a time === for _step in 1..self.config.max_new_tokens { + // Check for cancellation each iteration + if self.is_cancelled() { + tracing::info!("TTS generation cancelled after {} frames", generated_frames.len()); + break; + } + let past_len = kv_caches[0].seq_len(); // Input: just the last generated zeroth codebook token @@ -340,13 +361,22 @@ impl<'a> GenerationContext<'a> { &self, frames: &[Vec], ) -> std::result::Result, TtsError> { - let mut chunks = Vec::new(); + let mut chunks: Vec = Vec::new(); // Decode in groups of frames for efficiency let chunk_size = 10; // ~800ms per chunk at 12.5Hz let num_codebooks = self.speech_tokenizer.num_codebooks(); for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { + // Check for cancellation between streaming chunks + if self.is_cancelled() { + tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len()); + if let Some(last) = chunks.last_mut() { + last.is_final = true; + } + return Ok(chunks); + } + let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); // Transpose chunk frames diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs index c55c118..9bac794 100644 --- a/makima/src/tts/qwen3/mod.rs +++ b/makima/src/tts/qwen3/mod.rs @@ -30,6 +30,7 @@ pub mod speech_tokenizer; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use candle_core::{DType, Device}; use candle_nn::VarBuilder; @@ -168,6 +169,7 @@ impl Qwen3Tts { text: &str, reference_audio: Option<&[f32]>, gen_config: Option, + cancel_flag: Option>, ) -> Result, TtsError> { let config = gen_config.unwrap_or_default(); @@ -178,6 +180,7 @@ impl Qwen3Tts { &self.tokenizer, &self.device, config, + cancel_flag, ); ctx.generate(text, reference_audio) @@ -250,11 +253,12 @@ impl TtsEngine for Qwen3Tts { text: &str, reference_audio: Option<&[f32]>, _reference_sample_rate: Option, + cancel_flag: Option>, ) -> Result, TtsError> { // Note: reference audio should already be resampled to 24kHz // by the caller. If a different sample rate is provided, // the caller should resample using `resample_to_24k()`. - self.generate_speech(text, reference_audio, None) + self.generate_speech(text, reference_audio, None, cancel_flag) } fn is_ready(&self) -> bool { diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs index 551893b..8a1e986 100644 --- a/makima/src/tts/qwen3/model.rs +++ b/makima/src/tts/qwen3/model.rs @@ -10,7 +10,7 @@ //! Based on the candle-transformers Qwen2 model architecture, //! extended for Qwen3-TTS. -use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_core::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; use super::config::Qwen3LmConfig; diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs index 752050a..86e00f2 100644 --- a/makima/src/tts/qwen3/speech_tokenizer.rs +++ b/makima/src/tts/qwen3/speech_tokenizer.rs @@ -11,7 +11,7 @@ //! The speech tokenizer is a separate model (~682MB) loaded from //! `Qwen/Qwen3-TTS-Tokenizer-12Hz`. -use candle_core::{DType, Device, Module, Result, Tensor, D}; +use candle_core::{Device, Module, Result, Tensor, D}; use candle_nn::{ conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder, }; @@ -259,6 +259,7 @@ impl DecoderBlock { pub struct RvqCodebook { codebooks: Vec, num_codebooks: usize, + #[allow(dead_code)] codebook_dim: usize, } -- cgit v1.2.3