//! Qwen3 Language Model transformer backbone. //! //! Implements the 28-layer transformer with: //! - Rotary Position Embeddings (RoPE) //! - Grouped Query Attention (GQA) — 16 heads, 8 KV heads //! - SiLU-gated MLP //! - RMS normalization //! - KV cache for autoregressive generation //! //! Based on the candle-transformers Qwen2 model architecture, //! extended for Qwen3-TTS. use candle_core::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; use super::config::Qwen3LmConfig; // --------------------------------------------------------------------------- // Rotary Position Embeddings // --------------------------------------------------------------------------- /// Precomputed RoPE sin/cos tables. #[derive(Debug, Clone)] pub struct RotaryEmbedding { cos: Tensor, sin: Tensor, } impl RotaryEmbedding { pub fn new(config: &Qwen3LmConfig, dtype: DType, device: &Device) -> Result { let head_dim = config.head_dim; let max_seq = config.max_position_embeddings; let theta = config.rope_theta; let inv_freq: Vec = (0..head_dim) .step_by(2) .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32)) .collect(); let inv_freq_tensor = Tensor::from_vec(inv_freq, (head_dim / 2,), device)?.to_dtype(DType::F32)?; let positions: Vec = (0..max_seq).map(|p| p as f32).collect(); let positions_tensor = Tensor::from_vec(positions, (max_seq, 1), device)?; // [max_seq, head_dim/2] let freqs = positions_tensor.matmul(&inv_freq_tensor.unsqueeze(0)?)?; // [max_seq, head_dim] by repeating let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; let cos = emb.cos()?.to_dtype(dtype)?; let sin = emb.sin()?.to_dtype(dtype)?; Ok(Self { cos, sin }) } /// Apply RoPE to query and key tensors. /// Input shape: [batch, heads, seq_len, head_dim] pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { let seq_len = q.dim(2)?; let cos = self.cos.narrow(0, offset, seq_len)?; let sin = self.sin.narrow(0, offset, seq_len)?; let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, seq, dim] let sin = sin.unsqueeze(0)?.unsqueeze(0)?; let q_rotated = Self::rotate_half(q, &cos, &sin)?; let k_rotated = Self::rotate_half(k, &cos, &sin)?; Ok((q_rotated, k_rotated)) } fn rotate_half(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result { let half_dim = x.dim(D::Minus1)? / 2; let x1 = x.narrow(D::Minus1, 0, half_dim)?; let x2 = x.narrow(D::Minus1, half_dim, half_dim)?; // [-x2, x1] concatenated let neg_x2 = x2.neg()?; let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?; // x * cos + rotated * sin let result = x.broadcast_mul(cos)?.broadcast_add(&rotated.broadcast_mul(sin)?)?; Ok(result) } } // --------------------------------------------------------------------------- // KV Cache // --------------------------------------------------------------------------- /// Per-layer key-value cache for autoregressive generation. #[derive(Debug, Clone)] pub struct KvCache { key: Option, value: Option, } impl KvCache { pub fn new() -> Self { Self { key: None, value: None, } } /// Append new key/value tensors and return the full cached sequence. /// Input shapes: [batch, num_kv_heads, new_seq_len, head_dim] pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> { let (full_key, full_value) = match (&self.key, &self.value) { (Some(prev_k), Some(prev_v)) => { let k = Tensor::cat(&[prev_k, key], 2)?; let v = Tensor::cat(&[prev_v, value], 2)?; (k, v) } _ => (key.clone(), value.clone()), }; self.key = Some(full_key.clone()); self.value = Some(full_value.clone()); Ok((full_key, full_value)) } /// Current cached sequence length. pub fn seq_len(&self) -> usize { self.key .as_ref() .map(|k| k.dim(2).unwrap_or(0)) .unwrap_or(0) } /// Reset the cache. pub fn reset(&mut self) { self.key = None; self.value = None; } } // --------------------------------------------------------------------------- // Attention // --------------------------------------------------------------------------- /// Multi-head attention with GQA and RoPE. pub struct Qwen3Attention { q_proj: Linear, k_proj: Linear, v_proj: Linear, o_proj: Linear, q_norm: RmsNorm, k_norm: RmsNorm, num_heads: usize, num_kv_heads: usize, head_dim: usize, num_kv_groups: usize, } impl Qwen3Attention { pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result { let hidden = config.hidden_size; let num_heads = config.num_attention_heads; let num_kv_heads = config.num_key_value_heads; let head_dim = config.head_dim; let q_proj = linear_no_bias(hidden, num_heads * head_dim, vb.pp("q_proj"))?; let k_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("k_proj"))?; let v_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("v_proj"))?; let o_proj = linear_no_bias(num_heads * head_dim, hidden, vb.pp("o_proj"))?; let q_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("q_norm"))?; let k_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("k_norm"))?; Ok(Self { q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, num_heads, num_kv_heads, head_dim, num_kv_groups: config.num_kv_groups(), }) } /// Forward pass with KV cache and RoPE. /// Input: [batch, seq_len, hidden_size] /// Returns: [batch, seq_len, hidden_size] pub fn forward( &self, hidden_states: &Tensor, rope: &RotaryEmbedding, kv_cache: &mut KvCache, attention_mask: Option<&Tensor>, ) -> Result { let (batch, seq_len, _) = hidden_states.dims3()?; let offset = kv_cache.seq_len(); // Project Q, K, V let q = self.q_proj.forward(hidden_states)?; let k = self.k_proj.forward(hidden_states)?; let v = self.v_proj.forward(hidden_states)?; // Reshape: [batch, seq, heads*dim] -> [batch, heads, seq, dim] let q = q .reshape((batch, seq_len, self.num_heads, self.head_dim))? .transpose(1, 2)?; let k = k .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; let v = v .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))? .transpose(1, 2)?; // Apply QK normalization (Qwen3 specific) let q = self.apply_head_norm(&q, &self.q_norm)?; let k = self.apply_head_norm(&k, &self.k_norm)?; // Apply RoPE let (q, k) = rope.apply(&q, &k, offset)?; // Update KV cache let (k, v) = kv_cache.append(&k, &v)?; // Expand KV heads for GQA: [batch, kv_heads, seq, dim] -> [batch, heads, seq, dim] let k = self.repeat_kv(&k)?; let v = self.repeat_kv(&v)?; // Scaled dot-product attention let scale = (self.head_dim as f64).sqrt(); let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / scale)?; let attn_weights = match attention_mask { Some(mask) => attn_weights.broadcast_add(mask)?, None => attn_weights, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; // Attention output let attn_output = attn_weights.matmul(&v)?; // [batch, heads, seq, dim] -> [batch, seq, heads*dim] let attn_output = attn_output .transpose(1, 2)? .reshape((batch, seq_len, self.num_heads * self.head_dim))?; self.o_proj.forward(&attn_output) } /// Apply RMS norm per-head. fn apply_head_norm(&self, x: &Tensor, norm: &RmsNorm) -> Result { let (b, h, s, d) = x.dims4()?; // Reshape to [b*h*s, d] for norm, then back let flat = x.reshape((b * h * s, d))?; let normed = norm.forward(&flat)?; normed.reshape((b, h, s, d)) } /// Repeat KV heads for GQA. fn repeat_kv(&self, x: &Tensor) -> Result { if self.num_kv_groups == 1 { return Ok(x.clone()); } let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?; let x = x .unsqueeze(2)? .expand((batch, num_kv_heads, self.num_kv_groups, seq_len, head_dim))? .reshape((batch, self.num_heads, seq_len, head_dim))?; Ok(x) } } // --------------------------------------------------------------------------- // MLP // --------------------------------------------------------------------------- /// SiLU-gated feed-forward network. pub struct Qwen3Mlp { gate_proj: Linear, up_proj: Linear, down_proj: Linear, } impl Qwen3Mlp { pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result { let hidden = config.hidden_size; let intermediate = config.intermediate_size; let gate_proj = linear_no_bias(hidden, intermediate, vb.pp("gate_proj"))?; let up_proj = linear_no_bias(hidden, intermediate, vb.pp("up_proj"))?; let down_proj = linear_no_bias(intermediate, hidden, vb.pp("down_proj"))?; Ok(Self { gate_proj, up_proj, down_proj, }) } pub fn forward(&self, x: &Tensor) -> Result { let gate = self.gate_proj.forward(x)?; let gate = candle_nn::Activation::Silu.forward(&gate)?; let up = self.up_proj.forward(x)?; let hidden = (gate * up)?; self.down_proj.forward(&hidden) } } // --------------------------------------------------------------------------- // Transformer Layer // --------------------------------------------------------------------------- /// A single Qwen3 transformer decoder layer. pub struct Qwen3DecoderLayer { self_attn: Qwen3Attention, mlp: Qwen3Mlp, input_layernorm: RmsNorm, post_attention_layernorm: RmsNorm, } impl Qwen3DecoderLayer { pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result { let self_attn = Qwen3Attention::new(config, vb.pp("self_attn"))?; let mlp = Qwen3Mlp::new(config, vb.pp("mlp"))?; let input_layernorm = rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("input_layernorm"))?; let post_attention_layernorm = rms_norm( config.hidden_size, config.rms_norm_eps, vb.pp("post_attention_layernorm"), )?; Ok(Self { self_attn, mlp, input_layernorm, post_attention_layernorm, }) } pub fn forward( &self, hidden_states: &Tensor, rope: &RotaryEmbedding, kv_cache: &mut KvCache, attention_mask: Option<&Tensor>, ) -> Result { // Pre-norm attention let residual = hidden_states; let hidden_states = self.input_layernorm.forward(hidden_states)?; let hidden_states = self.self_attn .forward(&hidden_states, rope, kv_cache, attention_mask)?; let hidden_states = (residual + hidden_states)?; // Pre-norm MLP let residual = &hidden_states; let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; let hidden_states = self.mlp.forward(&hidden_states)?; let output = (residual + hidden_states)?; Ok(output) } } // --------------------------------------------------------------------------- // Full Model // --------------------------------------------------------------------------- /// The complete Qwen3 language model for TTS. /// /// Architecture: /// - Token embedding layer /// - 28 transformer decoder layers /// - Final RMS normalization /// - LM head (projects to vocab) pub struct Qwen3Model { embed_tokens: Embedding, layers: Vec, norm: RmsNorm, lm_head: Linear, rope: RotaryEmbedding, config: Qwen3LmConfig, /// Last hidden states (before lm_head), used by code predictor. last_hidden: std::cell::RefCell>, } impl Qwen3Model { pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result { // HuggingFace Qwen3-TTS uses "talker.model.*" prefix let talker_vb = vb.pp("talker"); let model_vb = talker_vb.pp("model"); // Text embedding (called "text_embedding" in HF, not "embed_tokens") let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?; let mut layers = Vec::with_capacity(config.num_hidden_layers); for i in 0..config.num_hidden_layers { let layer = Qwen3DecoderLayer::new(config, model_vb.pp(format!("layers.{i}")))?; layers.push(layer); } let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?; // Codec head (called "codec_head" in HF, not "lm_head") let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?; let dtype = vb.dtype(); let device = vb.device().clone(); let rope = RotaryEmbedding::new(config, dtype, &device)?; Ok(Self { embed_tokens, layers, norm, lm_head, rope, config: config.clone(), last_hidden: std::cell::RefCell::new(None), }) } /// Forward pass through the full model. /// /// `input_ids`: [batch, seq_len] — token IDs /// `kv_caches`: per-layer KV caches /// `attention_mask`: optional causal mask [batch, 1, seq_len, total_seq_len] /// /// Returns logits: [batch, seq_len, vocab_size] pub fn forward( &self, input_ids: &Tensor, kv_caches: &mut [KvCache], attention_mask: Option<&Tensor>, ) -> Result { let mut hidden_states = self.embed_tokens.forward(input_ids)?; for (i, layer) in self.layers.iter().enumerate() { hidden_states = layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?; } hidden_states = self.norm.forward(&hidden_states)?; // Store last hidden state for code predictor *self.last_hidden.borrow_mut() = Some(hidden_states.clone()); let logits = self.lm_head.forward(&hidden_states)?; Ok(logits) } /// Forward pass with pre-computed embeddings (for first iteration where /// text embeddings are concatenated with audio features). /// /// `inputs_embeds`: [batch, seq_len, hidden_size] pub fn forward_embeds( &self, inputs_embeds: &Tensor, kv_caches: &mut [KvCache], attention_mask: Option<&Tensor>, ) -> Result { let mut hidden_states = inputs_embeds.clone(); for (i, layer) in self.layers.iter().enumerate() { hidden_states = layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?; } hidden_states = self.norm.forward(&hidden_states)?; *self.last_hidden.borrow_mut() = Some(hidden_states.clone()); let logits = self.lm_head.forward(&hidden_states)?; Ok(logits) } /// Get the last hidden states (for the code predictor). pub fn last_hidden_state(&self) -> Option { self.last_hidden.borrow().clone() } /// Number of transformer layers. pub fn num_layers(&self) -> usize { self.config.num_hidden_layers } /// Hidden size. pub fn hidden_size(&self) -> usize { self.config.hidden_size } /// Get token embedding layer (for input preparation). pub fn embed_tokens(&self) -> &Embedding { &self.embed_tokens } /// Create a causal attention mask. pub fn make_causal_mask( seq_len: usize, past_len: usize, dtype: DType, device: &Device, ) -> Result { let total_len = past_len + seq_len; if seq_len == 1 { // Single token: no masking needed (can attend to everything) return Tensor::zeros((1, 1, 1, total_len), dtype, device); } // Full causal mask: lower triangular let mask: Vec = (0..seq_len) .flat_map(|i| { (0..total_len).map(move |j| { if j <= past_len + i { 0.0 } else { f32::NEG_INFINITY } }) }) .collect(); Tensor::from_vec(mask, (1, 1, seq_len, total_len), device)?.to_dtype(dtype) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_kv_cache() { let device = Device::Cpu; let mut cache = KvCache::new(); assert_eq!(cache.seq_len(), 0); let k = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap(); let v = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap(); let (fk, _fv) = cache.append(&k, &v).unwrap(); assert_eq!(cache.seq_len(), 5); assert_eq!(fk.dim(2).unwrap(), 5); let k2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap(); let v2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap(); let (fk2, _fv2) = cache.append(&k2, &v2).unwrap(); assert_eq!(cache.seq_len(), 6); assert_eq!(fk2.dim(2).unwrap(), 6); cache.reset(); assert_eq!(cache.seq_len(), 0); } #[test] fn test_causal_mask_single_token() { let mask = Qwen3Model::make_causal_mask(1, 10, DType::F32, &Device::Cpu).unwrap(); assert_eq!(mask.dims(), &[1, 1, 1, 11]); // All zeros — single token can attend to everything let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap(); assert_eq!(sum, 0.0); } #[test] fn test_causal_mask_multi_token() { let mask = Qwen3Model::make_causal_mask(3, 0, DType::F32, &Device::Cpu).unwrap(); assert_eq!(mask.dims(), &[1, 1, 3, 3]); // Upper triangle should be -inf let data: Vec = mask.flatten_all().unwrap().to_vec1().unwrap(); // Row 0: [0, -inf, -inf] assert_eq!(data[0], 0.0); assert!(data[1].is_infinite() && data[1] < 0.0); assert!(data[2].is_infinite() && data[2] < 0.0); // Row 1: [0, 0, -inf] assert_eq!(data[3], 0.0); assert_eq!(data[4], 0.0); assert!(data[5].is_infinite() && data[5] < 0.0); // Row 2: [0, 0, 0] assert_eq!(data[6], 0.0); assert_eq!(data[7], 0.0); assert_eq!(data[8], 0.0); } }