summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/config.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-02-02 22:52:05 +0000
committersoryu <soryu@soryu.co>2026-02-02 22:52:05 +0000
commit0f06a7f9968816e5e2553c4f1c2104f2fa504f96 (patch)
tree53d8db119c17d7d22f3127ae5a54e12a3f384e29 /makima/src/tts/qwen3/config.rs
parent151e9d87e117b7980e6aad522ac8f3633eeca87a (diff)
downloadsoryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.tar.gz
soryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.zip
Release in makima repo
Also remove all other TTS models
Diffstat (limited to 'makima/src/tts/qwen3/config.rs')
-rw-r--r--makima/src/tts/qwen3/config.rs271
1 files changed, 0 insertions, 271 deletions
diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs
deleted file mode 100644
index 6fb55d7..0000000
--- a/makima/src/tts/qwen3/config.rs
+++ /dev/null
@@ -1,271 +0,0 @@
-//! Qwen3-TTS model configuration.
-//!
-//! Parses config.json from the HuggingFace model repository to configure
-//! the language model, code predictor, and speech tokenizer.
-
-use serde::Deserialize;
-
-use crate::tts::TtsError;
-
-/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base.
-#[derive(Debug, Clone, Deserialize)]
-pub struct Qwen3TtsConfig {
- /// Language model (talker) configuration.
- #[serde(default = "Qwen3LmConfig::default")]
- pub lm: Qwen3LmConfig,
-
- /// Code predictor (multi-token prediction) configuration.
- #[serde(default = "CodePredictorConfig::default")]
- pub code_predictor: CodePredictorConfig,
-
- /// Speech tokenizer configuration.
- #[serde(default = "SpeechTokenizerConfig::default")]
- pub speech_tokenizer: SpeechTokenizerConfig,
-}
-
-impl Default for Qwen3TtsConfig {
- fn default() -> Self {
- Self {
- lm: Qwen3LmConfig::default(),
- code_predictor: CodePredictorConfig::default(),
- speech_tokenizer: SpeechTokenizerConfig::default(),
- }
- }
-}
-
-impl Qwen3TtsConfig {
- /// Load from a config.json file path.
- pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> {
- let content = std::fs::read_to_string(path)
- .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?;
- Self::from_json_str(&content)
- }
-
- /// Load from a JSON string.
- pub fn from_json_str(json: &str) -> Result<Self, TtsError> {
- // Try to parse the full HuggingFace config.json format first
- if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) {
- return Ok(Self::from_hf_config(&hf_config));
- }
- // Fall back to direct deserialization
- serde_json::from_str(json)
- .map_err(|e| TtsError::Config(format!("failed to parse config: {e}")))
- }
-
- /// Convert from HuggingFace's config.json format.
- fn from_hf_config(hf: &HfConfig) -> Self {
- Self {
- lm: Qwen3LmConfig {
- hidden_size: hf.hidden_size.unwrap_or(1024),
- num_hidden_layers: hf.num_hidden_layers.unwrap_or(28),
- num_attention_heads: hf.num_attention_heads.unwrap_or(16),
- num_key_value_heads: hf.num_key_value_heads.unwrap_or(8),
- intermediate_size: hf.intermediate_size.unwrap_or(3072),
- head_dim: hf.head_dim.unwrap_or(128),
- vocab_size: hf.vocab_size.unwrap_or(151_936),
- max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768),
- rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
- rope_theta: hf.rope_theta.unwrap_or(1_000_000.0),
- use_sliding_window: hf.use_sliding_window.unwrap_or(false),
- sliding_window: hf.sliding_window,
- hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()),
- },
- code_predictor: CodePredictorConfig {
- hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024),
- num_layers: hf.code_predictor_num_layers.unwrap_or(5),
- num_attention_heads: hf
- .code_predictor_num_attention_heads
- .unwrap_or(16),
- num_code_groups: hf.num_code_groups.unwrap_or(16),
- codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048),
- rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
- },
- speech_tokenizer: SpeechTokenizerConfig::default(),
- }
- }
-}
-
-/// Language model configuration (28-layer Qwen3 transformer).
-#[derive(Debug, Clone, Deserialize)]
-pub struct Qwen3LmConfig {
- /// Hidden dimension of transformer layers.
- pub hidden_size: usize,
- /// Number of transformer layers.
- pub num_hidden_layers: usize,
- /// Number of attention heads.
- pub num_attention_heads: usize,
- /// Number of key-value heads (GQA).
- pub num_key_value_heads: usize,
- /// Feed-forward intermediate size.
- pub intermediate_size: usize,
- /// Dimension per attention head.
- pub head_dim: usize,
- /// Text vocabulary size.
- pub vocab_size: usize,
- /// Maximum sequence length for RoPE.
- pub max_position_embeddings: usize,
- /// RMS normalization epsilon.
- pub rms_norm_eps: f64,
- /// RoPE theta parameter.
- pub rope_theta: f64,
- /// Whether to use sliding window attention.
- pub use_sliding_window: bool,
- /// Sliding window size (if enabled).
- pub sliding_window: Option<usize>,
- /// Activation function name.
- pub hidden_act: String,
-}
-
-impl Default for Qwen3LmConfig {
- fn default() -> Self {
- Self {
- hidden_size: 1024,
- num_hidden_layers: 28,
- num_attention_heads: 16,
- num_key_value_heads: 8,
- intermediate_size: 3072,
- head_dim: 128,
- vocab_size: 151_936,
- max_position_embeddings: 32_768,
- rms_norm_eps: 1e-6,
- rope_theta: 1_000_000.0,
- use_sliding_window: false,
- sliding_window: None,
- hidden_act: "silu".to_string(),
- }
- }
-}
-
-impl Qwen3LmConfig {
- /// Number of key-value head groups for GQA.
- pub fn num_kv_groups(&self) -> usize {
- self.num_attention_heads / self.num_key_value_heads
- }
-}
-
-/// Code predictor (multi-token prediction) configuration.
-#[derive(Debug, Clone, Deserialize)]
-pub struct CodePredictorConfig {
- /// Hidden size (matches LM hidden size).
- pub hidden_size: usize,
- /// Number of predictor transformer layers.
- pub num_layers: usize,
- /// Number of attention heads.
- pub num_attention_heads: usize,
- /// Number of codebook groups (residual codebooks).
- pub num_code_groups: usize,
- /// Vocabulary size per codebook.
- pub codebook_vocab_size: usize,
- /// RMS norm epsilon.
- pub rms_norm_eps: f64,
-}
-
-impl Default for CodePredictorConfig {
- fn default() -> Self {
- Self {
- hidden_size: 1024,
- num_layers: 5,
- num_attention_heads: 16,
- num_code_groups: 16,
- codebook_vocab_size: 2048,
- rms_norm_eps: 1e-6,
- }
- }
-}
-
-/// Speech tokenizer (ConvNet codec) configuration.
-#[derive(Debug, Clone, Deserialize)]
-pub struct SpeechTokenizerConfig {
- /// Number of RVQ codebooks.
- pub num_codebooks: usize,
- /// Codebook embedding dimension.
- pub codebook_dim: usize,
- /// Codebook vocabulary size per layer.
- pub codebook_size: usize,
- /// Encoder/decoder hidden channels.
- pub hidden_channels: usize,
- /// Output sample rate.
- pub sample_rate: u32,
- /// Token frame rate (Hz).
- pub frame_rate: f32,
- /// HuggingFace model ID for the speech tokenizer.
- pub model_id: String,
-}
-
-impl Default for SpeechTokenizerConfig {
- fn default() -> Self {
- Self {
- num_codebooks: 16,
- codebook_dim: 256,
- codebook_size: 2048,
- hidden_channels: 512,
- sample_rate: 24_000,
- frame_rate: 12.5,
- model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(),
- }
- }
-}
-
-/// HuggingFace config.json format (partial, fields we need).
-#[derive(Debug, Deserialize)]
-struct HfConfig {
- hidden_size: Option<usize>,
- num_hidden_layers: Option<usize>,
- num_attention_heads: Option<usize>,
- num_key_value_heads: Option<usize>,
- intermediate_size: Option<usize>,
- head_dim: Option<usize>,
- vocab_size: Option<usize>,
- max_position_embeddings: Option<usize>,
- rms_norm_eps: Option<f64>,
- rope_theta: Option<f64>,
- use_sliding_window: Option<bool>,
- sliding_window: Option<usize>,
- hidden_act: Option<String>,
- // Code predictor specific fields
- code_predictor_hidden_size: Option<usize>,
- code_predictor_num_layers: Option<usize>,
- code_predictor_num_attention_heads: Option<usize>,
- num_code_groups: Option<usize>,
- codebook_vocab_size: Option<usize>,
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_default_config() {
- let config = Qwen3TtsConfig::default();
- assert_eq!(config.lm.hidden_size, 1024);
- assert_eq!(config.lm.num_hidden_layers, 28);
- assert_eq!(config.lm.num_attention_heads, 16);
- assert_eq!(config.lm.num_key_value_heads, 8);
- assert_eq!(config.lm.head_dim, 128);
- assert_eq!(config.lm.num_kv_groups(), 2);
- assert_eq!(config.code_predictor.num_layers, 5);
- assert_eq!(config.code_predictor.num_code_groups, 16);
- assert_eq!(config.speech_tokenizer.num_codebooks, 16);
- }
-
- #[test]
- fn test_config_from_json() {
- let json = r#"{
- "hidden_size": 1024,
- "num_hidden_layers": 28,
- "num_attention_heads": 16,
- "num_key_value_heads": 8,
- "intermediate_size": 3072,
- "vocab_size": 151936,
- "max_position_embeddings": 32768,
- "rms_norm_eps": 1e-6,
- "rope_theta": 1000000.0,
- "hidden_act": "silu"
- }"#;
-
- let config = Qwen3TtsConfig::from_json_str(json).unwrap();
- assert_eq!(config.lm.hidden_size, 1024);
- assert_eq!(config.lm.num_hidden_layers, 28);
- assert_eq!(config.lm.vocab_size, 151_936);
- }
-}