//! Qwen3-TTS model configuration.
//!
//! Parses config.json from the HuggingFace model repository to configure
//! the language model, code predictor, and speech tokenizer.
use serde::Deserialize;
use crate::tts::TtsError;
/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base.
#[derive(Debug, Clone, Deserialize)]
pub struct Qwen3TtsConfig {
/// Language model (talker) configuration.
#[serde(default = "Qwen3LmConfig::default")]
pub lm: Qwen3LmConfig,
/// Code predictor (multi-token prediction) configuration.
#[serde(default = "CodePredictorConfig::default")]
pub code_predictor: CodePredictorConfig,
/// Speech tokenizer configuration.
#[serde(default = "SpeechTokenizerConfig::default")]
pub speech_tokenizer: SpeechTokenizerConfig,
}
impl Default for Qwen3TtsConfig {
fn default() -> Self {
Self {
lm: Qwen3LmConfig::default(),
code_predictor: CodePredictorConfig::default(),
speech_tokenizer: SpeechTokenizerConfig::default(),
}
}
}
impl Qwen3TtsConfig {
/// Load from a config.json file path.
pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> {
let content = std::fs::read_to_string(path)
.map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?;
Self::from_json_str(&content)
}
/// Load from a JSON string.
pub fn from_json_str(json: &str) -> Result<Self, TtsError> {
// Try to parse the full HuggingFace config.json format first
if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) {
return Ok(Self::from_hf_config(&hf_config));
}
// Fall back to direct deserialization
serde_json::from_str(json)
.map_err(|e| TtsError::Config(format!("failed to parse config: {e}")))
}
/// Convert from HuggingFace's config.json format.
fn from_hf_config(hf: &HfConfig) -> Self {
Self {
lm: Qwen3LmConfig {
hidden_size: hf.hidden_size.unwrap_or(1024),
num_hidden_layers: hf.num_hidden_layers.unwrap_or(28),
num_attention_heads: hf.num_attention_heads.unwrap_or(16),
num_key_value_heads: hf.num_key_value_heads.unwrap_or(8),
intermediate_size: hf.intermediate_size.unwrap_or(3072),
head_dim: hf.head_dim.unwrap_or(128),
vocab_size: hf.vocab_size.unwrap_or(151_936),
max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768),
rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
rope_theta: hf.rope_theta.unwrap_or(1_000_000.0),
use_sliding_window: hf.use_sliding_window.unwrap_or(false),
sliding_window: hf.sliding_window,
hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()),
},
code_predictor: CodePredictorConfig {
hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024),
num_layers: hf.code_predictor_num_layers.unwrap_or(5),
num_attention_heads: hf
.code_predictor_num_attention_heads
.unwrap_or(16),
num_code_groups: hf.num_code_groups.unwrap_or(16),
codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048),
rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
},
speech_tokenizer: SpeechTokenizerConfig::default(),
}
}
}
/// Language model configuration (28-layer Qwen3 transformer).
#[derive(Debug, Clone, Deserialize)]
pub struct Qwen3LmConfig {
/// Hidden dimension of transformer layers.
pub hidden_size: usize,
/// Number of transformer layers.
pub num_hidden_layers: usize,
/// Number of attention heads.
pub num_attention_heads: usize,
/// Number of key-value heads (GQA).
pub num_key_value_heads: usize,
/// Feed-forward intermediate size.
pub intermediate_size: usize,
/// Dimension per attention head.
pub head_dim: usize,
/// Text vocabulary size.
pub vocab_size: usize,
/// Maximum sequence length for RoPE.
pub max_position_embeddings: usize,
/// RMS normalization epsilon.
pub rms_norm_eps: f64,
/// RoPE theta parameter.
pub rope_theta: f64,
/// Whether to use sliding window attention.
pub use_sliding_window: bool,
/// Sliding window size (if enabled).
pub sliding_window: Option<usize>,
/// Activation function name.
pub hidden_act: String,
}
impl Default for Qwen3LmConfig {
    /// Default talker hyper-parameters. The model has 28 layers of width 1024.
    /// It uses 16 query heads sharing 8 key/value heads of 128 dims each, a
    /// 3072-wide "silu" feed-forward block, and RoPE with theta = 1e6 over up
    /// to 32 768 positions. Sliding-window attention is disabled.
    fn default() -> Self {
        let hidden_act = String::from("silu");
        Self {
            // Transformer shape.
            num_hidden_layers: 28,
            hidden_size: 1024,
            intermediate_size: 3072,
            hidden_act,
            // Attention layout.
            num_attention_heads: 16,
            num_key_value_heads: 8,
            head_dim: 128,
            // Positional encoding and normalization.
            rope_theta: 1e6,
            max_position_embeddings: 32 * 1024,
            rms_norm_eps: 1e-6,
            // Vocabulary and attention window.
            vocab_size: 151_936,
            use_sliding_window: false,
            sliding_window: None,
        }
    }
}
impl Qwen3LmConfig {
    /// Number of query heads that share each key/value head under
    /// grouped-query attention.
    ///
    /// Computed with integer division (`num_attention_heads /
    /// num_key_value_heads`), so any remainder is discarded.
    ///
    /// # Panics
    ///
    /// Panics if `num_key_value_heads` is zero.
    pub fn num_kv_groups(&self) -> usize {
        let query_heads = self.num_attention_heads;
        let kv_heads = self.num_key_value_heads;
        query_heads / kv_heads
    }
}
/// Code predictor (multi-token prediction) configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct CodePredictorConfig {
/// Hidden size (matches LM hidden size).
pub hidden_size: usize,
/// Number of predictor transformer layers.
pub num_layers: usize,
/// Number of attention heads.
pub num_attention_heads: usize,
/// Number of codebook groups (residual codebooks).
pub num_code_groups: usize,
/// Vocabulary size per codebook.
pub codebook_vocab_size: usize,
/// RMS norm epsilon.
pub rms_norm_eps: f64,
}
impl Default for CodePredictorConfig {
    /// Default predictor: 5 transformer layers of width 1024 with 16
    /// attention heads, configured for 16 code groups of 2048 entries each.
    fn default() -> Self {
        let (num_code_groups, codebook_vocab_size) = (16, 2048);
        Self {
            num_layers: 5,
            hidden_size: 1024,
            num_attention_heads: 16,
            rms_norm_eps: 1e-6,
            num_code_groups,
            codebook_vocab_size,
        }
    }
}
/// Speech tokenizer (ConvNet codec) configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct SpeechTokenizerConfig {
/// Number of RVQ codebooks.
pub num_codebooks: usize,
/// Codebook embedding dimension.
pub codebook_dim: usize,
/// Codebook vocabulary size per layer.
pub codebook_size: usize,
/// Encoder/decoder hidden channels.
pub hidden_channels: usize,
/// Output sample rate.
pub sample_rate: u32,
/// Token frame rate (Hz).
pub frame_rate: f32,
/// HuggingFace model ID for the speech tokenizer.
pub model_id: String,
}
impl Default for SpeechTokenizerConfig {
    /// Default codec settings: 16 RVQ codebooks of 2048 entries with
    /// 256-dimensional embeddings and 512 hidden channels. Output is 24 kHz
    /// audio at a 12.5 Hz token frame rate, and the model ID is
    /// `Qwen/Qwen3-TTS-Tokenizer-12Hz`.
    fn default() -> Self {
        const SAMPLE_RATE_HZ: u32 = 24_000;
        Self {
            model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".into(),
            sample_rate: SAMPLE_RATE_HZ,
            frame_rate: 12.5,
            num_codebooks: 16,
            codebook_size: 2048,
            codebook_dim: 256,
            hidden_channels: 512,
        }
    }
}
/// HuggingFace config.json format (partial, fields we need).
///
/// Every field is an `Option`, so a key that is absent from the JSON
/// deserializes to `None`. `Qwen3TtsConfig::from_hf_config` substitutes a
/// built-in default for most of them (`sliding_window` is passed through
/// as-is). No `deny_unknown_fields` is set, so keys not listed here are
/// ignored. As a consequence, nearly any JSON object deserializes into this
/// struct successfully.
#[derive(Debug, Deserialize)]
struct HfConfig {
    // Language model (talker) fields.
    /// Hidden dimension of the LM transformer layers.
    hidden_size: Option<usize>,
    /// Number of LM transformer layers.
    num_hidden_layers: Option<usize>,
    /// Number of LM attention heads.
    num_attention_heads: Option<usize>,
    /// Number of LM key-value heads (GQA).
    num_key_value_heads: Option<usize>,
    /// LM feed-forward intermediate size.
    intermediate_size: Option<usize>,
    /// Dimension per LM attention head.
    head_dim: Option<usize>,
    /// Text vocabulary size.
    vocab_size: Option<usize>,
    /// Maximum sequence length for RoPE.
    max_position_embeddings: Option<usize>,
    /// RMS normalization epsilon; `from_hf_config` applies it to both the LM
    /// and the code predictor.
    rms_norm_eps: Option<f64>,
    /// RoPE theta parameter.
    rope_theta: Option<f64>,
    /// Whether sliding window attention is enabled.
    use_sliding_window: Option<bool>,
    /// Sliding window size; stays `None` when absent.
    sliding_window: Option<usize>,
    /// Activation function name.
    hidden_act: Option<String>,
    // Code predictor specific fields
    /// Code predictor hidden size.
    code_predictor_hidden_size: Option<usize>,
    /// Number of code predictor transformer layers.
    code_predictor_num_layers: Option<usize>,
    /// Number of code predictor attention heads.
    code_predictor_num_attention_heads: Option<usize>,
    /// Number of codebook groups.
    num_code_groups: Option<usize>,
    /// Vocabulary size per codebook.
    codebook_vocab_size: Option<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults describe the 0.6B base model.
    #[test]
    fn test_default_config() {
        let Qwen3TtsConfig {
            lm,
            code_predictor,
            speech_tokenizer,
        } = Qwen3TtsConfig::default();

        // Talker shape and attention layout.
        assert_eq!(lm.hidden_size, 1024);
        assert_eq!(lm.num_hidden_layers, 28);
        assert_eq!(lm.num_attention_heads, 16);
        assert_eq!(lm.num_key_value_heads, 8);
        assert_eq!(lm.head_dim, 128);
        assert_eq!(lm.num_kv_groups(), 2);

        // Code predictor and codec.
        assert_eq!(code_predictor.num_layers, 5);
        assert_eq!(code_predictor.num_code_groups, 16);
        assert_eq!(speech_tokenizer.num_codebooks, 16);
    }

    /// A flat HuggingFace-style config.json populates the LM section.
    #[test]
    fn test_config_from_json() -> Result<(), TtsError> {
        let json = r#"{
"hidden_size": 1024,
"num_hidden_layers": 28,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"intermediate_size": 3072,
"vocab_size": 151936,
"max_position_embeddings": 32768,
"rms_norm_eps": 1e-6,
"rope_theta": 1000000.0,
"hidden_act": "silu"
}"#;
        let parsed = Qwen3TtsConfig::from_json_str(json)?;
        let lm = &parsed.lm;
        assert_eq!(lm.hidden_size, 1024);
        assert_eq!(lm.num_hidden_layers, 28);
        assert_eq!(lm.vocab_size, 151_936);
        Ok(())
    }
}