diff options
| author | soryu <soryu@soryu.co> | 2026-02-02 22:52:05 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2026-02-02 22:52:05 +0000 |
| commit | 0f06a7f9968816e5e2553c4f1c2104f2fa504f96 (patch) | |
| tree | 53d8db119c17d7d22f3127ae5a54e12a3f384e29 /makima/src/tts/qwen3/mod.rs | |
| parent | 151e9d87e117b7980e6aad522ac8f3633eeca87a (diff) | |
| download | soryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.tar.gz soryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.zip | |
Release in makima repo
Also remove all other TTS models
Diffstat (limited to 'makima/src/tts/qwen3/mod.rs')
| -rw-r--r-- | makima/src/tts/qwen3/mod.rs | 317 |
1 files changed, 0 insertions, 317 deletions
diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs deleted file mode 100644 index fc6c472..0000000 --- a/makima/src/tts/qwen3/mod.rs +++ /dev/null @@ -1,317 +0,0 @@ -//! Qwen3-TTS — Pure Rust implementation using candle. -//! -//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis -//! with voice cloning support. No Python, no ONNX — pure Rust inference -//! via the candle ML framework. -//! -//! # Architecture -//! -//! The model has three components: -//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens -//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers -//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes -//! -//! # Usage -//! -//! ```rust,no_run -//! use makima::tts::qwen3::Qwen3Tts; -//! use candle_core::Device; -//! -//! let device = Device::Cpu; -//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap(); -//! // Use via TtsEngine trait or direct API -//! ``` - -pub mod code_predictor; -pub mod config; -pub mod generate; -pub mod model; -pub mod speech_tokenizer; - -use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use candle_core::{DType, Device}; -use candle_nn::VarBuilder; -use hf_hub::api::sync::Api; -use tokenizers::Tokenizer; - -use self::code_predictor::CodePredictor; -use self::config::Qwen3TtsConfig; -use self::generate::{GenerationConfig, GenerationContext}; -use self::model::Qwen3Model; -use self::speech_tokenizer::SpeechTokenizer; -use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE}; - -/// HuggingFace model IDs. -const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"; -const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz"; -const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts"; - -/// Qwen3-TTS engine — pure Rust candle-based inference. -pub struct Qwen3Tts { - /// The 28-layer language model. - model: Qwen3Model, - /// Multi-token prediction code predictor. - code_predictor: CodePredictor, - /// Speech tokenizer (encoder + decoder + RVQ). - speech_tokenizer: SpeechTokenizer, - /// Text tokenizer. - tokenizer: Tokenizer, - /// Model configuration. - config: Qwen3TtsConfig, - /// Compute device (CPU/CUDA/Metal). - device: Device, - /// Whether the model is fully loaded and ready. - ready: AtomicBool, -} - -// SAFETY: All fields are either Send+Sync or behind appropriate synchronization. -// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync. -unsafe impl Send for Qwen3Tts {} -unsafe impl Sync for Qwen3Tts {} - -impl Qwen3Tts { - /// Load from a local directory or download from HuggingFace. - pub fn from_pretrained( - model_dir: Option<&str>, - device: &Device, - ) -> Result<Self, TtsError> { - let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR)); - - if !model_path.exists() { - Self::download_models(&model_path)?; - } - - Self::load_from_path(&model_path, device) - } - - /// Load all model components from a local directory. - pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> { - let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU - - // Load configuration - let config_path = model_dir.join("config.json"); - let config = if config_path.exists() { - Qwen3TtsConfig::from_json_path(&config_path)? - } else { - Qwen3TtsConfig::default() - }; - - // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats) - let tokenizer_json_path = model_dir.join("tokenizer.json"); - let tokenizer = if tokenizer_json_path.exists() { - Tokenizer::from_file(&tokenizer_json_path) - .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))? - } else { - // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format) - let vocab_path = model_dir.join("vocab.json"); - let merges_path = model_dir.join("merges.txt"); - - if !vocab_path.exists() || !merges_path.exists() { - return Err(TtsError::Tokenizer(format!( - "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}", - model_dir.display() - ))); - } - - tokenizers::Tokenizer::from_file(&vocab_path) - .or_else(|_| { - // Build BPE tokenizer from vocab and merges - use tokenizers::models::bpe::BPE; - let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy()) - .build() - .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?; - Ok(Tokenizer::new(bpe)) - }) - .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))? - }; - - // Load LM weights from safetensors - let lm_weights_path = model_dir.join("model.safetensors"); - let lm_data = std::fs::read(&lm_weights_path).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to read LM weights from {}: {e}", - lm_weights_path.display() - )) - })?; - let lm_vb = VarBuilder::from_buffered_safetensors( - lm_data, - dtype, - device, - ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?; - - // Build language model - let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| { - TtsError::ModelLoad(format!("failed to build LM model: {e}")) - })?; - - // Build code predictor (weights are in the same safetensors file) - let code_predictor = - CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| { - TtsError::ModelLoad(format!("failed to build code predictor: {e}")) - })?; - - // Load speech tokenizer from separate safetensors - let st_weights_path = model_dir.join("speech_tokenizer.safetensors"); - let st_data = std::fs::read(&st_weights_path).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to read speech tokenizer weights from {}: {e}", - st_weights_path.display() - )) - })?; - let st_vb = VarBuilder::from_buffered_safetensors( - st_data, - dtype, - device, - ).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to create speech tokenizer VarBuilder: {e}" - )) - })?; - - let speech_tokenizer = - SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| { - TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}")) - })?; - - Ok(Self { - model, - code_predictor, - speech_tokenizer, - tokenizer, - config, - device: device.clone(), - ready: AtomicBool::new(true), - }) - } - - /// Generate audio from text with optional voice reference. - pub fn generate_speech( - &self, - text: &str, - reference_audio: Option<&[f32]>, - gen_config: Option<GenerationConfig>, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Result<Vec<AudioChunk>, TtsError> { - let config = gen_config.unwrap_or_default(); - - let ctx = GenerationContext::new( - &self.model, - &self.code_predictor, - &self.speech_tokenizer, - &self.tokenizer, - &self.device, - config, - cancel_flag, - ); - - ctx.generate(text, reference_audio) - } - - /// Download model files from HuggingFace Hub. - fn download_models(target_dir: &Path) -> Result<(), TtsError> { - std::fs::create_dir_all(target_dir)?; - - let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?; - - // Download LM model files - println!("Downloading Qwen3-TTS language model..."); - let lm_repo = api.model(LM_MODEL_ID.to_string()); - - // Note: HuggingFace repo has vocab.json + merges.txt instead of tokenizer.json - let lm_files = [ - "model.safetensors", - "config.json", - "vocab.json", - "merges.txt", - "tokenizer_config.json", - ]; - - for file in &lm_files { - println!(" Downloading {file}..."); - let downloaded = lm_repo - .get(file) - .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?; - - let target = target_dir.join(file); - if !target.exists() { - std::fs::copy(&downloaded, &target)?; - } - } - - // Download speech tokenizer - println!("Downloading Qwen3-TTS speech tokenizer..."); - let st_repo = api.model(TOKENIZER_MODEL_ID.to_string()); - - let st_file = "model.safetensors"; - let downloaded = st_repo - .get(st_file) - .map_err(|e| { - TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}")) - })?; - - let target = target_dir.join("speech_tokenizer.safetensors"); - if !target.exists() { - std::fs::copy(&downloaded, &target)?; - } - - println!("All models downloaded to {}", target_dir.display()); - Ok(()) - } - - /// Get the model configuration. - pub fn config(&self) -> &Qwen3TtsConfig { - &self.config - } - - /// Get the compute device. - pub fn device(&self) -> &Device { - &self.device - } -} - -#[async_trait::async_trait] -impl TtsEngine for Qwen3Tts { - async fn generate( - &self, - text: &str, - reference_audio: Option<&[f32]>, - _reference_sample_rate: Option<u32>, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Result<Vec<AudioChunk>, TtsError> { - // Note: reference audio should already be resampled to 24kHz - // by the caller. If a different sample rate is provided, - // the caller should resample using `resample_to_24k()`. - self.generate_speech(text, reference_audio, None, cancel_flag) - } - - fn is_ready(&self) -> bool { - self.ready.load(Ordering::Relaxed) - } - - fn sample_rate(&self) -> u32 { - SAMPLE_RATE - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = Qwen3TtsConfig::default(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.code_predictor.num_code_groups, 16); - assert_eq!(config.speech_tokenizer.sample_rate, 24_000); - } - - #[test] - fn test_model_ids() { - assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base"); - assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz"); - } -} |
