1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
|
//! Multi-Token Prediction (MTP) code predictor.
//!
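//! After the main LM predicts the zeroth codebook token, this module
//! predicts the remaining 15 codebook tokens one group at a time: each
//! step is conditioned on the LM's last hidden state and on the token
//! chosen for the previous codebook.
//!
//! Architecture:
//! - 5 transformer layers (same structure as main LM layers)
//! - 16 output heads, one per codebook (vocab 2048 each); head 0 is unused
//!   because codebook 0 comes from the main LM head
//! - Input per step: last hidden state from main LM + embedding of the
//!   previous codebook token (the zeroth codebook token on the first step)
//! - Output: 16 codebook tokens (the given zeroth token plus 15 predictions)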
use candle_core::{Device, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use super::config::{CodePredictorConfig, Qwen3LmConfig};
use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};
/// One transformer layer of the code predictor.
///
/// Pre-norm residual layout, matching the main LM layers:
/// `h = x + attn(input_layernorm(x))`, then `h + mlp(post_attention_layernorm(h))`.
pub struct CodePredictorLayer {
    self_attn: Qwen3Attention,
    mlp: Qwen3Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl CodePredictorLayer {
    /// Loads one predictor layer from `vb`, reading the `self_attn`, `mlp`,
    /// `input_layernorm` and `post_attention_layernorm` sub-prefixes.
    ///
    /// # Errors
    /// Propagates any error raised while loading the sub-modules from `vb`.
    pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
        let hidden = config.hidden_size;
        let heads = config.num_attention_heads;

        // The shared attention/MLP constructors take a `Qwen3LmConfig`, so the
        // predictor's dimensions are expressed in that shape.
        let layer_config = Qwen3LmConfig {
            hidden_size: hidden,
            num_hidden_layers: config.num_layers,
            num_attention_heads: heads,
            // As many KV heads as query heads: no grouped-query attention.
            num_key_value_heads: heads,
            // MLP width fixed at 3x hidden (3072 for hidden = 1024).
            intermediate_size: hidden * 3,
            head_dim: hidden / heads,
            rms_norm_eps: config.rms_norm_eps,
            ..Qwen3LmConfig::default()
        };
        let load_norm = |prefix: &str| rms_norm(hidden, config.rms_norm_eps, vb.pp(prefix));

        let self_attn = Qwen3Attention::new(&layer_config, vb.pp("self_attn"))?;
        let mlp = Qwen3Mlp::new(&layer_config, vb.pp("mlp"))?;
        let input_layernorm = load_norm("input_layernorm")?;
        let post_attention_layernorm = load_norm("post_attention_layernorm")?;

        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    /// Applies attention then MLP, each wrapped in a pre-norm residual.
    ///
    /// `kv_cache` is forwarded to the attention sub-module untouched by this
    /// layer; `attention_mask` is optional and passed through as-is.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        rope: &RotaryEmbedding,
        kv_cache: &mut KvCache,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let attn_out = {
            let normed = self.input_layernorm.forward(hidden_states)?;
            self.self_attn.forward(&normed, rope, kv_cache, attention_mask)?
        };
        let after_attn = (hidden_states + attn_out)?;

        let mlp_in = self.post_attention_layernorm.forward(&after_attn)?;
        let mlp_out = self.mlp.forward(&mlp_in)?;
        &after_attn + mlp_out
    }
}
/// Multi-token prediction code predictor.
///
/// Takes the hidden states from the main LM and predicts all 16 codebook
/// tokens. The zeroth codebook is predicted by the main LM head; this
/// module predicts the remaining 15 residual codebooks.
pub struct CodePredictor {
/// Embedding layer for codebook tokens (shared across groups).
code_embeddings: Vec<Embedding>,
/// Projection from LM hidden + code embedding to predictor hidden.
input_proj: Linear,
/// 5 transformer layers.
layers: Vec<CodePredictorLayer>,
/// Final normalization.
norm: RmsNorm,
/// Per-codebook output heads (16 heads, each projecting to codebook_vocab_size).
output_heads: Vec<Linear>,
/// RoPE for the predictor's attention layers.
rope: RotaryEmbedding,
config: CodePredictorConfig,
}
impl CodePredictor {
pub fn new(
config: &CodePredictorConfig,
lm_config: &Qwen3LmConfig,
vb: VarBuilder,
) -> Result<Self> {
let predictor_vb = vb.pp("code_predictor");
// Code embeddings for each codebook group
let mut code_embeddings = Vec::with_capacity(config.num_code_groups);
for i in 0..config.num_code_groups {
let emb = embedding(
config.codebook_vocab_size,
config.hidden_size,
predictor_vb.pp(format!("code_embeddings.{i}")),
)?;
code_embeddings.push(emb);
}
// Input projection: LM hidden (1024) + code embedding (1024) -> predictor hidden (1024)
let input_proj = linear_no_bias(
config.hidden_size * 2,
config.hidden_size,
predictor_vb.pp("input_proj"),
)?;
// Transformer layers
let mut layers = Vec::with_capacity(config.num_layers);
for i in 0..config.num_layers {
let layer =
CodePredictorLayer::new(config, predictor_vb.pp(format!("layers.{i}")))?;
layers.push(layer);
}
let norm = rms_norm(
config.hidden_size,
config.rms_norm_eps,
predictor_vb.pp("norm"),
)?;
// Output heads for each codebook
let mut output_heads = Vec::with_capacity(config.num_code_groups);
for i in 0..config.num_code_groups {
let head = linear_no_bias(
config.hidden_size,
config.codebook_vocab_size,
predictor_vb.pp(format!("output_heads.{i}")),
)?;
output_heads.push(head);
}
// RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
let predictor_head_dim = config.hidden_size / config.num_attention_heads;
let rope_config = Qwen3LmConfig {
head_dim: predictor_head_dim,
rope_theta: lm_config.rope_theta,
max_position_embeddings: lm_config.max_position_embeddings,
..Qwen3LmConfig::default()
};
let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;
Ok(Self {
code_embeddings,
input_proj,
layers,
norm,
output_heads,
rope,
config: config.clone(),
})
}
/// Predict all 16 codebook tokens from the LM hidden state.
///
/// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
/// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
///
/// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
pub fn predict(
&self,
lm_hidden: &Tensor,
zeroth_code: u32,
device: &Device,
) -> Result<Vec<u32>> {
let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
all_codes.push(zeroth_code);
// The code predictor iterates through codebook groups.
// For each group i (1..16), it:
// 1. Embeds the previous codebook token
// 2. Concatenates with LM hidden state
// 3. Projects through the predictor layers
// 4. Predicts the next codebook token via output_head[i]
let mut prev_code = zeroth_code;
for group_idx in 1..self.config.num_code_groups {
// Embed the previous codebook token
let code_tensor = Tensor::from_vec(
vec![prev_code],
(1, 1),
device,
)?;
let code_emb = self.code_embeddings[group_idx - 1].forward(&code_tensor)?;
// Concatenate LM hidden state with code embedding
let combined = Tensor::cat(&[lm_hidden, &code_emb], D::Minus1)?;
// Project to predictor hidden size
let mut hidden = self.input_proj.forward(&combined)?;
// Run through predictor transformer layers (no KV cache needed — single step)
let mut kv_caches: Vec<KvCache> =
(0..self.config.num_layers).map(|_| KvCache::new()).collect();
for (i, layer) in self.layers.iter().enumerate() {
hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
}
hidden = self.norm.forward(&hidden)?;
// Predict codebook token
let logits = self.output_heads[group_idx].forward(&hidden)?;
// Greedy decode: argmax
let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
let next_code = logits_flat
.argmax(0)?
.to_scalar::<u32>()?;
all_codes.push(next_code);
prev_code = next_code;
}
Ok(all_codes)
}
/// Number of codebook groups.
pub fn num_code_groups(&self) -> usize {
self.config.num_code_groups
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default predictor config matches the documented architecture:
    /// 5 layers, 16 codebook groups, 2048-token codebooks, hidden size 1024.
    #[test]
    fn test_code_predictor_config() {
        let cfg = CodePredictorConfig::default();
        let actual = (
            cfg.num_layers,
            cfg.num_code_groups,
            cfg.codebook_vocab_size,
            cfg.hidden_size,
        );
        assert_eq!(actual, (5, 16, 2048, 1024));
    }
}
|