//! Qwen3-TTS model configuration. //! //! Parses config.json from the HuggingFace model repository to configure //! the language model, code predictor, and speech tokenizer. use serde::Deserialize; use crate::tts::TtsError; /// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base. #[derive(Debug, Clone, Deserialize)] pub struct Qwen3TtsConfig { /// Language model (talker) configuration. #[serde(default = "Qwen3LmConfig::default")] pub lm: Qwen3LmConfig, /// Code predictor (multi-token prediction) configuration. #[serde(default = "CodePredictorConfig::default")] pub code_predictor: CodePredictorConfig, /// Speech tokenizer configuration. #[serde(default = "SpeechTokenizerConfig::default")] pub speech_tokenizer: SpeechTokenizerConfig, } impl Default for Qwen3TtsConfig { fn default() -> Self { Self { lm: Qwen3LmConfig::default(), code_predictor: CodePredictorConfig::default(), speech_tokenizer: SpeechTokenizerConfig::default(), } } } impl Qwen3TtsConfig { /// Load from a config.json file path. pub fn from_json_path(path: &std::path::Path) -> Result { let content = std::fs::read_to_string(path) .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?; Self::from_json_str(&content) } /// Load from a JSON string. pub fn from_json_str(json: &str) -> Result { // Try to parse the full HuggingFace config.json format first if let Ok(hf_config) = serde_json::from_str::(json) { return Ok(Self::from_hf_config(&hf_config)); } // Fall back to direct deserialization serde_json::from_str(json) .map_err(|e| TtsError::Config(format!("failed to parse config: {e}"))) } /// Convert from HuggingFace's config.json format. fn from_hf_config(hf: &HfConfig) -> Self { Self { lm: Qwen3LmConfig { hidden_size: hf.hidden_size.unwrap_or(1024), num_hidden_layers: hf.num_hidden_layers.unwrap_or(28), num_attention_heads: hf.num_attention_heads.unwrap_or(16), num_key_value_heads: hf.num_key_value_heads.unwrap_or(8), intermediate_size: hf.intermediate_size.unwrap_or(3072), head_dim: hf.head_dim.unwrap_or(128), vocab_size: hf.vocab_size.unwrap_or(151_936), max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768), rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), rope_theta: hf.rope_theta.unwrap_or(1_000_000.0), use_sliding_window: hf.use_sliding_window.unwrap_or(false), sliding_window: hf.sliding_window, hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()), }, code_predictor: CodePredictorConfig { hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024), num_layers: hf.code_predictor_num_layers.unwrap_or(5), num_attention_heads: hf .code_predictor_num_attention_heads .unwrap_or(16), num_code_groups: hf.num_code_groups.unwrap_or(16), codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048), rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), }, speech_tokenizer: SpeechTokenizerConfig::default(), } } } /// Language model configuration (28-layer Qwen3 transformer). #[derive(Debug, Clone, Deserialize)] pub struct Qwen3LmConfig { /// Hidden dimension of transformer layers. pub hidden_size: usize, /// Number of transformer layers. pub num_hidden_layers: usize, /// Number of attention heads. pub num_attention_heads: usize, /// Number of key-value heads (GQA). pub num_key_value_heads: usize, /// Feed-forward intermediate size. pub intermediate_size: usize, /// Dimension per attention head. pub head_dim: usize, /// Text vocabulary size. pub vocab_size: usize, /// Maximum sequence length for RoPE. pub max_position_embeddings: usize, /// RMS normalization epsilon. pub rms_norm_eps: f64, /// RoPE theta parameter. pub rope_theta: f64, /// Whether to use sliding window attention. pub use_sliding_window: bool, /// Sliding window size (if enabled). pub sliding_window: Option, /// Activation function name. pub hidden_act: String, } impl Default for Qwen3LmConfig { fn default() -> Self { Self { hidden_size: 1024, num_hidden_layers: 28, num_attention_heads: 16, num_key_value_heads: 8, intermediate_size: 3072, head_dim: 128, vocab_size: 151_936, max_position_embeddings: 32_768, rms_norm_eps: 1e-6, rope_theta: 1_000_000.0, use_sliding_window: false, sliding_window: None, hidden_act: "silu".to_string(), } } } impl Qwen3LmConfig { /// Number of key-value head groups for GQA. pub fn num_kv_groups(&self) -> usize { self.num_attention_heads / self.num_key_value_heads } } /// Code predictor (multi-token prediction) configuration. #[derive(Debug, Clone, Deserialize)] pub struct CodePredictorConfig { /// Hidden size (matches LM hidden size). pub hidden_size: usize, /// Number of predictor transformer layers. pub num_layers: usize, /// Number of attention heads. pub num_attention_heads: usize, /// Number of codebook groups (residual codebooks). pub num_code_groups: usize, /// Vocabulary size per codebook. pub codebook_vocab_size: usize, /// RMS norm epsilon. pub rms_norm_eps: f64, } impl Default for CodePredictorConfig { fn default() -> Self { Self { hidden_size: 1024, num_layers: 5, num_attention_heads: 16, num_code_groups: 16, codebook_vocab_size: 2048, rms_norm_eps: 1e-6, } } } /// Speech tokenizer (ConvNet codec) configuration. #[derive(Debug, Clone, Deserialize)] pub struct SpeechTokenizerConfig { /// Number of RVQ codebooks. pub num_codebooks: usize, /// Codebook embedding dimension. pub codebook_dim: usize, /// Codebook vocabulary size per layer. pub codebook_size: usize, /// Encoder/decoder hidden channels. pub hidden_channels: usize, /// Output sample rate. pub sample_rate: u32, /// Token frame rate (Hz). pub frame_rate: f32, /// HuggingFace model ID for the speech tokenizer. pub model_id: String, } impl Default for SpeechTokenizerConfig { fn default() -> Self { Self { num_codebooks: 16, codebook_dim: 256, codebook_size: 2048, hidden_channels: 512, sample_rate: 24_000, frame_rate: 12.5, model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(), } } } /// HuggingFace config.json format (partial, fields we need). #[derive(Debug, Deserialize)] struct HfConfig { hidden_size: Option, num_hidden_layers: Option, num_attention_heads: Option, num_key_value_heads: Option, intermediate_size: Option, head_dim: Option, vocab_size: Option, max_position_embeddings: Option, rms_norm_eps: Option, rope_theta: Option, use_sliding_window: Option, sliding_window: Option, hidden_act: Option, // Code predictor specific fields code_predictor_hidden_size: Option, code_predictor_num_layers: Option, code_predictor_num_attention_heads: Option, num_code_groups: Option, codebook_vocab_size: Option, } #[cfg(test)] mod tests { use super::*; #[test] fn test_default_config() { let config = Qwen3TtsConfig::default(); assert_eq!(config.lm.hidden_size, 1024); assert_eq!(config.lm.num_hidden_layers, 28); assert_eq!(config.lm.num_attention_heads, 16); assert_eq!(config.lm.num_key_value_heads, 8); assert_eq!(config.lm.head_dim, 128); assert_eq!(config.lm.num_kv_groups(), 2); assert_eq!(config.code_predictor.num_layers, 5); assert_eq!(config.code_predictor.num_code_groups, 16); assert_eq!(config.speech_tokenizer.num_codebooks, 16); } #[test] fn test_config_from_json() { let json = r#"{ "hidden_size": 1024, "num_hidden_layers": 28, "num_attention_heads": 16, "num_key_value_heads": 8, "intermediate_size": 3072, "vocab_size": 151936, "max_position_embeddings": 32768, "rms_norm_eps": 1e-6, "rope_theta": 1000000.0, "hidden_act": "silu" }"#; let config = Qwen3TtsConfig::from_json_str(json).unwrap(); assert_eq!(config.lm.hidden_size, 1024); assert_eq!(config.lm.num_hidden_layers, 28); assert_eq!(config.lm.vocab_size, 151_936); } }