//! Qwen3-TTS — Pure Rust implementation using candle.
//!
//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis
//! with voice cloning support. No Python, no ONNX — pure Rust inference
//! via the candle ML framework.
//!
//! # Architecture
//!
//! The model has three components:
//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens
//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers
//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes
//!
//! # Usage
//!
//! ```rust,no_run
//! use makima::tts::qwen3::Qwen3Tts;
//! use candle_core::Device;
//!
//! let device = Device::Cpu;
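//! // Loads from `models/qwen3-tts`, downloading from HuggingFace if needed.
//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap();
//! // Direct API (the engine also implements the `TtsEngine` trait):
//! let _audio = tts.generate_speech("Hello, world!", None, None, None).unwrap();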
//! ```
pub mod code_predictor;
pub mod config;
pub mod generate;
pub mod model;
pub mod speech_tokenizer;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use tokenizers::Tokenizer;
use self::code_predictor::CodePredictor;
use self::config::Qwen3TtsConfig;
use self::generate::{GenerationConfig, GenerationContext};
use self::model::Qwen3Model;
use self::speech_tokenizer::SpeechTokenizer;
use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE};
/// HuggingFace repository with the language model, code predictor and text tokenizer.
const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base";

/// HuggingFace repository with the speech tokenizer (audio codec) weights.
const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz";

/// Local directory used by `Qwen3Tts::from_pretrained` when no path is supplied.
const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts";
/// Qwen3-TTS engine — pure Rust candle-based inference.
///
/// Owns every component needed for synthesis. Build one with
/// [`Qwen3Tts::from_pretrained`] or [`Qwen3Tts::load_from_path`].
pub struct Qwen3Tts {
    /// Transformer language model (28 layers in the default config).
    model: Qwen3Model,
    /// Multi-token-prediction head for the remaining codebook layers.
    code_predictor: CodePredictor,
    /// Audio codec: encoder, decoder and residual vector quantizer.
    speech_tokenizer: SpeechTokenizer,
    /// Text tokenizer used to turn input text into token ids.
    tokenizer: Tokenizer,
    /// Configuration the components were built from.
    config: Qwen3TtsConfig,
    /// Device every tensor lives on (CPU, CUDA or Metal).
    device: Device,
    /// Set to `true` when loading finishes; reported by `TtsEngine::is_ready`.
    ready: AtomicBool,
}
// SAFETY: All fields are either Send+Sync or behind appropriate synchronization.
// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync.
// NOTE(review): if this claim holds for *every* field (including the project types
// `Qwen3Model`, `CodePredictor` and `SpeechTokenizer`), the compiler already derives
// Send/Sync automatically and these impls are redundant. If it does NOT hold (e.g. a
// component keeps a `RefCell`/`Rc`-based cache), these impls are unsound, because
// `generate_speech` takes `&self` and could then race across threads.
// Consider deleting them and letting the compiler verify, or adding a compile-time
// assertion per field type — TODO confirm against the component definitions.
unsafe impl Send for Qwen3Tts {}
unsafe impl Sync for Qwen3Tts {}
impl Qwen3Tts {
/// Load from a local directory or download from HuggingFace.
pub fn from_pretrained(
model_dir: Option<&str>,
device: &Device,
) -> Result<Self, TtsError> {
let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
if !model_path.exists() {
Self::download_models(&model_path)?;
}
Self::load_from_path(&model_path, device)
}
/// Load all model components from a local directory.
pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> {
let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU
// Load configuration
let config_path = model_dir.join("config.json");
let config = if config_path.exists() {
Qwen3TtsConfig::from_json_path(&config_path)?
} else {
Qwen3TtsConfig::default()
};
// Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats)
let tokenizer_json_path = model_dir.join("tokenizer.json");
let tokenizer = if tokenizer_json_path.exists() {
Tokenizer::from_file(&tokenizer_json_path)
.map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))?
} else {
// Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format)
let vocab_path = model_dir.join("vocab.json");
let merges_path = model_dir.join("merges.txt");
if !vocab_path.exists() || !merges_path.exists() {
return Err(TtsError::Tokenizer(format!(
"tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}",
model_dir.display()
)));
}
tokenizers::Tokenizer::from_file(&vocab_path)
.or_else(|_| {
// Build BPE tokenizer from vocab and merges
use tokenizers::models::bpe::BPE;
let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy())
.build()
.map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?;
Ok(Tokenizer::new(bpe))
})
.map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))?
};
// Load LM weights from safetensors
let lm_weights_path = model_dir.join("model.safetensors");
let lm_data = std::fs::read(&lm_weights_path).map_err(|e| {
TtsError::ModelLoad(format!(
"failed to read LM weights from {}: {e}",
lm_weights_path.display()
))
})?;
let lm_vb = VarBuilder::from_buffered_safetensors(
lm_data,
dtype,
device,
).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?;
// Build language model
let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| {
TtsError::ModelLoad(format!("failed to build LM model: {e}"))
})?;
// Build code predictor (weights are in the same safetensors file)
let code_predictor =
CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| {
TtsError::ModelLoad(format!("failed to build code predictor: {e}"))
})?;
// Load speech tokenizer from separate safetensors
let st_weights_path = model_dir.join("speech_tokenizer.safetensors");
let st_data = std::fs::read(&st_weights_path).map_err(|e| {
TtsError::ModelLoad(format!(
"failed to read speech tokenizer weights from {}: {e}",
st_weights_path.display()
))
})?;
let st_vb = VarBuilder::from_buffered_safetensors(
st_data,
dtype,
device,
).map_err(|e| {
TtsError::ModelLoad(format!(
"failed to create speech tokenizer VarBuilder: {e}"
))
})?;
let speech_tokenizer =
SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| {
TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}"))
})?;
Ok(Self {
model,
code_predictor,
speech_tokenizer,
tokenizer,
config,
device: device.clone(),
ready: AtomicBool::new(true),
})
}
/// Generate audio from text with optional voice reference.
pub fn generate_speech(
&self,
text: &str,
reference_audio: Option<&[f32]>,
gen_config: Option<GenerationConfig>,
cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<Vec<AudioChunk>, TtsError> {
let config = gen_config.unwrap_or_default();
let ctx = GenerationContext::new(
&self.model,
&self.code_predictor,
&self.speech_tokenizer,
&self.tokenizer,
&self.device,
config,
cancel_flag,
);
ctx.generate(text, reference_audio)
}
/// Download model files from HuggingFace Hub.
fn download_models(target_dir: &Path) -> Result<(), TtsError> {
std::fs::create_dir_all(target_dir)?;
let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?;
// Download LM model files
println!("Downloading Qwen3-TTS language model...");
let lm_repo = api.model(LM_MODEL_ID.to_string());
let lm_files = [
"model.safetensors",
"config.json",
"tokenizer.json",
"tokenizer_config.json",
];
for file in &lm_files {
println!(" Downloading {file}...");
let downloaded = lm_repo
.get(file)
.map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?;
let target = target_dir.join(file);
if !target.exists() {
std::fs::copy(&downloaded, &target)?;
}
}
// Download speech tokenizer
println!("Downloading Qwen3-TTS speech tokenizer...");
let st_repo = api.model(TOKENIZER_MODEL_ID.to_string());
let st_file = "model.safetensors";
let downloaded = st_repo
.get(st_file)
.map_err(|e| {
TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}"))
})?;
let target = target_dir.join("speech_tokenizer.safetensors");
if !target.exists() {
std::fs::copy(&downloaded, &target)?;
}
println!("All models downloaded to {}", target_dir.display());
Ok(())
}
/// Get the model configuration.
pub fn config(&self) -> &Qwen3TtsConfig {
&self.config
}
/// Get the compute device.
pub fn device(&self) -> &Device {
&self.device
}
}
#[async_trait::async_trait]
impl TtsEngine for Qwen3Tts {
async fn generate(
&self,
text: &str,
reference_audio: Option<&[f32]>,
_reference_sample_rate: Option<u32>,
cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<Vec<AudioChunk>, TtsError> {
// Note: reference audio should already be resampled to 24kHz
// by the caller. If a different sample rate is provided,
// the caller should resample using `resample_to_24k()`.
self.generate_speech(text, reference_audio, None, cancel_flag)
}
fn is_ready(&self) -> bool {
self.ready.load(Ordering::Relaxed)
}
fn sample_rate(&self) -> u32 {
SAMPLE_RATE
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must describe the 0.6B-Base checkpoint.
    #[test]
    fn test_default_config() {
        let cfg = Qwen3TtsConfig::default();
        let lm = &cfg.lm;
        assert_eq!(lm.hidden_size, 1024);
        assert_eq!(lm.num_hidden_layers, 28);
        assert_eq!(cfg.code_predictor.num_code_groups, 16);
        assert_eq!(cfg.speech_tokenizer.sample_rate, 24_000);
    }

    /// Guards against accidental edits to the HuggingFace repository ids.
    #[test]
    fn test_model_ids() {
        let expected = [
            (LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base"),
            (TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz"),
        ];
        for (actual, wanted) in expected {
            assert_eq!(actual, wanted);
        }
    }
}