diff options
Diffstat (limited to 'makima/src/tts/qwen3/config.rs')
| -rw-r--r-- | makima/src/tts/qwen3/config.rs | 271 |
1 files changed, 0 insertions, 271 deletions
diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs deleted file mode 100644 index 6fb55d7..0000000 --- a/makima/src/tts/qwen3/config.rs +++ /dev/null @@ -1,271 +0,0 @@ -//! Qwen3-TTS model configuration. -//! -//! Parses config.json from the HuggingFace model repository to configure -//! the language model, code predictor, and speech tokenizer. - -use serde::Deserialize; - -use crate::tts::TtsError; - -/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base. -#[derive(Debug, Clone, Deserialize)] -pub struct Qwen3TtsConfig { - /// Language model (talker) configuration. - #[serde(default = "Qwen3LmConfig::default")] - pub lm: Qwen3LmConfig, - - /// Code predictor (multi-token prediction) configuration. - #[serde(default = "CodePredictorConfig::default")] - pub code_predictor: CodePredictorConfig, - - /// Speech tokenizer configuration. - #[serde(default = "SpeechTokenizerConfig::default")] - pub speech_tokenizer: SpeechTokenizerConfig, -} - -impl Default for Qwen3TtsConfig { - fn default() -> Self { - Self { - lm: Qwen3LmConfig::default(), - code_predictor: CodePredictorConfig::default(), - speech_tokenizer: SpeechTokenizerConfig::default(), - } - } -} - -impl Qwen3TtsConfig { - /// Load from a config.json file path. - pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> { - let content = std::fs::read_to_string(path) - .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?; - Self::from_json_str(&content) - } - - /// Load from a JSON string. - pub fn from_json_str(json: &str) -> Result<Self, TtsError> { - // Try to parse the full HuggingFace config.json format first - if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) { - return Ok(Self::from_hf_config(&hf_config)); - } - // Fall back to direct deserialization - serde_json::from_str(json) - .map_err(|e| TtsError::Config(format!("failed to parse config: {e}"))) - } - - /// Convert from HuggingFace's config.json format. - fn from_hf_config(hf: &HfConfig) -> Self { - Self { - lm: Qwen3LmConfig { - hidden_size: hf.hidden_size.unwrap_or(1024), - num_hidden_layers: hf.num_hidden_layers.unwrap_or(28), - num_attention_heads: hf.num_attention_heads.unwrap_or(16), - num_key_value_heads: hf.num_key_value_heads.unwrap_or(8), - intermediate_size: hf.intermediate_size.unwrap_or(3072), - head_dim: hf.head_dim.unwrap_or(128), - vocab_size: hf.vocab_size.unwrap_or(151_936), - max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768), - rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), - rope_theta: hf.rope_theta.unwrap_or(1_000_000.0), - use_sliding_window: hf.use_sliding_window.unwrap_or(false), - sliding_window: hf.sliding_window, - hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()), - }, - code_predictor: CodePredictorConfig { - hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024), - num_layers: hf.code_predictor_num_layers.unwrap_or(5), - num_attention_heads: hf - .code_predictor_num_attention_heads - .unwrap_or(16), - num_code_groups: hf.num_code_groups.unwrap_or(16), - codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048), - rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), - }, - speech_tokenizer: SpeechTokenizerConfig::default(), - } - } -} - -/// Language model configuration (28-layer Qwen3 transformer). -#[derive(Debug, Clone, Deserialize)] -pub struct Qwen3LmConfig { - /// Hidden dimension of transformer layers. - pub hidden_size: usize, - /// Number of transformer layers. - pub num_hidden_layers: usize, - /// Number of attention heads. - pub num_attention_heads: usize, - /// Number of key-value heads (GQA). - pub num_key_value_heads: usize, - /// Feed-forward intermediate size. - pub intermediate_size: usize, - /// Dimension per attention head. - pub head_dim: usize, - /// Text vocabulary size. - pub vocab_size: usize, - /// Maximum sequence length for RoPE. - pub max_position_embeddings: usize, - /// RMS normalization epsilon. - pub rms_norm_eps: f64, - /// RoPE theta parameter. - pub rope_theta: f64, - /// Whether to use sliding window attention. - pub use_sliding_window: bool, - /// Sliding window size (if enabled). - pub sliding_window: Option<usize>, - /// Activation function name. - pub hidden_act: String, -} - -impl Default for Qwen3LmConfig { - fn default() -> Self { - Self { - hidden_size: 1024, - num_hidden_layers: 28, - num_attention_heads: 16, - num_key_value_heads: 8, - intermediate_size: 3072, - head_dim: 128, - vocab_size: 151_936, - max_position_embeddings: 32_768, - rms_norm_eps: 1e-6, - rope_theta: 1_000_000.0, - use_sliding_window: false, - sliding_window: None, - hidden_act: "silu".to_string(), - } - } -} - -impl Qwen3LmConfig { - /// Number of key-value head groups for GQA. - pub fn num_kv_groups(&self) -> usize { - self.num_attention_heads / self.num_key_value_heads - } -} - -/// Code predictor (multi-token prediction) configuration. -#[derive(Debug, Clone, Deserialize)] -pub struct CodePredictorConfig { - /// Hidden size (matches LM hidden size). - pub hidden_size: usize, - /// Number of predictor transformer layers. - pub num_layers: usize, - /// Number of attention heads. - pub num_attention_heads: usize, - /// Number of codebook groups (residual codebooks). - pub num_code_groups: usize, - /// Vocabulary size per codebook. - pub codebook_vocab_size: usize, - /// RMS norm epsilon. - pub rms_norm_eps: f64, -} - -impl Default for CodePredictorConfig { - fn default() -> Self { - Self { - hidden_size: 1024, - num_layers: 5, - num_attention_heads: 16, - num_code_groups: 16, - codebook_vocab_size: 2048, - rms_norm_eps: 1e-6, - } - } -} - -/// Speech tokenizer (ConvNet codec) configuration. -#[derive(Debug, Clone, Deserialize)] -pub struct SpeechTokenizerConfig { - /// Number of RVQ codebooks. - pub num_codebooks: usize, - /// Codebook embedding dimension. - pub codebook_dim: usize, - /// Codebook vocabulary size per layer. - pub codebook_size: usize, - /// Encoder/decoder hidden channels. - pub hidden_channels: usize, - /// Output sample rate. - pub sample_rate: u32, - /// Token frame rate (Hz). - pub frame_rate: f32, - /// HuggingFace model ID for the speech tokenizer. - pub model_id: String, -} - -impl Default for SpeechTokenizerConfig { - fn default() -> Self { - Self { - num_codebooks: 16, - codebook_dim: 256, - codebook_size: 2048, - hidden_channels: 512, - sample_rate: 24_000, - frame_rate: 12.5, - model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(), - } - } -} - -/// HuggingFace config.json format (partial, fields we need). -#[derive(Debug, Deserialize)] -struct HfConfig { - hidden_size: Option<usize>, - num_hidden_layers: Option<usize>, - num_attention_heads: Option<usize>, - num_key_value_heads: Option<usize>, - intermediate_size: Option<usize>, - head_dim: Option<usize>, - vocab_size: Option<usize>, - max_position_embeddings: Option<usize>, - rms_norm_eps: Option<f64>, - rope_theta: Option<f64>, - use_sliding_window: Option<bool>, - sliding_window: Option<usize>, - hidden_act: Option<String>, - // Code predictor specific fields - code_predictor_hidden_size: Option<usize>, - code_predictor_num_layers: Option<usize>, - code_predictor_num_attention_heads: Option<usize>, - num_code_groups: Option<usize>, - codebook_vocab_size: Option<usize>, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = Qwen3TtsConfig::default(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.lm.num_attention_heads, 16); - assert_eq!(config.lm.num_key_value_heads, 8); - assert_eq!(config.lm.head_dim, 128); - assert_eq!(config.lm.num_kv_groups(), 2); - assert_eq!(config.code_predictor.num_layers, 5); - assert_eq!(config.code_predictor.num_code_groups, 16); - assert_eq!(config.speech_tokenizer.num_codebooks, 16); - } - - #[test] - fn test_config_from_json() { - let json = r#"{ - "hidden_size": 1024, - "num_hidden_layers": 28, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 3072, - "vocab_size": 151936, - "max_position_embeddings": 32768, - "rms_norm_eps": 1e-6, - "rope_theta": 1000000.0, - "hidden_act": "silu" - }"#; - - let config = Qwen3TtsConfig::from_json_str(json).unwrap(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.lm.vocab_size, 151_936); - } -} |
