diff options
Diffstat (limited to 'makima/src/tts/qwen3/generate.rs')
| -rw-r--r-- | makima/src/tts/qwen3/generate.rs | 426 |
1 files changed, 426 insertions, 0 deletions
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs new file mode 100644 index 0000000..02161e6 --- /dev/null +++ b/makima/src/tts/qwen3/generate.rs @@ -0,0 +1,426 @@ +//! Autoregressive generation loop for Qwen3-TTS. +//! +//! Orchestrates the full inference pipeline: +//! 1. Encode reference audio → speaker embedding via speech tokenizer +//! 2. Tokenize text → token IDs +//! 3. Autoregressive LM generation → zeroth codebook tokens +//! 4. Code predictor → remaining 15 codebook tokens per frame +//! 5. Speech tokenizer decoder → waveform audio + +use candle_core::{DType, Device, IndexOp, Result, Tensor}; +use tokenizers::Tokenizer; + +use super::code_predictor::CodePredictor; +use super::model::{KvCache, Qwen3Model}; +use super::speech_tokenizer::SpeechTokenizer; +use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE}; + +/// Special tokens for the Qwen3-TTS vocabulary. +pub const BOS_TOKEN_ID: u32 = 151_643; +pub const EOS_TOKEN_ID: u32 = 151_645; +pub const PAD_TOKEN_ID: u32 = 151_643; + +/// Speech-specific control tokens. +/// These are placeholders — actual values come from the tokenizer config. +pub const START_OF_SPEECH: u32 = 151_668; +pub const END_OF_SPEECH: u32 = 151_669; + +/// Generation configuration. +#[derive(Debug, Clone)] +pub struct GenerationConfig { + /// Maximum number of speech tokens to generate. + pub max_new_tokens: usize, + /// Temperature for sampling (1.0 = greedy if top_k=1). + pub temperature: f32, + /// Top-k sampling (0 = disabled, use greedy argmax). + pub top_k: usize, + /// Repetition penalty. + pub repetition_penalty: f32, + /// Whether to generate audio chunks incrementally (streaming). + pub streaming: bool, +} + +impl Default for GenerationConfig { + fn default() -> Self { + Self { + max_new_tokens: 2048, + temperature: 1.0, + top_k: 0, // Greedy by default + repetition_penalty: 1.2, + streaming: false, + } + } +} + +/// Manages the full generation pipeline. +pub struct GenerationContext<'a> { + model: &'a Qwen3Model, + code_predictor: &'a CodePredictor, + speech_tokenizer: &'a SpeechTokenizer, + tokenizer: &'a Tokenizer, + device: &'a Device, + config: GenerationConfig, +} + +impl<'a> GenerationContext<'a> { + pub fn new( + model: &'a Qwen3Model, + code_predictor: &'a CodePredictor, + speech_tokenizer: &'a SpeechTokenizer, + tokenizer: &'a Tokenizer, + device: &'a Device, + config: GenerationConfig, + ) -> Self { + Self { + model, + code_predictor, + speech_tokenizer, + tokenizer, + device, + config, + } + } + + /// Generate audio from text, optionally with a voice reference. + /// + /// Returns a list of audio chunks. If `streaming` is false, returns + /// a single chunk with the complete audio. + pub fn generate( + &self, + text: &str, + reference_audio: Option<&[f32]>, + ) -> std::result::Result<Vec<AudioChunk>, TtsError> { + // 1. Encode reference audio if provided + let reference_codes = match reference_audio { + Some(audio) => Some( + self.speech_tokenizer + .encode(audio) + .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?, + ), + None => None, + }; + + // 2. Tokenize text + let encoding = self + .tokenizer + .encode(text, true) + .map_err(|e| TtsError::Tokenizer(e.to_string()))?; + + let text_token_ids: Vec<u32> = encoding.get_ids().to_vec(); + + // 3. Prepare input sequence + // Format: [BOS] [text_tokens] [START_OF_SPEECH] + let mut input_ids = Vec::new(); + input_ids.push(BOS_TOKEN_ID); + input_ids.extend_from_slice(&text_token_ids); + input_ids.push(START_OF_SPEECH); + + // 4. Run autoregressive generation + let generated_frames = self + .autoregressive_generate(&input_ids, reference_codes.as_deref()) + .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?; + + if generated_frames.is_empty() { + return Ok(vec![AudioChunk { + samples: vec![], + sample_rate: SAMPLE_RATE, + is_final: true, + }]); + } + + // 5. Decode all frames to audio + if self.config.streaming { + self.decode_streaming(&generated_frames) + } else { + self.decode_batch(&generated_frames) + } + } + + /// Autoregressive generation loop. + /// + /// Generates zeroth codebook tokens one at a time, then uses the code + /// predictor to fill in the remaining 15 codebooks per frame. + /// + /// Returns: Vec of frames, each frame is [num_codebooks] tokens. + fn autoregressive_generate( + &self, + input_ids: &[u32], + _reference_codes: Option<&[Vec<u32>]>, + ) -> Result<Vec<Vec<u32>>> { + let _num_codebooks = self.code_predictor.num_code_groups(); + let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers()) + .map(|_| KvCache::new()) + .collect(); + + let mut generated_frames: Vec<Vec<u32>> = Vec::new(); + let mut past_zeroth_tokens: Vec<u32> = Vec::new(); + + // === First iteration: process the full input sequence === + let input_tensor = Tensor::from_vec( + input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(), + (1, input_ids.len()), + self.device, + )? + .to_dtype(DType::I64)?; + + let seq_len = input_ids.len(); + let attention_mask = + Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?; + + let logits = + self.model + .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?; + + // Get the logits for the last position + let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size] + let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?; + + if first_token == END_OF_SPEECH as u32 { + return Ok(generated_frames); + } + + // Use code predictor for all codebooks + let lm_hidden = self + .model + .last_hidden_state() + .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; + let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?; + + let frame_codes = self + .code_predictor + .predict(&last_hidden, first_token, self.device)?; + generated_frames.push(frame_codes); + past_zeroth_tokens.push(first_token); + + // === Subsequent iterations: one token at a time === + for _step in 1..self.config.max_new_tokens { + let past_len = kv_caches[0].seq_len(); + + // Input: just the last generated zeroth codebook token + let last_token = *past_zeroth_tokens.last().unwrap(); + let token_tensor = Tensor::from_vec( + vec![last_token as i64], + (1, 1), + self.device, + )? + .to_dtype(DType::I64)?; + + // Single-token attention mask + let attention_mask = + Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?; + + let logits = + self.model + .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?; + + let next_logits = logits.i((0, 0, ..))?; // [vocab_size] + let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?; + + if next_token == END_OF_SPEECH as u32 { + break; + } + + // Predict all codebooks for this frame + let lm_hidden = self + .model + .last_hidden_state() + .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; + + let frame_codes = self + .code_predictor + .predict(&lm_hidden, next_token, self.device)?; + generated_frames.push(frame_codes); + past_zeroth_tokens.push(next_token); + } + + Ok(generated_frames) + } + + /// Sample a token from logits. + fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> { + let mut logits_vec: Vec<f32> = logits.to_vec1()?; + + // Apply repetition penalty + if self.config.repetition_penalty != 1.0 { + for &token in past_tokens { + let idx = token as usize; + if idx < logits_vec.len() { + let score = logits_vec[idx]; + logits_vec[idx] = if score < 0.0 { + score * self.config.repetition_penalty + } else { + score / self.config.repetition_penalty + }; + } + } + } + + if self.config.top_k == 0 || self.config.temperature == 0.0 { + // Greedy: argmax + let (max_idx, _) = logits_vec + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| { + a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) + }) + .unwrap_or((0, &0.0)); + Ok(max_idx as u32) + } else { + // Top-k sampling with temperature + let temperature = self.config.temperature; + + // Apply temperature + for v in logits_vec.iter_mut() { + *v /= temperature; + } + + // Sort indices by logit value (descending) + let mut indexed: Vec<(usize, f32)> = + logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); + indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Keep only top-k + let k = self.config.top_k.min(indexed.len()); + let top_k = &indexed[..k]; + + // Softmax over top-k + let max_val = top_k[0].1; + let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum(); + let probs: Vec<(usize, f32)> = top_k + .iter() + .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum)) + .collect(); + + // Sample from distribution (simple linear scan) + let r: f32 = random_float(); + let mut cumulative = 0.0; + for (idx, prob) in &probs { + cumulative += prob; + if cumulative >= r { + return Ok(*idx as u32); + } + } + + // Fallback to highest probability + Ok(probs[0].0 as u32) + } + } + + /// Decode all frames in batch (non-streaming). + fn decode_batch( + &self, + frames: &[Vec<u32>], + ) -> std::result::Result<Vec<AudioChunk>, TtsError> { + let num_codebooks = self.speech_tokenizer.num_codebooks(); + + // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames] + let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; + for frame in frames { + for (cb_idx, &code) in frame.iter().enumerate() { + if cb_idx < num_codebooks { + codes_by_codebook[cb_idx].push(code); + } + } + } + + let samples = self + .speech_tokenizer + .decode(&codes_by_codebook) + .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?; + + Ok(vec![AudioChunk { + samples, + sample_rate: SAMPLE_RATE, + is_final: true, + }]) + } + + /// Decode frames incrementally (streaming). + fn decode_streaming( + &self, + frames: &[Vec<u32>], + ) -> std::result::Result<Vec<AudioChunk>, TtsError> { + let mut chunks = Vec::new(); + + // Decode in groups of frames for efficiency + let chunk_size = 10; // ~800ms per chunk at 12.5Hz + let num_codebooks = self.speech_tokenizer.num_codebooks(); + + for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { + let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); + + // Transpose chunk frames + let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; + for frame in frame_chunk { + for (cb_idx, &code) in frame.iter().enumerate() { + if cb_idx < num_codebooks { + codes_by_codebook[cb_idx].push(code); + } + } + } + + let samples = self + .speech_tokenizer + .decode(&codes_by_codebook) + .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?; + + chunks.push(AudioChunk { + samples, + sample_rate: SAMPLE_RATE, + is_final: is_last, + }); + } + + Ok(chunks) + } +} + +/// Simple pseudo-random float in [0, 1) using thread-local state. +/// Uses a basic xorshift for reproducibility without external deps. +fn random_float() -> f32 { + use std::cell::Cell; + thread_local! { + static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0); + } + + STATE.with(|s| { + let mut x = s.get(); + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + s.set(x); + (x as f32) / (u64::MAX as f32) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generation_config_default() { + let config = GenerationConfig::default(); + assert_eq!(config.max_new_tokens, 2048); + assert_eq!(config.top_k, 0); + assert_eq!(config.temperature, 1.0); + assert_eq!(config.repetition_penalty, 1.2); + assert!(!config.streaming); + } + + #[test] + fn test_random_float_range() { + for _ in 0..100 { + let r = random_float(); + assert!(r >= 0.0); + assert!(r < 1.0); + } + } + + #[test] + fn test_special_tokens() { + assert_eq!(BOS_TOKEN_ID, 151_643); + assert_eq!(EOS_TOKEN_ID, 151_645); + assert_eq!(START_OF_SPEECH, 151_668); + assert_eq!(END_OF_SPEECH, 151_669); + } +} |
