summaryrefslogblamecommitdiff
path: root/makima/src/tts/qwen3/code_predictor.rs
blob: 363105f1b24322668462ca9478c16b12dbcab2dd (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12











                                                                       
                                                  














































































                                                                                             
                                                                                    
                                    



                                    
                                                                    











                                                  








                                                                                 


                                           
                                                            



                                      



                                                               
                                                                                     





                                
                                

           



                                                                       


                                           
                                                        















                                                                                                    






















                                                                                       

                                                                               
                                                  


                                                               

                                        
                                                        





                                                
                                                                                  
 

                                                                                      












































                                                                                            
//! Multi-Token Prediction (MTP) code predictor.
//!
//! After the main LM predicts the zeroth codebook token, this module
//! predicts the remaining 15 residual codebook tokens one group at a time
//! from the LM's last hidden state. Each step is conditioned on the code
//! chosen in the previous step.
//!
//! Architecture (default configuration):
//! - 5 transformer layers (same structure as main LM layers)
//! - 15 output heads, one per residual codebook (vocab 2048 each)
//! - Input: last hidden state from main LM + embedding of the previous code
//!   (the zeroth code for the first residual group)
//! - Output: 16 codebook tokens (the zeroth code followed by 15 predictions)

use candle_core::{Device, Module, Result, Tensor};
use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};

use super::config::{CodePredictorConfig, Qwen3LmConfig};
use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};

/// One transformer block of the code predictor.
///
/// Pre-norm residual layout: self-attention followed by an MLP, each applied
/// as `x + f(norm(x))`, mirroring the main LM's decoder layers.
pub struct CodePredictorLayer {
    self_attn: Qwen3Attention,
    mlp: Qwen3Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl CodePredictorLayer {
    /// Load one predictor layer from `vb` (expects `self_attn`, `mlp`,
    /// `input_layernorm` and `post_attention_layernorm` sub-paths).
    pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
        let hidden = config.hidden_size;
        let heads = config.num_attention_heads;
        let eps = config.rms_norm_eps;

        // The attention/MLP constructors take a main-LM config, so build an
        // adapter view carrying the predictor's dimensions.
        let lm_view = Qwen3LmConfig {
            hidden_size: hidden,
            num_hidden_layers: config.num_layers,
            num_attention_heads: heads,
            // One KV head per query head (no grouped-query attention).
            num_key_value_heads: heads,
            // 3x expansion, e.g. 3072 for a 1024-wide predictor.
            intermediate_size: hidden * 3,
            head_dim: hidden / heads,
            rms_norm_eps: eps,
            ..Qwen3LmConfig::default()
        };

        Ok(Self {
            self_attn: Qwen3Attention::new(&lm_view, vb.pp("self_attn"))?,
            mlp: Qwen3Mlp::new(&lm_view, vb.pp("mlp"))?,
            input_layernorm: rms_norm(hidden, eps, vb.pp("input_layernorm"))?,
            post_attention_layernorm: rms_norm(hidden, eps, vb.pp("post_attention_layernorm"))?,
        })
    }

    /// Apply the layer: attention sub-block, then MLP sub-block, each with a
    /// residual connection around a pre-normalization.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        rope: &RotaryEmbedding,
        kv_cache: &mut KvCache,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        // Attention sub-block: x + Attn(Norm(x)).
        let normed = self.input_layernorm.forward(hidden_states)?;
        let attn_out = self
            .self_attn
            .forward(&normed, rope, kv_cache, attention_mask)?;
        let after_attn = (hidden_states + attn_out)?;

        // Feed-forward sub-block: h + Mlp(Norm(h)).
        let mlp_out = self
            .mlp
            .forward(&self.post_attention_layernorm.forward(&after_attn)?)?;
        &after_attn + mlp_out
    }
}

/// Residual-codebook predictor that runs after the main LM.
///
/// The main LM head supplies the zeroth codebook token. This module then
/// predicts each of the remaining `num_code_groups - 1` residual codebook
/// tokens in turn with its own small transformer stack.
pub struct CodePredictor {
    /// Predictor hyperparameters (layer count, group count, sizes).
    config: CodePredictorConfig,
    /// One token embedding table per residual codebook group.
    code_embeddings: Vec<Embedding>,
    /// Predictor transformer stack (`config.num_layers` layers).
    layers: Vec<CodePredictorLayer>,
    /// RMSNorm applied after the last transformer layer.
    norm: RmsNorm,
    /// One projection to codebook logits per residual codebook group.
    output_heads: Vec<Linear>,
    /// Rotary position embedding shared by the predictor's attention layers.
    rope: RotaryEmbedding,
}

impl CodePredictor {
    pub fn new(
        config: &CodePredictorConfig,
        lm_config: &Qwen3LmConfig,
        vb: VarBuilder,
    ) -> Result<Self> {
        // HuggingFace Qwen3-TTS uses "talker.code_predictor.*" prefix
        let predictor_vb = vb.pp("talker").pp("code_predictor");
        let model_vb = predictor_vb.pp("model");

        // Code embeddings for residual codebook groups (15 groups, indices 0-14)
        // HF names them "codec_embedding" not "code_embeddings"
        let num_residual_groups = config.num_code_groups - 1; // 15, not 16
        let mut code_embeddings = Vec::with_capacity(num_residual_groups);
        for i in 0..num_residual_groups {
            let emb = embedding(
                config.codebook_vocab_size,
                config.hidden_size,
                model_vb.pp(format!("codec_embedding.{i}")),
            )?;
            code_embeddings.push(emb);
        }

        // Transformer layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for i in 0..config.num_layers {
            let layer =
                CodePredictorLayer::new(config, model_vb.pp(format!("layers.{i}")))?;
            layers.push(layer);
        }

        let norm = rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            model_vb.pp("norm"),
        )?;

        // Output heads for residual codebooks (15 heads, indices 0-14)
        // HF names them "lm_head" not "output_heads"
        let mut output_heads = Vec::with_capacity(num_residual_groups);
        for i in 0..num_residual_groups {
            let head = linear_no_bias(
                config.hidden_size,
                config.codebook_vocab_size,
                predictor_vb.pp(format!("lm_head.{i}")),
            )?;
            output_heads.push(head);
        }

        // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
        let predictor_head_dim = config.hidden_size / config.num_attention_heads;
        let rope_config = Qwen3LmConfig {
            head_dim: predictor_head_dim,
            rope_theta: lm_config.rope_theta,
            max_position_embeddings: lm_config.max_position_embeddings,
            ..Qwen3LmConfig::default()
        };
        let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;

        Ok(Self {
            code_embeddings,
            layers,
            norm,
            output_heads,
            rope,
            config: config.clone(),
        })
    }

    /// Predict all 16 codebook tokens from the LM hidden state.
    ///
    /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
    /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
    ///
    /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
    pub fn predict(
        &self,
        lm_hidden: &Tensor,
        zeroth_code: u32,
        device: &Device,
    ) -> Result<Vec<u32>> {
        let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
        all_codes.push(zeroth_code);

        // The code predictor iterates through the 15 residual codebook groups.
        // For each group i (0..15), it:
        //   1. Embeds the previous codebook token
        //   2. Adds to LM hidden state
        //   3. Runs through predictor layers
        //   4. Predicts the next codebook token via lm_head[i]
        let mut prev_code = zeroth_code;

        for group_idx in 0..self.code_embeddings.len() {
            // Embed the previous codebook token
            let code_tensor = Tensor::from_vec(
                vec![prev_code],
                (1, 1),
                device,
            )?;
            let code_emb = self.code_embeddings[group_idx].forward(&code_tensor)?;

            // Add code embedding to LM hidden state (no concatenation, no projection)
            let mut hidden = (lm_hidden + &code_emb)?;

            // Run through predictor transformer layers (no KV cache needed — single step)
            let mut kv_caches: Vec<KvCache> =
                (0..self.config.num_layers).map(|_| KvCache::new()).collect();
            for (i, layer) in self.layers.iter().enumerate() {
                hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
            }

            hidden = self.norm.forward(&hidden)?;

            // Predict codebook token
            let logits = self.output_heads[group_idx].forward(&hidden)?;

            // Greedy decode: argmax
            let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
            let next_code = logits_flat
                .argmax(0)?
                .to_scalar::<u32>()?;

            all_codes.push(next_code);
            prev_code = next_code;
        }

        Ok(all_codes)
    }

    /// Number of codebook groups.
    pub fn num_code_groups(&self) -> usize {
        self.config.num_code_groups
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default predictor hyperparameters that the module-level
    /// documentation describes (layers, groups, vocab size, width).
    #[test]
    fn test_code_predictor_config() {
        let cfg = CodePredictorConfig::default();
        let observed = (
            cfg.num_layers,
            cfg.num_code_groups,
            cfg.codebook_vocab_size,
            cfg.hidden_size,
        );
        assert_eq!(observed, (5, 16, 2048, 1024));
    }
}