//! Autoregressive generation loop for Qwen3-TTS. //! //! Orchestrates the full inference pipeline: //! 1. Encode reference audio → speaker embedding via speech tokenizer //! 2. Tokenize text → token IDs //! 3. Autoregressive LM generation → zeroth codebook tokens //! 4. Code predictor → remaining 15 codebook tokens per frame //! 5. Speech tokenizer decoder → waveform audio use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use candle_core::{DType, Device, IndexOp, Result, Tensor}; use tokenizers::Tokenizer; use super::code_predictor::CodePredictor; use super::model::{KvCache, Qwen3Model}; use super::speech_tokenizer::SpeechTokenizer; use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE}; /// Special tokens for the Qwen3-TTS vocabulary. pub const BOS_TOKEN_ID: u32 = 151_643; pub const EOS_TOKEN_ID: u32 = 151_645; pub const PAD_TOKEN_ID: u32 = 151_643; /// Speech-specific control tokens. /// These are placeholders — actual values come from the tokenizer config. pub const START_OF_SPEECH: u32 = 151_668; pub const END_OF_SPEECH: u32 = 151_669; /// Generation configuration. #[derive(Debug, Clone)] pub struct GenerationConfig { /// Maximum number of speech tokens to generate. pub max_new_tokens: usize, /// Temperature for sampling (1.0 = greedy if top_k=1). pub temperature: f32, /// Top-k sampling (0 = disabled, use greedy argmax). pub top_k: usize, /// Repetition penalty. pub repetition_penalty: f32, /// Whether to generate audio chunks incrementally (streaming). pub streaming: bool, } impl Default for GenerationConfig { fn default() -> Self { Self { max_new_tokens: 2048, temperature: 1.0, top_k: 0, // Greedy by default repetition_penalty: 1.2, streaming: false, } } } /// Manages the full generation pipeline. pub struct GenerationContext<'a> { model: &'a Qwen3Model, code_predictor: &'a CodePredictor, speech_tokenizer: &'a SpeechTokenizer, tokenizer: &'a Tokenizer, device: &'a Device, config: GenerationConfig, /// Optional cancellation flag. 
When set to `true`, the generation loop /// will break early and return whatever audio has been produced so far. cancel_flag: Option>, } impl<'a> GenerationContext<'a> { pub fn new( model: &'a Qwen3Model, code_predictor: &'a CodePredictor, speech_tokenizer: &'a SpeechTokenizer, tokenizer: &'a Tokenizer, device: &'a Device, config: GenerationConfig, cancel_flag: Option>, ) -> Self { Self { model, code_predictor, speech_tokenizer, tokenizer, device, config, cancel_flag, } } /// Check whether cancellation has been requested. fn is_cancelled(&self) -> bool { self.cancel_flag .as_ref() .map_or(false, |f| f.load(Ordering::Relaxed)) } /// Generate audio from text, optionally with a voice reference. /// /// Returns a list of audio chunks. If `streaming` is false, returns /// a single chunk with the complete audio. pub fn generate( &self, text: &str, reference_audio: Option<&[f32]>, ) -> std::result::Result, TtsError> { // 1. Encode reference audio if provided let reference_codes = match reference_audio { Some(audio) => Some( self.speech_tokenizer .encode(audio) .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?, ), None => None, }; // 2. Tokenize text let encoding = self .tokenizer .encode(text, true) .map_err(|e| TtsError::Tokenizer(e.to_string()))?; let text_token_ids: Vec = encoding.get_ids().to_vec(); // 3. Prepare input sequence // Format: [BOS] [text_tokens] [START_OF_SPEECH] let mut input_ids = Vec::new(); input_ids.push(BOS_TOKEN_ID); input_ids.extend_from_slice(&text_token_ids); input_ids.push(START_OF_SPEECH); // 4. Run autoregressive generation let generated_frames = self .autoregressive_generate(&input_ids, reference_codes.as_deref()) .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?; if generated_frames.is_empty() { return Ok(vec![AudioChunk { samples: vec![], sample_rate: SAMPLE_RATE, is_final: true, }]); } // 5. 
Decode all frames to audio if self.config.streaming { self.decode_streaming(&generated_frames) } else { self.decode_batch(&generated_frames) } } /// Autoregressive generation loop. /// /// Generates zeroth codebook tokens one at a time, then uses the code /// predictor to fill in the remaining 15 codebooks per frame. /// /// Returns: Vec of frames, each frame is [num_codebooks] tokens. fn autoregressive_generate( &self, input_ids: &[u32], _reference_codes: Option<&[Vec]>, ) -> Result>> { let _num_codebooks = self.code_predictor.num_code_groups(); let mut kv_caches: Vec = (0..self.model.num_layers()) .map(|_| KvCache::new()) .collect(); let mut generated_frames: Vec> = Vec::new(); let mut past_zeroth_tokens: Vec = Vec::new(); // === First iteration: process the full input sequence === let input_tensor = Tensor::from_vec( input_ids.iter().map(|&x| x as i64).collect::>(), (1, input_ids.len()), self.device, )? .to_dtype(DType::I64)?; let seq_len = input_ids.len(); let attention_mask = Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?; let logits = self.model .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?; // Get the logits for the last position let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size] let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?; if first_token == END_OF_SPEECH as u32 { return Ok(generated_frames); } // Use code predictor for all codebooks let lm_hidden = self .model .last_hidden_state() .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?; let frame_codes = self .code_predictor .predict(&last_hidden, first_token, self.device)?; generated_frames.push(frame_codes); past_zeroth_tokens.push(first_token); // === Subsequent iterations: one token at a time === for _step in 1..self.config.max_new_tokens { // Check for cancellation each iteration if self.is_cancelled() { tracing::info!("TTS generation 
cancelled after {} frames", generated_frames.len()); break; } let past_len = kv_caches[0].seq_len(); // Input: just the last generated zeroth codebook token let last_token = *past_zeroth_tokens.last().unwrap(); let token_tensor = Tensor::from_vec( vec![last_token as i64], (1, 1), self.device, )? .to_dtype(DType::I64)?; // Single-token attention mask let attention_mask = Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?; let logits = self.model .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?; let next_logits = logits.i((0, 0, ..))?; // [vocab_size] let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?; if next_token == END_OF_SPEECH as u32 { break; } // Predict all codebooks for this frame let lm_hidden = self .model .last_hidden_state() .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; let frame_codes = self .code_predictor .predict(&lm_hidden, next_token, self.device)?; generated_frames.push(frame_codes); past_zeroth_tokens.push(next_token); } Ok(generated_frames) } /// Sample a token from logits. 
fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result { let mut logits_vec: Vec = logits.to_vec1()?; // Apply repetition penalty if self.config.repetition_penalty != 1.0 { for &token in past_tokens { let idx = token as usize; if idx < logits_vec.len() { let score = logits_vec[idx]; logits_vec[idx] = if score < 0.0 { score * self.config.repetition_penalty } else { score / self.config.repetition_penalty }; } } } if self.config.top_k == 0 || self.config.temperature == 0.0 { // Greedy: argmax let (max_idx, _) = logits_vec .iter() .enumerate() .max_by(|(_, a), (_, b)| { a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) }) .unwrap_or((0, &0.0)); Ok(max_idx as u32) } else { // Top-k sampling with temperature let temperature = self.config.temperature; // Apply temperature for v in logits_vec.iter_mut() { *v /= temperature; } // Sort indices by logit value (descending) let mut indexed: Vec<(usize, f32)> = logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); // Keep only top-k let k = self.config.top_k.min(indexed.len()); let top_k = &indexed[..k]; // Softmax over top-k let max_val = top_k[0].1; let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::>().iter().sum(); let probs: Vec<(usize, f32)> = top_k .iter() .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum)) .collect(); // Sample from distribution (simple linear scan) let r: f32 = random_float(); let mut cumulative = 0.0; for (idx, prob) in &probs { cumulative += prob; if cumulative >= r { return Ok(*idx as u32); } } // Fallback to highest probability Ok(probs[0].0 as u32) } } /// Decode all frames in batch (non-streaming). 
fn decode_batch( &self, frames: &[Vec], ) -> std::result::Result, TtsError> { let num_codebooks = self.speech_tokenizer.num_codebooks(); // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames] let mut codes_by_codebook: Vec> = vec![Vec::new(); num_codebooks]; for frame in frames { for (cb_idx, &code) in frame.iter().enumerate() { if cb_idx < num_codebooks { codes_by_codebook[cb_idx].push(code); } } } let samples = self .speech_tokenizer .decode(&codes_by_codebook) .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?; Ok(vec![AudioChunk { samples, sample_rate: SAMPLE_RATE, is_final: true, }]) } /// Decode frames incrementally (streaming). fn decode_streaming( &self, frames: &[Vec], ) -> std::result::Result, TtsError> { let mut chunks: Vec = Vec::new(); // Decode in groups of frames for efficiency let chunk_size = 10; // ~800ms per chunk at 12.5Hz let num_codebooks = self.speech_tokenizer.num_codebooks(); for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { // Check for cancellation between streaming chunks if self.is_cancelled() { tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len()); if let Some(last) = chunks.last_mut() { last.is_final = true; } return Ok(chunks); } let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); // Transpose chunk frames let mut codes_by_codebook: Vec> = vec![Vec::new(); num_codebooks]; for frame in frame_chunk { for (cb_idx, &code) in frame.iter().enumerate() { if cb_idx < num_codebooks { codes_by_codebook[cb_idx].push(code); } } } let samples = self .speech_tokenizer .decode(&codes_by_codebook) .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?; chunks.push(AudioChunk { samples, sample_rate: SAMPLE_RATE, is_final: is_last, }); } Ok(chunks) } } /// Simple pseudo-random float in [0, 1) using thread-local state. /// Uses a basic xorshift for reproducibility without external deps. 
fn random_float() -> f32 {
    use std::cell::Cell;

    // Every thread starts from the same fixed seed, so each thread produces
    // the same sequence of values.
    thread_local! {
        static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0);
    }

    STATE.with(|s| {
        // One xorshift step; the non-zero seed keeps the state non-zero.
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        unit_interval_from_bits(x)
    })
}

/// Map 64 random bits to an `f32` in `[0, 1)`.
///
/// Only the top 24 bits are used: that is the width of an `f32` significand,
/// so the integer converts exactly and the largest result is
/// `(2^24 - 1) / 2^24`, strictly below `1.0`. Dividing the full `u64` by
/// `u64::MAX as f32` instead rounds large inputs up to exactly `1.0`.
fn unit_interval_from_bits(bits: u64) -> f32 {
    const SIGNIFICAND_BITS: u32 = 24;
    let top_bits = bits >> (u64::BITS - SIGNIFICAND_BITS);
    top_bits as f32 / (1u32 << SIGNIFICAND_BITS) as f32
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config_default() {
        let config = GenerationConfig::default();
        assert_eq!(config.max_new_tokens, 2048);
        assert_eq!(config.top_k, 0);
        assert_eq!(config.temperature, 1.0);
        assert_eq!(config.repetition_penalty, 1.2);
        assert!(!config.streaming);
    }

    #[test]
    fn test_random_float_range() {
        for _ in 0..100 {
            let r = random_float();
            assert!(r >= 0.0);
            assert!(r < 1.0);
        }
    }

    #[test]
    fn test_unit_interval_from_bits_bounds() {
        // The extremes of the generator state must stay inside [0, 1).
        assert_eq!(unit_interval_from_bits(0), 0.0);
        assert!(unit_interval_from_bits(u64::MAX) < 1.0);
        assert!(unit_interval_from_bits(u64::MAX) > 0.999);
        assert_eq!(unit_interval_from_bits(1u64 << 63), 0.5);
    }

    #[test]
    fn test_special_tokens() {
        assert_eq!(BOS_TOKEN_ID, 151_643);
        assert_eq!(EOS_TOKEN_ID, 151_645);
        assert_eq!(START_OF_SPEECH, 151_668);
        assert_eq!(END_OF_SPEECH, 151_669);
    }
}