summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/code_predictor.rs
blob: 363105f1b24322668462ca9478c16b12dbcab2dd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
//! Multi-Token Prediction (MTP) code predictor.
//!
//! After the main LM predicts the zeroth codebook token, this module
//! predicts the remaining 15 residual codebook tokens one at a time from the
//! LM's last hidden state, conditioning each step on the token predicted at
//! the previous step.
//!
//! Architecture (default configuration):
//! - 5 transformer layers (same structure as main LM layers)
//! - 15 output heads, one per residual codebook (vocab 2048 each)
//! - Input per step: last hidden state from main LM + embedding of the previous codebook token
//! - Output: 16 codebook tokens (the zeroth token followed by the 15 predicted residual tokens)

use candle_core::{Device, Module, Result, Tensor};
use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};

use super::config::{CodePredictorConfig, Qwen3LmConfig};
use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};

/// A single code predictor transformer layer.
///
/// Uses the same pre-norm residual structure as the main LM layers.
pub struct CodePredictorLayer {
    self_attn: Qwen3Attention,
    mlp: Qwen3Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl CodePredictorLayer {
    pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
        // Construct a Qwen3LmConfig-like view for the attention/MLP constructors
        let lm_config = Qwen3LmConfig {
            hidden_size: config.hidden_size,
            num_hidden_layers: config.num_layers,
            num_attention_heads: config.num_attention_heads,
            num_key_value_heads: config.num_attention_heads, // No GQA in predictor
            intermediate_size: config.hidden_size * 3, // 3072 for hidden=1024
            head_dim: config.hidden_size / config.num_attention_heads,
            rms_norm_eps: config.rms_norm_eps,
            ..Qwen3LmConfig::default()
        };

        let self_attn = Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?;
        let mlp = Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?;
        let input_layernorm = rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            vb.pp("input_layernorm"),
        )?;
        let post_attention_layernorm = rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;

        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    pub fn forward(
        &self,
        hidden_states: &Tensor,
        rope: &RotaryEmbedding,
        kv_cache: &mut KvCache,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = hidden_states;
        let hidden_states = self.input_layernorm.forward(hidden_states)?;
        let hidden_states =
            self.self_attn
                .forward(&hidden_states, rope, kv_cache, attention_mask)?;
        let hidden_states = (residual + hidden_states)?;

        let residual = &hidden_states;
        let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        let output = (residual + hidden_states)?;

        Ok(output)
    }
}

/// Multi-token prediction code predictor.
///
/// Takes the hidden states from the main LM and predicts all 16 codebook
/// tokens. The zeroth codebook is predicted by the main LM head; this
/// module predicts the remaining 15 residual codebooks.
pub struct CodePredictor {
    /// Embedding layer for codebook tokens (one per residual codebook group, 0-14).
    code_embeddings: Vec<Embedding>,
    /// 5 transformer layers.
    layers: Vec<CodePredictorLayer>,
    /// Final normalization.
    norm: RmsNorm,
    /// Per-codebook output heads (15 heads for residual codebooks).
    output_heads: Vec<Linear>,
    /// RoPE for the predictor's attention layers.
    rope: RotaryEmbedding,
    config: CodePredictorConfig,
}

impl CodePredictor {
    pub fn new(
        config: &CodePredictorConfig,
        lm_config: &Qwen3LmConfig,
        vb: VarBuilder,
    ) -> Result<Self> {
        // HuggingFace Qwen3-TTS uses "talker.code_predictor.*" prefix
        let predictor_vb = vb.pp("talker").pp("code_predictor");
        let model_vb = predictor_vb.pp("model");

        // Code embeddings for residual codebook groups (15 groups, indices 0-14)
        // HF names them "codec_embedding" not "code_embeddings"
        let num_residual_groups = config.num_code_groups - 1; // 15, not 16
        let mut code_embeddings = Vec::with_capacity(num_residual_groups);
        for i in 0..num_residual_groups {
            let emb = embedding(
                config.codebook_vocab_size,
                config.hidden_size,
                model_vb.pp(format!("codec_embedding.{i}")),
            )?;
            code_embeddings.push(emb);
        }

        // Transformer layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for i in 0..config.num_layers {
            let layer =
                CodePredictorLayer::new(config, model_vb.pp(format!("layers.{i}")))?;
            layers.push(layer);
        }

        let norm = rms_norm(
            config.hidden_size,
            config.rms_norm_eps,
            model_vb.pp("norm"),
        )?;

        // Output heads for residual codebooks (15 heads, indices 0-14)
        // HF names them "lm_head" not "output_heads"
        let mut output_heads = Vec::with_capacity(num_residual_groups);
        for i in 0..num_residual_groups {
            let head = linear_no_bias(
                config.hidden_size,
                config.codebook_vocab_size,
                predictor_vb.pp(format!("lm_head.{i}")),
            )?;
            output_heads.push(head);
        }

        // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
        let predictor_head_dim = config.hidden_size / config.num_attention_heads;
        let rope_config = Qwen3LmConfig {
            head_dim: predictor_head_dim,
            rope_theta: lm_config.rope_theta,
            max_position_embeddings: lm_config.max_position_embeddings,
            ..Qwen3LmConfig::default()
        };
        let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;

        Ok(Self {
            code_embeddings,
            layers,
            norm,
            output_heads,
            rope,
            config: config.clone(),
        })
    }

    /// Predict all 16 codebook tokens from the LM hidden state.
    ///
    /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
    /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
    ///
    /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
    pub fn predict(
        &self,
        lm_hidden: &Tensor,
        zeroth_code: u32,
        device: &Device,
    ) -> Result<Vec<u32>> {
        let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
        all_codes.push(zeroth_code);

        // The code predictor iterates through the 15 residual codebook groups.
        // For each group i (0..15), it:
        //   1. Embeds the previous codebook token
        //   2. Adds to LM hidden state
        //   3. Runs through predictor layers
        //   4. Predicts the next codebook token via lm_head[i]
        let mut prev_code = zeroth_code;

        for group_idx in 0..self.code_embeddings.len() {
            // Embed the previous codebook token
            let code_tensor = Tensor::from_vec(
                vec![prev_code],
                (1, 1),
                device,
            )?;
            let code_emb = self.code_embeddings[group_idx].forward(&code_tensor)?;

            // Add code embedding to LM hidden state (no concatenation, no projection)
            let mut hidden = (lm_hidden + &code_emb)?;

            // Run through predictor transformer layers (no KV cache needed — single step)
            let mut kv_caches: Vec<KvCache> =
                (0..self.config.num_layers).map(|_| KvCache::new()).collect();
            for (i, layer) in self.layers.iter().enumerate() {
                hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
            }

            hidden = self.norm.forward(&hidden)?;

            // Predict codebook token
            let logits = self.output_heads[group_idx].forward(&hidden)?;

            // Greedy decode: argmax
            let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
            let next_code = logits_flat
                .argmax(0)?
                .to_scalar::<u32>()?;

            all_codes.push(next_code);
            prev_code = next_code;
        }

        Ok(all_codes)
    }

    /// Number of codebook groups.
    pub fn num_code_groups(&self) -> usize {
        self.config.num_code_groups
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default predictor config has 5 layers, 16 code groups,
    /// 2048-entry codebooks and a hidden size of 1024.
    #[test]
    fn test_code_predictor_config() {
        let defaults = CodePredictorConfig::default();
        let observed = (
            defaults.num_layers,
            defaults.num_code_groups,
            defaults.codebook_vocab_size,
            defaults.hidden_size,
        );
        assert_eq!(observed, (5, 16, 2048, 1024));
    }
}