summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/tts/qwen3')
-rw-r--r--makima/src/tts/qwen3/code_predictor.rs253
-rw-r--r--makima/src/tts/qwen3/config.rs271
-rw-r--r--makima/src/tts/qwen3/generate.rs456
-rw-r--r--makima/src/tts/qwen3/mod.rs317
-rw-r--r--makima/src/tts/qwen3/model.rs584
-rw-r--r--makima/src/tts/qwen3/speech_tokenizer.rs613
6 files changed, 0 insertions, 2494 deletions
diff --git a/makima/src/tts/qwen3/code_predictor.rs b/makima/src/tts/qwen3/code_predictor.rs
deleted file mode 100644
index 363105f..0000000
--- a/makima/src/tts/qwen3/code_predictor.rs
+++ /dev/null
@@ -1,253 +0,0 @@
-//! Multi-Token Prediction (MTP) code predictor.
-//!
-//! After the main LM predicts the zeroth codebook token, this module
-//! predicts the remaining 15 codebook layers in parallel from the
-//! LM's hidden states.
-//!
-//! Architecture:
-//! - 5 transformer layers (same structure as main LM layers)
-//! - 16 output heads, one per codebook (vocab 2048 each)
-//! - Input: last hidden state from main LM + zeroth codebook embedding
-//! - Output: 16 codebook token predictions
-
-use candle_core::{Device, Module, Result, Tensor};
-use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
-
-use super::config::{CodePredictorConfig, Qwen3LmConfig};
-use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};
-
-/// A single code predictor transformer layer.
-///
-/// Uses the same pre-norm residual structure as the main LM layers.
-pub struct CodePredictorLayer {
- self_attn: Qwen3Attention,
- mlp: Qwen3Mlp,
- input_layernorm: RmsNorm,
- post_attention_layernorm: RmsNorm,
-}
-
-impl CodePredictorLayer {
- pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
- // Construct a Qwen3LmConfig-like view for the attention/MLP constructors
- let lm_config = Qwen3LmConfig {
- hidden_size: config.hidden_size,
- num_hidden_layers: config.num_layers,
- num_attention_heads: config.num_attention_heads,
- num_key_value_heads: config.num_attention_heads, // No GQA in predictor
- intermediate_size: config.hidden_size * 3, // 3072 for hidden=1024
- head_dim: config.hidden_size / config.num_attention_heads,
- rms_norm_eps: config.rms_norm_eps,
- ..Qwen3LmConfig::default()
- };
-
- let self_attn = Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?;
- let mlp = Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?;
- let input_layernorm = rms_norm(
- config.hidden_size,
- config.rms_norm_eps,
- vb.pp("input_layernorm"),
- )?;
- let post_attention_layernorm = rms_norm(
- config.hidden_size,
- config.rms_norm_eps,
- vb.pp("post_attention_layernorm"),
- )?;
-
- Ok(Self {
- self_attn,
- mlp,
- input_layernorm,
- post_attention_layernorm,
- })
- }
-
- pub fn forward(
- &self,
- hidden_states: &Tensor,
- rope: &RotaryEmbedding,
- kv_cache: &mut KvCache,
- attention_mask: Option<&Tensor>,
- ) -> Result<Tensor> {
- let residual = hidden_states;
- let hidden_states = self.input_layernorm.forward(hidden_states)?;
- let hidden_states =
- self.self_attn
- .forward(&hidden_states, rope, kv_cache, attention_mask)?;
- let hidden_states = (residual + hidden_states)?;
-
- let residual = &hidden_states;
- let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
- let hidden_states = self.mlp.forward(&hidden_states)?;
- let output = (residual + hidden_states)?;
-
- Ok(output)
- }
-}
-
-/// Multi-token prediction code predictor.
-///
-/// Takes the hidden states from the main LM and predicts all 16 codebook
-/// tokens. The zeroth codebook is predicted by the main LM head; this
-/// module predicts the remaining 15 residual codebooks.
-pub struct CodePredictor {
- /// Embedding layer for codebook tokens (one per residual codebook group, 0-14).
- code_embeddings: Vec<Embedding>,
- /// 5 transformer layers.
- layers: Vec<CodePredictorLayer>,
- /// Final normalization.
- norm: RmsNorm,
- /// Per-codebook output heads (15 heads for residual codebooks).
- output_heads: Vec<Linear>,
- /// RoPE for the predictor's attention layers.
- rope: RotaryEmbedding,
- config: CodePredictorConfig,
-}
-
-impl CodePredictor {
- pub fn new(
- config: &CodePredictorConfig,
- lm_config: &Qwen3LmConfig,
- vb: VarBuilder,
- ) -> Result<Self> {
- // HuggingFace Qwen3-TTS uses "talker.code_predictor.*" prefix
- let predictor_vb = vb.pp("talker").pp("code_predictor");
- let model_vb = predictor_vb.pp("model");
-
- // Code embeddings for residual codebook groups (15 groups, indices 0-14)
- // HF names them "codec_embedding" not "code_embeddings"
- let num_residual_groups = config.num_code_groups - 1; // 15, not 16
- let mut code_embeddings = Vec::with_capacity(num_residual_groups);
- for i in 0..num_residual_groups {
- let emb = embedding(
- config.codebook_vocab_size,
- config.hidden_size,
- model_vb.pp(format!("codec_embedding.{i}")),
- )?;
- code_embeddings.push(emb);
- }
-
- // Transformer layers
- let mut layers = Vec::with_capacity(config.num_layers);
- for i in 0..config.num_layers {
- let layer =
- CodePredictorLayer::new(config, model_vb.pp(format!("layers.{i}")))?;
- layers.push(layer);
- }
-
- let norm = rms_norm(
- config.hidden_size,
- config.rms_norm_eps,
- model_vb.pp("norm"),
- )?;
-
- // Output heads for residual codebooks (15 heads, indices 0-14)
- // HF names them "lm_head" not "output_heads"
- let mut output_heads = Vec::with_capacity(num_residual_groups);
- for i in 0..num_residual_groups {
- let head = linear_no_bias(
- config.hidden_size,
- config.codebook_vocab_size,
- predictor_vb.pp(format!("lm_head.{i}")),
- )?;
- output_heads.push(head);
- }
-
- // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
- let predictor_head_dim = config.hidden_size / config.num_attention_heads;
- let rope_config = Qwen3LmConfig {
- head_dim: predictor_head_dim,
- rope_theta: lm_config.rope_theta,
- max_position_embeddings: lm_config.max_position_embeddings,
- ..Qwen3LmConfig::default()
- };
- let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;
-
- Ok(Self {
- code_embeddings,
- layers,
- norm,
- output_heads,
- rope,
- config: config.clone(),
- })
- }
-
- /// Predict all 16 codebook tokens from the LM hidden state.
- ///
- /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
- /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
- ///
- /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
- pub fn predict(
- &self,
- lm_hidden: &Tensor,
- zeroth_code: u32,
- device: &Device,
- ) -> Result<Vec<u32>> {
- let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
- all_codes.push(zeroth_code);
-
- // The code predictor iterates through the 15 residual codebook groups.
- // For each group i (0..15), it:
- // 1. Embeds the previous codebook token
- // 2. Adds to LM hidden state
- // 3. Runs through predictor layers
- // 4. Predicts the next codebook token via lm_head[i]
- let mut prev_code = zeroth_code;
-
- for group_idx in 0..self.code_embeddings.len() {
- // Embed the previous codebook token
- let code_tensor = Tensor::from_vec(
- vec![prev_code],
- (1, 1),
- device,
- )?;
- let code_emb = self.code_embeddings[group_idx].forward(&code_tensor)?;
-
- // Add code embedding to LM hidden state (no concatenation, no projection)
- let mut hidden = (lm_hidden + &code_emb)?;
-
- // Run through predictor transformer layers (no KV cache needed — single step)
- let mut kv_caches: Vec<KvCache> =
- (0..self.config.num_layers).map(|_| KvCache::new()).collect();
- for (i, layer) in self.layers.iter().enumerate() {
- hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
- }
-
- hidden = self.norm.forward(&hidden)?;
-
- // Predict codebook token
- let logits = self.output_heads[group_idx].forward(&hidden)?;
-
- // Greedy decode: argmax
- let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
- let next_code = logits_flat
- .argmax(0)?
- .to_scalar::<u32>()?;
-
- all_codes.push(next_code);
- prev_code = next_code;
- }
-
- Ok(all_codes)
- }
-
- /// Number of codebook groups.
- pub fn num_code_groups(&self) -> usize {
- self.config.num_code_groups
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_code_predictor_config() {
- let config = CodePredictorConfig::default();
- assert_eq!(config.num_layers, 5);
- assert_eq!(config.num_code_groups, 16);
- assert_eq!(config.codebook_vocab_size, 2048);
- assert_eq!(config.hidden_size, 1024);
- }
-}
diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs
deleted file mode 100644
index 6fb55d7..0000000
--- a/makima/src/tts/qwen3/config.rs
+++ /dev/null
@@ -1,271 +0,0 @@
-//! Qwen3-TTS model configuration.
-//!
-//! Parses config.json from the HuggingFace model repository to configure
-//! the language model, code predictor, and speech tokenizer.
-
-use serde::Deserialize;
-
-use crate::tts::TtsError;
-
-/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base.
-#[derive(Debug, Clone, Deserialize)]
-pub struct Qwen3TtsConfig {
- /// Language model (talker) configuration.
- #[serde(default = "Qwen3LmConfig::default")]
- pub lm: Qwen3LmConfig,
-
- /// Code predictor (multi-token prediction) configuration.
- #[serde(default = "CodePredictorConfig::default")]
- pub code_predictor: CodePredictorConfig,
-
- /// Speech tokenizer configuration.
- #[serde(default = "SpeechTokenizerConfig::default")]
- pub speech_tokenizer: SpeechTokenizerConfig,
-}
-
-impl Default for Qwen3TtsConfig {
- fn default() -> Self {
- Self {
- lm: Qwen3LmConfig::default(),
- code_predictor: CodePredictorConfig::default(),
- speech_tokenizer: SpeechTokenizerConfig::default(),
- }
- }
-}
-
-impl Qwen3TtsConfig {
- /// Load from a config.json file path.
- pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> {
- let content = std::fs::read_to_string(path)
- .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?;
- Self::from_json_str(&content)
- }
-
- /// Load from a JSON string.
- pub fn from_json_str(json: &str) -> Result<Self, TtsError> {
- // Try to parse the full HuggingFace config.json format first
- if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) {
- return Ok(Self::from_hf_config(&hf_config));
- }
- // Fall back to direct deserialization
- serde_json::from_str(json)
- .map_err(|e| TtsError::Config(format!("failed to parse config: {e}")))
- }
-
- /// Convert from HuggingFace's config.json format.
- fn from_hf_config(hf: &HfConfig) -> Self {
- Self {
- lm: Qwen3LmConfig {
- hidden_size: hf.hidden_size.unwrap_or(1024),
- num_hidden_layers: hf.num_hidden_layers.unwrap_or(28),
- num_attention_heads: hf.num_attention_heads.unwrap_or(16),
- num_key_value_heads: hf.num_key_value_heads.unwrap_or(8),
- intermediate_size: hf.intermediate_size.unwrap_or(3072),
- head_dim: hf.head_dim.unwrap_or(128),
- vocab_size: hf.vocab_size.unwrap_or(151_936),
- max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768),
- rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
- rope_theta: hf.rope_theta.unwrap_or(1_000_000.0),
- use_sliding_window: hf.use_sliding_window.unwrap_or(false),
- sliding_window: hf.sliding_window,
- hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()),
- },
- code_predictor: CodePredictorConfig {
- hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024),
- num_layers: hf.code_predictor_num_layers.unwrap_or(5),
- num_attention_heads: hf
- .code_predictor_num_attention_heads
- .unwrap_or(16),
- num_code_groups: hf.num_code_groups.unwrap_or(16),
- codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048),
- rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
- },
- speech_tokenizer: SpeechTokenizerConfig::default(),
- }
- }
-}
-
-/// Language model configuration (28-layer Qwen3 transformer).
-#[derive(Debug, Clone, Deserialize)]
-pub struct Qwen3LmConfig {
- /// Hidden dimension of transformer layers.
- pub hidden_size: usize,
- /// Number of transformer layers.
- pub num_hidden_layers: usize,
- /// Number of attention heads.
- pub num_attention_heads: usize,
- /// Number of key-value heads (GQA).
- pub num_key_value_heads: usize,
- /// Feed-forward intermediate size.
- pub intermediate_size: usize,
- /// Dimension per attention head.
- pub head_dim: usize,
- /// Text vocabulary size.
- pub vocab_size: usize,
- /// Maximum sequence length for RoPE.
- pub max_position_embeddings: usize,
- /// RMS normalization epsilon.
- pub rms_norm_eps: f64,
- /// RoPE theta parameter.
- pub rope_theta: f64,
- /// Whether to use sliding window attention.
- pub use_sliding_window: bool,
- /// Sliding window size (if enabled).
- pub sliding_window: Option<usize>,
- /// Activation function name.
- pub hidden_act: String,
-}
-
-impl Default for Qwen3LmConfig {
- fn default() -> Self {
- Self {
- hidden_size: 1024,
- num_hidden_layers: 28,
- num_attention_heads: 16,
- num_key_value_heads: 8,
- intermediate_size: 3072,
- head_dim: 128,
- vocab_size: 151_936,
- max_position_embeddings: 32_768,
- rms_norm_eps: 1e-6,
- rope_theta: 1_000_000.0,
- use_sliding_window: false,
- sliding_window: None,
- hidden_act: "silu".to_string(),
- }
- }
-}
-
-impl Qwen3LmConfig {
- /// Number of key-value head groups for GQA.
- pub fn num_kv_groups(&self) -> usize {
- self.num_attention_heads / self.num_key_value_heads
- }
-}
-
-/// Code predictor (multi-token prediction) configuration.
-#[derive(Debug, Clone, Deserialize)]
-pub struct CodePredictorConfig {
- /// Hidden size (matches LM hidden size).
- pub hidden_size: usize,
- /// Number of predictor transformer layers.
- pub num_layers: usize,
- /// Number of attention heads.
- pub num_attention_heads: usize,
- /// Number of codebook groups (residual codebooks).
- pub num_code_groups: usize,
- /// Vocabulary size per codebook.
- pub codebook_vocab_size: usize,
- /// RMS norm epsilon.
- pub rms_norm_eps: f64,
-}
-
-impl Default for CodePredictorConfig {
- fn default() -> Self {
- Self {
- hidden_size: 1024,
- num_layers: 5,
- num_attention_heads: 16,
- num_code_groups: 16,
- codebook_vocab_size: 2048,
- rms_norm_eps: 1e-6,
- }
- }
-}
-
-/// Speech tokenizer (ConvNet codec) configuration.
-#[derive(Debug, Clone, Deserialize)]
-pub struct SpeechTokenizerConfig {
- /// Number of RVQ codebooks.
- pub num_codebooks: usize,
- /// Codebook embedding dimension.
- pub codebook_dim: usize,
- /// Codebook vocabulary size per layer.
- pub codebook_size: usize,
- /// Encoder/decoder hidden channels.
- pub hidden_channels: usize,
- /// Output sample rate.
- pub sample_rate: u32,
- /// Token frame rate (Hz).
- pub frame_rate: f32,
- /// HuggingFace model ID for the speech tokenizer.
- pub model_id: String,
-}
-
-impl Default for SpeechTokenizerConfig {
- fn default() -> Self {
- Self {
- num_codebooks: 16,
- codebook_dim: 256,
- codebook_size: 2048,
- hidden_channels: 512,
- sample_rate: 24_000,
- frame_rate: 12.5,
- model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(),
- }
- }
-}
-
-/// HuggingFace config.json format (partial, fields we need).
-#[derive(Debug, Deserialize)]
-struct HfConfig {
- hidden_size: Option<usize>,
- num_hidden_layers: Option<usize>,
- num_attention_heads: Option<usize>,
- num_key_value_heads: Option<usize>,
- intermediate_size: Option<usize>,
- head_dim: Option<usize>,
- vocab_size: Option<usize>,
- max_position_embeddings: Option<usize>,
- rms_norm_eps: Option<f64>,
- rope_theta: Option<f64>,
- use_sliding_window: Option<bool>,
- sliding_window: Option<usize>,
- hidden_act: Option<String>,
- // Code predictor specific fields
- code_predictor_hidden_size: Option<usize>,
- code_predictor_num_layers: Option<usize>,
- code_predictor_num_attention_heads: Option<usize>,
- num_code_groups: Option<usize>,
- codebook_vocab_size: Option<usize>,
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_default_config() {
- let config = Qwen3TtsConfig::default();
- assert_eq!(config.lm.hidden_size, 1024);
- assert_eq!(config.lm.num_hidden_layers, 28);
- assert_eq!(config.lm.num_attention_heads, 16);
- assert_eq!(config.lm.num_key_value_heads, 8);
- assert_eq!(config.lm.head_dim, 128);
- assert_eq!(config.lm.num_kv_groups(), 2);
- assert_eq!(config.code_predictor.num_layers, 5);
- assert_eq!(config.code_predictor.num_code_groups, 16);
- assert_eq!(config.speech_tokenizer.num_codebooks, 16);
- }
-
- #[test]
- fn test_config_from_json() {
- let json = r#"{
- "hidden_size": 1024,
- "num_hidden_layers": 28,
- "num_attention_heads": 16,
- "num_key_value_heads": 8,
- "intermediate_size": 3072,
- "vocab_size": 151936,
- "max_position_embeddings": 32768,
- "rms_norm_eps": 1e-6,
- "rope_theta": 1000000.0,
- "hidden_act": "silu"
- }"#;
-
- let config = Qwen3TtsConfig::from_json_str(json).unwrap();
- assert_eq!(config.lm.hidden_size, 1024);
- assert_eq!(config.lm.num_hidden_layers, 28);
- assert_eq!(config.lm.vocab_size, 151_936);
- }
-}
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs
deleted file mode 100644
index 30d165b..0000000
--- a/makima/src/tts/qwen3/generate.rs
+++ /dev/null
@@ -1,456 +0,0 @@
-//! Autoregressive generation loop for Qwen3-TTS.
-//!
-//! Orchestrates the full inference pipeline:
-//! 1. Encode reference audio → speaker embedding via speech tokenizer
-//! 2. Tokenize text → token IDs
-//! 3. Autoregressive LM generation → zeroth codebook tokens
-//! 4. Code predictor → remaining 15 codebook tokens per frame
-//! 5. Speech tokenizer decoder → waveform audio
-
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-
-use candle_core::{DType, Device, IndexOp, Result, Tensor};
-use tokenizers::Tokenizer;
-
-use super::code_predictor::CodePredictor;
-use super::model::{KvCache, Qwen3Model};
-use super::speech_tokenizer::SpeechTokenizer;
-use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE};
-
-/// Special tokens for the Qwen3-TTS vocabulary.
-pub const BOS_TOKEN_ID: u32 = 151_643;
-pub const EOS_TOKEN_ID: u32 = 151_645;
-pub const PAD_TOKEN_ID: u32 = 151_643;
-
-/// Speech-specific control tokens.
-/// These are placeholders — actual values come from the tokenizer config.
-pub const START_OF_SPEECH: u32 = 151_668;
-pub const END_OF_SPEECH: u32 = 151_669;
-
-/// Generation configuration.
-#[derive(Debug, Clone)]
-pub struct GenerationConfig {
- /// Maximum number of speech tokens to generate.
- pub max_new_tokens: usize,
- /// Temperature for sampling (1.0 = greedy if top_k=1).
- pub temperature: f32,
- /// Top-k sampling (0 = disabled, use greedy argmax).
- pub top_k: usize,
- /// Repetition penalty.
- pub repetition_penalty: f32,
- /// Whether to generate audio chunks incrementally (streaming).
- pub streaming: bool,
-}
-
-impl Default for GenerationConfig {
- fn default() -> Self {
- Self {
- max_new_tokens: 2048,
- temperature: 1.0,
- top_k: 0, // Greedy by default
- repetition_penalty: 1.2,
- streaming: false,
- }
- }
-}
-
-/// Manages the full generation pipeline.
-pub struct GenerationContext<'a> {
- model: &'a Qwen3Model,
- code_predictor: &'a CodePredictor,
- speech_tokenizer: &'a SpeechTokenizer,
- tokenizer: &'a Tokenizer,
- device: &'a Device,
- config: GenerationConfig,
- /// Optional cancellation flag. When set to `true`, the generation loop
- /// will break early and return whatever audio has been produced so far.
- cancel_flag: Option<Arc<AtomicBool>>,
-}
-
-impl<'a> GenerationContext<'a> {
- pub fn new(
- model: &'a Qwen3Model,
- code_predictor: &'a CodePredictor,
- speech_tokenizer: &'a SpeechTokenizer,
- tokenizer: &'a Tokenizer,
- device: &'a Device,
- config: GenerationConfig,
- cancel_flag: Option<Arc<AtomicBool>>,
- ) -> Self {
- Self {
- model,
- code_predictor,
- speech_tokenizer,
- tokenizer,
- device,
- config,
- cancel_flag,
- }
- }
-
- /// Check whether cancellation has been requested.
- fn is_cancelled(&self) -> bool {
- self.cancel_flag
- .as_ref()
- .map_or(false, |f| f.load(Ordering::Relaxed))
- }
-
- /// Generate audio from text, optionally with a voice reference.
- ///
- /// Returns a list of audio chunks. If `streaming` is false, returns
- /// a single chunk with the complete audio.
- pub fn generate(
- &self,
- text: &str,
- reference_audio: Option<&[f32]>,
- ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- // 1. Encode reference audio if provided
- let reference_codes = match reference_audio {
- Some(audio) => Some(
- self.speech_tokenizer
- .encode(audio)
- .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?,
- ),
- None => None,
- };
-
- // 2. Tokenize text
- let encoding = self
- .tokenizer
- .encode(text, true)
- .map_err(|e| TtsError::Tokenizer(e.to_string()))?;
-
- let text_token_ids: Vec<u32> = encoding.get_ids().to_vec();
-
- // 3. Prepare input sequence
- // Format: [BOS] [text_tokens] [START_OF_SPEECH]
- let mut input_ids = Vec::new();
- input_ids.push(BOS_TOKEN_ID);
- input_ids.extend_from_slice(&text_token_ids);
- input_ids.push(START_OF_SPEECH);
-
- // 4. Run autoregressive generation
- let generated_frames = self
- .autoregressive_generate(&input_ids, reference_codes.as_deref())
- .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?;
-
- if generated_frames.is_empty() {
- return Ok(vec![AudioChunk {
- samples: vec![],
- sample_rate: SAMPLE_RATE,
- is_final: true,
- }]);
- }
-
- // 5. Decode all frames to audio
- if self.config.streaming {
- self.decode_streaming(&generated_frames)
- } else {
- self.decode_batch(&generated_frames)
- }
- }
-
- /// Autoregressive generation loop.
- ///
- /// Generates zeroth codebook tokens one at a time, then uses the code
- /// predictor to fill in the remaining 15 codebooks per frame.
- ///
- /// Returns: Vec of frames, each frame is [num_codebooks] tokens.
- fn autoregressive_generate(
- &self,
- input_ids: &[u32],
- _reference_codes: Option<&[Vec<u32>]>,
- ) -> Result<Vec<Vec<u32>>> {
- let _num_codebooks = self.code_predictor.num_code_groups();
- let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers())
- .map(|_| KvCache::new())
- .collect();
-
- let mut generated_frames: Vec<Vec<u32>> = Vec::new();
- let mut past_zeroth_tokens: Vec<u32> = Vec::new();
-
- // === First iteration: process the full input sequence ===
- let input_tensor = Tensor::from_vec(
- input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
- (1, input_ids.len()),
- self.device,
- )?
- .to_dtype(DType::I64)?;
-
- let seq_len = input_ids.len();
- let attention_mask =
- Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?;
-
- let logits =
- self.model
- .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?;
-
- // Get the logits for the last position
- let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size]
- let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?;
-
- if first_token == END_OF_SPEECH as u32 {
- return Ok(generated_frames);
- }
-
- // Use code predictor for all codebooks
- let lm_hidden = self
- .model
- .last_hidden_state()
- .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
- let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?;
-
- let frame_codes = self
- .code_predictor
- .predict(&last_hidden, first_token, self.device)?;
- generated_frames.push(frame_codes);
- past_zeroth_tokens.push(first_token);
-
- // === Subsequent iterations: one token at a time ===
- for _step in 1..self.config.max_new_tokens {
- // Check for cancellation each iteration
- if self.is_cancelled() {
- tracing::info!("TTS generation cancelled after {} frames", generated_frames.len());
- break;
- }
-
- let past_len = kv_caches[0].seq_len();
-
- // Input: just the last generated zeroth codebook token
- let last_token = *past_zeroth_tokens.last().unwrap();
- let token_tensor = Tensor::from_vec(
- vec![last_token as i64],
- (1, 1),
- self.device,
- )?
- .to_dtype(DType::I64)?;
-
- // Single-token attention mask
- let attention_mask =
- Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?;
-
- let logits =
- self.model
- .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?;
-
- let next_logits = logits.i((0, 0, ..))?; // [vocab_size]
- let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?;
-
- if next_token == END_OF_SPEECH as u32 {
- break;
- }
-
- // Predict all codebooks for this frame
- let lm_hidden = self
- .model
- .last_hidden_state()
- .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
-
- let frame_codes = self
- .code_predictor
- .predict(&lm_hidden, next_token, self.device)?;
- generated_frames.push(frame_codes);
- past_zeroth_tokens.push(next_token);
- }
-
- Ok(generated_frames)
- }
-
- /// Sample a token from logits.
- fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> {
- let mut logits_vec: Vec<f32> = logits.to_vec1()?;
-
- // Apply repetition penalty
- if self.config.repetition_penalty != 1.0 {
- for &token in past_tokens {
- let idx = token as usize;
- if idx < logits_vec.len() {
- let score = logits_vec[idx];
- logits_vec[idx] = if score < 0.0 {
- score * self.config.repetition_penalty
- } else {
- score / self.config.repetition_penalty
- };
- }
- }
- }
-
- if self.config.top_k == 0 || self.config.temperature == 0.0 {
- // Greedy: argmax
- let (max_idx, _) = logits_vec
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| {
- a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
- })
- .unwrap_or((0, &0.0));
- Ok(max_idx as u32)
- } else {
- // Top-k sampling with temperature
- let temperature = self.config.temperature;
-
- // Apply temperature
- for v in logits_vec.iter_mut() {
- *v /= temperature;
- }
-
- // Sort indices by logit value (descending)
- let mut indexed: Vec<(usize, f32)> =
- logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
- indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
-
- // Keep only top-k
- let k = self.config.top_k.min(indexed.len());
- let top_k = &indexed[..k];
-
- // Softmax over top-k
- let max_val = top_k[0].1;
- let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum();
- let probs: Vec<(usize, f32)> = top_k
- .iter()
- .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum))
- .collect();
-
- // Sample from distribution (simple linear scan)
- let r: f32 = random_float();
- let mut cumulative = 0.0;
- for (idx, prob) in &probs {
- cumulative += prob;
- if cumulative >= r {
- return Ok(*idx as u32);
- }
- }
-
- // Fallback to highest probability
- Ok(probs[0].0 as u32)
- }
- }
-
- /// Decode all frames in batch (non-streaming).
- fn decode_batch(
- &self,
- frames: &[Vec<u32>],
- ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- let num_codebooks = self.speech_tokenizer.num_codebooks();
-
- // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames]
- let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
- for frame in frames {
- for (cb_idx, &code) in frame.iter().enumerate() {
- if cb_idx < num_codebooks {
- codes_by_codebook[cb_idx].push(code);
- }
- }
- }
-
- let samples = self
- .speech_tokenizer
- .decode(&codes_by_codebook)
- .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?;
-
- Ok(vec![AudioChunk {
- samples,
- sample_rate: SAMPLE_RATE,
- is_final: true,
- }])
- }
-
- /// Decode frames incrementally (streaming).
- fn decode_streaming(
- &self,
- frames: &[Vec<u32>],
- ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- let mut chunks: Vec<AudioChunk> = Vec::new();
-
- // Decode in groups of frames for efficiency
- let chunk_size = 10; // ~800ms per chunk at 12.5Hz
- let num_codebooks = self.speech_tokenizer.num_codebooks();
-
- for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
- // Check for cancellation between streaming chunks
- if self.is_cancelled() {
- tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len());
- if let Some(last) = chunks.last_mut() {
- last.is_final = true;
- }
- return Ok(chunks);
- }
-
- let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
-
- // Transpose chunk frames
- let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
- for frame in frame_chunk {
- for (cb_idx, &code) in frame.iter().enumerate() {
- if cb_idx < num_codebooks {
- codes_by_codebook[cb_idx].push(code);
- }
- }
- }
-
- let samples = self
- .speech_tokenizer
- .decode(&codes_by_codebook)
- .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?;
-
- chunks.push(AudioChunk {
- samples,
- sample_rate: SAMPLE_RATE,
- is_final: is_last,
- });
- }
-
- Ok(chunks)
- }
-}
-
-/// Simple pseudo-random float in [0, 1) using thread-local state.
-/// Uses a basic xorshift for reproducibility without external deps.
-fn random_float() -> f32 {
- use std::cell::Cell;
- thread_local! {
- static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0);
- }
-
- STATE.with(|s| {
- let mut x = s.get();
- x ^= x << 13;
- x ^= x >> 7;
- x ^= x << 17;
- s.set(x);
- (x as f32) / (u64::MAX as f32)
- })
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_generation_config_default() {
- let config = GenerationConfig::default();
- assert_eq!(config.max_new_tokens, 2048);
- assert_eq!(config.top_k, 0);
- assert_eq!(config.temperature, 1.0);
- assert_eq!(config.repetition_penalty, 1.2);
- assert!(!config.streaming);
- }
-
- #[test]
- fn test_random_float_range() {
- for _ in 0..100 {
- let r = random_float();
- assert!(r >= 0.0);
- assert!(r < 1.0);
- }
- }
-
- #[test]
- fn test_special_tokens() {
- assert_eq!(BOS_TOKEN_ID, 151_643);
- assert_eq!(EOS_TOKEN_ID, 151_645);
- assert_eq!(START_OF_SPEECH, 151_668);
- assert_eq!(END_OF_SPEECH, 151_669);
- }
-}
diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs
deleted file mode 100644
index fc6c472..0000000
--- a/makima/src/tts/qwen3/mod.rs
+++ /dev/null
@@ -1,317 +0,0 @@
-//! Qwen3-TTS — Pure Rust implementation using candle.
-//!
-//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis
-//! with voice cloning support. No Python, no ONNX — pure Rust inference
-//! via the candle ML framework.
-//!
-//! # Architecture
-//!
-//! The model has three components:
-//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens
-//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers
-//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes
-//!
-//! # Usage
-//!
-//! ```rust,no_run
-//! use makima::tts::qwen3::Qwen3Tts;
-//! use candle_core::Device;
-//!
-//! let device = Device::Cpu;
-//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap();
-//! // Use via TtsEngine trait or direct API
-//! ```
-
-pub mod code_predictor;
-pub mod config;
-pub mod generate;
-pub mod model;
-pub mod speech_tokenizer;
-
-use std::path::{Path, PathBuf};
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-
-use candle_core::{DType, Device};
-use candle_nn::VarBuilder;
-use hf_hub::api::sync::Api;
-use tokenizers::Tokenizer;
-
-use self::code_predictor::CodePredictor;
-use self::config::Qwen3TtsConfig;
-use self::generate::{GenerationConfig, GenerationContext};
-use self::model::Qwen3Model;
-use self::speech_tokenizer::SpeechTokenizer;
-use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE};
-
-/// HuggingFace model IDs.
-const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base";
-const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
-const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts";
-
-/// Qwen3-TTS engine — pure Rust candle-based inference.
-pub struct Qwen3Tts {
- /// The 28-layer language model.
- model: Qwen3Model,
- /// Multi-token prediction code predictor.
- code_predictor: CodePredictor,
- /// Speech tokenizer (encoder + decoder + RVQ).
- speech_tokenizer: SpeechTokenizer,
- /// Text tokenizer.
- tokenizer: Tokenizer,
- /// Model configuration.
- config: Qwen3TtsConfig,
- /// Compute device (CPU/CUDA/Metal).
- device: Device,
- /// Whether the model is fully loaded and ready.
- ready: AtomicBool,
-}
-
-// SAFETY: All fields are either Send+Sync or behind appropriate synchronization.
-// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync.
-unsafe impl Send for Qwen3Tts {}
-unsafe impl Sync for Qwen3Tts {}
-
-impl Qwen3Tts {
- /// Load from a local directory or download from HuggingFace.
- pub fn from_pretrained(
- model_dir: Option<&str>,
- device: &Device,
- ) -> Result<Self, TtsError> {
- let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
-
- if !model_path.exists() {
- Self::download_models(&model_path)?;
- }
-
- Self::load_from_path(&model_path, device)
- }
-
- /// Load all model components from a local directory.
- pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> {
- let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU
-
- // Load configuration
- let config_path = model_dir.join("config.json");
- let config = if config_path.exists() {
- Qwen3TtsConfig::from_json_path(&config_path)?
- } else {
- Qwen3TtsConfig::default()
- };
-
- // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats)
- let tokenizer_json_path = model_dir.join("tokenizer.json");
- let tokenizer = if tokenizer_json_path.exists() {
- Tokenizer::from_file(&tokenizer_json_path)
- .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))?
- } else {
- // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format)
- let vocab_path = model_dir.join("vocab.json");
- let merges_path = model_dir.join("merges.txt");
-
- if !vocab_path.exists() || !merges_path.exists() {
- return Err(TtsError::Tokenizer(format!(
- "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}",
- model_dir.display()
- )));
- }
-
- tokenizers::Tokenizer::from_file(&vocab_path)
- .or_else(|_| {
- // Build BPE tokenizer from vocab and merges
- use tokenizers::models::bpe::BPE;
- let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy())
- .build()
- .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?;
- Ok(Tokenizer::new(bpe))
- })
- .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))?
- };
-
- // Load LM weights from safetensors
- let lm_weights_path = model_dir.join("model.safetensors");
- let lm_data = std::fs::read(&lm_weights_path).map_err(|e| {
- TtsError::ModelLoad(format!(
- "failed to read LM weights from {}: {e}",
- lm_weights_path.display()
- ))
- })?;
- let lm_vb = VarBuilder::from_buffered_safetensors(
- lm_data,
- dtype,
- device,
- ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?;
-
- // Build language model
- let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| {
- TtsError::ModelLoad(format!("failed to build LM model: {e}"))
- })?;
-
- // Build code predictor (weights are in the same safetensors file)
- let code_predictor =
- CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| {
- TtsError::ModelLoad(format!("failed to build code predictor: {e}"))
- })?;
-
- // Load speech tokenizer from separate safetensors
- let st_weights_path = model_dir.join("speech_tokenizer.safetensors");
- let st_data = std::fs::read(&st_weights_path).map_err(|e| {
- TtsError::ModelLoad(format!(
- "failed to read speech tokenizer weights from {}: {e}",
- st_weights_path.display()
- ))
- })?;
- let st_vb = VarBuilder::from_buffered_safetensors(
- st_data,
- dtype,
- device,
- ).map_err(|e| {
- TtsError::ModelLoad(format!(
- "failed to create speech tokenizer VarBuilder: {e}"
- ))
- })?;
-
- let speech_tokenizer =
- SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| {
- TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}"))
- })?;
-
- Ok(Self {
- model,
- code_predictor,
- speech_tokenizer,
- tokenizer,
- config,
- device: device.clone(),
- ready: AtomicBool::new(true),
- })
- }
-
- /// Generate audio from text with optional voice reference.
- pub fn generate_speech(
- &self,
- text: &str,
- reference_audio: Option<&[f32]>,
- gen_config: Option<GenerationConfig>,
- cancel_flag: Option<Arc<AtomicBool>>,
- ) -> Result<Vec<AudioChunk>, TtsError> {
- let config = gen_config.unwrap_or_default();
-
- let ctx = GenerationContext::new(
- &self.model,
- &self.code_predictor,
- &self.speech_tokenizer,
- &self.tokenizer,
- &self.device,
- config,
- cancel_flag,
- );
-
- ctx.generate(text, reference_audio)
- }
-
- /// Download model files from HuggingFace Hub.
- fn download_models(target_dir: &Path) -> Result<(), TtsError> {
- std::fs::create_dir_all(target_dir)?;
-
- let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?;
-
- // Download LM model files
- println!("Downloading Qwen3-TTS language model...");
- let lm_repo = api.model(LM_MODEL_ID.to_string());
-
- // Note: HuggingFace repo has vocab.json + merges.txt instead of tokenizer.json
- let lm_files = [
- "model.safetensors",
- "config.json",
- "vocab.json",
- "merges.txt",
- "tokenizer_config.json",
- ];
-
- for file in &lm_files {
- println!(" Downloading {file}...");
- let downloaded = lm_repo
- .get(file)
- .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?;
-
- let target = target_dir.join(file);
- if !target.exists() {
- std::fs::copy(&downloaded, &target)?;
- }
- }
-
- // Download speech tokenizer
- println!("Downloading Qwen3-TTS speech tokenizer...");
- let st_repo = api.model(TOKENIZER_MODEL_ID.to_string());
-
- let st_file = "model.safetensors";
- let downloaded = st_repo
- .get(st_file)
- .map_err(|e| {
- TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}"))
- })?;
-
- let target = target_dir.join("speech_tokenizer.safetensors");
- if !target.exists() {
- std::fs::copy(&downloaded, &target)?;
- }
-
- println!("All models downloaded to {}", target_dir.display());
- Ok(())
- }
-
- /// Get the model configuration.
- pub fn config(&self) -> &Qwen3TtsConfig {
- &self.config
- }
-
- /// Get the compute device.
- pub fn device(&self) -> &Device {
- &self.device
- }
-}
-
-#[async_trait::async_trait]
-impl TtsEngine for Qwen3Tts {
- async fn generate(
- &self,
- text: &str,
- reference_audio: Option<&[f32]>,
- _reference_sample_rate: Option<u32>,
- cancel_flag: Option<Arc<AtomicBool>>,
- ) -> Result<Vec<AudioChunk>, TtsError> {
- // Note: reference audio should already be resampled to 24kHz
- // by the caller. If a different sample rate is provided,
- // the caller should resample using `resample_to_24k()`.
- self.generate_speech(text, reference_audio, None, cancel_flag)
- }
-
- fn is_ready(&self) -> bool {
- self.ready.load(Ordering::Relaxed)
- }
-
- fn sample_rate(&self) -> u32 {
- SAMPLE_RATE
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_default_config() {
- let config = Qwen3TtsConfig::default();
- assert_eq!(config.lm.hidden_size, 1024);
- assert_eq!(config.lm.num_hidden_layers, 28);
- assert_eq!(config.code_predictor.num_code_groups, 16);
- assert_eq!(config.speech_tokenizer.sample_rate, 24_000);
- }
-
- #[test]
- fn test_model_ids() {
- assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base");
- assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz");
- }
-}
diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs
deleted file mode 100644
index e19e5f9..0000000
--- a/makima/src/tts/qwen3/model.rs
+++ /dev/null
@@ -1,584 +0,0 @@
-//! Qwen3 Language Model transformer backbone.
-//!
-//! Implements the 28-layer transformer with:
-//! - Rotary Position Embeddings (RoPE)
-//! - Grouped Query Attention (GQA) — 16 heads, 8 KV heads
-//! - SiLU-gated MLP
-//! - RMS normalization
-//! - KV cache for autoregressive generation
-//!
-//! Based on the candle-transformers Qwen2 model architecture,
-//! extended for Qwen3-TTS.
-
-use candle_core::{DType, Device, Module, Result, Tensor, D};
-use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
-
-use super::config::Qwen3LmConfig;
-
-// ---------------------------------------------------------------------------
-// Rotary Position Embeddings
-// ---------------------------------------------------------------------------
-
-/// Precomputed RoPE sin/cos tables.
-#[derive(Debug, Clone)]
-pub struct RotaryEmbedding {
- cos: Tensor,
- sin: Tensor,
-}
-
-impl RotaryEmbedding {
- pub fn new(config: &Qwen3LmConfig, dtype: DType, device: &Device) -> Result<Self> {
- let head_dim = config.head_dim;
- let max_seq = config.max_position_embeddings;
- let theta = config.rope_theta;
-
- let inv_freq: Vec<f32> = (0..head_dim)
- .step_by(2)
- .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32))
- .collect();
-
- let inv_freq_tensor =
- Tensor::from_vec(inv_freq, (head_dim / 2,), device)?.to_dtype(DType::F32)?;
-
- let positions: Vec<f32> = (0..max_seq).map(|p| p as f32).collect();
- let positions_tensor = Tensor::from_vec(positions, (max_seq, 1), device)?;
-
- // [max_seq, head_dim/2]
- let freqs = positions_tensor.matmul(&inv_freq_tensor.unsqueeze(0)?)?;
- // [max_seq, head_dim] by repeating
- let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
-
- let cos = emb.cos()?.to_dtype(dtype)?;
- let sin = emb.sin()?.to_dtype(dtype)?;
-
- Ok(Self { cos, sin })
- }
-
- /// Apply RoPE to query and key tensors.
- /// Input shape: [batch, heads, seq_len, head_dim]
- pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
- let seq_len = q.dim(2)?;
- let cos = self.cos.narrow(0, offset, seq_len)?;
- let sin = self.sin.narrow(0, offset, seq_len)?;
-
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, seq, dim]
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
-
- let q_rotated = Self::rotate_half(q, &cos, &sin)?;
- let k_rotated = Self::rotate_half(k, &cos, &sin)?;
-
- Ok((q_rotated, k_rotated))
- }
-
- fn rotate_half(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
- let half_dim = x.dim(D::Minus1)? / 2;
- let x1 = x.narrow(D::Minus1, 0, half_dim)?;
- let x2 = x.narrow(D::Minus1, half_dim, half_dim)?;
-
- // [-x2, x1] concatenated
- let neg_x2 = x2.neg()?;
- let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?;
-
- // x * cos + rotated * sin
- let result = x.broadcast_mul(cos)?.broadcast_add(&rotated.broadcast_mul(sin)?)?;
- Ok(result)
- }
-}
-
-// ---------------------------------------------------------------------------
-// KV Cache
-// ---------------------------------------------------------------------------
-
-/// Per-layer key-value cache for autoregressive generation.
-#[derive(Debug, Clone)]
-pub struct KvCache {
- key: Option<Tensor>,
- value: Option<Tensor>,
-}
-
-impl KvCache {
- pub fn new() -> Self {
- Self {
- key: None,
- value: None,
- }
- }
-
- /// Append new key/value tensors and return the full cached sequence.
- /// Input shapes: [batch, num_kv_heads, new_seq_len, head_dim]
- pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
- let (full_key, full_value) = match (&self.key, &self.value) {
- (Some(prev_k), Some(prev_v)) => {
- let k = Tensor::cat(&[prev_k, key], 2)?;
- let v = Tensor::cat(&[prev_v, value], 2)?;
- (k, v)
- }
- _ => (key.clone(), value.clone()),
- };
-
- self.key = Some(full_key.clone());
- self.value = Some(full_value.clone());
-
- Ok((full_key, full_value))
- }
-
- /// Current cached sequence length.
- pub fn seq_len(&self) -> usize {
- self.key
- .as_ref()
- .map(|k| k.dim(2).unwrap_or(0))
- .unwrap_or(0)
- }
-
- /// Reset the cache.
- pub fn reset(&mut self) {
- self.key = None;
- self.value = None;
- }
-}
-
-// ---------------------------------------------------------------------------
-// Attention
-// ---------------------------------------------------------------------------
-
-/// Multi-head attention with GQA and RoPE.
-pub struct Qwen3Attention {
- q_proj: Linear,
- k_proj: Linear,
- v_proj: Linear,
- o_proj: Linear,
- q_norm: RmsNorm,
- k_norm: RmsNorm,
- num_heads: usize,
- num_kv_heads: usize,
- head_dim: usize,
- num_kv_groups: usize,
-}
-
-impl Qwen3Attention {
- pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
- let hidden = config.hidden_size;
- let num_heads = config.num_attention_heads;
- let num_kv_heads = config.num_key_value_heads;
- let head_dim = config.head_dim;
-
- let q_proj = linear_no_bias(hidden, num_heads * head_dim, vb.pp("q_proj"))?;
- let k_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("k_proj"))?;
- let v_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("v_proj"))?;
- let o_proj = linear_no_bias(num_heads * head_dim, hidden, vb.pp("o_proj"))?;
-
- let q_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("q_norm"))?;
- let k_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("k_norm"))?;
-
- Ok(Self {
- q_proj,
- k_proj,
- v_proj,
- o_proj,
- q_norm,
- k_norm,
- num_heads,
- num_kv_heads,
- head_dim,
- num_kv_groups: config.num_kv_groups(),
- })
- }
-
- /// Forward pass with KV cache and RoPE.
- /// Input: [batch, seq_len, hidden_size]
- /// Returns: [batch, seq_len, hidden_size]
- pub fn forward(
- &self,
- hidden_states: &Tensor,
- rope: &RotaryEmbedding,
- kv_cache: &mut KvCache,
- attention_mask: Option<&Tensor>,
- ) -> Result<Tensor> {
- let (batch, seq_len, _) = hidden_states.dims3()?;
- let offset = kv_cache.seq_len();
-
- // Project Q, K, V
- let q = self.q_proj.forward(hidden_states)?;
- let k = self.k_proj.forward(hidden_states)?;
- let v = self.v_proj.forward(hidden_states)?;
-
- // Reshape: [batch, seq, heads*dim] -> [batch, heads, seq, dim]
- let q = q
- .reshape((batch, seq_len, self.num_heads, self.head_dim))?
- .transpose(1, 2)?;
- let k = k
- .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
- .transpose(1, 2)?;
- let v = v
- .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
- .transpose(1, 2)?;
-
- // Apply QK normalization (Qwen3 specific)
- let q = self.apply_head_norm(&q, &self.q_norm)?;
- let k = self.apply_head_norm(&k, &self.k_norm)?;
-
- // Apply RoPE
- let (q, k) = rope.apply(&q, &k, offset)?;
-
- // Update KV cache
- let (k, v) = kv_cache.append(&k, &v)?;
-
- // Expand KV heads for GQA: [batch, kv_heads, seq, dim] -> [batch, heads, seq, dim]
- let k = self.repeat_kv(&k)?;
- let v = self.repeat_kv(&v)?;
-
- // Scaled dot-product attention
- let scale = (self.head_dim as f64).sqrt();
- let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / scale)?;
-
- let attn_weights = match attention_mask {
- Some(mask) => attn_weights.broadcast_add(mask)?,
- None => attn_weights,
- };
-
- let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
-
- // Attention output
- let attn_output = attn_weights.matmul(&v)?;
-
- // [batch, heads, seq, dim] -> [batch, seq, heads*dim]
- let attn_output = attn_output
- .transpose(1, 2)?
- .reshape((batch, seq_len, self.num_heads * self.head_dim))?;
-
- self.o_proj.forward(&attn_output)
- }
-
- /// Apply RMS norm per-head.
- fn apply_head_norm(&self, x: &Tensor, norm: &RmsNorm) -> Result<Tensor> {
- let (b, h, s, d) = x.dims4()?;
- // Reshape to [b*h*s, d] for norm, then back
- let flat = x.reshape((b * h * s, d))?;
- let normed = norm.forward(&flat)?;
- normed.reshape((b, h, s, d))
- }
-
- /// Repeat KV heads for GQA.
- fn repeat_kv(&self, x: &Tensor) -> Result<Tensor> {
- if self.num_kv_groups == 1 {
- return Ok(x.clone());
- }
- let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?;
- let x = x
- .unsqueeze(2)?
- .expand((batch, num_kv_heads, self.num_kv_groups, seq_len, head_dim))?
- .reshape((batch, self.num_heads, seq_len, head_dim))?;
- Ok(x)
- }
-}
-
-// ---------------------------------------------------------------------------
-// MLP
-// ---------------------------------------------------------------------------
-
-/// SiLU-gated feed-forward network.
-pub struct Qwen3Mlp {
- gate_proj: Linear,
- up_proj: Linear,
- down_proj: Linear,
-}
-
-impl Qwen3Mlp {
- pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
- let hidden = config.hidden_size;
- let intermediate = config.intermediate_size;
-
- let gate_proj = linear_no_bias(hidden, intermediate, vb.pp("gate_proj"))?;
- let up_proj = linear_no_bias(hidden, intermediate, vb.pp("up_proj"))?;
- let down_proj = linear_no_bias(intermediate, hidden, vb.pp("down_proj"))?;
-
- Ok(Self {
- gate_proj,
- up_proj,
- down_proj,
- })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let gate = self.gate_proj.forward(x)?;
- let gate = candle_nn::Activation::Silu.forward(&gate)?;
- let up = self.up_proj.forward(x)?;
- let hidden = (gate * up)?;
- self.down_proj.forward(&hidden)
- }
-}
-
-// ---------------------------------------------------------------------------
-// Transformer Layer
-// ---------------------------------------------------------------------------
-
-/// A single Qwen3 transformer decoder layer.
-pub struct Qwen3DecoderLayer {
- self_attn: Qwen3Attention,
- mlp: Qwen3Mlp,
- input_layernorm: RmsNorm,
- post_attention_layernorm: RmsNorm,
-}
-
-impl Qwen3DecoderLayer {
- pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
- let self_attn = Qwen3Attention::new(config, vb.pp("self_attn"))?;
- let mlp = Qwen3Mlp::new(config, vb.pp("mlp"))?;
- let input_layernorm =
- rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("input_layernorm"))?;
- let post_attention_layernorm = rms_norm(
- config.hidden_size,
- config.rms_norm_eps,
- vb.pp("post_attention_layernorm"),
- )?;
-
- Ok(Self {
- self_attn,
- mlp,
- input_layernorm,
- post_attention_layernorm,
- })
- }
-
- pub fn forward(
- &self,
- hidden_states: &Tensor,
- rope: &RotaryEmbedding,
- kv_cache: &mut KvCache,
- attention_mask: Option<&Tensor>,
- ) -> Result<Tensor> {
- // Pre-norm attention
- let residual = hidden_states;
- let hidden_states = self.input_layernorm.forward(hidden_states)?;
- let hidden_states =
- self.self_attn
- .forward(&hidden_states, rope, kv_cache, attention_mask)?;
- let hidden_states = (residual + hidden_states)?;
-
- // Pre-norm MLP
- let residual = &hidden_states;
- let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
- let hidden_states = self.mlp.forward(&hidden_states)?;
- let output = (residual + hidden_states)?;
-
- Ok(output)
- }
-}
-
-// ---------------------------------------------------------------------------
-// Full Model
-// ---------------------------------------------------------------------------
-
-/// The complete Qwen3 language model for TTS.
-///
-/// Architecture:
-/// - Token embedding layer
-/// - 28 transformer decoder layers
-/// - Final RMS normalization
-/// - LM head (projects to vocab)
-pub struct Qwen3Model {
- embed_tokens: Embedding,
- layers: Vec<Qwen3DecoderLayer>,
- norm: RmsNorm,
- lm_head: Linear,
- rope: RotaryEmbedding,
- config: Qwen3LmConfig,
- /// Last hidden states (before lm_head), used by code predictor.
- last_hidden: std::cell::RefCell<Option<Tensor>>,
-}
-
-impl Qwen3Model {
- pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
- // HuggingFace Qwen3-TTS uses "talker.model.*" prefix
- let talker_vb = vb.pp("talker");
- let model_vb = talker_vb.pp("model");
-
- // Text embedding (called "text_embedding" in HF, not "embed_tokens")
- let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?;
-
- let mut layers = Vec::with_capacity(config.num_hidden_layers);
- for i in 0..config.num_hidden_layers {
- let layer = Qwen3DecoderLayer::new(config, model_vb.pp(format!("layers.{i}")))?;
- layers.push(layer);
- }
-
- let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?;
-
- // Codec head (called "codec_head" in HF, not "lm_head")
- let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?;
-
- let dtype = vb.dtype();
- let device = vb.device().clone();
- let rope = RotaryEmbedding::new(config, dtype, &device)?;
-
- Ok(Self {
- embed_tokens,
- layers,
- norm,
- lm_head,
- rope,
- config: config.clone(),
- last_hidden: std::cell::RefCell::new(None),
- })
- }
-
- /// Forward pass through the full model.
- ///
- /// `input_ids`: [batch, seq_len] — token IDs
- /// `kv_caches`: per-layer KV caches
- /// `attention_mask`: optional causal mask [batch, 1, seq_len, total_seq_len]
- ///
- /// Returns logits: [batch, seq_len, vocab_size]
- pub fn forward(
- &self,
- input_ids: &Tensor,
- kv_caches: &mut [KvCache],
- attention_mask: Option<&Tensor>,
- ) -> Result<Tensor> {
- let mut hidden_states = self.embed_tokens.forward(input_ids)?;
-
- for (i, layer) in self.layers.iter().enumerate() {
- hidden_states =
- layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?;
- }
-
- hidden_states = self.norm.forward(&hidden_states)?;
-
- // Store last hidden state for code predictor
- *self.last_hidden.borrow_mut() = Some(hidden_states.clone());
-
- let logits = self.lm_head.forward(&hidden_states)?;
- Ok(logits)
- }
-
- /// Forward pass with pre-computed embeddings (for first iteration where
- /// text embeddings are concatenated with audio features).
- ///
- /// `inputs_embeds`: [batch, seq_len, hidden_size]
- pub fn forward_embeds(
- &self,
- inputs_embeds: &Tensor,
- kv_caches: &mut [KvCache],
- attention_mask: Option<&Tensor>,
- ) -> Result<Tensor> {
- let mut hidden_states = inputs_embeds.clone();
-
- for (i, layer) in self.layers.iter().enumerate() {
- hidden_states =
- layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?;
- }
-
- hidden_states = self.norm.forward(&hidden_states)?;
-
- *self.last_hidden.borrow_mut() = Some(hidden_states.clone());
-
- let logits = self.lm_head.forward(&hidden_states)?;
- Ok(logits)
- }
-
- /// Get the last hidden states (for the code predictor).
- pub fn last_hidden_state(&self) -> Option<Tensor> {
- self.last_hidden.borrow().clone()
- }
-
- /// Number of transformer layers.
- pub fn num_layers(&self) -> usize {
- self.config.num_hidden_layers
- }
-
- /// Hidden size.
- pub fn hidden_size(&self) -> usize {
- self.config.hidden_size
- }
-
- /// Get token embedding layer (for input preparation).
- pub fn embed_tokens(&self) -> &Embedding {
- &self.embed_tokens
- }
-
- /// Create a causal attention mask.
- pub fn make_causal_mask(
- seq_len: usize,
- past_len: usize,
- dtype: DType,
- device: &Device,
- ) -> Result<Tensor> {
- let total_len = past_len + seq_len;
-
- if seq_len == 1 {
- // Single token: no masking needed (can attend to everything)
- return Tensor::zeros((1, 1, 1, total_len), dtype, device);
- }
-
- // Full causal mask: lower triangular
- let mask: Vec<f32> = (0..seq_len)
- .flat_map(|i| {
- (0..total_len).map(move |j| {
- if j <= past_len + i {
- 0.0
- } else {
- f32::NEG_INFINITY
- }
- })
- })
- .collect();
-
- Tensor::from_vec(mask, (1, 1, seq_len, total_len), device)?.to_dtype(dtype)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_kv_cache() {
- let device = Device::Cpu;
- let mut cache = KvCache::new();
- assert_eq!(cache.seq_len(), 0);
-
- let k = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap();
- let v = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap();
- let (fk, _fv) = cache.append(&k, &v).unwrap();
- assert_eq!(cache.seq_len(), 5);
- assert_eq!(fk.dim(2).unwrap(), 5);
-
- let k2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap();
- let v2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap();
- let (fk2, _fv2) = cache.append(&k2, &v2).unwrap();
- assert_eq!(cache.seq_len(), 6);
- assert_eq!(fk2.dim(2).unwrap(), 6);
-
- cache.reset();
- assert_eq!(cache.seq_len(), 0);
- }
-
- #[test]
- fn test_causal_mask_single_token() {
- let mask = Qwen3Model::make_causal_mask(1, 10, DType::F32, &Device::Cpu).unwrap();
- assert_eq!(mask.dims(), &[1, 1, 1, 11]);
- // All zeros — single token can attend to everything
- let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
- assert_eq!(sum, 0.0);
- }
-
- #[test]
- fn test_causal_mask_multi_token() {
- let mask = Qwen3Model::make_causal_mask(3, 0, DType::F32, &Device::Cpu).unwrap();
- assert_eq!(mask.dims(), &[1, 1, 3, 3]);
- // Upper triangle should be -inf
- let data: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
- // Row 0: [0, -inf, -inf]
- assert_eq!(data[0], 0.0);
- assert!(data[1].is_infinite() && data[1] < 0.0);
- assert!(data[2].is_infinite() && data[2] < 0.0);
- // Row 1: [0, 0, -inf]
- assert_eq!(data[3], 0.0);
- assert_eq!(data[4], 0.0);
- assert!(data[5].is_infinite() && data[5] < 0.0);
- // Row 2: [0, 0, 0]
- assert_eq!(data[6], 0.0);
- assert_eq!(data[7], 0.0);
- assert_eq!(data[8], 0.0);
- }
-}
diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs
deleted file mode 100644
index 86e00f2..0000000
--- a/makima/src/tts/qwen3/speech_tokenizer.rs
+++ /dev/null
@@ -1,613 +0,0 @@
-//! Speech Tokenizer — ConvNet encoder/decoder with RVQ codebooks.
-//!
-//! Two sub-components:
-//!
-//! **Encoder** (voice cloning): converts reference audio waveform to discrete
-//! multi-codebook tokens via a causal 1D ConvNet + RVQ.
-//!
-//! **Decoder** (audio synthesis): reconstructs waveform from discrete codebook
-//! indices via embedding lookup + causal 1D ConvNet.
-//!
-//! The speech tokenizer is a separate model (~682MB) loaded from
-//! `Qwen/Qwen3-TTS-Tokenizer-12Hz`.
-
-use candle_core::{Device, Module, Result, Tensor, D};
-use candle_nn::{
- conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder,
-};
-
-use super::config::SpeechTokenizerConfig;
-
-// ---------------------------------------------------------------------------
-// Weight-Normalized Conv1d
-// ---------------------------------------------------------------------------
-
-/// A 1D convolution with optional weight normalization and activation.
-pub struct ConvBlock {
- conv: Conv1d,
- activation: ConvActivation,
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum ConvActivation {
- None,
- Elu,
- Tanh,
-}
-
-impl ConvBlock {
- pub fn new(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- stride: usize,
- padding: usize,
- dilation: usize,
- activation: ConvActivation,
- vb: VarBuilder,
- ) -> Result<Self> {
- let config = Conv1dConfig {
- stride,
- padding,
- dilation,
- groups: 1,
- };
- let conv = conv1d(in_channels, out_channels, kernel_size, config, vb.pp("conv"))?;
-
- Ok(Self { conv, activation })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let out = self.conv.forward(x)?;
- match self.activation {
- ConvActivation::None => Ok(out),
- ConvActivation::Elu => elu(&out, 1.0),
- ConvActivation::Tanh => out.tanh(),
- }
- }
-}
-
-/// ELU activation: x if x >= 0, alpha * (exp(x) - 1) if x < 0
-fn elu(x: &Tensor, alpha: f64) -> Result<Tensor> {
- let zeros = x.zeros_like()?;
- let positive = x.maximum(&zeros)?;
- let negative_mask = x.lt(&zeros)?.to_dtype(x.dtype())?;
- let exp_x = x.exp()?;
- let one = Tensor::ones_like(&exp_x)?;
- let negative = ((exp_x - one)? * alpha)?.broadcast_mul(&negative_mask)?;
- positive + negative
-}
-
-// ---------------------------------------------------------------------------
-// Residual Unit
-// ---------------------------------------------------------------------------
-
-/// Residual convolutional unit with dilated convolutions.
-pub struct ResidualUnit {
- conv1: ConvBlock,
- conv2: ConvBlock,
-}
-
-impl ResidualUnit {
- pub fn new(
- channels: usize,
- dilation: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- // Dilated causal conv (kernel=7, dilation varies)
- let padding = (7 - 1) * dilation / 2; // causal-ish padding
- let conv1 = ConvBlock::new(
- channels,
- channels,
- 7,
- 1,
- padding,
- dilation,
- ConvActivation::Elu,
- vb.pp("block.0"),
- )?;
-
- // Pointwise conv (kernel=1)
- let conv2 = ConvBlock::new(
- channels,
- channels,
- 1,
- 1,
- 0,
- 1,
- ConvActivation::Elu,
- vb.pp("block.1"),
- )?;
-
- Ok(Self { conv1, conv2 })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let residual = x;
- let out = self.conv1.forward(x)?;
- let out = self.conv2.forward(&out)?;
- // Match sequence lengths if needed (causal conv may change length)
- let out_len = out.dim(D::Minus1)?;
- let res_len = residual.dim(D::Minus1)?;
- if out_len != res_len {
- let start = res_len.saturating_sub(out_len);
- let residual = residual.narrow(D::Minus1, start, out_len)?;
- residual + out
- } else {
- residual + out
- }
- }
-}
-
-// ---------------------------------------------------------------------------
-// Encoder Block
-// ---------------------------------------------------------------------------
-
-/// Encoder downsampling block: residual units + strided conv.
-pub struct EncoderBlock {
- residual_units: Vec<ResidualUnit>,
- downsample: ConvBlock,
-}
-
-impl EncoderBlock {
- pub fn new(
- in_channels: usize,
- out_channels: usize,
- stride: usize,
- num_residuals: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let mut residual_units = Vec::with_capacity(num_residuals);
- for i in 0..num_residuals {
- let dilation = 3usize.pow(i as u32); // 1, 3, 9
- let unit = ResidualUnit::new(in_channels, dilation, vb.pp(format!("residuals.{i}")))?;
- residual_units.push(unit);
- }
-
- // Strided downsampling convolution
- let kernel_size = stride * 2;
- let padding = stride / 2;
- let downsample = ConvBlock::new(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- 1,
- ConvActivation::Elu,
- vb.pp("downsample"),
- )?;
-
- Ok(Self {
- residual_units,
- downsample,
- })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let mut out = x.clone();
- for unit in &self.residual_units {
- out = unit.forward(&out)?;
- }
- self.downsample.forward(&out)
- }
-}
-
-// ---------------------------------------------------------------------------
-// Decoder Block
-// ---------------------------------------------------------------------------
-
-/// Decoder upsampling block: transposed conv + residual units.
-pub struct DecoderBlock {
- upsample: ConvBlock,
- residual_units: Vec<ResidualUnit>,
-}
-
-impl DecoderBlock {
- pub fn new(
- in_channels: usize,
- out_channels: usize,
- stride: usize,
- num_residuals: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- // Strided upsampling (transpose conv simulated by regular conv + padding)
- let kernel_size = stride * 2;
- let padding = stride / 2;
- let upsample = ConvBlock::new(
- in_channels,
- out_channels,
- kernel_size,
- 1, // stride=1 for output; upsample via repeat/interpolation
- padding,
- 1,
- ConvActivation::Elu,
- vb.pp("upsample"),
- )?;
-
- let mut residual_units = Vec::with_capacity(num_residuals);
- for i in 0..num_residuals {
- let dilation = 3usize.pow(i as u32);
- let unit =
- ResidualUnit::new(out_channels, dilation, vb.pp(format!("residuals.{i}")))?;
- residual_units.push(unit);
- }
-
- Ok(Self {
- upsample,
- residual_units,
- })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let mut out = self.upsample.forward(x)?;
- for unit in &self.residual_units {
- out = unit.forward(&out)?;
- }
- Ok(out)
- }
-}
-
-// ---------------------------------------------------------------------------
-// RVQ Codebook
-// ---------------------------------------------------------------------------
-
-/// Residual Vector Quantization codebook.
-///
-/// Contains `num_codebooks` embedding tables, each mapping
-/// `codebook_size` indices to `codebook_dim`-dimensional vectors.
-pub struct RvqCodebook {
- codebooks: Vec<Embedding>,
- num_codebooks: usize,
- #[allow(dead_code)]
- codebook_dim: usize,
-}
-
-impl RvqCodebook {
- pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder) -> Result<Self> {
- let mut codebooks = Vec::with_capacity(config.num_codebooks);
- for i in 0..config.num_codebooks {
- let cb = embedding(
- config.codebook_size,
- config.codebook_dim,
- vb.pp(format!("codebooks.{i}")),
- )?;
- codebooks.push(cb);
- }
-
- Ok(Self {
- codebooks,
- num_codebooks: config.num_codebooks,
- codebook_dim: config.codebook_dim,
- })
- }
-
- /// Look up codebook embeddings for all codebook layers.
- ///
- /// `codes`: [num_codebooks, seq_len] — codebook indices per layer
- /// Returns: [1, codebook_dim, seq_len] — sum of all codebook embeddings
- pub fn decode(&self, codes: &[Vec<u32>], device: &Device) -> Result<Tensor> {
- assert_eq!(codes.len(), self.num_codebooks, "Expected {} codebook layers", self.num_codebooks);
-
- let seq_len = codes[0].len();
- let mut sum: Option<Tensor> = None;
-
- for (i, code_layer) in codes.iter().enumerate() {
- assert_eq!(code_layer.len(), seq_len, "Codebook layer {i} length mismatch");
-
- let indices = Tensor::from_vec(
- code_layer.clone(),
- (1, seq_len),
- device,
- )?;
-
- // [1, seq_len, codebook_dim]
- let emb = self.codebooks[i].forward(&indices)?;
-
- sum = Some(match sum {
- Some(prev) => (prev + emb)?,
- None => emb,
- });
- }
-
- // [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
- let result = sum.unwrap().transpose(1, 2)?;
- Ok(result)
- }
-
- /// Number of codebooks.
- pub fn num_codebooks(&self) -> usize {
- self.num_codebooks
- }
-}
-
-// ---------------------------------------------------------------------------
-// Speech Tokenizer (Encoder + Decoder)
-// ---------------------------------------------------------------------------
-
-/// The complete speech tokenizer with encoder and decoder.
-pub struct SpeechTokenizer {
- /// Encoder: waveform -> latent (for voice cloning).
- encoder_input_conv: ConvBlock,
- encoder_blocks: Vec<EncoderBlock>,
- encoder_output_conv: ConvBlock,
-
- /// RVQ codebooks for quantization.
- codebook: RvqCodebook,
-
- /// Decoder: codes -> waveform.
- decoder_input_conv: ConvBlock,
- decoder_blocks: Vec<DecoderBlock>,
- decoder_output_conv: ConvBlock,
-
- /// Projection from codebook dim to decoder hidden channels.
- decoder_proj: Linear,
-
- config: SpeechTokenizerConfig,
- device: Device,
-}
-
-impl SpeechTokenizer {
- /// Load the speech tokenizer from safetensors.
- pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder, device: &Device) -> Result<Self> {
- let hidden = config.hidden_channels; // 512
-
- // ===== Encoder =====
- // Input: [batch, 1, samples] -> [batch, hidden/8, ...]
- let encoder_input_conv = ConvBlock::new(
- 1,
- hidden / 8, // 64
- 7,
- 1,
- 3,
- 1,
- ConvActivation::Elu,
- vb.pp("encoder.input_conv"),
- )?;
-
- // Downsampling blocks with increasing channels
- let strides = [8, 5, 4, 3]; // Total downsampling: 8*5*4*3 = 480
- let channels = [hidden / 8, hidden / 4, hidden / 2, hidden]; // 64, 128, 256, 512
- let mut encoder_blocks = Vec::with_capacity(strides.len());
- for (i, (&stride, &out_ch)) in strides.iter().zip(channels.iter().skip(0)).enumerate() {
- let in_ch = if i == 0 { hidden / 8 } else { channels[i - 1] };
- let block = EncoderBlock::new(
- in_ch,
- out_ch,
- stride,
- 3, // 3 residual units per block
- vb.pp(format!("encoder.blocks.{i}")),
- )?;
- encoder_blocks.push(block);
- }
-
- // Encoder output projection to codebook dim
- let encoder_output_conv = ConvBlock::new(
- hidden,
- config.codebook_dim,
- 3,
- 1,
- 1,
- 1,
- ConvActivation::None,
- vb.pp("encoder.output_conv"),
- )?;
-
- // ===== RVQ Codebook =====
- let codebook = RvqCodebook::new(config, vb.pp("quantizer"))?;
-
- // ===== Decoder =====
- // Projection from codebook dim to decoder hidden
- let decoder_proj = linear_no_bias(
- config.codebook_dim,
- hidden,
- vb.pp("decoder.proj"),
- )?;
-
- // Input conv
- let decoder_input_conv = ConvBlock::new(
- hidden,
- hidden,
- 7,
- 1,
- 3,
- 1,
- ConvActivation::Elu,
- vb.pp("decoder.input_conv"),
- )?;
-
- // Upsampling blocks (reverse order of encoder)
- let dec_strides = [3, 4, 5, 8];
- let dec_channels = [hidden, hidden / 2, hidden / 4, hidden / 8]; // 512, 256, 128, 64
- let mut decoder_blocks = Vec::with_capacity(dec_strides.len());
- for (i, (&stride, &out_ch)) in dec_strides.iter().zip(dec_channels.iter().skip(0)).enumerate()
- {
- let in_ch = if i == 0 { hidden } else { dec_channels[i - 1] };
- let block = DecoderBlock::new(
- in_ch,
- out_ch,
- stride,
- 3,
- vb.pp(format!("decoder.blocks.{i}")),
- )?;
- decoder_blocks.push(block);
- }
-
- // Output conv: hidden/8 -> 1 channel (waveform)
- let decoder_output_conv = ConvBlock::new(
- hidden / 8,
- 1,
- 7,
- 1,
- 3,
- 1,
- ConvActivation::Tanh,
- vb.pp("decoder.output_conv"),
- )?;
-
- Ok(Self {
- encoder_input_conv,
- encoder_blocks,
- encoder_output_conv,
- codebook,
- decoder_input_conv,
- decoder_blocks,
- decoder_output_conv,
- decoder_proj,
- config: config.clone(),
- device: device.clone(),
- })
- }
-
- /// Encode reference audio waveform to discrete codebook tokens.
- ///
- /// `audio`: [num_samples] — mono 24kHz audio
- /// Returns: Vec of `num_codebooks` vectors, each containing token indices.
- pub fn encode(&self, audio: &[f32]) -> Result<Vec<Vec<u32>>> {
- // [1, 1, num_samples]
- let x = Tensor::from_vec(audio.to_vec(), (1, 1, audio.len()), &self.device)?;
-
- // Run encoder
- let mut hidden = self.encoder_input_conv.forward(&x)?;
- for block in &self.encoder_blocks {
- hidden = block.forward(&hidden)?;
- }
- let latent = self.encoder_output_conv.forward(&hidden)?;
-
- // latent: [1, codebook_dim, seq_len]
- // Quantize via nearest-neighbor lookup in each codebook
- let seq_len = latent.dim(D::Minus1)?;
- let mut all_codes = Vec::with_capacity(self.config.num_codebooks);
-
- // Residual quantization: subtract each codebook's contribution
- let mut residual = latent.clone();
-
- for cb_idx in 0..self.config.num_codebooks {
- // residual: [1, codebook_dim, seq_len] -> find nearest codebook entry per timestep
- let codes = self.quantize_layer(&residual, cb_idx, seq_len)?;
-
- // Look up the quantized vectors and subtract from residual
- let code_indices =
- Tensor::from_vec(codes.clone(), (1, seq_len), &self.device)?;
- let quantized = self.codebook.codebooks[cb_idx].forward(&code_indices)?;
- // quantized: [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
- let quantized = quantized.transpose(1, 2)?;
- residual = (residual - quantized)?;
-
- all_codes.push(codes);
- }
-
- Ok(all_codes)
- }
-
- /// Quantize a single RVQ layer by finding the nearest codebook entry.
- fn quantize_layer(
- &self,
- residual: &Tensor,
- codebook_idx: usize,
- _seq_len: usize,
- ) -> Result<Vec<u32>> {
- // residual: [1, codebook_dim, seq_len]
- // codebook weights: [codebook_size, codebook_dim]
- let cb_weight = self.codebook.codebooks[codebook_idx]
- .embeddings()
- .clone(); // [codebook_size, codebook_dim]
-
- // Transpose residual: [1, seq_len, codebook_dim]
- let residual_t = residual.transpose(1, 2)?.squeeze(0)?; // [seq_len, codebook_dim]
-
- // Compute L2 distances: ||r - c||^2 = ||r||^2 - 2*r*c^T + ||c||^2
- let r_sq = residual_t.sqr()?.sum(D::Minus1)?; // [seq_len]
- let c_sq = cb_weight.sqr()?.sum(D::Minus1)?; // [codebook_size]
- let rc = residual_t.matmul(&cb_weight.t()?)?; // [seq_len, codebook_size]
-
- let r_sq = r_sq.unsqueeze(1)?; // [seq_len, 1]
- let c_sq = c_sq.unsqueeze(0)?; // [1, codebook_size]
-
- let distances = (r_sq.broadcast_add(&c_sq)? - (rc * 2.0)?)?; // [seq_len, codebook_size]
-
- // Argmin per timestep
- let indices = distances.argmin(D::Minus1)?; // [seq_len]
- let codes: Vec<u32> = indices.to_vec1()?;
-
- Ok(codes)
- }
-
- /// Decode discrete codebook tokens to audio waveform.
- ///
- /// `codes`: Vec of `num_codebooks` vectors of token indices.
- /// Returns: Vec<f32> — mono 24kHz audio samples.
- pub fn decode(&self, codes: &[Vec<u32>]) -> Result<Vec<f32>> {
- // Look up and sum all codebook embeddings
- let embeddings = self.codebook.decode(codes, &self.device)?;
- // embeddings: [1, codebook_dim, seq_len]
-
- // Project to decoder hidden size: [1, seq_len, codebook_dim] -> [1, seq_len, hidden]
- let emb_t = embeddings.transpose(1, 2)?; // [1, seq_len, codebook_dim]
- let projected = self.decoder_proj.forward(&emb_t)?; // [1, seq_len, hidden]
- let mut hidden = projected.transpose(1, 2)?; // [1, hidden, seq_len]
-
- // Run decoder
- hidden = self.decoder_input_conv.forward(&hidden)?;
- for block in &self.decoder_blocks {
- hidden = block.forward(&hidden)?;
- }
- let waveform = self.decoder_output_conv.forward(&hidden)?;
-
- // [1, 1, num_samples] -> Vec<f32>
- let samples: Vec<f32> = waveform.flatten_all()?.to_vec1()?;
- Ok(samples)
- }
-
- /// Decode a single frame's codes to audio samples (for streaming).
- ///
- /// `frame_codes`: [num_codebooks] — one token per codebook for a single frame
- /// Returns: audio samples for this frame (~1920 samples at 24kHz / 12.5Hz)
- pub fn decode_frame(&self, frame_codes: &[u32]) -> Result<Vec<f32>> {
- let codes: Vec<Vec<u32>> = frame_codes.iter().map(|&c| vec![c]).collect();
- self.decode(&codes)
- }
-
- /// Get the number of codebooks.
- pub fn num_codebooks(&self) -> usize {
- self.config.num_codebooks
- }
-
- /// Get the output sample rate.
- pub fn sample_rate(&self) -> u32 {
- self.config.sample_rate
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_elu_positive() {
- let device = Device::Cpu;
- let x = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap();
- let result = elu(&x, 1.0).unwrap();
- let values: Vec<f32> = result.to_vec1().unwrap();
- assert!((values[0] - 1.0).abs() < 1e-5);
- assert!((values[1] - 2.0).abs() < 1e-5);
- }
-
- #[test]
- fn test_elu_negative() {
- let device = Device::Cpu;
- let x = Tensor::from_vec(vec![-1.0f32], (1,), &device).unwrap();
- let result = elu(&x, 1.0).unwrap();
- let values: Vec<f32> = result.to_vec1().unwrap();
- // ELU(-1) = exp(-1) - 1 ≈ -0.6321
- assert!((values[0] - (-0.6321)).abs() < 0.01);
- }
-
- #[test]
- fn test_speech_tokenizer_config() {
- let config = SpeechTokenizerConfig::default();
- assert_eq!(config.num_codebooks, 16);
- assert_eq!(config.codebook_size, 2048);
- assert_eq!(config.sample_rate, 24_000);
- }
-}