summaryrefslogblamecommitdiff
path: root/makima/src/tts/qwen3/mod.rs
blob: 1520be619be9f44482ace1c618530ce34ebf982c (plain) (tree)































                                                                                 
                   



































































                                                                                        



























                                                                                                           

































































                                                                                            
                                             









                                                    
                        







































































                                                                                              
                                             



                                                                     
                                                                      





























                                                                        
//! Qwen3-TTS — Pure Rust implementation using candle.
//!
//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis
//! with voice cloning support. No Python, no ONNX — pure Rust inference
//! via the candle ML framework.
//!
//! # Architecture
//!
//! The model has three components:
//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens
//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers
//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes
//!
//! # Usage
//!
//! ```rust,no_run
//! use makima::tts::qwen3::Qwen3Tts;
//! use candle_core::Device;
//!
//! let device = Device::Cpu;
//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap();
//! // Use via TtsEngine trait or direct API
//! ```

pub mod code_predictor;
pub mod config;
pub mod generate;
pub mod model;
pub mod speech_tokenizer;

use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use tokenizers::Tokenizer;

use self::code_predictor::CodePredictor;
use self::config::Qwen3TtsConfig;
use self::generate::{GenerationConfig, GenerationContext};
use self::model::Qwen3Model;
use self::speech_tokenizer::SpeechTokenizer;
use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE};

/// HuggingFace repository ID of the language model; `download_models` fetches
/// the LM weights, `config.json`, and the text-tokenizer files from it.
const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base";
/// HuggingFace repository ID of the speech tokenizer; its `model.safetensors`
/// is stored locally as `speech_tokenizer.safetensors`.
const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
/// Local directory used by `from_pretrained` when no `model_dir` is given.
const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts";

/// Qwen3-TTS engine — pure Rust candle-based inference.
///
/// Build one with `Qwen3Tts::from_pretrained` (which may download model files
/// first) or `Qwen3Tts::load_from_path`; synthesize speech with
/// `Qwen3Tts::generate_speech` or through the `TtsEngine` implementation.
pub struct Qwen3Tts {
    /// The transformer language model (28 layers in the default config).
    model: Qwen3Model,
    /// Multi-token prediction code predictor; its weights live in the same
    /// safetensors file as the language model.
    code_predictor: CodePredictor,
    /// Speech tokenizer (encoder + decoder + RVQ), loaded from
    /// `speech_tokenizer.safetensors`.
    speech_tokenizer: SpeechTokenizer,
    /// Text tokenizer, handed to `GenerationContext` for every generation.
    tokenizer: Tokenizer,
    /// Model configuration, read from `config.json` or defaulted.
    config: Qwen3TtsConfig,
    /// Device the weights were loaded on (CPU/CUDA/Metal).
    device: Device,
    /// Readiness flag reported by `TtsEngine::is_ready`; it is initialised to
    /// `true` once loading has succeeded.
    ready: AtomicBool,
}

// SAFETY: The claimed invariant is that every field is itself Send + Sync.
// candle tensors, `tokenizers::Tokenizer`, and `AtomicBool` are Send + Sync,
// so these impls should only restate what the compiler would derive.
// NOTE(review): If that claim holds, these `unsafe impl`s are redundant. They
// should be removed so the compiler keeps checking the invariant when fields
// change. If any field is not Sync (e.g. a `RefCell`/`Cell` KV cache or a raw
// pointer inside `Qwen3Model`, `CodePredictor` or `SpeechTokenizer`),
// concurrent `&self` calls such as `generate_speech` would be a data race.
// Verify against those types before relying on this.
unsafe impl Send for Qwen3Tts {}
unsafe impl Sync for Qwen3Tts {}

impl Qwen3Tts {
    /// Load from a local directory or download from HuggingFace.
    pub fn from_pretrained(
        model_dir: Option<&str>,
        device: &Device,
    ) -> Result<Self, TtsError> {
        let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));

        if !model_path.exists() {
            Self::download_models(&model_path)?;
        }

        Self::load_from_path(&model_path, device)
    }

    /// Load all model components from a local directory.
    pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> {
        let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU

        // Load configuration
        let config_path = model_dir.join("config.json");
        let config = if config_path.exists() {
            Qwen3TtsConfig::from_json_path(&config_path)?
        } else {
            Qwen3TtsConfig::default()
        };

        // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats)
        let tokenizer_json_path = model_dir.join("tokenizer.json");
        let tokenizer = if tokenizer_json_path.exists() {
            Tokenizer::from_file(&tokenizer_json_path)
                .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))?
        } else {
            // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format)
            let vocab_path = model_dir.join("vocab.json");
            let merges_path = model_dir.join("merges.txt");

            if !vocab_path.exists() || !merges_path.exists() {
                return Err(TtsError::Tokenizer(format!(
                    "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}",
                    model_dir.display()
                )));
            }

            tokenizers::Tokenizer::from_file(&vocab_path)
                .or_else(|_| {
                    // Build BPE tokenizer from vocab and merges
                    use tokenizers::models::bpe::BPE;
                    let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy())
                        .build()
                        .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?;
                    Ok(Tokenizer::new(bpe))
                })
                .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))?
        };

        // Load LM weights from safetensors
        let lm_weights_path = model_dir.join("model.safetensors");
        let lm_data = std::fs::read(&lm_weights_path).map_err(|e| {
            TtsError::ModelLoad(format!(
                "failed to read LM weights from {}: {e}",
                lm_weights_path.display()
            ))
        })?;
        let lm_vb = VarBuilder::from_buffered_safetensors(
            lm_data,
            dtype,
            device,
        ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?;

        // Build language model
        let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| {
            TtsError::ModelLoad(format!("failed to build LM model: {e}"))
        })?;

        // Build code predictor (weights are in the same safetensors file)
        let code_predictor =
            CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| {
                TtsError::ModelLoad(format!("failed to build code predictor: {e}"))
            })?;

        // Load speech tokenizer from separate safetensors
        let st_weights_path = model_dir.join("speech_tokenizer.safetensors");
        let st_data = std::fs::read(&st_weights_path).map_err(|e| {
            TtsError::ModelLoad(format!(
                "failed to read speech tokenizer weights from {}: {e}",
                st_weights_path.display()
            ))
        })?;
        let st_vb = VarBuilder::from_buffered_safetensors(
            st_data,
            dtype,
            device,
        ).map_err(|e| {
            TtsError::ModelLoad(format!(
                "failed to create speech tokenizer VarBuilder: {e}"
            ))
        })?;

        let speech_tokenizer =
            SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| {
                TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}"))
            })?;

        Ok(Self {
            model,
            code_predictor,
            speech_tokenizer,
            tokenizer,
            config,
            device: device.clone(),
            ready: AtomicBool::new(true),
        })
    }

    /// Generate audio from text with optional voice reference.
    pub fn generate_speech(
        &self,
        text: &str,
        reference_audio: Option<&[f32]>,
        gen_config: Option<GenerationConfig>,
        cancel_flag: Option<Arc<AtomicBool>>,
    ) -> Result<Vec<AudioChunk>, TtsError> {
        let config = gen_config.unwrap_or_default();

        let ctx = GenerationContext::new(
            &self.model,
            &self.code_predictor,
            &self.speech_tokenizer,
            &self.tokenizer,
            &self.device,
            config,
            cancel_flag,
        );

        ctx.generate(text, reference_audio)
    }

    /// Download model files from HuggingFace Hub.
    fn download_models(target_dir: &Path) -> Result<(), TtsError> {
        std::fs::create_dir_all(target_dir)?;

        let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?;

        // Download LM model files
        println!("Downloading Qwen3-TTS language model...");
        let lm_repo = api.model(LM_MODEL_ID.to_string());

        let lm_files = [
            "model.safetensors",
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
        ];

        for file in &lm_files {
            println!("  Downloading {file}...");
            let downloaded = lm_repo
                .get(file)
                .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?;

            let target = target_dir.join(file);
            if !target.exists() {
                std::fs::copy(&downloaded, &target)?;
            }
        }

        // Download speech tokenizer
        println!("Downloading Qwen3-TTS speech tokenizer...");
        let st_repo = api.model(TOKENIZER_MODEL_ID.to_string());

        let st_file = "model.safetensors";
        let downloaded = st_repo
            .get(st_file)
            .map_err(|e| {
                TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}"))
            })?;

        let target = target_dir.join("speech_tokenizer.safetensors");
        if !target.exists() {
            std::fs::copy(&downloaded, &target)?;
        }

        println!("All models downloaded to {}", target_dir.display());
        Ok(())
    }

    /// Get the model configuration.
    pub fn config(&self) -> &Qwen3TtsConfig {
        &self.config
    }

    /// Get the compute device.
    pub fn device(&self) -> &Device {
        &self.device
    }
}

#[async_trait::async_trait]
impl TtsEngine for Qwen3Tts {
    async fn generate(
        &self,
        text: &str,
        reference_audio: Option<&[f32]>,
        _reference_sample_rate: Option<u32>,
        cancel_flag: Option<Arc<AtomicBool>>,
    ) -> Result<Vec<AudioChunk>, TtsError> {
        // Note: reference audio should already be resampled to 24kHz
        // by the caller. If a different sample rate is provided,
        // the caller should resample using `resample_to_24k()`.
        self.generate_speech(text, reference_audio, None, cancel_flag)
    }

    fn is_ready(&self) -> bool {
        self.ready.load(Ordering::Relaxed)
    }

    fn sample_rate(&self) -> u32 {
        SAMPLE_RATE
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check key fields of the built-in default configuration.
    #[test]
    fn test_default_config() {
        let cfg = Qwen3TtsConfig::default();

        let lm = &cfg.lm;
        assert_eq!(lm.hidden_size, 1024);
        assert_eq!(lm.num_hidden_layers, 28);

        assert_eq!(cfg.code_predictor.num_code_groups, 16);
        assert_eq!(cfg.speech_tokenizer.sample_rate, 24_000);
    }

    /// Guard against accidental edits to the HuggingFace repository IDs.
    #[test]
    fn test_model_ids() {
        let expected = [
            (LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base"),
            (TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz"),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
    }
}