//! Multi-Token Prediction (MTP) code predictor.
//!
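//! After the main LM predicts the zeroth codebook token, this module
//! predicts the remaining 15 codebook layers from the LM's hidden state.
//! The groups are decoded sequentially, not in parallel: each step is
//! conditioned on the code predicted for the previous group.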
//!
//! Architecture:
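//! - `num_layers` transformer layers (5 by default; same structure as main LM layers)
//! - `num_code_groups` output heads, one per codebook (16 by default, vocab 2048 each)
//! - Input per step: last hidden state from the main LM concatenated with the
//!   embedding of the previously predicted code (the zeroth code on the first step)
//! - Output: one token per codebook group, starting with the zeroth code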
use candle_core::{Device, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use super::config::{CodePredictorConfig, Qwen3LmConfig};
use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};
/// One transformer layer of the code predictor.
///
/// Pre-norm residual block with the same structure as the main LM layers:
/// an RMS-normalized self-attention sub-layer, then an RMS-normalized MLP
/// sub-layer, each wrapped in a residual connection.
pub struct CodePredictorLayer {
    /// Normalization applied to the layer input before self-attention.
    input_layernorm: RmsNorm,
    /// Self-attention sub-layer.
    self_attn: Qwen3Attention,
    /// Normalization applied to the post-attention residual before the MLP.
    post_attention_layernorm: RmsNorm,
    /// Feed-forward sub-layer.
    mlp: Qwen3Mlp,
}
impl CodePredictorLayer {
    /// Loads one predictor layer from `vb`.
    ///
    /// The attention and MLP constructors take a [`Qwen3LmConfig`], so one is
    /// synthesized from the predictor settings:
    /// - key/value heads equal the query heads (no grouped-query attention);
    /// - the MLP width is `3 * hidden_size`;
    /// - `head_dim` is `hidden_size / num_attention_heads`.
    ///
    /// Every other field comes from `Qwen3LmConfig::default()`.
    ///
    /// # Errors
    /// Propagates any weight-loading error from the sub-module constructors.
    pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
        let hidden = config.hidden_size;
        let heads = config.num_attention_heads;
        let eps = config.rms_norm_eps;

        let synthesized = Qwen3LmConfig {
            hidden_size: hidden,
            num_hidden_layers: config.num_layers,
            num_attention_heads: heads,
            // No GQA in the predictor: one KV head per query head.
            num_key_value_heads: heads,
            // 3072 for hidden=1024.
            intermediate_size: hidden * 3,
            head_dim: hidden / heads,
            rms_norm_eps: eps,
            ..Qwen3LmConfig::default()
        };

        let self_attn = Qwen3Attention::new(&synthesized, vb.pp("self_attn"))?;
        let mlp = Qwen3Mlp::new(&synthesized, vb.pp("mlp"))?;

        let load_norm = |prefix: &str| rms_norm(hidden, eps, vb.pp(prefix));
        let input_layernorm = load_norm("input_layernorm")?;
        let post_attention_layernorm = load_norm("post_attention_layernorm")?;

        Ok(Self {
            input_layernorm,
            self_attn,
            post_attention_layernorm,
            mlp,
        })
    }

    /// Applies the layer: `h = x + attn(norm1(x))`, then returns `h + mlp(norm2(h))`.
    ///
    /// `rope`, `kv_cache` and `attention_mask` are passed through unchanged to
    /// the attention sub-layer.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        rope: &RotaryEmbedding,
        kv_cache: &mut KvCache,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        // Attention sub-layer with residual connection.
        let normed = self.input_layernorm.forward(hidden_states)?;
        let attn_out = self
            .self_attn
            .forward(&normed, rope, kv_cache, attention_mask)?;
        let after_attn = (hidden_states + attn_out)?;

        // Feed-forward sub-layer with residual connection.
        let normed = self.post_attention_layernorm.forward(&after_attn)?;
        let mlp_out = self.mlp.forward(&normed)?;
        &after_attn + mlp_out
    }
}
/// Multi-token prediction code predictor.
///
/// Takes the main LM's last hidden state and the zeroth-codebook token chosen
/// by the main LM head. It then predicts the tokens of the remaining
/// `num_code_groups - 1` codebooks (15 with the default config). Groups are
/// predicted one at a time, and each step is conditioned on the code predicted
/// for the previous group.
pub struct CodePredictor {
    /// One code-embedding table per codebook group (`num_code_groups` tables).
    /// When predicting group `group_idx`, the previous code is embedded with
    /// table `group_idx - 1`.
    code_embeddings: Vec<Embedding>,
    /// Projection from `[LM hidden ‖ code embedding]` (`2 * hidden_size`
    /// features) to the predictor's `hidden_size`.
    input_proj: Linear,
    /// Predictor transformer layers (`num_layers` of them, 5 by default).
    layers: Vec<CodePredictorLayer>,
    /// Final RMS normalization applied before the output heads.
    norm: RmsNorm,
    /// Per-codebook output heads (`num_code_groups` of them). Each projects
    /// `hidden_size` to `codebook_vocab_size` logits.
    ///
    /// NOTE(review): prediction only uses heads `1..num_code_groups` and
    /// embedding tables `0..num_code_groups - 1`. Head 0 and the last embedding
    /// table are loaded but never used; confirm this matches the checkpoint
    /// layout.
    output_heads: Vec<Linear>,
    /// Rotary embedding sized for the predictor's head dimension.
    rope: RotaryEmbedding,
    /// Copy of the configuration the predictor was built from.
    config: CodePredictorConfig,
}
impl CodePredictor {
pub fn new(
config: &CodePredictorConfig,
lm_config: &Qwen3LmConfig,
vb: VarBuilder,
) -> Result<Self> {
let predictor_vb = vb.pp("code_predictor");
// Code embeddings for each codebook group
let mut code_embeddings = Vec::with_capacity(config.num_code_groups);
for i in 0..config.num_code_groups {
let emb = embedding(
config.codebook_vocab_size,
config.hidden_size,
predictor_vb.pp(format!("code_embeddings.{i}")),
)?;
code_embeddings.push(emb);
}
// Input projection: LM hidden (1024) + code embedding (1024) -> predictor hidden (1024)
let input_proj = linear_no_bias(
config.hidden_size * 2,
config.hidden_size,
predictor_vb.pp("input_proj"),
)?;
// Transformer layers
let mut layers = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
let layer =
CodePredictorLayer::new(config, predictor_vb.pp(format!("layers.{i}")))?;
layers.push(layer);
}
let norm = rms_norm(
config.hidden_size,
config.rms_norm_eps,
predictor_vb.pp("norm"),
)?;
// Output heads for each codebook
let mut output_heads = Vec::with_capacity(config.num_code_groups);
for i in 0..config.num_code_groups {
let head = linear_no_bias(
config.hidden_size,
config.codebook_vocab_size,
predictor_vb.pp(format!("output_heads.{i}")),
)?;
output_heads.push(head);
}
// RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
let predictor_head_dim = config.hidden_size / config.num_attention_heads;
let rope_config = Qwen3LmConfig {
head_dim: predictor_head_dim,
rope_theta: lm_config.rope_theta,
max_position_embeddings: lm_config.max_position_embeddings,
..Qwen3LmConfig::default()
};
let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;
Ok(Self {
code_embeddings,
input_proj,
layers,
norm,
output_heads,
rope,
config: config.clone(),
})
}
/// Predict all 16 codebook tokens from the LM hidden state.
///
/// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
/// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
///
/// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
pub fn predict(
&self,
lm_hidden: &Tensor,
zeroth_code: u32,
device: &Device,
) -> Result<Vec<u32>> {
let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
all_codes.push(zeroth_code);
// The code predictor iterates through codebook groups.
// For each group i (1..16), it:
// 1. Embeds the previous codebook token
// 2. Concatenates with LM hidden state
// 3. Projects through the predictor layers
// 4. Predicts the next codebook token via output_head[i]
let mut prev_code = zeroth_code;
for group_idx in 1..self.config.num_code_groups {
// Embed the previous codebook token
let code_tensor = Tensor::from_vec(
vec![prev_code],
(1, 1),
device,
)?;
let code_emb = self.code_embeddings[group_idx - 1].forward(&code_tensor)?;
// Concatenate LM hidden state with code embedding
let combined = Tensor::cat(&[lm_hidden, &code_emb], D::Minus1)?;
// Project to predictor hidden size
let mut hidden = self.input_proj.forward(&combined)?;
// Run through predictor transformer layers (no KV cache needed — single step)
let mut kv_caches: Vec<KvCache> =
(0..self.config.num_layers).map(|_| KvCache::new()).collect();
for (i, layer) in self.layers.iter().enumerate() {
hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
}
hidden = self.norm.forward(&hidden)?;
// Predict codebook token
let logits = self.output_heads[group_idx].forward(&hidden)?;
// Greedy decode: argmax
let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
let next_code = logits_flat
.argmax(0)?
.to_scalar::<u32>()?;
all_codes.push(next_code);
prev_code = next_code;
}
Ok(all_codes)
}
/// Number of codebook groups.
pub fn num_code_groups(&self) -> usize {
self.config.num_code_groups
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default predictor config matches the published architecture sizes.
    #[test]
    fn test_code_predictor_config() {
        let cfg = CodePredictorConfig::default();
        // (num_layers, num_code_groups, codebook_vocab_size, hidden_size)
        assert_eq!(
            (
                cfg.num_layers,
                cfg.num_code_groups,
                cfg.codebook_vocab_size,
                cfg.hidden_size,
            ),
            (5, 16, 2048, 1024)
        );
    }
}