summaryrefslogblamecommitdiff
path: root/makima/src/tts/qwen3/code_predictor.rs
blob: 0ef8a1d2c66ea6c0845f42f7c219337fc5a711aa (plain) (tree)




































































































































































































































































                                                                                                    
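//! Multi-Token Prediction (MTP) code predictor.
//!
//! After the main LM predicts the zeroth codebook token, this module
//! predicts the remaining codebook layers (15 with the default 16-group
//! configuration) from the LM's last hidden state. The groups are decoded
//! sequentially, not in parallel: each step is conditioned on the token
//! chosen for the previous group.
//!
//! Architecture (default configuration):
//! - 5 transformer layers (same structure as main LM layers)
//! - 16 output heads, one per codebook (vocab 2048 each); head 0 is loaded
//!   but unused, because the zeroth token comes from the main LM
//! - Input per step: last hidden state from main LM + embedding of the
//!   previous codebook's token
//! - Output: 16 codebook tokens, the first being the zeroth token passed in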

use candle_core::{Device, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};

use super::config::{CodePredictorConfig, Qwen3LmConfig};
use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};

/// One transformer block of the code predictor.
///
/// Pre-norm residual layout, mirroring the main LM's decoder layers:
/// `y = x + attn(norm(x))`, then `out = y + mlp(norm(y))`.
pub struct CodePredictorLayer {
    self_attn: Qwen3Attention,
    mlp: Qwen3Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl CodePredictorLayer {
    /// Loads one predictor layer's weights from `vb`.
    ///
    /// The shared attention/MLP constructors take a [`Qwen3LmConfig`], so one
    /// is synthesized from the predictor settings:
    /// - key/value heads equal the query heads;
    /// - the MLP width is `3 * hidden_size`;
    /// - `head_dim` is `hidden_size / num_attention_heads`.
    ///
    /// Every other field comes from `Qwen3LmConfig::default()`.
    pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
        let hidden = config.hidden_size;
        let heads = config.num_attention_heads;
        let eps = config.rms_norm_eps;

        // NOTE(review): the MLP width (3x hidden, e.g. 3072 for 1024) and the
        // absence of GQA are hard-coded here, not read from the predictor
        // config. Confirm they match the checkpoint being loaded.
        let lm_config = Qwen3LmConfig {
            hidden_size: hidden,
            num_hidden_layers: config.num_layers,
            num_attention_heads: heads,
            num_key_value_heads: heads,
            intermediate_size: 3 * hidden,
            head_dim: hidden / heads,
            rms_norm_eps: eps,
            ..Qwen3LmConfig::default()
        };

        let load_norm = |name: &str| rms_norm(hidden, eps, vb.pp(name));

        Ok(Self {
            self_attn: Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?,
            mlp: Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?,
            input_layernorm: load_norm("input_layernorm")?,
            post_attention_layernorm: load_norm("post_attention_layernorm")?,
        })
    }

    /// Runs the block on `hidden_states`.
    ///
    /// The output has the same shape as `hidden_states`, because both residual
    /// additions are element-wise. `rope`, `kv_cache` and `attention_mask` are
    /// passed through unchanged to the self-attention sub-layer.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        rope: &RotaryEmbedding,
        kv_cache: &mut KvCache,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        // Attention sub-layer plus residual.
        let normed = self.input_layernorm.forward(hidden_states)?;
        let attn_out = self
            .self_attn
            .forward(&normed, rope, kv_cache, attention_mask)?;
        let after_attn = (hidden_states + attn_out)?;

        // Feed-forward sub-layer plus residual.
        let mlp_in = self.post_attention_layernorm.forward(&after_attn)?;
        let mlp_out = self.mlp.forward(&mlp_in)?;
        &after_attn + mlp_out
    }
}

/// Multi-token prediction code predictor.
///
/// Takes the last hidden state from the main LM and produces one token per
/// codebook group. The zeroth codebook is predicted by the main LM head.
/// This module predicts the remaining `num_code_groups - 1` residual
/// codebooks (15 with the default config), one group at a time.
pub struct CodePredictor {
    /// One embedding table per codebook group: `num_code_groups` tables with
    /// `codebook_vocab_size` rows each. When predicting group `i`, table
    /// `i - 1` embeds the token chosen for group `i - 1`, so the last table
    /// is loaded but not read by `predict`.
    code_embeddings: Vec<Embedding>,
    /// Projection from `[LM hidden ; code embedding]` (`2 * hidden_size`
    /// wide) to the predictor hidden size.
    input_proj: Linear,
    /// Predictor transformer layers (`num_layers` of them, 5 by default).
    layers: Vec<CodePredictorLayer>,
    /// Final RMS normalization applied before the output heads.
    norm: RmsNorm,
    /// Per-codebook output heads (`num_code_groups` heads, each projecting to
    /// `codebook_vocab_size`). Head 0 is loaded but unused, because the
    /// zeroth token comes from the main LM.
    output_heads: Vec<Linear>,
    /// RoPE tables for the predictor's attention layers.
    rope: RotaryEmbedding,
    /// Copy of the configuration the predictor was built with.
    config: CodePredictorConfig,
}

impl CodePredictor {
    pub fn new(
        config: &CodePredictorConfig,
        lm_config: &Qwen3LmConfig,
        vb: VarBuilder,
    ) -> Result<Self> {
        let predictor_vb = vb.pp("code_predictor");

        // Code embeddings for each codebook group
        let mut code_embeddings = Vec::with_capacity(config.num_code_groups);
        for i in 0..config.num_code_groups {
            let emb = embedding(
                config.codebook_vocab_size,
                config.hidden_size,
                predictor_vb.pp(format!("code_embeddings.{i}")),
            )?;
            code_embeddings.push(emb);
        }

        // Input projection: LM hidden (1024) + code embedding (1024) -> predictor hidden (1024)
        let input_proj = linear_no_bias(
            config.hidden_size * 2,
            config.hidden_size,
            predictor_vb.pp("input_proj"),
        )?;

        // Transformer layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for i in 0..config.num_layers {
            let layer =
                CodePredictorLayer::new(config, predictor_vb.pp(format!("layers.{i}")))?;
            layers.push(layer);
        }

        let norm = rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            predictor_vb.pp("norm"),
        )?;

        // Output heads for each codebook
        let mut output_heads = Vec::with_capacity(config.num_code_groups);
        for i in 0..config.num_code_groups {
            let head = linear_no_bias(
                config.hidden_size,
                config.codebook_vocab_size,
                predictor_vb.pp(format!("output_heads.{i}")),
            )?;
            output_heads.push(head);
        }

        // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
        let predictor_head_dim = config.hidden_size / config.num_attention_heads;
        let rope_config = Qwen3LmConfig {
            head_dim: predictor_head_dim,
            rope_theta: lm_config.rope_theta,
            max_position_embeddings: lm_config.max_position_embeddings,
            ..Qwen3LmConfig::default()
        };
        let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;

        Ok(Self {
            code_embeddings,
            input_proj,
            layers,
            norm,
            output_heads,
            rope,
            config: config.clone(),
        })
    }

    /// Predict all 16 codebook tokens from the LM hidden state.
    ///
    /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
    /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
    ///
    /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
    pub fn predict(
        &self,
        lm_hidden: &Tensor,
        zeroth_code: u32,
        device: &Device,
    ) -> Result<Vec<u32>> {
        let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
        all_codes.push(zeroth_code);

        // The code predictor iterates through codebook groups.
        // For each group i (1..16), it:
        //   1. Embeds the previous codebook token
        //   2. Concatenates with LM hidden state
        //   3. Projects through the predictor layers
        //   4. Predicts the next codebook token via output_head[i]
        let mut prev_code = zeroth_code;

        for group_idx in 1..self.config.num_code_groups {
            // Embed the previous codebook token
            let code_tensor = Tensor::from_vec(
                vec![prev_code],
                (1, 1),
                device,
            )?;
            let code_emb = self.code_embeddings[group_idx - 1].forward(&code_tensor)?;

            // Concatenate LM hidden state with code embedding
            let combined = Tensor::cat(&[lm_hidden, &code_emb], D::Minus1)?;

            // Project to predictor hidden size
            let mut hidden = self.input_proj.forward(&combined)?;

            // Run through predictor transformer layers (no KV cache needed — single step)
            let mut kv_caches: Vec<KvCache> =
                (0..self.config.num_layers).map(|_| KvCache::new()).collect();
            for (i, layer) in self.layers.iter().enumerate() {
                hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
            }

            hidden = self.norm.forward(&hidden)?;

            // Predict codebook token
            let logits = self.output_heads[group_idx].forward(&hidden)?;

            // Greedy decode: argmax
            let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
            let next_code = logits_flat
                .argmax(0)?
                .to_scalar::<u32>()?;

            all_codes.push(next_code);
            prev_code = next_code;
        }

        Ok(all_codes)
    }

    /// Number of codebook groups.
    pub fn num_code_groups(&self) -> usize {
        self.config.num_code_groups
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the default predictor hyperparameters:
    /// (layers, code groups, codebook vocab, hidden size).
    #[test]
    fn test_code_predictor_config() {
        let cfg = CodePredictorConfig::default();
        let actual = (
            cfg.num_layers,
            cfg.num_code_groups,
            cfg.codebook_vocab_size,
            cfg.hidden_size,
        );
        assert_eq!(actual, (5, 16, 2048, 1024));
    }
}