diff options
| author | soryu <soryu@soryu.co> | 2026-02-02 22:52:05 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2026-02-02 22:52:05 +0000 |
| commit | 0f06a7f9968816e5e2553c4f1c2104f2fa504f96 (patch) | |
| tree | 53d8db119c17d7d22f3127ae5a54e12a3f384e29 /makima/src/tts/qwen3/generate.rs | |
| parent | 151e9d87e117b7980e6aad522ac8f3633eeca87a (diff) | |
| download | soryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.tar.gz soryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.zip | |
Release in makima repo
Also remove all other TTS models
Diffstat (limited to 'makima/src/tts/qwen3/generate.rs')
| -rw-r--r-- | makima/src/tts/qwen3/generate.rs | 456 |
1 files changed, 0 insertions, 456 deletions
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs deleted file mode 100644 index 30d165b..0000000 --- a/makima/src/tts/qwen3/generate.rs +++ /dev/null @@ -1,456 +0,0 @@ -//! Autoregressive generation loop for Qwen3-TTS. -//! -//! Orchestrates the full inference pipeline: -//! 1. Encode reference audio → speaker embedding via speech tokenizer -//! 2. Tokenize text → token IDs -//! 3. Autoregressive LM generation → zeroth codebook tokens -//! 4. Code predictor → remaining 15 codebook tokens per frame -//! 5. Speech tokenizer decoder → waveform audio - -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use candle_core::{DType, Device, IndexOp, Result, Tensor}; -use tokenizers::Tokenizer; - -use super::code_predictor::CodePredictor; -use super::model::{KvCache, Qwen3Model}; -use super::speech_tokenizer::SpeechTokenizer; -use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE}; - -/// Special tokens for the Qwen3-TTS vocabulary. -pub const BOS_TOKEN_ID: u32 = 151_643; -pub const EOS_TOKEN_ID: u32 = 151_645; -pub const PAD_TOKEN_ID: u32 = 151_643; - -/// Speech-specific control tokens. -/// These are placeholders — actual values come from the tokenizer config. -pub const START_OF_SPEECH: u32 = 151_668; -pub const END_OF_SPEECH: u32 = 151_669; - -/// Generation configuration. -#[derive(Debug, Clone)] -pub struct GenerationConfig { - /// Maximum number of speech tokens to generate. - pub max_new_tokens: usize, - /// Temperature for sampling (1.0 = greedy if top_k=1). - pub temperature: f32, - /// Top-k sampling (0 = disabled, use greedy argmax). - pub top_k: usize, - /// Repetition penalty. - pub repetition_penalty: f32, - /// Whether to generate audio chunks incrementally (streaming). - pub streaming: bool, -} - -impl Default for GenerationConfig { - fn default() -> Self { - Self { - max_new_tokens: 2048, - temperature: 1.0, - top_k: 0, // Greedy by default - repetition_penalty: 1.2, - streaming: false, - } - } -} - -/// Manages the full generation pipeline. -pub struct GenerationContext<'a> { - model: &'a Qwen3Model, - code_predictor: &'a CodePredictor, - speech_tokenizer: &'a SpeechTokenizer, - tokenizer: &'a Tokenizer, - device: &'a Device, - config: GenerationConfig, - /// Optional cancellation flag. When set to `true`, the generation loop - /// will break early and return whatever audio has been produced so far. - cancel_flag: Option<Arc<AtomicBool>>, -} - -impl<'a> GenerationContext<'a> { - pub fn new( - model: &'a Qwen3Model, - code_predictor: &'a CodePredictor, - speech_tokenizer: &'a SpeechTokenizer, - tokenizer: &'a Tokenizer, - device: &'a Device, - config: GenerationConfig, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Self { - Self { - model, - code_predictor, - speech_tokenizer, - tokenizer, - device, - config, - cancel_flag, - } - } - - /// Check whether cancellation has been requested. - fn is_cancelled(&self) -> bool { - self.cancel_flag - .as_ref() - .map_or(false, |f| f.load(Ordering::Relaxed)) - } - - /// Generate audio from text, optionally with a voice reference. - /// - /// Returns a list of audio chunks. If `streaming` is false, returns - /// a single chunk with the complete audio. - pub fn generate( - &self, - text: &str, - reference_audio: Option<&[f32]>, - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - // 1. Encode reference audio if provided - let reference_codes = match reference_audio { - Some(audio) => Some( - self.speech_tokenizer - .encode(audio) - .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?, - ), - None => None, - }; - - // 2. Tokenize text - let encoding = self - .tokenizer - .encode(text, true) - .map_err(|e| TtsError::Tokenizer(e.to_string()))?; - - let text_token_ids: Vec<u32> = encoding.get_ids().to_vec(); - - // 3. Prepare input sequence - // Format: [BOS] [text_tokens] [START_OF_SPEECH] - let mut input_ids = Vec::new(); - input_ids.push(BOS_TOKEN_ID); - input_ids.extend_from_slice(&text_token_ids); - input_ids.push(START_OF_SPEECH); - - // 4. Run autoregressive generation - let generated_frames = self - .autoregressive_generate(&input_ids, reference_codes.as_deref()) - .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?; - - if generated_frames.is_empty() { - return Ok(vec![AudioChunk { - samples: vec![], - sample_rate: SAMPLE_RATE, - is_final: true, - }]); - } - - // 5. Decode all frames to audio - if self.config.streaming { - self.decode_streaming(&generated_frames) - } else { - self.decode_batch(&generated_frames) - } - } - - /// Autoregressive generation loop. - /// - /// Generates zeroth codebook tokens one at a time, then uses the code - /// predictor to fill in the remaining 15 codebooks per frame. - /// - /// Returns: Vec of frames, each frame is [num_codebooks] tokens. - fn autoregressive_generate( - &self, - input_ids: &[u32], - _reference_codes: Option<&[Vec<u32>]>, - ) -> Result<Vec<Vec<u32>>> { - let _num_codebooks = self.code_predictor.num_code_groups(); - let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers()) - .map(|_| KvCache::new()) - .collect(); - - let mut generated_frames: Vec<Vec<u32>> = Vec::new(); - let mut past_zeroth_tokens: Vec<u32> = Vec::new(); - - // === First iteration: process the full input sequence === - let input_tensor = Tensor::from_vec( - input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(), - (1, input_ids.len()), - self.device, - )? - .to_dtype(DType::I64)?; - - let seq_len = input_ids.len(); - let attention_mask = - Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?; - - let logits = - self.model - .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?; - - // Get the logits for the last position - let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size] - let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?; - - if first_token == END_OF_SPEECH as u32 { - return Ok(generated_frames); - } - - // Use code predictor for all codebooks - let lm_hidden = self - .model - .last_hidden_state() - .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; - let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?; - - let frame_codes = self - .code_predictor - .predict(&last_hidden, first_token, self.device)?; - generated_frames.push(frame_codes); - past_zeroth_tokens.push(first_token); - - // === Subsequent iterations: one token at a time === - for _step in 1..self.config.max_new_tokens { - // Check for cancellation each iteration - if self.is_cancelled() { - tracing::info!("TTS generation cancelled after {} frames", generated_frames.len()); - break; - } - - let past_len = kv_caches[0].seq_len(); - - // Input: just the last generated zeroth codebook token - let last_token = *past_zeroth_tokens.last().unwrap(); - let token_tensor = Tensor::from_vec( - vec![last_token as i64], - (1, 1), - self.device, - )? - .to_dtype(DType::I64)?; - - // Single-token attention mask - let attention_mask = - Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?; - - let logits = - self.model - .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?; - - let next_logits = logits.i((0, 0, ..))?; // [vocab_size] - let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?; - - if next_token == END_OF_SPEECH as u32 { - break; - } - - // Predict all codebooks for this frame - let lm_hidden = self - .model - .last_hidden_state() - .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; - - let frame_codes = self - .code_predictor - .predict(&lm_hidden, next_token, self.device)?; - generated_frames.push(frame_codes); - past_zeroth_tokens.push(next_token); - } - - Ok(generated_frames) - } - - /// Sample a token from logits. - fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> { - let mut logits_vec: Vec<f32> = logits.to_vec1()?; - - // Apply repetition penalty - if self.config.repetition_penalty != 1.0 { - for &token in past_tokens { - let idx = token as usize; - if idx < logits_vec.len() { - let score = logits_vec[idx]; - logits_vec[idx] = if score < 0.0 { - score * self.config.repetition_penalty - } else { - score / self.config.repetition_penalty - }; - } - } - } - - if self.config.top_k == 0 || self.config.temperature == 0.0 { - // Greedy: argmax - let (max_idx, _) = logits_vec - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }) - .unwrap_or((0, &0.0)); - Ok(max_idx as u32) - } else { - // Top-k sampling with temperature - let temperature = self.config.temperature; - - // Apply temperature - for v in logits_vec.iter_mut() { - *v /= temperature; - } - - // Sort indices by logit value (descending) - let mut indexed: Vec<(usize, f32)> = - logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); - indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Keep only top-k - let k = self.config.top_k.min(indexed.len()); - let top_k = &indexed[..k]; - - // Softmax over top-k - let max_val = top_k[0].1; - let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum(); - let probs: Vec<(usize, f32)> = top_k - .iter() - .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum)) - .collect(); - - // Sample from distribution (simple linear scan) - let r: f32 = random_float(); - let mut cumulative = 0.0; - for (idx, prob) in &probs { - cumulative += prob; - if cumulative >= r { - return Ok(*idx as u32); - } - } - - // Fallback to highest probability - Ok(probs[0].0 as u32) - } - } - - /// Decode all frames in batch (non-streaming). - fn decode_batch( - &self, - frames: &[Vec<u32>], - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let num_codebooks = self.speech_tokenizer.num_codebooks(); - - // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames] - let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; - for frame in frames { - for (cb_idx, &code) in frame.iter().enumerate() { - if cb_idx < num_codebooks { - codes_by_codebook[cb_idx].push(code); - } - } - } - - let samples = self - .speech_tokenizer - .decode(&codes_by_codebook) - .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?; - - Ok(vec![AudioChunk { - samples, - sample_rate: SAMPLE_RATE, - is_final: true, - }]) - } - - /// Decode frames incrementally (streaming). - fn decode_streaming( - &self, - frames: &[Vec<u32>], - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let mut chunks: Vec<AudioChunk> = Vec::new(); - - // Decode in groups of frames for efficiency - let chunk_size = 10; // ~800ms per chunk at 12.5Hz - let num_codebooks = self.speech_tokenizer.num_codebooks(); - - for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { - // Check for cancellation between streaming chunks - if self.is_cancelled() { - tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len()); - if let Some(last) = chunks.last_mut() { - last.is_final = true; - } - return Ok(chunks); - } - - let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); - - // Transpose chunk frames - let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; - for frame in frame_chunk { - for (cb_idx, &code) in frame.iter().enumerate() { - if cb_idx < num_codebooks { - codes_by_codebook[cb_idx].push(code); - } - } - } - - let samples = self - .speech_tokenizer - .decode(&codes_by_codebook) - .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?; - - chunks.push(AudioChunk { - samples, - sample_rate: SAMPLE_RATE, - is_final: is_last, - }); - } - - Ok(chunks) - } -} - -/// Simple pseudo-random float in [0, 1) using thread-local state. -/// Uses a basic xorshift for reproducibility without external deps. -fn random_float() -> f32 { - use std::cell::Cell; - thread_local! { - static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0); - } - - STATE.with(|s| { - let mut x = s.get(); - x ^= x << 13; - x ^= x >> 7; - x ^= x << 17; - s.set(x); - (x as f32) / (u64::MAX as f32) - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generation_config_default() { - let config = GenerationConfig::default(); - assert_eq!(config.max_new_tokens, 2048); - assert_eq!(config.top_k, 0); - assert_eq!(config.temperature, 1.0); - assert_eq!(config.repetition_penalty, 1.2); - assert!(!config.streaming); - } - - #[test] - fn test_random_float_range() { - for _ in 0..100 { - let r = random_float(); - assert!(r >= 0.0); - assert!(r < 1.0); - } - } - - #[test] - fn test_special_tokens() { - assert_eq!(BOS_TOKEN_ID, 151_643); - assert_eq!(EOS_TOKEN_ID, 151_645); - assert_eq!(START_OF_SPEECH, 151_668); - assert_eq!(END_OF_SPEECH, 151_669); - } -} |
