diff options
Diffstat (limited to 'makima/src/tts/qwen3')
| -rw-r--r-- | makima/src/tts/qwen3/code_predictor.rs | 253 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/config.rs | 271 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/generate.rs | 456 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/mod.rs | 317 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/model.rs | 584 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/speech_tokenizer.rs | 613 |
6 files changed, 0 insertions, 2494 deletions
diff --git a/makima/src/tts/qwen3/code_predictor.rs b/makima/src/tts/qwen3/code_predictor.rs deleted file mode 100644 index 363105f..0000000 --- a/makima/src/tts/qwen3/code_predictor.rs +++ /dev/null @@ -1,253 +0,0 @@ -//! Multi-Token Prediction (MTP) code predictor. -//! -//! After the main LM predicts the zeroth codebook token, this module -//! predicts the remaining 15 codebook layers in parallel from the -//! LM's hidden states. -//! -//! Architecture: -//! - 5 transformer layers (same structure as main LM layers) -//! - 16 output heads, one per codebook (vocab 2048 each) -//! - Input: last hidden state from main LM + zeroth codebook embedding -//! - Output: 16 codebook token predictions - -use candle_core::{Device, Module, Result, Tensor}; -use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; - -use super::config::{CodePredictorConfig, Qwen3LmConfig}; -use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding}; - -/// A single code predictor transformer layer. -/// -/// Uses the same pre-norm residual structure as the main LM layers. -pub struct CodePredictorLayer { - self_attn: Qwen3Attention, - mlp: Qwen3Mlp, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, -} - -impl CodePredictorLayer { - pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> { - // Construct a Qwen3LmConfig-like view for the attention/MLP constructors - let lm_config = Qwen3LmConfig { - hidden_size: config.hidden_size, - num_hidden_layers: config.num_layers, - num_attention_heads: config.num_attention_heads, - num_key_value_heads: config.num_attention_heads, // No GQA in predictor - intermediate_size: config.hidden_size * 3, // 3072 for hidden=1024 - head_dim: config.hidden_size / config.num_attention_heads, - rms_norm_eps: config.rms_norm_eps, - ..Qwen3LmConfig::default() - }; - - let self_attn = Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?; - let mlp = Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?; - let input_layernorm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - vb.pp("input_layernorm"), - )?; - let post_attention_layernorm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - - Ok(Self { - self_attn, - mlp, - input_layernorm, - post_attention_layernorm, - }) - } - - pub fn forward( - &self, - hidden_states: &Tensor, - rope: &RotaryEmbedding, - kv_cache: &mut KvCache, - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let residual = hidden_states; - let hidden_states = self.input_layernorm.forward(hidden_states)?; - let hidden_states = - self.self_attn - .forward(&hidden_states, rope, kv_cache, attention_mask)?; - let hidden_states = (residual + hidden_states)?; - - let residual = &hidden_states; - let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; - let hidden_states = self.mlp.forward(&hidden_states)?; - let output = (residual + hidden_states)?; - - Ok(output) - } -} - -/// Multi-token prediction code predictor. -/// -/// Takes the hidden states from the main LM and predicts all 16 codebook -/// tokens. The zeroth codebook is predicted by the main LM head; this -/// module predicts the remaining 15 residual codebooks. -pub struct CodePredictor { - /// Embedding layer for codebook tokens (one per residual codebook group, 0-14). - code_embeddings: Vec<Embedding>, - /// 5 transformer layers. - layers: Vec<CodePredictorLayer>, - /// Final normalization. - norm: RmsNorm, - /// Per-codebook output heads (15 heads for residual codebooks). - output_heads: Vec<Linear>, - /// RoPE for the predictor's attention layers. - rope: RotaryEmbedding, - config: CodePredictorConfig, -} - -impl CodePredictor { - pub fn new( - config: &CodePredictorConfig, - lm_config: &Qwen3LmConfig, - vb: VarBuilder, - ) -> Result<Self> { - // HuggingFace Qwen3-TTS uses "talker.code_predictor.*" prefix - let predictor_vb = vb.pp("talker").pp("code_predictor"); - let model_vb = predictor_vb.pp("model"); - - // Code embeddings for residual codebook groups (15 groups, indices 0-14) - // HF names them "codec_embedding" not "code_embeddings" - let num_residual_groups = config.num_code_groups - 1; // 15, not 16 - let mut code_embeddings = Vec::with_capacity(num_residual_groups); - for i in 0..num_residual_groups { - let emb = embedding( - config.codebook_vocab_size, - config.hidden_size, - model_vb.pp(format!("codec_embedding.{i}")), - )?; - code_embeddings.push(emb); - } - - // Transformer layers - let mut layers = Vec::with_capacity(config.num_layers); - for i in 0..config.num_layers { - let layer = - CodePredictorLayer::new(config, model_vb.pp(format!("layers.{i}")))?; - layers.push(layer); - } - - let norm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - model_vb.pp("norm"), - )?; - - // Output heads for residual codebooks (15 heads, indices 0-14) - // HF names them "lm_head" not "output_heads" - let mut output_heads = Vec::with_capacity(num_residual_groups); - for i in 0..num_residual_groups { - let head = linear_no_bias( - config.hidden_size, - config.codebook_vocab_size, - predictor_vb.pp(format!("lm_head.{i}")), - )?; - output_heads.push(head); - } - - // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim) - let predictor_head_dim = config.hidden_size / config.num_attention_heads; - let rope_config = Qwen3LmConfig { - head_dim: predictor_head_dim, - rope_theta: lm_config.rope_theta, - max_position_embeddings: lm_config.max_position_embeddings, - ..Qwen3LmConfig::default() - }; - let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?; - - Ok(Self { - code_embeddings, - layers, - norm, - output_heads, - rope, - config: config.clone(), - }) - } - - /// Predict all 16 codebook tokens from the LM hidden state. - /// - /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM - /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook) - /// - /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code. - pub fn predict( - &self, - lm_hidden: &Tensor, - zeroth_code: u32, - device: &Device, - ) -> Result<Vec<u32>> { - let mut all_codes = Vec::with_capacity(self.config.num_code_groups); - all_codes.push(zeroth_code); - - // The code predictor iterates through the 15 residual codebook groups. - // For each group i (0..15), it: - // 1. Embeds the previous codebook token - // 2. Adds to LM hidden state - // 3. Runs through predictor layers - // 4. Predicts the next codebook token via lm_head[i] - let mut prev_code = zeroth_code; - - for group_idx in 0..self.code_embeddings.len() { - // Embed the previous codebook token - let code_tensor = Tensor::from_vec( - vec![prev_code], - (1, 1), - device, - )?; - let code_emb = self.code_embeddings[group_idx].forward(&code_tensor)?; - - // Add code embedding to LM hidden state (no concatenation, no projection) - let mut hidden = (lm_hidden + &code_emb)?; - - // Run through predictor transformer layers (no KV cache needed — single step) - let mut kv_caches: Vec<KvCache> = - (0..self.config.num_layers).map(|_| KvCache::new()).collect(); - for (i, layer) in self.layers.iter().enumerate() { - hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?; - } - - hidden = self.norm.forward(&hidden)?; - - // Predict codebook token - let logits = self.output_heads[group_idx].forward(&hidden)?; - - // Greedy decode: argmax - let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size] - let next_code = logits_flat - .argmax(0)? - .to_scalar::<u32>()?; - - all_codes.push(next_code); - prev_code = next_code; - } - - Ok(all_codes) - } - - /// Number of codebook groups. - pub fn num_code_groups(&self) -> usize { - self.config.num_code_groups - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_code_predictor_config() { - let config = CodePredictorConfig::default(); - assert_eq!(config.num_layers, 5); - assert_eq!(config.num_code_groups, 16); - assert_eq!(config.codebook_vocab_size, 2048); - assert_eq!(config.hidden_size, 1024); - } -} diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs deleted file mode 100644 index 6fb55d7..0000000 --- a/makima/src/tts/qwen3/config.rs +++ /dev/null @@ -1,271 +0,0 @@ -//! Qwen3-TTS model configuration. -//! -//! Parses config.json from the HuggingFace model repository to configure -//! the language model, code predictor, and speech tokenizer. - -use serde::Deserialize; - -use crate::tts::TtsError; - -/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base. -#[derive(Debug, Clone, Deserialize)] -pub struct Qwen3TtsConfig { - /// Language model (talker) configuration. - #[serde(default = "Qwen3LmConfig::default")] - pub lm: Qwen3LmConfig, - - /// Code predictor (multi-token prediction) configuration. - #[serde(default = "CodePredictorConfig::default")] - pub code_predictor: CodePredictorConfig, - - /// Speech tokenizer configuration. - #[serde(default = "SpeechTokenizerConfig::default")] - pub speech_tokenizer: SpeechTokenizerConfig, -} - -impl Default for Qwen3TtsConfig { - fn default() -> Self { - Self { - lm: Qwen3LmConfig::default(), - code_predictor: CodePredictorConfig::default(), - speech_tokenizer: SpeechTokenizerConfig::default(), - } - } -} - -impl Qwen3TtsConfig { - /// Load from a config.json file path. - pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> { - let content = std::fs::read_to_string(path) - .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?; - Self::from_json_str(&content) - } - - /// Load from a JSON string. - pub fn from_json_str(json: &str) -> Result<Self, TtsError> { - // Try to parse the full HuggingFace config.json format first - if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) { - return Ok(Self::from_hf_config(&hf_config)); - } - // Fall back to direct deserialization - serde_json::from_str(json) - .map_err(|e| TtsError::Config(format!("failed to parse config: {e}"))) - } - - /// Convert from HuggingFace's config.json format. - fn from_hf_config(hf: &HfConfig) -> Self { - Self { - lm: Qwen3LmConfig { - hidden_size: hf.hidden_size.unwrap_or(1024), - num_hidden_layers: hf.num_hidden_layers.unwrap_or(28), - num_attention_heads: hf.num_attention_heads.unwrap_or(16), - num_key_value_heads: hf.num_key_value_heads.unwrap_or(8), - intermediate_size: hf.intermediate_size.unwrap_or(3072), - head_dim: hf.head_dim.unwrap_or(128), - vocab_size: hf.vocab_size.unwrap_or(151_936), - max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768), - rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), - rope_theta: hf.rope_theta.unwrap_or(1_000_000.0), - use_sliding_window: hf.use_sliding_window.unwrap_or(false), - sliding_window: hf.sliding_window, - hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()), - }, - code_predictor: CodePredictorConfig { - hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024), - num_layers: hf.code_predictor_num_layers.unwrap_or(5), - num_attention_heads: hf - .code_predictor_num_attention_heads - .unwrap_or(16), - num_code_groups: hf.num_code_groups.unwrap_or(16), - codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048), - rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), - }, - speech_tokenizer: SpeechTokenizerConfig::default(), - } - } -} - -/// Language model configuration (28-layer Qwen3 transformer). -#[derive(Debug, Clone, Deserialize)] -pub struct Qwen3LmConfig { - /// Hidden dimension of transformer layers. - pub hidden_size: usize, - /// Number of transformer layers. - pub num_hidden_layers: usize, - /// Number of attention heads. - pub num_attention_heads: usize, - /// Number of key-value heads (GQA). - pub num_key_value_heads: usize, - /// Feed-forward intermediate size. - pub intermediate_size: usize, - /// Dimension per attention head. - pub head_dim: usize, - /// Text vocabulary size. - pub vocab_size: usize, - /// Maximum sequence length for RoPE. - pub max_position_embeddings: usize, - /// RMS normalization epsilon. - pub rms_norm_eps: f64, - /// RoPE theta parameter. - pub rope_theta: f64, - /// Whether to use sliding window attention. - pub use_sliding_window: bool, - /// Sliding window size (if enabled). - pub sliding_window: Option<usize>, - /// Activation function name. - pub hidden_act: String, -} - -impl Default for Qwen3LmConfig { - fn default() -> Self { - Self { - hidden_size: 1024, - num_hidden_layers: 28, - num_attention_heads: 16, - num_key_value_heads: 8, - intermediate_size: 3072, - head_dim: 128, - vocab_size: 151_936, - max_position_embeddings: 32_768, - rms_norm_eps: 1e-6, - rope_theta: 1_000_000.0, - use_sliding_window: false, - sliding_window: None, - hidden_act: "silu".to_string(), - } - } -} - -impl Qwen3LmConfig { - /// Number of key-value head groups for GQA. - pub fn num_kv_groups(&self) -> usize { - self.num_attention_heads / self.num_key_value_heads - } -} - -/// Code predictor (multi-token prediction) configuration. -#[derive(Debug, Clone, Deserialize)] -pub struct CodePredictorConfig { - /// Hidden size (matches LM hidden size). - pub hidden_size: usize, - /// Number of predictor transformer layers. - pub num_layers: usize, - /// Number of attention heads. - pub num_attention_heads: usize, - /// Number of codebook groups (residual codebooks). - pub num_code_groups: usize, - /// Vocabulary size per codebook. - pub codebook_vocab_size: usize, - /// RMS norm epsilon. - pub rms_norm_eps: f64, -} - -impl Default for CodePredictorConfig { - fn default() -> Self { - Self { - hidden_size: 1024, - num_layers: 5, - num_attention_heads: 16, - num_code_groups: 16, - codebook_vocab_size: 2048, - rms_norm_eps: 1e-6, - } - } -} - -/// Speech tokenizer (ConvNet codec) configuration. -#[derive(Debug, Clone, Deserialize)] -pub struct SpeechTokenizerConfig { - /// Number of RVQ codebooks. - pub num_codebooks: usize, - /// Codebook embedding dimension. - pub codebook_dim: usize, - /// Codebook vocabulary size per layer. - pub codebook_size: usize, - /// Encoder/decoder hidden channels. - pub hidden_channels: usize, - /// Output sample rate. - pub sample_rate: u32, - /// Token frame rate (Hz). - pub frame_rate: f32, - /// HuggingFace model ID for the speech tokenizer. - pub model_id: String, -} - -impl Default for SpeechTokenizerConfig { - fn default() -> Self { - Self { - num_codebooks: 16, - codebook_dim: 256, - codebook_size: 2048, - hidden_channels: 512, - sample_rate: 24_000, - frame_rate: 12.5, - model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(), - } - } -} - -/// HuggingFace config.json format (partial, fields we need). -#[derive(Debug, Deserialize)] -struct HfConfig { - hidden_size: Option<usize>, - num_hidden_layers: Option<usize>, - num_attention_heads: Option<usize>, - num_key_value_heads: Option<usize>, - intermediate_size: Option<usize>, - head_dim: Option<usize>, - vocab_size: Option<usize>, - max_position_embeddings: Option<usize>, - rms_norm_eps: Option<f64>, - rope_theta: Option<f64>, - use_sliding_window: Option<bool>, - sliding_window: Option<usize>, - hidden_act: Option<String>, - // Code predictor specific fields - code_predictor_hidden_size: Option<usize>, - code_predictor_num_layers: Option<usize>, - code_predictor_num_attention_heads: Option<usize>, - num_code_groups: Option<usize>, - codebook_vocab_size: Option<usize>, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = Qwen3TtsConfig::default(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.lm.num_attention_heads, 16); - assert_eq!(config.lm.num_key_value_heads, 8); - assert_eq!(config.lm.head_dim, 128); - assert_eq!(config.lm.num_kv_groups(), 2); - assert_eq!(config.code_predictor.num_layers, 5); - assert_eq!(config.code_predictor.num_code_groups, 16); - assert_eq!(config.speech_tokenizer.num_codebooks, 16); - } - - #[test] - fn test_config_from_json() { - let json = r#"{ - "hidden_size": 1024, - "num_hidden_layers": 28, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 3072, - "vocab_size": 151936, - "max_position_embeddings": 32768, - "rms_norm_eps": 1e-6, - "rope_theta": 1000000.0, - "hidden_act": "silu" - }"#; - - let config = Qwen3TtsConfig::from_json_str(json).unwrap(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.lm.vocab_size, 151_936); - } -} diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs deleted file mode 100644 index 30d165b..0000000 --- a/makima/src/tts/qwen3/generate.rs +++ /dev/null @@ -1,456 +0,0 @@ -//! Autoregressive generation loop for Qwen3-TTS. -//! -//! Orchestrates the full inference pipeline: -//! 1. Encode reference audio → speaker embedding via speech tokenizer -//! 2. Tokenize text → token IDs -//! 3. Autoregressive LM generation → zeroth codebook tokens -//! 4. Code predictor → remaining 15 codebook tokens per frame -//! 5. Speech tokenizer decoder → waveform audio - -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use candle_core::{DType, Device, IndexOp, Result, Tensor}; -use tokenizers::Tokenizer; - -use super::code_predictor::CodePredictor; -use super::model::{KvCache, Qwen3Model}; -use super::speech_tokenizer::SpeechTokenizer; -use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE}; - -/// Special tokens for the Qwen3-TTS vocabulary. -pub const BOS_TOKEN_ID: u32 = 151_643; -pub const EOS_TOKEN_ID: u32 = 151_645; -pub const PAD_TOKEN_ID: u32 = 151_643; - -/// Speech-specific control tokens. -/// These are placeholders — actual values come from the tokenizer config. -pub const START_OF_SPEECH: u32 = 151_668; -pub const END_OF_SPEECH: u32 = 151_669; - -/// Generation configuration. -#[derive(Debug, Clone)] -pub struct GenerationConfig { - /// Maximum number of speech tokens to generate. - pub max_new_tokens: usize, - /// Temperature for sampling (1.0 = greedy if top_k=1). - pub temperature: f32, - /// Top-k sampling (0 = disabled, use greedy argmax). - pub top_k: usize, - /// Repetition penalty. - pub repetition_penalty: f32, - /// Whether to generate audio chunks incrementally (streaming). - pub streaming: bool, -} - -impl Default for GenerationConfig { - fn default() -> Self { - Self { - max_new_tokens: 2048, - temperature: 1.0, - top_k: 0, // Greedy by default - repetition_penalty: 1.2, - streaming: false, - } - } -} - -/// Manages the full generation pipeline. -pub struct GenerationContext<'a> { - model: &'a Qwen3Model, - code_predictor: &'a CodePredictor, - speech_tokenizer: &'a SpeechTokenizer, - tokenizer: &'a Tokenizer, - device: &'a Device, - config: GenerationConfig, - /// Optional cancellation flag. When set to `true`, the generation loop - /// will break early and return whatever audio has been produced so far. - cancel_flag: Option<Arc<AtomicBool>>, -} - -impl<'a> GenerationContext<'a> { - pub fn new( - model: &'a Qwen3Model, - code_predictor: &'a CodePredictor, - speech_tokenizer: &'a SpeechTokenizer, - tokenizer: &'a Tokenizer, - device: &'a Device, - config: GenerationConfig, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Self { - Self { - model, - code_predictor, - speech_tokenizer, - tokenizer, - device, - config, - cancel_flag, - } - } - - /// Check whether cancellation has been requested. - fn is_cancelled(&self) -> bool { - self.cancel_flag - .as_ref() - .map_or(false, |f| f.load(Ordering::Relaxed)) - } - - /// Generate audio from text, optionally with a voice reference. - /// - /// Returns a list of audio chunks. If `streaming` is false, returns - /// a single chunk with the complete audio. - pub fn generate( - &self, - text: &str, - reference_audio: Option<&[f32]>, - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - // 1. Encode reference audio if provided - let reference_codes = match reference_audio { - Some(audio) => Some( - self.speech_tokenizer - .encode(audio) - .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?, - ), - None => None, - }; - - // 2. Tokenize text - let encoding = self - .tokenizer - .encode(text, true) - .map_err(|e| TtsError::Tokenizer(e.to_string()))?; - - let text_token_ids: Vec<u32> = encoding.get_ids().to_vec(); - - // 3. Prepare input sequence - // Format: [BOS] [text_tokens] [START_OF_SPEECH] - let mut input_ids = Vec::new(); - input_ids.push(BOS_TOKEN_ID); - input_ids.extend_from_slice(&text_token_ids); - input_ids.push(START_OF_SPEECH); - - // 4. Run autoregressive generation - let generated_frames = self - .autoregressive_generate(&input_ids, reference_codes.as_deref()) - .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?; - - if generated_frames.is_empty() { - return Ok(vec![AudioChunk { - samples: vec![], - sample_rate: SAMPLE_RATE, - is_final: true, - }]); - } - - // 5. Decode all frames to audio - if self.config.streaming { - self.decode_streaming(&generated_frames) - } else { - self.decode_batch(&generated_frames) - } - } - - /// Autoregressive generation loop. - /// - /// Generates zeroth codebook tokens one at a time, then uses the code - /// predictor to fill in the remaining 15 codebooks per frame. - /// - /// Returns: Vec of frames, each frame is [num_codebooks] tokens. - fn autoregressive_generate( - &self, - input_ids: &[u32], - _reference_codes: Option<&[Vec<u32>]>, - ) -> Result<Vec<Vec<u32>>> { - let _num_codebooks = self.code_predictor.num_code_groups(); - let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers()) - .map(|_| KvCache::new()) - .collect(); - - let mut generated_frames: Vec<Vec<u32>> = Vec::new(); - let mut past_zeroth_tokens: Vec<u32> = Vec::new(); - - // === First iteration: process the full input sequence === - let input_tensor = Tensor::from_vec( - input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(), - (1, input_ids.len()), - self.device, - )? - .to_dtype(DType::I64)?; - - let seq_len = input_ids.len(); - let attention_mask = - Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?; - - let logits = - self.model - .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?; - - // Get the logits for the last position - let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size] - let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?; - - if first_token == END_OF_SPEECH as u32 { - return Ok(generated_frames); - } - - // Use code predictor for all codebooks - let lm_hidden = self - .model - .last_hidden_state() - .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; - let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?; - - let frame_codes = self - .code_predictor - .predict(&last_hidden, first_token, self.device)?; - generated_frames.push(frame_codes); - past_zeroth_tokens.push(first_token); - - // === Subsequent iterations: one token at a time === - for _step in 1..self.config.max_new_tokens { - // Check for cancellation each iteration - if self.is_cancelled() { - tracing::info!("TTS generation cancelled after {} frames", generated_frames.len()); - break; - } - - let past_len = kv_caches[0].seq_len(); - - // Input: just the last generated zeroth codebook token - let last_token = *past_zeroth_tokens.last().unwrap(); - let token_tensor = Tensor::from_vec( - vec![last_token as i64], - (1, 1), - self.device, - )? - .to_dtype(DType::I64)?; - - // Single-token attention mask - let attention_mask = - Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?; - - let logits = - self.model - .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?; - - let next_logits = logits.i((0, 0, ..))?; // [vocab_size] - let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?; - - if next_token == END_OF_SPEECH as u32 { - break; - } - - // Predict all codebooks for this frame - let lm_hidden = self - .model - .last_hidden_state() - .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; - - let frame_codes = self - .code_predictor - .predict(&lm_hidden, next_token, self.device)?; - generated_frames.push(frame_codes); - past_zeroth_tokens.push(next_token); - } - - Ok(generated_frames) - } - - /// Sample a token from logits. - fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> { - let mut logits_vec: Vec<f32> = logits.to_vec1()?; - - // Apply repetition penalty - if self.config.repetition_penalty != 1.0 { - for &token in past_tokens { - let idx = token as usize; - if idx < logits_vec.len() { - let score = logits_vec[idx]; - logits_vec[idx] = if score < 0.0 { - score * self.config.repetition_penalty - } else { - score / self.config.repetition_penalty - }; - } - } - } - - if self.config.top_k == 0 || self.config.temperature == 0.0 { - // Greedy: argmax - let (max_idx, _) = logits_vec - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }) - .unwrap_or((0, &0.0)); - Ok(max_idx as u32) - } else { - // Top-k sampling with temperature - let temperature = self.config.temperature; - - // Apply temperature - for v in logits_vec.iter_mut() { - *v /= temperature; - } - - // Sort indices by logit value (descending) - let mut indexed: Vec<(usize, f32)> = - logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); - indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Keep only top-k - let k = self.config.top_k.min(indexed.len()); - let top_k = &indexed[..k]; - - // Softmax over top-k - let max_val = top_k[0].1; - let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum(); - let probs: Vec<(usize, f32)> = top_k - .iter() - .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum)) - .collect(); - - // Sample from distribution (simple linear scan) - let r: f32 = random_float(); - let mut cumulative = 0.0; - for (idx, prob) in &probs { - cumulative += prob; - if cumulative >= r { - return Ok(*idx as u32); - } - } - - // Fallback to highest probability - Ok(probs[0].0 as u32) - } - } - - /// Decode all frames in batch (non-streaming). - fn decode_batch( - &self, - frames: &[Vec<u32>], - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let num_codebooks = self.speech_tokenizer.num_codebooks(); - - // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames] - let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; - for frame in frames { - for (cb_idx, &code) in frame.iter().enumerate() { - if cb_idx < num_codebooks { - codes_by_codebook[cb_idx].push(code); - } - } - } - - let samples = self - .speech_tokenizer - .decode(&codes_by_codebook) - .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?; - - Ok(vec![AudioChunk { - samples, - sample_rate: SAMPLE_RATE, - is_final: true, - }]) - } - - /// Decode frames incrementally (streaming). - fn decode_streaming( - &self, - frames: &[Vec<u32>], - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let mut chunks: Vec<AudioChunk> = Vec::new(); - - // Decode in groups of frames for efficiency - let chunk_size = 10; // ~800ms per chunk at 12.5Hz - let num_codebooks = self.speech_tokenizer.num_codebooks(); - - for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { - // Check for cancellation between streaming chunks - if self.is_cancelled() { - tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len()); - if let Some(last) = chunks.last_mut() { - last.is_final = true; - } - return Ok(chunks); - } - - let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); - - // Transpose chunk frames - let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; - for frame in frame_chunk { - for (cb_idx, &code) in frame.iter().enumerate() { - if cb_idx < num_codebooks { - codes_by_codebook[cb_idx].push(code); - } - } - } - - let samples = self - .speech_tokenizer - .decode(&codes_by_codebook) - .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?; - - chunks.push(AudioChunk { - samples, - sample_rate: SAMPLE_RATE, - is_final: is_last, - }); - } - - Ok(chunks) - } -} - -/// Simple pseudo-random float in [0, 1) using thread-local state. -/// Uses a basic xorshift for reproducibility without external deps. -fn random_float() -> f32 { - use std::cell::Cell; - thread_local! { - static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0); - } - - STATE.with(|s| { - let mut x = s.get(); - x ^= x << 13; - x ^= x >> 7; - x ^= x << 17; - s.set(x); - (x as f32) / (u64::MAX as f32) - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generation_config_default() { - let config = GenerationConfig::default(); - assert_eq!(config.max_new_tokens, 2048); - assert_eq!(config.top_k, 0); - assert_eq!(config.temperature, 1.0); - assert_eq!(config.repetition_penalty, 1.2); - assert!(!config.streaming); - } - - #[test] - fn test_random_float_range() { - for _ in 0..100 { - let r = random_float(); - assert!(r >= 0.0); - assert!(r < 1.0); - } - } - - #[test] - fn test_special_tokens() { - assert_eq!(BOS_TOKEN_ID, 151_643); - assert_eq!(EOS_TOKEN_ID, 151_645); - assert_eq!(START_OF_SPEECH, 151_668); - assert_eq!(END_OF_SPEECH, 151_669); - } -} diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs deleted file mode 100644 index fc6c472..0000000 --- a/makima/src/tts/qwen3/mod.rs +++ /dev/null @@ -1,317 +0,0 @@ -//! Qwen3-TTS — Pure Rust implementation using candle. -//! -//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis -//! with voice cloning support. No Python, no ONNX — pure Rust inference -//! via the candle ML framework. -//! -//! # Architecture -//! -//! The model has three components: -//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens -//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers -//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes -//! -//! # Usage -//! -//! ```rust,no_run -//! use makima::tts::qwen3::Qwen3Tts; -//! use candle_core::Device; -//! -//! let device = Device::Cpu; -//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap(); -//! // Use via TtsEngine trait or direct API -//! ``` - -pub mod code_predictor; -pub mod config; -pub mod generate; -pub mod model; -pub mod speech_tokenizer; - -use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use candle_core::{DType, Device}; -use candle_nn::VarBuilder; -use hf_hub::api::sync::Api; -use tokenizers::Tokenizer; - -use self::code_predictor::CodePredictor; -use self::config::Qwen3TtsConfig; -use self::generate::{GenerationConfig, GenerationContext}; -use self::model::Qwen3Model; -use self::speech_tokenizer::SpeechTokenizer; -use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE}; - -/// HuggingFace model IDs. -const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"; -const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz"; -const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts"; - -/// Qwen3-TTS engine — pure Rust candle-based inference. -pub struct Qwen3Tts { - /// The 28-layer language model. - model: Qwen3Model, - /// Multi-token prediction code predictor. - code_predictor: CodePredictor, - /// Speech tokenizer (encoder + decoder + RVQ). - speech_tokenizer: SpeechTokenizer, - /// Text tokenizer. - tokenizer: Tokenizer, - /// Model configuration. - config: Qwen3TtsConfig, - /// Compute device (CPU/CUDA/Metal). - device: Device, - /// Whether the model is fully loaded and ready. - ready: AtomicBool, -} - -// SAFETY: All fields are either Send+Sync or behind appropriate synchronization. -// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync. -unsafe impl Send for Qwen3Tts {} -unsafe impl Sync for Qwen3Tts {} - -impl Qwen3Tts { - /// Load from a local directory or download from HuggingFace. - pub fn from_pretrained( - model_dir: Option<&str>, - device: &Device, - ) -> Result<Self, TtsError> { - let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR)); - - if !model_path.exists() { - Self::download_models(&model_path)?; - } - - Self::load_from_path(&model_path, device) - } - - /// Load all model components from a local directory. - pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> { - let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU - - // Load configuration - let config_path = model_dir.join("config.json"); - let config = if config_path.exists() { - Qwen3TtsConfig::from_json_path(&config_path)? - } else { - Qwen3TtsConfig::default() - }; - - // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats) - let tokenizer_json_path = model_dir.join("tokenizer.json"); - let tokenizer = if tokenizer_json_path.exists() { - Tokenizer::from_file(&tokenizer_json_path) - .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))? - } else { - // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format) - let vocab_path = model_dir.join("vocab.json"); - let merges_path = model_dir.join("merges.txt"); - - if !vocab_path.exists() || !merges_path.exists() { - return Err(TtsError::Tokenizer(format!( - "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}", - model_dir.display() - ))); - } - - tokenizers::Tokenizer::from_file(&vocab_path) - .or_else(|_| { - // Build BPE tokenizer from vocab and merges - use tokenizers::models::bpe::BPE; - let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy()) - .build() - .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?; - Ok(Tokenizer::new(bpe)) - }) - .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))? - }; - - // Load LM weights from safetensors - let lm_weights_path = model_dir.join("model.safetensors"); - let lm_data = std::fs::read(&lm_weights_path).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to read LM weights from {}: {e}", - lm_weights_path.display() - )) - })?; - let lm_vb = VarBuilder::from_buffered_safetensors( - lm_data, - dtype, - device, - ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?; - - // Build language model - let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| { - TtsError::ModelLoad(format!("failed to build LM model: {e}")) - })?; - - // Build code predictor (weights are in the same safetensors file) - let code_predictor = - CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| { - TtsError::ModelLoad(format!("failed to build code predictor: {e}")) - })?; - - // Load speech tokenizer from separate safetensors - let st_weights_path = model_dir.join("speech_tokenizer.safetensors"); - let st_data = std::fs::read(&st_weights_path).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to read speech tokenizer weights from {}: {e}", - st_weights_path.display() - )) - })?; - let st_vb = VarBuilder::from_buffered_safetensors( - st_data, - dtype, - device, - ).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to create speech tokenizer VarBuilder: {e}" - )) - })?; - - let speech_tokenizer = - SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| { - TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}")) - })?; - - Ok(Self { - model, - code_predictor, - speech_tokenizer, - tokenizer, - config, - device: device.clone(), - ready: AtomicBool::new(true), - }) - } - - /// Generate audio from text with optional voice reference. - pub fn generate_speech( - &self, - text: &str, - reference_audio: Option<&[f32]>, - gen_config: Option<GenerationConfig>, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Result<Vec<AudioChunk>, TtsError> { - let config = gen_config.unwrap_or_default(); - - let ctx = GenerationContext::new( - &self.model, - &self.code_predictor, - &self.speech_tokenizer, - &self.tokenizer, - &self.device, - config, - cancel_flag, - ); - - ctx.generate(text, reference_audio) - } - - /// Download model files from HuggingFace Hub. - fn download_models(target_dir: &Path) -> Result<(), TtsError> { - std::fs::create_dir_all(target_dir)?; - - let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?; - - // Download LM model files - println!("Downloading Qwen3-TTS language model..."); - let lm_repo = api.model(LM_MODEL_ID.to_string()); - - // Note: HuggingFace repo has vocab.json + merges.txt instead of tokenizer.json - let lm_files = [ - "model.safetensors", - "config.json", - "vocab.json", - "merges.txt", - "tokenizer_config.json", - ]; - - for file in &lm_files { - println!(" Downloading {file}..."); - let downloaded = lm_repo - .get(file) - .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?; - - let target = target_dir.join(file); - if !target.exists() { - std::fs::copy(&downloaded, &target)?; - } - } - - // Download speech tokenizer - println!("Downloading Qwen3-TTS speech tokenizer..."); - let st_repo = api.model(TOKENIZER_MODEL_ID.to_string()); - - let st_file = "model.safetensors"; - let downloaded = st_repo - .get(st_file) - .map_err(|e| { - TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}")) - })?; - - let target = target_dir.join("speech_tokenizer.safetensors"); - if !target.exists() { - std::fs::copy(&downloaded, &target)?; - } - - println!("All models downloaded to {}", target_dir.display()); - Ok(()) - } - - /// Get the model configuration. - pub fn config(&self) -> &Qwen3TtsConfig { - &self.config - } - - /// Get the compute device. - pub fn device(&self) -> &Device { - &self.device - } -} - -#[async_trait::async_trait] -impl TtsEngine for Qwen3Tts { - async fn generate( - &self, - text: &str, - reference_audio: Option<&[f32]>, - _reference_sample_rate: Option<u32>, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Result<Vec<AudioChunk>, TtsError> { - // Note: reference audio should already be resampled to 24kHz - // by the caller. If a different sample rate is provided, - // the caller should resample using `resample_to_24k()`. - self.generate_speech(text, reference_audio, None, cancel_flag) - } - - fn is_ready(&self) -> bool { - self.ready.load(Ordering::Relaxed) - } - - fn sample_rate(&self) -> u32 { - SAMPLE_RATE - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = Qwen3TtsConfig::default(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.code_predictor.num_code_groups, 16); - assert_eq!(config.speech_tokenizer.sample_rate, 24_000); - } - - #[test] - fn test_model_ids() { - assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base"); - assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz"); - } -} diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs deleted file mode 100644 index e19e5f9..0000000 --- a/makima/src/tts/qwen3/model.rs +++ /dev/null @@ -1,584 +0,0 @@ -//! Qwen3 Language Model transformer backbone. -//! -//! Implements the 28-layer transformer with: -//! - Rotary Position Embeddings (RoPE) -//! - Grouped Query Attention (GQA) — 16 heads, 8 KV heads -//! - SiLU-gated MLP -//! - RMS normalization -//! - KV cache for autoregressive generation -//! -//! Based on the candle-transformers Qwen2 model architecture, -//! extended for Qwen3-TTS. - -use candle_core::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; - -use super::config::Qwen3LmConfig; - -// --------------------------------------------------------------------------- -// Rotary Position Embeddings -// --------------------------------------------------------------------------- - -/// Precomputed RoPE sin/cos tables. -#[derive(Debug, Clone)] -pub struct RotaryEmbedding { - cos: Tensor, - sin: Tensor, -} - -impl RotaryEmbedding { - pub fn new(config: &Qwen3LmConfig, dtype: DType, device: &Device) -> Result<Self> { - let head_dim = config.head_dim; - let max_seq = config.max_position_embeddings; - let theta = config.rope_theta; - - let inv_freq: Vec<f32> = (0..head_dim) - .step_by(2) - .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32)) - .collect(); - - let inv_freq_tensor = - Tensor::from_vec(inv_freq, (head_dim / 2,), device)?.to_dtype(DType::F32)?; - - let positions: Vec<f32> = (0..max_seq).map(|p| p as f32).collect(); - let positions_tensor = Tensor::from_vec(positions, (max_seq, 1), device)?; - - // [max_seq, head_dim/2] - let freqs = positions_tensor.matmul(&inv_freq_tensor.unsqueeze(0)?)?; - // [max_seq, head_dim] by repeating - let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; - - let cos = emb.cos()?.to_dtype(dtype)?; - let sin = emb.sin()?.to_dtype(dtype)?; - - Ok(Self { cos, sin }) - } - - /// Apply RoPE to query and key tensors. - /// Input shape: [batch, heads, seq_len, head_dim] - pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { - let seq_len = q.dim(2)?; - let cos = self.cos.narrow(0, offset, seq_len)?; - let sin = self.sin.narrow(0, offset, seq_len)?; - - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, seq, dim] - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; - - let q_rotated = Self::rotate_half(q, &cos, &sin)?; - let k_rotated = Self::rotate_half(k, &cos, &sin)?; - - Ok((q_rotated, k_rotated)) - } - - fn rotate_half(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> { - let half_dim = x.dim(D::Minus1)? / 2; - let x1 = x.narrow(D::Minus1, 0, half_dim)?; - let x2 = x.narrow(D::Minus1, half_dim, half_dim)?; - - // [-x2, x1] concatenated - let neg_x2 = x2.neg()?; - let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?; - - // x * cos + rotated * sin - let result = x.broadcast_mul(cos)?.broadcast_add(&rotated.broadcast_mul(sin)?)?; - Ok(result) - } -} - -// --------------------------------------------------------------------------- -// KV Cache -// --------------------------------------------------------------------------- - -/// Per-layer key-value cache for autoregressive generation. -#[derive(Debug, Clone)] -pub struct KvCache { - key: Option<Tensor>, - value: Option<Tensor>, -} - -impl KvCache { - pub fn new() -> Self { - Self { - key: None, - value: None, - } - } - - /// Append new key/value tensors and return the full cached sequence. - /// Input shapes: [batch, num_kv_heads, new_seq_len, head_dim] - pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> { - let (full_key, full_value) = match (&self.key, &self.value) { - (Some(prev_k), Some(prev_v)) => { - let k = Tensor::cat(&[prev_k, key], 2)?; - let v = Tensor::cat(&[prev_v, value], 2)?; - (k, v) - } - _ => (key.clone(), value.clone()), - }; - - self.key = Some(full_key.clone()); - self.value = Some(full_value.clone()); - - Ok((full_key, full_value)) - } - - /// Current cached sequence length. - pub fn seq_len(&self) -> usize { - self.key - .as_ref() - .map(|k| k.dim(2).unwrap_or(0)) - .unwrap_or(0) - } - - /// Reset the cache. - pub fn reset(&mut self) { - self.key = None; - self.value = None; - } -} - -// --------------------------------------------------------------------------- -// Attention -// --------------------------------------------------------------------------- - -/// Multi-head attention with GQA and RoPE. -pub struct Qwen3Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, - q_norm: RmsNorm, - k_norm: RmsNorm, - num_heads: usize, - num_kv_heads: usize, - head_dim: usize, - num_kv_groups: usize, -} - -impl Qwen3Attention { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let hidden = config.hidden_size; - let num_heads = config.num_attention_heads; - let num_kv_heads = config.num_key_value_heads; - let head_dim = config.head_dim; - - let q_proj = linear_no_bias(hidden, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden, vb.pp("o_proj"))?; - - let q_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("q_norm"))?; - let k_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("k_norm"))?; - - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - q_norm, - k_norm, - num_heads, - num_kv_heads, - head_dim, - num_kv_groups: config.num_kv_groups(), - }) - } - - /// Forward pass with KV cache and RoPE. - /// Input: [batch, seq_len, hidden_size] - /// Returns: [batch, seq_len, hidden_size] - pub fn forward( - &self, - hidden_states: &Tensor, - rope: &RotaryEmbedding, - kv_cache: &mut KvCache, - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let (batch, seq_len, _) = hidden_states.dims3()?; - let offset = kv_cache.seq_len(); - - // Project Q, K, V - let q = self.q_proj.forward(hidden_states)?; - let k = self.k_proj.forward(hidden_states)?; - let v = self.v_proj.forward(hidden_states)?; - - // Reshape: [batch, seq, heads*dim] -> [batch, heads, seq, dim] - let q = q - .reshape((batch, seq_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - - // Apply QK normalization (Qwen3 specific) - let q = self.apply_head_norm(&q, &self.q_norm)?; - let k = self.apply_head_norm(&k, &self.k_norm)?; - - // Apply RoPE - let (q, k) = rope.apply(&q, &k, offset)?; - - // Update KV cache - let (k, v) = kv_cache.append(&k, &v)?; - - // Expand KV heads for GQA: [batch, kv_heads, seq, dim] -> [batch, heads, seq, dim] - let k = self.repeat_kv(&k)?; - let v = self.repeat_kv(&v)?; - - // Scaled dot-product attention - let scale = (self.head_dim as f64).sqrt(); - let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / scale)?; - - let attn_weights = match attention_mask { - Some(mask) => attn_weights.broadcast_add(mask)?, - None => attn_weights, - }; - - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - - // Attention output - let attn_output = attn_weights.matmul(&v)?; - - // [batch, heads, seq, dim] -> [batch, seq, heads*dim] - let attn_output = attn_output - .transpose(1, 2)? - .reshape((batch, seq_len, self.num_heads * self.head_dim))?; - - self.o_proj.forward(&attn_output) - } - - /// Apply RMS norm per-head. - fn apply_head_norm(&self, x: &Tensor, norm: &RmsNorm) -> Result<Tensor> { - let (b, h, s, d) = x.dims4()?; - // Reshape to [b*h*s, d] for norm, then back - let flat = x.reshape((b * h * s, d))?; - let normed = norm.forward(&flat)?; - normed.reshape((b, h, s, d)) - } - - /// Repeat KV heads for GQA. - fn repeat_kv(&self, x: &Tensor) -> Result<Tensor> { - if self.num_kv_groups == 1 { - return Ok(x.clone()); - } - let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?; - let x = x - .unsqueeze(2)? - .expand((batch, num_kv_heads, self.num_kv_groups, seq_len, head_dim))? - .reshape((batch, self.num_heads, seq_len, head_dim))?; - Ok(x) - } -} - -// --------------------------------------------------------------------------- -// MLP -// --------------------------------------------------------------------------- - -/// SiLU-gated feed-forward network. -pub struct Qwen3Mlp { - gate_proj: Linear, - up_proj: Linear, - down_proj: Linear, -} - -impl Qwen3Mlp { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let hidden = config.hidden_size; - let intermediate = config.intermediate_size; - - let gate_proj = linear_no_bias(hidden, intermediate, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden, intermediate, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate, hidden, vb.pp("down_proj"))?; - - Ok(Self { - gate_proj, - up_proj, - down_proj, - }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let gate = self.gate_proj.forward(x)?; - let gate = candle_nn::Activation::Silu.forward(&gate)?; - let up = self.up_proj.forward(x)?; - let hidden = (gate * up)?; - self.down_proj.forward(&hidden) - } -} - -// --------------------------------------------------------------------------- -// Transformer Layer -// --------------------------------------------------------------------------- - -/// A single Qwen3 transformer decoder layer. -pub struct Qwen3DecoderLayer { - self_attn: Qwen3Attention, - mlp: Qwen3Mlp, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, -} - -impl Qwen3DecoderLayer { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let self_attn = Qwen3Attention::new(config, vb.pp("self_attn"))?; - let mlp = Qwen3Mlp::new(config, vb.pp("mlp"))?; - let input_layernorm = - rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - - Ok(Self { - self_attn, - mlp, - input_layernorm, - post_attention_layernorm, - }) - } - - pub fn forward( - &self, - hidden_states: &Tensor, - rope: &RotaryEmbedding, - kv_cache: &mut KvCache, - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - // Pre-norm attention - let residual = hidden_states; - let hidden_states = self.input_layernorm.forward(hidden_states)?; - let hidden_states = - self.self_attn - .forward(&hidden_states, rope, kv_cache, attention_mask)?; - let hidden_states = (residual + hidden_states)?; - - // Pre-norm MLP - let residual = &hidden_states; - let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; - let hidden_states = self.mlp.forward(&hidden_states)?; - let output = (residual + hidden_states)?; - - Ok(output) - } -} - -// --------------------------------------------------------------------------- -// Full Model -// --------------------------------------------------------------------------- - -/// The complete Qwen3 language model for TTS. -/// -/// Architecture: -/// - Token embedding layer -/// - 28 transformer decoder layers -/// - Final RMS normalization -/// - LM head (projects to vocab) -pub struct Qwen3Model { - embed_tokens: Embedding, - layers: Vec<Qwen3DecoderLayer>, - norm: RmsNorm, - lm_head: Linear, - rope: RotaryEmbedding, - config: Qwen3LmConfig, - /// Last hidden states (before lm_head), used by code predictor. - last_hidden: std::cell::RefCell<Option<Tensor>>, -} - -impl Qwen3Model { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - // HuggingFace Qwen3-TTS uses "talker.model.*" prefix - let talker_vb = vb.pp("talker"); - let model_vb = talker_vb.pp("model"); - - // Text embedding (called "text_embedding" in HF, not "embed_tokens") - let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?; - - let mut layers = Vec::with_capacity(config.num_hidden_layers); - for i in 0..config.num_hidden_layers { - let layer = Qwen3DecoderLayer::new(config, model_vb.pp(format!("layers.{i}")))?; - layers.push(layer); - } - - let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?; - - // Codec head (called "codec_head" in HF, not "lm_head") - let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?; - - let dtype = vb.dtype(); - let device = vb.device().clone(); - let rope = RotaryEmbedding::new(config, dtype, &device)?; - - Ok(Self { - embed_tokens, - layers, - norm, - lm_head, - rope, - config: config.clone(), - last_hidden: std::cell::RefCell::new(None), - }) - } - - /// Forward pass through the full model. - /// - /// `input_ids`: [batch, seq_len] — token IDs - /// `kv_caches`: per-layer KV caches - /// `attention_mask`: optional causal mask [batch, 1, seq_len, total_seq_len] - /// - /// Returns logits: [batch, seq_len, vocab_size] - pub fn forward( - &self, - input_ids: &Tensor, - kv_caches: &mut [KvCache], - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let mut hidden_states = self.embed_tokens.forward(input_ids)?; - - for (i, layer) in self.layers.iter().enumerate() { - hidden_states = - layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?; - } - - hidden_states = self.norm.forward(&hidden_states)?; - - // Store last hidden state for code predictor - *self.last_hidden.borrow_mut() = Some(hidden_states.clone()); - - let logits = self.lm_head.forward(&hidden_states)?; - Ok(logits) - } - - /// Forward pass with pre-computed embeddings (for first iteration where - /// text embeddings are concatenated with audio features). - /// - /// `inputs_embeds`: [batch, seq_len, hidden_size] - pub fn forward_embeds( - &self, - inputs_embeds: &Tensor, - kv_caches: &mut [KvCache], - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let mut hidden_states = inputs_embeds.clone(); - - for (i, layer) in self.layers.iter().enumerate() { - hidden_states = - layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?; - } - - hidden_states = self.norm.forward(&hidden_states)?; - - *self.last_hidden.borrow_mut() = Some(hidden_states.clone()); - - let logits = self.lm_head.forward(&hidden_states)?; - Ok(logits) - } - - /// Get the last hidden states (for the code predictor). - pub fn last_hidden_state(&self) -> Option<Tensor> { - self.last_hidden.borrow().clone() - } - - /// Number of transformer layers. - pub fn num_layers(&self) -> usize { - self.config.num_hidden_layers - } - - /// Hidden size. - pub fn hidden_size(&self) -> usize { - self.config.hidden_size - } - - /// Get token embedding layer (for input preparation). - pub fn embed_tokens(&self) -> &Embedding { - &self.embed_tokens - } - - /// Create a causal attention mask. - pub fn make_causal_mask( - seq_len: usize, - past_len: usize, - dtype: DType, - device: &Device, - ) -> Result<Tensor> { - let total_len = past_len + seq_len; - - if seq_len == 1 { - // Single token: no masking needed (can attend to everything) - return Tensor::zeros((1, 1, 1, total_len), dtype, device); - } - - // Full causal mask: lower triangular - let mask: Vec<f32> = (0..seq_len) - .flat_map(|i| { - (0..total_len).map(move |j| { - if j <= past_len + i { - 0.0 - } else { - f32::NEG_INFINITY - } - }) - }) - .collect(); - - Tensor::from_vec(mask, (1, 1, seq_len, total_len), device)?.to_dtype(dtype) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_kv_cache() { - let device = Device::Cpu; - let mut cache = KvCache::new(); - assert_eq!(cache.seq_len(), 0); - - let k = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap(); - let v = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap(); - let (fk, _fv) = cache.append(&k, &v).unwrap(); - assert_eq!(cache.seq_len(), 5); - assert_eq!(fk.dim(2).unwrap(), 5); - - let k2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap(); - let v2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap(); - let (fk2, _fv2) = cache.append(&k2, &v2).unwrap(); - assert_eq!(cache.seq_len(), 6); - assert_eq!(fk2.dim(2).unwrap(), 6); - - cache.reset(); - assert_eq!(cache.seq_len(), 0); - } - - #[test] - fn test_causal_mask_single_token() { - let mask = Qwen3Model::make_causal_mask(1, 10, DType::F32, &Device::Cpu).unwrap(); - assert_eq!(mask.dims(), &[1, 1, 1, 11]); - // All zeros — single token can attend to everything - let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap(); - assert_eq!(sum, 0.0); - } - - #[test] - fn test_causal_mask_multi_token() { - let mask = Qwen3Model::make_causal_mask(3, 0, DType::F32, &Device::Cpu).unwrap(); - assert_eq!(mask.dims(), &[1, 1, 3, 3]); - // Upper triangle should be -inf - let data: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap(); - // Row 0: [0, -inf, -inf] - assert_eq!(data[0], 0.0); - assert!(data[1].is_infinite() && data[1] < 0.0); - assert!(data[2].is_infinite() && data[2] < 0.0); - // Row 1: [0, 0, -inf] - assert_eq!(data[3], 0.0); - assert_eq!(data[4], 0.0); - assert!(data[5].is_infinite() && data[5] < 0.0); - // Row 2: [0, 0, 0] - assert_eq!(data[6], 0.0); - assert_eq!(data[7], 0.0); - assert_eq!(data[8], 0.0); - } -} diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs deleted file mode 100644 index 86e00f2..0000000 --- a/makima/src/tts/qwen3/speech_tokenizer.rs +++ /dev/null @@ -1,613 +0,0 @@ -//! Speech Tokenizer — ConvNet encoder/decoder with RVQ codebooks. -//! -//! Two sub-components: -//! -//! **Encoder** (voice cloning): converts reference audio waveform to discrete -//! multi-codebook tokens via a causal 1D ConvNet + RVQ. -//! -//! **Decoder** (audio synthesis): reconstructs waveform from discrete codebook -//! indices via embedding lookup + causal 1D ConvNet. -//! -//! The speech tokenizer is a separate model (~682MB) loaded from -//! `Qwen/Qwen3-TTS-Tokenizer-12Hz`. - -use candle_core::{Device, Module, Result, Tensor, D}; -use candle_nn::{ - conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder, -}; - -use super::config::SpeechTokenizerConfig; - -// --------------------------------------------------------------------------- -// Weight-Normalized Conv1d -// --------------------------------------------------------------------------- - -/// A 1D convolution with optional weight normalization and activation. -pub struct ConvBlock { - conv: Conv1d, - activation: ConvActivation, -} - -#[derive(Debug, Clone, Copy)] -pub enum ConvActivation { - None, - Elu, - Tanh, -} - -impl ConvBlock { - pub fn new( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - activation: ConvActivation, - vb: VarBuilder, - ) -> Result<Self> { - let config = Conv1dConfig { - stride, - padding, - dilation, - groups: 1, - }; - let conv = conv1d(in_channels, out_channels, kernel_size, config, vb.pp("conv"))?; - - Ok(Self { conv, activation }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let out = self.conv.forward(x)?; - match self.activation { - ConvActivation::None => Ok(out), - ConvActivation::Elu => elu(&out, 1.0), - ConvActivation::Tanh => out.tanh(), - } - } -} - -/// ELU activation: x if x >= 0, alpha * (exp(x) - 1) if x < 0 -fn elu(x: &Tensor, alpha: f64) -> Result<Tensor> { - let zeros = x.zeros_like()?; - let positive = x.maximum(&zeros)?; - let negative_mask = x.lt(&zeros)?.to_dtype(x.dtype())?; - let exp_x = x.exp()?; - let one = Tensor::ones_like(&exp_x)?; - let negative = ((exp_x - one)? * alpha)?.broadcast_mul(&negative_mask)?; - positive + negative -} - -// --------------------------------------------------------------------------- -// Residual Unit -// --------------------------------------------------------------------------- - -/// Residual convolutional unit with dilated convolutions. -pub struct ResidualUnit { - conv1: ConvBlock, - conv2: ConvBlock, -} - -impl ResidualUnit { - pub fn new( - channels: usize, - dilation: usize, - vb: VarBuilder, - ) -> Result<Self> { - // Dilated causal conv (kernel=7, dilation varies) - let padding = (7 - 1) * dilation / 2; // causal-ish padding - let conv1 = ConvBlock::new( - channels, - channels, - 7, - 1, - padding, - dilation, - ConvActivation::Elu, - vb.pp("block.0"), - )?; - - // Pointwise conv (kernel=1) - let conv2 = ConvBlock::new( - channels, - channels, - 1, - 1, - 0, - 1, - ConvActivation::Elu, - vb.pp("block.1"), - )?; - - Ok(Self { conv1, conv2 }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let residual = x; - let out = self.conv1.forward(x)?; - let out = self.conv2.forward(&out)?; - // Match sequence lengths if needed (causal conv may change length) - let out_len = out.dim(D::Minus1)?; - let res_len = residual.dim(D::Minus1)?; - if out_len != res_len { - let start = res_len.saturating_sub(out_len); - let residual = residual.narrow(D::Minus1, start, out_len)?; - residual + out - } else { - residual + out - } - } -} - -// --------------------------------------------------------------------------- -// Encoder Block -// --------------------------------------------------------------------------- - -/// Encoder downsampling block: residual units + strided conv. -pub struct EncoderBlock { - residual_units: Vec<ResidualUnit>, - downsample: ConvBlock, -} - -impl EncoderBlock { - pub fn new( - in_channels: usize, - out_channels: usize, - stride: usize, - num_residuals: usize, - vb: VarBuilder, - ) -> Result<Self> { - let mut residual_units = Vec::with_capacity(num_residuals); - for i in 0..num_residuals { - let dilation = 3usize.pow(i as u32); // 1, 3, 9 - let unit = ResidualUnit::new(in_channels, dilation, vb.pp(format!("residuals.{i}")))?; - residual_units.push(unit); - } - - // Strided downsampling convolution - let kernel_size = stride * 2; - let padding = stride / 2; - let downsample = ConvBlock::new( - in_channels, - out_channels, - kernel_size, - stride, - padding, - 1, - ConvActivation::Elu, - vb.pp("downsample"), - )?; - - Ok(Self { - residual_units, - downsample, - }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let mut out = x.clone(); - for unit in &self.residual_units { - out = unit.forward(&out)?; - } - self.downsample.forward(&out) - } -} - -// --------------------------------------------------------------------------- -// Decoder Block -// --------------------------------------------------------------------------- - -/// Decoder upsampling block: transposed conv + residual units. -pub struct DecoderBlock { - upsample: ConvBlock, - residual_units: Vec<ResidualUnit>, -} - -impl DecoderBlock { - pub fn new( - in_channels: usize, - out_channels: usize, - stride: usize, - num_residuals: usize, - vb: VarBuilder, - ) -> Result<Self> { - // Strided upsampling (transpose conv simulated by regular conv + padding) - let kernel_size = stride * 2; - let padding = stride / 2; - let upsample = ConvBlock::new( - in_channels, - out_channels, - kernel_size, - 1, // stride=1 for output; upsample via repeat/interpolation - padding, - 1, - ConvActivation::Elu, - vb.pp("upsample"), - )?; - - let mut residual_units = Vec::with_capacity(num_residuals); - for i in 0..num_residuals { - let dilation = 3usize.pow(i as u32); - let unit = - ResidualUnit::new(out_channels, dilation, vb.pp(format!("residuals.{i}")))?; - residual_units.push(unit); - } - - Ok(Self { - upsample, - residual_units, - }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let mut out = self.upsample.forward(x)?; - for unit in &self.residual_units { - out = unit.forward(&out)?; - } - Ok(out) - } -} - -// --------------------------------------------------------------------------- -// RVQ Codebook -// --------------------------------------------------------------------------- - -/// Residual Vector Quantization codebook. -/// -/// Contains `num_codebooks` embedding tables, each mapping -/// `codebook_size` indices to `codebook_dim`-dimensional vectors. -pub struct RvqCodebook { - codebooks: Vec<Embedding>, - num_codebooks: usize, - #[allow(dead_code)] - codebook_dim: usize, -} - -impl RvqCodebook { - pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder) -> Result<Self> { - let mut codebooks = Vec::with_capacity(config.num_codebooks); - for i in 0..config.num_codebooks { - let cb = embedding( - config.codebook_size, - config.codebook_dim, - vb.pp(format!("codebooks.{i}")), - )?; - codebooks.push(cb); - } - - Ok(Self { - codebooks, - num_codebooks: config.num_codebooks, - codebook_dim: config.codebook_dim, - }) - } - - /// Look up codebook embeddings for all codebook layers. - /// - /// `codes`: [num_codebooks, seq_len] — codebook indices per layer - /// Returns: [1, codebook_dim, seq_len] — sum of all codebook embeddings - pub fn decode(&self, codes: &[Vec<u32>], device: &Device) -> Result<Tensor> { - assert_eq!(codes.len(), self.num_codebooks, "Expected {} codebook layers", self.num_codebooks); - - let seq_len = codes[0].len(); - let mut sum: Option<Tensor> = None; - - for (i, code_layer) in codes.iter().enumerate() { - assert_eq!(code_layer.len(), seq_len, "Codebook layer {i} length mismatch"); - - let indices = Tensor::from_vec( - code_layer.clone(), - (1, seq_len), - device, - )?; - - // [1, seq_len, codebook_dim] - let emb = self.codebooks[i].forward(&indices)?; - - sum = Some(match sum { - Some(prev) => (prev + emb)?, - None => emb, - }); - } - - // [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len] - let result = sum.unwrap().transpose(1, 2)?; - Ok(result) - } - - /// Number of codebooks. - pub fn num_codebooks(&self) -> usize { - self.num_codebooks - } -} - -// --------------------------------------------------------------------------- -// Speech Tokenizer (Encoder + Decoder) -// --------------------------------------------------------------------------- - -/// The complete speech tokenizer with encoder and decoder. -pub struct SpeechTokenizer { - /// Encoder: waveform -> latent (for voice cloning). - encoder_input_conv: ConvBlock, - encoder_blocks: Vec<EncoderBlock>, - encoder_output_conv: ConvBlock, - - /// RVQ codebooks for quantization. - codebook: RvqCodebook, - - /// Decoder: codes -> waveform. - decoder_input_conv: ConvBlock, - decoder_blocks: Vec<DecoderBlock>, - decoder_output_conv: ConvBlock, - - /// Projection from codebook dim to decoder hidden channels. - decoder_proj: Linear, - - config: SpeechTokenizerConfig, - device: Device, -} - -impl SpeechTokenizer { - /// Load the speech tokenizer from safetensors. - pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder, device: &Device) -> Result<Self> { - let hidden = config.hidden_channels; // 512 - - // ===== Encoder ===== - // Input: [batch, 1, samples] -> [batch, hidden/8, ...] - let encoder_input_conv = ConvBlock::new( - 1, - hidden / 8, // 64 - 7, - 1, - 3, - 1, - ConvActivation::Elu, - vb.pp("encoder.input_conv"), - )?; - - // Downsampling blocks with increasing channels - let strides = [8, 5, 4, 3]; // Total downsampling: 8*5*4*3 = 480 - let channels = [hidden / 8, hidden / 4, hidden / 2, hidden]; // 64, 128, 256, 512 - let mut encoder_blocks = Vec::with_capacity(strides.len()); - for (i, (&stride, &out_ch)) in strides.iter().zip(channels.iter().skip(0)).enumerate() { - let in_ch = if i == 0 { hidden / 8 } else { channels[i - 1] }; - let block = EncoderBlock::new( - in_ch, - out_ch, - stride, - 3, // 3 residual units per block - vb.pp(format!("encoder.blocks.{i}")), - )?; - encoder_blocks.push(block); - } - - // Encoder output projection to codebook dim - let encoder_output_conv = ConvBlock::new( - hidden, - config.codebook_dim, - 3, - 1, - 1, - 1, - ConvActivation::None, - vb.pp("encoder.output_conv"), - )?; - - // ===== RVQ Codebook ===== - let codebook = RvqCodebook::new(config, vb.pp("quantizer"))?; - - // ===== Decoder ===== - // Projection from codebook dim to decoder hidden - let decoder_proj = linear_no_bias( - config.codebook_dim, - hidden, - vb.pp("decoder.proj"), - )?; - - // Input conv - let decoder_input_conv = ConvBlock::new( - hidden, - hidden, - 7, - 1, - 3, - 1, - ConvActivation::Elu, - vb.pp("decoder.input_conv"), - )?; - - // Upsampling blocks (reverse order of encoder) - let dec_strides = [3, 4, 5, 8]; - let dec_channels = [hidden, hidden / 2, hidden / 4, hidden / 8]; // 512, 256, 128, 64 - let mut decoder_blocks = Vec::with_capacity(dec_strides.len()); - for (i, (&stride, &out_ch)) in dec_strides.iter().zip(dec_channels.iter().skip(0)).enumerate() - { - let in_ch = if i == 0 { hidden } else { dec_channels[i - 1] }; - let block = DecoderBlock::new( - in_ch, - out_ch, - stride, - 3, - vb.pp(format!("decoder.blocks.{i}")), - )?; - decoder_blocks.push(block); - } - - // Output conv: hidden/8 -> 1 channel (waveform) - let decoder_output_conv = ConvBlock::new( - hidden / 8, - 1, - 7, - 1, - 3, - 1, - ConvActivation::Tanh, - vb.pp("decoder.output_conv"), - )?; - - Ok(Self { - encoder_input_conv, - encoder_blocks, - encoder_output_conv, - codebook, - decoder_input_conv, - decoder_blocks, - decoder_output_conv, - decoder_proj, - config: config.clone(), - device: device.clone(), - }) - } - - /// Encode reference audio waveform to discrete codebook tokens. - /// - /// `audio`: [num_samples] — mono 24kHz audio - /// Returns: Vec of `num_codebooks` vectors, each containing token indices. - pub fn encode(&self, audio: &[f32]) -> Result<Vec<Vec<u32>>> { - // [1, 1, num_samples] - let x = Tensor::from_vec(audio.to_vec(), (1, 1, audio.len()), &self.device)?; - - // Run encoder - let mut hidden = self.encoder_input_conv.forward(&x)?; - for block in &self.encoder_blocks { - hidden = block.forward(&hidden)?; - } - let latent = self.encoder_output_conv.forward(&hidden)?; - - // latent: [1, codebook_dim, seq_len] - // Quantize via nearest-neighbor lookup in each codebook - let seq_len = latent.dim(D::Minus1)?; - let mut all_codes = Vec::with_capacity(self.config.num_codebooks); - - // Residual quantization: subtract each codebook's contribution - let mut residual = latent.clone(); - - for cb_idx in 0..self.config.num_codebooks { - // residual: [1, codebook_dim, seq_len] -> find nearest codebook entry per timestep - let codes = self.quantize_layer(&residual, cb_idx, seq_len)?; - - // Look up the quantized vectors and subtract from residual - let code_indices = - Tensor::from_vec(codes.clone(), (1, seq_len), &self.device)?; - let quantized = self.codebook.codebooks[cb_idx].forward(&code_indices)?; - // quantized: [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len] - let quantized = quantized.transpose(1, 2)?; - residual = (residual - quantized)?; - - all_codes.push(codes); - } - - Ok(all_codes) - } - - /// Quantize a single RVQ layer by finding the nearest codebook entry. - fn quantize_layer( - &self, - residual: &Tensor, - codebook_idx: usize, - _seq_len: usize, - ) -> Result<Vec<u32>> { - // residual: [1, codebook_dim, seq_len] - // codebook weights: [codebook_size, codebook_dim] - let cb_weight = self.codebook.codebooks[codebook_idx] - .embeddings() - .clone(); // [codebook_size, codebook_dim] - - // Transpose residual: [1, seq_len, codebook_dim] - let residual_t = residual.transpose(1, 2)?.squeeze(0)?; // [seq_len, codebook_dim] - - // Compute L2 distances: ||r - c||^2 = ||r||^2 - 2*r*c^T + ||c||^2 - let r_sq = residual_t.sqr()?.sum(D::Minus1)?; // [seq_len] - let c_sq = cb_weight.sqr()?.sum(D::Minus1)?; // [codebook_size] - let rc = residual_t.matmul(&cb_weight.t()?)?; // [seq_len, codebook_size] - - let r_sq = r_sq.unsqueeze(1)?; // [seq_len, 1] - let c_sq = c_sq.unsqueeze(0)?; // [1, codebook_size] - - let distances = (r_sq.broadcast_add(&c_sq)? - (rc * 2.0)?)?; // [seq_len, codebook_size] - - // Argmin per timestep - let indices = distances.argmin(D::Minus1)?; // [seq_len] - let codes: Vec<u32> = indices.to_vec1()?; - - Ok(codes) - } - - /// Decode discrete codebook tokens to audio waveform. - /// - /// `codes`: Vec of `num_codebooks` vectors of token indices. - /// Returns: Vec<f32> — mono 24kHz audio samples. - pub fn decode(&self, codes: &[Vec<u32>]) -> Result<Vec<f32>> { - // Look up and sum all codebook embeddings - let embeddings = self.codebook.decode(codes, &self.device)?; - // embeddings: [1, codebook_dim, seq_len] - - // Project to decoder hidden size: [1, seq_len, codebook_dim] -> [1, seq_len, hidden] - let emb_t = embeddings.transpose(1, 2)?; // [1, seq_len, codebook_dim] - let projected = self.decoder_proj.forward(&emb_t)?; // [1, seq_len, hidden] - let mut hidden = projected.transpose(1, 2)?; // [1, hidden, seq_len] - - // Run decoder - hidden = self.decoder_input_conv.forward(&hidden)?; - for block in &self.decoder_blocks { - hidden = block.forward(&hidden)?; - } - let waveform = self.decoder_output_conv.forward(&hidden)?; - - // [1, 1, num_samples] -> Vec<f32> - let samples: Vec<f32> = waveform.flatten_all()?.to_vec1()?; - Ok(samples) - } - - /// Decode a single frame's codes to audio samples (for streaming). - /// - /// `frame_codes`: [num_codebooks] — one token per codebook for a single frame - /// Returns: audio samples for this frame (~1920 samples at 24kHz / 12.5Hz) - pub fn decode_frame(&self, frame_codes: &[u32]) -> Result<Vec<f32>> { - let codes: Vec<Vec<u32>> = frame_codes.iter().map(|&c| vec![c]).collect(); - self.decode(&codes) - } - - /// Get the number of codebooks. - pub fn num_codebooks(&self) -> usize { - self.config.num_codebooks - } - - /// Get the output sample rate. - pub fn sample_rate(&self) -> u32 { - self.config.sample_rate - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_elu_positive() { - let device = Device::Cpu; - let x = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap(); - let result = elu(&x, 1.0).unwrap(); - let values: Vec<f32> = result.to_vec1().unwrap(); - assert!((values[0] - 1.0).abs() < 1e-5); - assert!((values[1] - 2.0).abs() < 1e-5); - } - - #[test] - fn test_elu_negative() { - let device = Device::Cpu; - let x = Tensor::from_vec(vec![-1.0f32], (1,), &device).unwrap(); - let result = elu(&x, 1.0).unwrap(); - let values: Vec<f32> = result.to_vec1().unwrap(); - // ELU(-1) = exp(-1) - 1 ≈ -0.6321 - assert!((values[0] - (-0.6321)).abs() < 0.01); - } - - #[test] - fn test_speech_tokenizer_config() { - let config = SpeechTokenizerConfig::default(); - assert_eq!(config.num_codebooks, 16); - assert_eq!(config.codebook_size, 2048); - assert_eq!(config.sample_rate, 24_000); - } -} |
