summaryrefslogblamecommitdiff
path: root/makima/src/tts/qwen3/config.rs
blob: 6fb55d70e2737e270f3243d4b83e8b6fc8163a18 (plain) (tree)














































































































































































































































































                                                                                        
//! Qwen3-TTS model configuration.
//!
//! Parses config.json from the HuggingFace model repository to configure
//! the language model, code predictor, and speech tokenizer.

use serde::Deserialize;

use crate::tts::TtsError;

/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base.
///
/// Groups the configuration of the three model components. Each section
/// carries a `serde(default = ...)` attribute, so when this struct is
/// deserialized directly, a section that is absent from the input is filled
/// from that section's `Default` implementation rather than causing an error.
#[derive(Debug, Clone, Deserialize)]
pub struct Qwen3TtsConfig {
    /// Language model (talker) configuration.
    ///
    /// Falls back to `Qwen3LmConfig::default()` when the `lm` key is missing.
    #[serde(default = "Qwen3LmConfig::default")]
    pub lm: Qwen3LmConfig,

    /// Code predictor (multi-token prediction) configuration.
    ///
    /// Falls back to `CodePredictorConfig::default()` when the
    /// `code_predictor` key is missing.
    #[serde(default = "CodePredictorConfig::default")]
    pub code_predictor: CodePredictorConfig,

    /// Speech tokenizer configuration.
    ///
    /// Falls back to `SpeechTokenizerConfig::default()` when the
    /// `speech_tokenizer` key is missing.
    #[serde(default = "SpeechTokenizerConfig::default")]
    pub speech_tokenizer: SpeechTokenizerConfig,
}

impl Default for Qwen3TtsConfig {
    /// Builds a configuration in which every section holds its own defaults.
    fn default() -> Self {
        let lm = Qwen3LmConfig::default();
        let code_predictor = CodePredictorConfig::default();
        let speech_tokenizer = SpeechTokenizerConfig::default();

        Self {
            lm,
            code_predictor,
            speech_tokenizer,
        }
    }
}

impl Qwen3TtsConfig {
    /// Load from a config.json file path.
    pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> {
        let content = std::fs::read_to_string(path)
            .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?;
        Self::from_json_str(&content)
    }

    /// Load from a JSON string.
    pub fn from_json_str(json: &str) -> Result<Self, TtsError> {
        // Try to parse the full HuggingFace config.json format first
        if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) {
            return Ok(Self::from_hf_config(&hf_config));
        }
        // Fall back to direct deserialization
        serde_json::from_str(json)
            .map_err(|e| TtsError::Config(format!("failed to parse config: {e}")))
    }

    /// Convert from HuggingFace's config.json format.
    fn from_hf_config(hf: &HfConfig) -> Self {
        Self {
            lm: Qwen3LmConfig {
                hidden_size: hf.hidden_size.unwrap_or(1024),
                num_hidden_layers: hf.num_hidden_layers.unwrap_or(28),
                num_attention_heads: hf.num_attention_heads.unwrap_or(16),
                num_key_value_heads: hf.num_key_value_heads.unwrap_or(8),
                intermediate_size: hf.intermediate_size.unwrap_or(3072),
                head_dim: hf.head_dim.unwrap_or(128),
                vocab_size: hf.vocab_size.unwrap_or(151_936),
                max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768),
                rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
                rope_theta: hf.rope_theta.unwrap_or(1_000_000.0),
                use_sliding_window: hf.use_sliding_window.unwrap_or(false),
                sliding_window: hf.sliding_window,
                hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()),
            },
            code_predictor: CodePredictorConfig {
                hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024),
                num_layers: hf.code_predictor_num_layers.unwrap_or(5),
                num_attention_heads: hf
                    .code_predictor_num_attention_heads
                    .unwrap_or(16),
                num_code_groups: hf.num_code_groups.unwrap_or(16),
                codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048),
                rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
            },
            speech_tokenizer: SpeechTokenizerConfig::default(),
        }
    }
}

/// Language model configuration (28-layer Qwen3 transformer).
///
/// The struct carries no container-level `serde(default)`, so when it is
/// deserialized directly every field except the `Option`-typed
/// `sliding_window` must be present in the input.
#[derive(Debug, Clone, Deserialize)]
pub struct Qwen3LmConfig {
    /// Hidden dimension of transformer layers.
    pub hidden_size: usize,
    /// Number of transformer layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads.
    pub num_attention_heads: usize,
    /// Number of key-value heads (GQA).
    ///
    /// Used as the divisor in `num_kv_groups`, so it must be non-zero there.
    pub num_key_value_heads: usize,
    /// Feed-forward intermediate size.
    pub intermediate_size: usize,
    /// Dimension per attention head.
    pub head_dim: usize,
    /// Text vocabulary size.
    pub vocab_size: usize,
    /// Maximum sequence length for RoPE.
    pub max_position_embeddings: usize,
    /// RMS normalization epsilon.
    pub rms_norm_eps: f64,
    /// RoPE theta parameter.
    pub rope_theta: f64,
    /// Whether to use sliding window attention.
    pub use_sliding_window: bool,
    /// Sliding window size (if enabled).
    ///
    /// `None` when no window size was supplied.
    pub sliding_window: Option<usize>,
    /// Activation function name.
    pub hidden_act: String,
}

impl Default for Qwen3LmConfig {
    /// Built-in language-model settings used when no value is supplied.
    fn default() -> Self {
        Self {
            // Layer stack and widths.
            num_hidden_layers: 28,
            hidden_size: 1024,
            intermediate_size: 3072,
            hidden_act: String::from("silu"),
            rms_norm_eps: 1e-6,

            // Attention geometry.
            num_attention_heads: 16,
            num_key_value_heads: 8,
            head_dim: 128,

            // Vocabulary and positional encoding.
            vocab_size: 151_936,
            max_position_embeddings: 32_768,
            rope_theta: 1_000_000.0,

            // Sliding-window attention is off by default.
            use_sliding_window: false,
            sliding_window: None,
        }
    }
}

impl Qwen3LmConfig {
    /// Number of key-value head groups for GQA.
    ///
    /// Computed as `num_attention_heads / num_key_value_heads` using integer
    /// division, so any remainder is discarded.
    ///
    /// # Panics
    ///
    /// Panics if `num_key_value_heads` is zero (integer division by zero).
    pub fn num_kv_groups(&self) -> usize {
        let attention_heads = self.num_attention_heads;
        let kv_heads = self.num_key_value_heads;
        attention_heads / kv_heads
    }
}

/// Code predictor (multi-token prediction) configuration.
///
/// The struct carries no container-level `serde(default)`, so when it is
/// deserialized directly every field must be present in the input.
#[derive(Debug, Clone, Deserialize)]
pub struct CodePredictorConfig {
    /// Hidden size (matches LM hidden size).
    pub hidden_size: usize,
    /// Number of predictor transformer layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_attention_heads: usize,
    /// Number of codebook groups (residual codebooks).
    pub num_code_groups: usize,
    /// Vocabulary size per codebook.
    pub codebook_vocab_size: usize,
    /// RMS norm epsilon.
    pub rms_norm_eps: f64,
}

impl Default for CodePredictorConfig {
    /// Built-in code-predictor settings used when no value is supplied.
    fn default() -> Self {
        Self {
            // Transformer shape.
            num_layers: 5,
            hidden_size: 1024,
            num_attention_heads: 16,
            rms_norm_eps: 1e-6,

            // Codebook layout.
            num_code_groups: 16,
            codebook_vocab_size: 2048,
        }
    }
}

/// Speech tokenizer (ConvNet codec) configuration.
///
/// The struct carries no container-level `serde(default)`, so when it is
/// deserialized directly every field must be present in the input. The
/// HuggingFace conversion in this file does not read this section from the
/// input and uses `SpeechTokenizerConfig::default()` instead.
#[derive(Debug, Clone, Deserialize)]
pub struct SpeechTokenizerConfig {
    /// Number of RVQ codebooks.
    pub num_codebooks: usize,
    /// Codebook embedding dimension.
    pub codebook_dim: usize,
    /// Codebook vocabulary size per layer.
    pub codebook_size: usize,
    /// Encoder/decoder hidden channels.
    pub hidden_channels: usize,
    /// Output sample rate.
    pub sample_rate: u32,
    /// Token frame rate (Hz).
    pub frame_rate: f32,
    /// HuggingFace model ID for the speech tokenizer.
    pub model_id: String,
}

impl Default for SpeechTokenizerConfig {
    /// Built-in speech-tokenizer settings used when no value is supplied.
    fn default() -> Self {
        Self {
            // Where the tokenizer weights come from.
            model_id: String::from("Qwen/Qwen3-TTS-Tokenizer-12Hz"),

            // Codebook layout.
            num_codebooks: 16,
            codebook_size: 2048,
            codebook_dim: 256,

            // Codec network and audio timing.
            hidden_channels: 512,
            frame_rate: 12.5,
            sample_rate: 24_000,
        }
    }
}

/// HuggingFace config.json format (partial, fields we need).
///
/// Every field is an `Option`, so a key that is absent from the JSON
/// deserializes to `None` and is replaced by a fallback value in
/// `Qwen3TtsConfig::from_hf_config`. Keys not listed here are ignored.
/// As a consequence any JSON object, including an empty one, deserializes
/// into this struct successfully.
#[derive(Debug, Deserialize)]
struct HfConfig {
    // Language model fields (mapped onto `Qwen3LmConfig`).
    hidden_size: Option<usize>,
    num_hidden_layers: Option<usize>,
    num_attention_heads: Option<usize>,
    num_key_value_heads: Option<usize>,
    intermediate_size: Option<usize>,
    head_dim: Option<usize>,
    vocab_size: Option<usize>,
    max_position_embeddings: Option<usize>,
    // Also used for the code predictor's `rms_norm_eps`.
    rms_norm_eps: Option<f64>,
    rope_theta: Option<f64>,
    use_sliding_window: Option<bool>,
    sliding_window: Option<usize>,
    hidden_act: Option<String>,
    // Code predictor specific fields (mapped onto `CodePredictorConfig`).
    code_predictor_hidden_size: Option<usize>,
    code_predictor_num_layers: Option<usize>,
    code_predictor_num_attention_heads: Option<usize>,
    num_code_groups: Option<usize>,
    codebook_vocab_size: Option<usize>,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults expose the expected architecture constants.
    #[test]
    fn test_default_config() {
        let Qwen3TtsConfig {
            lm,
            code_predictor,
            speech_tokenizer,
        } = Qwen3TtsConfig::default();

        // Language model.
        assert_eq!(lm.hidden_size, 1024);
        assert_eq!(lm.num_hidden_layers, 28);
        assert_eq!(lm.num_attention_heads, 16);
        assert_eq!(lm.num_key_value_heads, 8);
        assert_eq!(lm.head_dim, 128);
        assert_eq!(lm.num_kv_groups(), 2);

        // Code predictor.
        assert_eq!(code_predictor.num_layers, 5);
        assert_eq!(code_predictor.num_code_groups, 16);

        // Speech tokenizer.
        assert_eq!(speech_tokenizer.num_codebooks, 16);
    }

    /// A flat HuggingFace-style document populates the language model section.
    #[test]
    fn test_config_from_json() {
        let json = r#"{
            "hidden_size": 1024,
            "num_hidden_layers": 28,
            "num_attention_heads": 16,
            "num_key_value_heads": 8,
            "intermediate_size": 3072,
            "vocab_size": 151936,
            "max_position_embeddings": 32768,
            "rms_norm_eps": 1e-6,
            "rope_theta": 1000000.0,
            "hidden_act": "silu"
        }"#;

        let lm = Qwen3TtsConfig::from_json_str(json).unwrap().lm;

        assert_eq!(lm.hidden_size, 1024);
        assert_eq!(lm.num_hidden_layers, 28);
        assert_eq!(lm.vocab_size, 151_936);
    }
}