summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/tts/qwen3/mod.rs')
-rw-r--r--makima/src/tts/qwen3/mod.rs317
1 files changed, 0 insertions, 317 deletions
diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs
deleted file mode 100644
index fc6c472..0000000
--- a/makima/src/tts/qwen3/mod.rs
+++ /dev/null
@@ -1,317 +0,0 @@
-//! Qwen3-TTS — Pure Rust implementation using candle.
-//!
-//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis
-//! with voice cloning support. No Python, no ONNX — pure Rust inference
-//! via the candle ML framework.
-//!
-//! # Architecture
-//!
-//! The model has three components:
-//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens
-//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers
-//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes
-//!
-//! # Usage
-//!
-//! ```rust,no_run
-//! use makima::tts::qwen3::Qwen3Tts;
-//! use candle_core::Device;
-//!
-//! let device = Device::Cpu;
-//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap();
-//! // Use via TtsEngine trait or direct API
-//! ```
-
-pub mod code_predictor;
-pub mod config;
-pub mod generate;
-pub mod model;
-pub mod speech_tokenizer;
-
-use std::path::{Path, PathBuf};
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-
-use candle_core::{DType, Device};
-use candle_nn::VarBuilder;
-use hf_hub::api::sync::Api;
-use tokenizers::Tokenizer;
-
-use self::code_predictor::CodePredictor;
-use self::config::Qwen3TtsConfig;
-use self::generate::{GenerationConfig, GenerationContext};
-use self::model::Qwen3Model;
-use self::speech_tokenizer::SpeechTokenizer;
-use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE};
-
-/// HuggingFace model IDs.
-const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base";
-const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
-const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts";
-
-/// Qwen3-TTS engine — pure Rust candle-based inference.
-pub struct Qwen3Tts {
- /// The 28-layer language model.
- model: Qwen3Model,
- /// Multi-token prediction code predictor.
- code_predictor: CodePredictor,
- /// Speech tokenizer (encoder + decoder + RVQ).
- speech_tokenizer: SpeechTokenizer,
- /// Text tokenizer.
- tokenizer: Tokenizer,
- /// Model configuration.
- config: Qwen3TtsConfig,
- /// Compute device (CPU/CUDA/Metal).
- device: Device,
- /// Whether the model is fully loaded and ready.
- ready: AtomicBool,
-}
-
-// SAFETY: All fields are either Send+Sync or behind appropriate synchronization.
-// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync.
-unsafe impl Send for Qwen3Tts {}
-unsafe impl Sync for Qwen3Tts {}
-
-impl Qwen3Tts {
- /// Load from a local directory or download from HuggingFace.
- pub fn from_pretrained(
- model_dir: Option<&str>,
- device: &Device,
- ) -> Result<Self, TtsError> {
- let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
-
- if !model_path.exists() {
- Self::download_models(&model_path)?;
- }
-
- Self::load_from_path(&model_path, device)
- }
-
- /// Load all model components from a local directory.
- pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> {
- let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU
-
- // Load configuration
- let config_path = model_dir.join("config.json");
- let config = if config_path.exists() {
- Qwen3TtsConfig::from_json_path(&config_path)?
- } else {
- Qwen3TtsConfig::default()
- };
-
- // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats)
- let tokenizer_json_path = model_dir.join("tokenizer.json");
- let tokenizer = if tokenizer_json_path.exists() {
- Tokenizer::from_file(&tokenizer_json_path)
- .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))?
- } else {
- // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format)
- let vocab_path = model_dir.join("vocab.json");
- let merges_path = model_dir.join("merges.txt");
-
- if !vocab_path.exists() || !merges_path.exists() {
- return Err(TtsError::Tokenizer(format!(
- "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}",
- model_dir.display()
- )));
- }
-
- tokenizers::Tokenizer::from_file(&vocab_path)
- .or_else(|_| {
- // Build BPE tokenizer from vocab and merges
- use tokenizers::models::bpe::BPE;
- let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy())
- .build()
- .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?;
- Ok(Tokenizer::new(bpe))
- })
- .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))?
- };
-
- // Load LM weights from safetensors
- let lm_weights_path = model_dir.join("model.safetensors");
- let lm_data = std::fs::read(&lm_weights_path).map_err(|e| {
- TtsError::ModelLoad(format!(
- "failed to read LM weights from {}: {e}",
- lm_weights_path.display()
- ))
- })?;
- let lm_vb = VarBuilder::from_buffered_safetensors(
- lm_data,
- dtype,
- device,
- ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?;
-
- // Build language model
- let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| {
- TtsError::ModelLoad(format!("failed to build LM model: {e}"))
- })?;
-
- // Build code predictor (weights are in the same safetensors file)
- let code_predictor =
- CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| {
- TtsError::ModelLoad(format!("failed to build code predictor: {e}"))
- })?;
-
- // Load speech tokenizer from separate safetensors
- let st_weights_path = model_dir.join("speech_tokenizer.safetensors");
- let st_data = std::fs::read(&st_weights_path).map_err(|e| {
- TtsError::ModelLoad(format!(
- "failed to read speech tokenizer weights from {}: {e}",
- st_weights_path.display()
- ))
- })?;
- let st_vb = VarBuilder::from_buffered_safetensors(
- st_data,
- dtype,
- device,
- ).map_err(|e| {
- TtsError::ModelLoad(format!(
- "failed to create speech tokenizer VarBuilder: {e}"
- ))
- })?;
-
- let speech_tokenizer =
- SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| {
- TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}"))
- })?;
-
- Ok(Self {
- model,
- code_predictor,
- speech_tokenizer,
- tokenizer,
- config,
- device: device.clone(),
- ready: AtomicBool::new(true),
- })
- }
-
- /// Generate audio from text with optional voice reference.
- pub fn generate_speech(
- &self,
- text: &str,
- reference_audio: Option<&[f32]>,
- gen_config: Option<GenerationConfig>,
- cancel_flag: Option<Arc<AtomicBool>>,
- ) -> Result<Vec<AudioChunk>, TtsError> {
- let config = gen_config.unwrap_or_default();
-
- let ctx = GenerationContext::new(
- &self.model,
- &self.code_predictor,
- &self.speech_tokenizer,
- &self.tokenizer,
- &self.device,
- config,
- cancel_flag,
- );
-
- ctx.generate(text, reference_audio)
- }
-
- /// Download model files from HuggingFace Hub.
- fn download_models(target_dir: &Path) -> Result<(), TtsError> {
- std::fs::create_dir_all(target_dir)?;
-
- let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?;
-
- // Download LM model files
- println!("Downloading Qwen3-TTS language model...");
- let lm_repo = api.model(LM_MODEL_ID.to_string());
-
- // Note: HuggingFace repo has vocab.json + merges.txt instead of tokenizer.json
- let lm_files = [
- "model.safetensors",
- "config.json",
- "vocab.json",
- "merges.txt",
- "tokenizer_config.json",
- ];
-
- for file in &lm_files {
- println!(" Downloading {file}...");
- let downloaded = lm_repo
- .get(file)
- .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?;
-
- let target = target_dir.join(file);
- if !target.exists() {
- std::fs::copy(&downloaded, &target)?;
- }
- }
-
- // Download speech tokenizer
- println!("Downloading Qwen3-TTS speech tokenizer...");
- let st_repo = api.model(TOKENIZER_MODEL_ID.to_string());
-
- let st_file = "model.safetensors";
- let downloaded = st_repo
- .get(st_file)
- .map_err(|e| {
- TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}"))
- })?;
-
- let target = target_dir.join("speech_tokenizer.safetensors");
- if !target.exists() {
- std::fs::copy(&downloaded, &target)?;
- }
-
- println!("All models downloaded to {}", target_dir.display());
- Ok(())
- }
-
- /// Get the model configuration.
- pub fn config(&self) -> &Qwen3TtsConfig {
- &self.config
- }
-
- /// Get the compute device.
- pub fn device(&self) -> &Device {
- &self.device
- }
-}
-
-#[async_trait::async_trait]
-impl TtsEngine for Qwen3Tts {
- async fn generate(
- &self,
- text: &str,
- reference_audio: Option<&[f32]>,
- _reference_sample_rate: Option<u32>,
- cancel_flag: Option<Arc<AtomicBool>>,
- ) -> Result<Vec<AudioChunk>, TtsError> {
- // Note: reference audio should already be resampled to 24kHz
- // by the caller. If a different sample rate is provided,
- // the caller should resample using `resample_to_24k()`.
- self.generate_speech(text, reference_audio, None, cancel_flag)
- }
-
- fn is_ready(&self) -> bool {
- self.ready.load(Ordering::Relaxed)
- }
-
- fn sample_rate(&self) -> u32 {
- SAMPLE_RATE
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_default_config() {
- let config = Qwen3TtsConfig::default();
- assert_eq!(config.lm.hidden_size, 1024);
- assert_eq!(config.lm.num_hidden_layers, 28);
- assert_eq!(config.code_predictor.num_code_groups, 16);
- assert_eq!(config.speech_tokenizer.sample_rate, 24_000);
- }
-
- #[test]
- fn test_model_ids() {
- assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base");
- assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz");
- }
-}