//! Qwen3-TTS — Pure Rust implementation using candle. //! //! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis //! with voice cloning support. No Python, no ONNX — pure Rust inference //! via the candle ML framework. //! //! # Architecture //! //! The model has three components: //! - **Language Model** (28-layer transformer): generates zeroth codebook tokens //! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers //! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes //! //! # Usage //! //! ```rust,no_run //! use makima::tts::qwen3::Qwen3Tts; //! use candle_core::Device; //! //! let device = Device::Cpu; //! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap(); //! // Use via TtsEngine trait or direct API //! ``` pub mod code_predictor; pub mod config; pub mod generate; pub mod model; pub mod speech_tokenizer; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use candle_core::{DType, Device}; use candle_nn::VarBuilder; use hf_hub::api::sync::Api; use tokenizers::Tokenizer; use self::code_predictor::CodePredictor; use self::config::Qwen3TtsConfig; use self::generate::{GenerationConfig, GenerationContext}; use self::model::Qwen3Model; use self::speech_tokenizer::SpeechTokenizer; use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE}; /// HuggingFace model IDs. const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"; const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz"; const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts"; /// Qwen3-TTS engine — pure Rust candle-based inference. pub struct Qwen3Tts { /// The 28-layer language model. model: Qwen3Model, /// Multi-token prediction code predictor. code_predictor: CodePredictor, /// Speech tokenizer (encoder + decoder + RVQ). speech_tokenizer: SpeechTokenizer, /// Text tokenizer. tokenizer: Tokenizer, /// Model configuration. config: Qwen3TtsConfig, /// Compute device (CPU/CUDA/Metal). device: Device, /// Whether the model is fully loaded and ready. ready: AtomicBool, } // SAFETY: All fields are either Send+Sync or behind appropriate synchronization. // candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync. unsafe impl Send for Qwen3Tts {} unsafe impl Sync for Qwen3Tts {} impl Qwen3Tts { /// Load from a local directory or download from HuggingFace. pub fn from_pretrained( model_dir: Option<&str>, device: &Device, ) -> Result { let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR)); if !model_path.exists() { Self::download_models(&model_path)?; } Self::load_from_path(&model_path, device) } /// Load all model components from a local directory. pub fn load_from_path(model_dir: &Path, device: &Device) -> Result { let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU // Load configuration let config_path = model_dir.join("config.json"); let config = if config_path.exists() { Qwen3TtsConfig::from_json_path(&config_path)? } else { Qwen3TtsConfig::default() }; // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats) let tokenizer_json_path = model_dir.join("tokenizer.json"); let tokenizer = if tokenizer_json_path.exists() { Tokenizer::from_file(&tokenizer_json_path) .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))? } else { // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format) let vocab_path = model_dir.join("vocab.json"); let merges_path = model_dir.join("merges.txt"); if !vocab_path.exists() || !merges_path.exists() { return Err(TtsError::Tokenizer(format!( "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}", model_dir.display() ))); } tokenizers::Tokenizer::from_file(&vocab_path) .or_else(|_| { // Build BPE tokenizer from vocab and merges use tokenizers::models::bpe::BPE; let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy()) .build() .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?; Ok(Tokenizer::new(bpe)) }) .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))? }; // Load LM weights from safetensors let lm_weights_path = model_dir.join("model.safetensors"); let lm_data = std::fs::read(&lm_weights_path).map_err(|e| { TtsError::ModelLoad(format!( "failed to read LM weights from {}: {e}", lm_weights_path.display() )) })?; let lm_vb = VarBuilder::from_buffered_safetensors( lm_data, dtype, device, ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?; // Build language model let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| { TtsError::ModelLoad(format!("failed to build LM model: {e}")) })?; // Build code predictor (weights are in the same safetensors file) let code_predictor = CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| { TtsError::ModelLoad(format!("failed to build code predictor: {e}")) })?; // Load speech tokenizer from separate safetensors let st_weights_path = model_dir.join("speech_tokenizer.safetensors"); let st_data = std::fs::read(&st_weights_path).map_err(|e| { TtsError::ModelLoad(format!( "failed to read speech tokenizer weights from {}: {e}", st_weights_path.display() )) })?; let st_vb = VarBuilder::from_buffered_safetensors( st_data, dtype, device, ).map_err(|e| { TtsError::ModelLoad(format!( "failed to create speech tokenizer VarBuilder: {e}" )) })?; let speech_tokenizer = SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| { TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}")) })?; Ok(Self { model, code_predictor, speech_tokenizer, tokenizer, config, device: device.clone(), ready: AtomicBool::new(true), }) } /// Generate audio from text with optional voice reference. pub fn generate_speech( &self, text: &str, reference_audio: Option<&[f32]>, gen_config: Option, cancel_flag: Option>, ) -> Result, TtsError> { let config = gen_config.unwrap_or_default(); let ctx = GenerationContext::new( &self.model, &self.code_predictor, &self.speech_tokenizer, &self.tokenizer, &self.device, config, cancel_flag, ); ctx.generate(text, reference_audio) } /// Download model files from HuggingFace Hub. fn download_models(target_dir: &Path) -> Result<(), TtsError> { std::fs::create_dir_all(target_dir)?; let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?; // Download LM model files println!("Downloading Qwen3-TTS language model..."); let lm_repo = api.model(LM_MODEL_ID.to_string()); let lm_files = [ "model.safetensors", "config.json", "tokenizer.json", "tokenizer_config.json", ]; for file in &lm_files { println!(" Downloading {file}..."); let downloaded = lm_repo .get(file) .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?; let target = target_dir.join(file); if !target.exists() { std::fs::copy(&downloaded, &target)?; } } // Download speech tokenizer println!("Downloading Qwen3-TTS speech tokenizer..."); let st_repo = api.model(TOKENIZER_MODEL_ID.to_string()); let st_file = "model.safetensors"; let downloaded = st_repo .get(st_file) .map_err(|e| { TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}")) })?; let target = target_dir.join("speech_tokenizer.safetensors"); if !target.exists() { std::fs::copy(&downloaded, &target)?; } println!("All models downloaded to {}", target_dir.display()); Ok(()) } /// Get the model configuration. pub fn config(&self) -> &Qwen3TtsConfig { &self.config } /// Get the compute device. pub fn device(&self) -> &Device { &self.device } } #[async_trait::async_trait] impl TtsEngine for Qwen3Tts { async fn generate( &self, text: &str, reference_audio: Option<&[f32]>, _reference_sample_rate: Option, cancel_flag: Option>, ) -> Result, TtsError> { // Note: reference audio should already be resampled to 24kHz // by the caller. If a different sample rate is provided, // the caller should resample using `resample_to_24k()`. self.generate_speech(text, reference_audio, None, cancel_flag) } fn is_ready(&self) -> bool { self.ready.load(Ordering::Relaxed) } fn sample_rate(&self) -> u32 { SAMPLE_RATE } } #[cfg(test)] mod tests { use super::*; #[test] fn test_default_config() { let config = Qwen3TtsConfig::default(); assert_eq!(config.lm.hidden_size, 1024); assert_eq!(config.lm.num_hidden_layers, 28); assert_eq!(config.code_predictor.num_code_groups, 16); assert_eq!(config.speech_tokenizer.sample_rate, 24_000); } #[test] fn test_model_ids() { assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base"); assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz"); } }