//! Autoregressive generation loop for Qwen3-TTS.
//!
//! Orchestrates the full inference pipeline:
//! 1. Encode reference audio → reference codec codes via the speech tokenizer (currently not consumed by generation)
//! 2. Tokenize text → token IDs
//! 3. Autoregressive LM generation → zeroth codebook tokens
//! 4. Code predictor → remaining 15 codebook tokens per frame
//! 5. Speech tokenizer decoder → waveform audio
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use candle_core::{DType, Device, IndexOp, Result, Tensor};
use tokenizers::Tokenizer;
use super::code_predictor::CodePredictor;
use super::model::{KvCache, Qwen3Model};
use super::speech_tokenizer::SpeechTokenizer;
use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE};
// ---- Text-vocabulary special tokens -------------------------------------

/// Beginning-of-sequence token; `generate` places it at the start of every prompt.
pub const BOS_TOKEN_ID: u32 = 151_643;
/// End-of-sequence token of the text vocabulary.
pub const EOS_TOKEN_ID: u32 = 151_645;
/// Padding token. It deliberately shares its id with [`BOS_TOKEN_ID`].
pub const PAD_TOKEN_ID: u32 = BOS_TOKEN_ID;

// ---- Speech control tokens ----------------------------------------------
// NOTE(review): these ids are hard-coded placeholders; they are expected to
// match the tokenizer config but are not read from it here.

/// Marks the end of the text prompt and the start of generated speech.
pub const START_OF_SPEECH: u32 = 151_668;
/// Emitted by the model to stop speech generation.
pub const END_OF_SPEECH: u32 = 151_669;
/// Sampling and decoding options for [`GenerationContext`].
///
/// The defaults select deterministic greedy decoding with a mild repetition
/// penalty and a single, non-streamed output chunk.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum number of speech frames to generate (one zeroth-codebook token
    /// is sampled per frame).
    pub max_new_tokens: usize,
    /// Logits are divided by this value before top-k sampling.
    ///
    /// Only used when `top_k > 0`. A temperature of `0.0` forces greedy argmax
    /// decoding. Values should be non-negative.
    pub temperature: f32,
    /// Number of highest-scoring tokens to sample from.
    ///
    /// `0` disables sampling and decodes greedily with argmax. `1` also always
    /// picks the argmax, regardless of temperature.
    pub top_k: usize,
    /// Penalty for tokens that were already generated.
    ///
    /// Positive logits are divided by this factor and negative logits are
    /// multiplied by it. `1.0` disables the penalty.
    pub repetition_penalty: f32,
    /// When `true`, the generated frames are decoded in fixed-size groups and
    /// returned as several audio chunks instead of a single one.
    pub streaming: bool,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: 2048,
            temperature: 1.0,
            top_k: 0, // Greedy by default
            repetition_penalty: 1.2,
            streaming: false,
        }
    }
}
/// Borrowed bundle of everything needed to turn text into audio.
///
/// Holds references to the loaded tokenizer, language model, code predictor
/// and speech tokenizer, together with the sampling configuration and an
/// optional cancellation flag. `generate` drives the full pipeline.
pub struct GenerationContext<'a> {
    /// Text tokenizer that converts the prompt into token ids.
    tokenizer: &'a Tokenizer,
    /// Language model that predicts one zeroth-codebook token per frame.
    model: &'a Qwen3Model,
    /// Produces the codes of a whole frame from the LM hidden state and the
    /// sampled zeroth-codebook token.
    code_predictor: &'a CodePredictor,
    /// Encodes reference audio and decodes generated codes back into samples.
    speech_tokenizer: &'a SpeechTokenizer,
    /// Device on which input tensors and attention masks are created.
    device: &'a Device,
    /// Sampling / decoding options.
    config: GenerationConfig,
    /// Optional cancellation flag. When set to `true`, the generation loop
    /// will break early and return whatever audio has been produced so far.
    cancel_flag: Option<Arc<AtomicBool>>,
}
impl<'a> GenerationContext<'a> {
pub fn new(
model: &'a Qwen3Model,
code_predictor: &'a CodePredictor,
speech_tokenizer: &'a SpeechTokenizer,
tokenizer: &'a Tokenizer,
device: &'a Device,
config: GenerationConfig,
cancel_flag: Option<Arc<AtomicBool>>,
) -> Self {
Self {
model,
code_predictor,
speech_tokenizer,
tokenizer,
device,
config,
cancel_flag,
}
}
/// Check whether cancellation has been requested.
fn is_cancelled(&self) -> bool {
self.cancel_flag
.as_ref()
.map_or(false, |f| f.load(Ordering::Relaxed))
}
/// Generate audio from text, optionally with a voice reference.
///
/// Returns a list of audio chunks. If `streaming` is false, returns
/// a single chunk with the complete audio.
pub fn generate(
&self,
text: &str,
reference_audio: Option<&[f32]>,
) -> std::result::Result<Vec<AudioChunk>, TtsError> {
// 1. Encode reference audio if provided
let reference_codes = match reference_audio {
Some(audio) => Some(
self.speech_tokenizer
.encode(audio)
.map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?,
),
None => None,
};
// 2. Tokenize text
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| TtsError::Tokenizer(e.to_string()))?;
let text_token_ids: Vec<u32> = encoding.get_ids().to_vec();
// 3. Prepare input sequence
// Format: [BOS] [text_tokens] [START_OF_SPEECH]
let mut input_ids = Vec::new();
input_ids.push(BOS_TOKEN_ID);
input_ids.extend_from_slice(&text_token_ids);
input_ids.push(START_OF_SPEECH);
// 4. Run autoregressive generation
let generated_frames = self
.autoregressive_generate(&input_ids, reference_codes.as_deref())
.map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?;
if generated_frames.is_empty() {
return Ok(vec![AudioChunk {
samples: vec![],
sample_rate: SAMPLE_RATE,
is_final: true,
}]);
}
// 5. Decode all frames to audio
if self.config.streaming {
self.decode_streaming(&generated_frames)
} else {
self.decode_batch(&generated_frames)
}
}
/// Autoregressive generation loop.
///
/// Generates zeroth codebook tokens one at a time, then uses the code
/// predictor to fill in the remaining 15 codebooks per frame.
///
/// Returns: Vec of frames, each frame is [num_codebooks] tokens.
fn autoregressive_generate(
&self,
input_ids: &[u32],
_reference_codes: Option<&[Vec<u32>]>,
) -> Result<Vec<Vec<u32>>> {
let _num_codebooks = self.code_predictor.num_code_groups();
let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers())
.map(|_| KvCache::new())
.collect();
let mut generated_frames: Vec<Vec<u32>> = Vec::new();
let mut past_zeroth_tokens: Vec<u32> = Vec::new();
// === First iteration: process the full input sequence ===
let input_tensor = Tensor::from_vec(
input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
(1, input_ids.len()),
self.device,
)?
.to_dtype(DType::I64)?;
let seq_len = input_ids.len();
let attention_mask =
Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?;
let logits =
self.model
.forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?;
// Get the logits for the last position
let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size]
let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?;
if first_token == END_OF_SPEECH as u32 {
return Ok(generated_frames);
}
// Use code predictor for all codebooks
let lm_hidden = self
.model
.last_hidden_state()
.ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?;
let frame_codes = self
.code_predictor
.predict(&last_hidden, first_token, self.device)?;
generated_frames.push(frame_codes);
past_zeroth_tokens.push(first_token);
// === Subsequent iterations: one token at a time ===
for _step in 1..self.config.max_new_tokens {
// Check for cancellation each iteration
if self.is_cancelled() {
tracing::info!("TTS generation cancelled after {} frames", generated_frames.len());
break;
}
let past_len = kv_caches[0].seq_len();
// Input: just the last generated zeroth codebook token
let last_token = *past_zeroth_tokens.last().unwrap();
let token_tensor = Tensor::from_vec(
vec![last_token as i64],
(1, 1),
self.device,
)?
.to_dtype(DType::I64)?;
// Single-token attention mask
let attention_mask =
Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?;
let logits =
self.model
.forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?;
let next_logits = logits.i((0, 0, ..))?; // [vocab_size]
let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?;
if next_token == END_OF_SPEECH as u32 {
break;
}
// Predict all codebooks for this frame
let lm_hidden = self
.model
.last_hidden_state()
.ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
let frame_codes = self
.code_predictor
.predict(&lm_hidden, next_token, self.device)?;
generated_frames.push(frame_codes);
past_zeroth_tokens.push(next_token);
}
Ok(generated_frames)
}
/// Sample a token from logits.
fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> {
let mut logits_vec: Vec<f32> = logits.to_vec1()?;
// Apply repetition penalty
if self.config.repetition_penalty != 1.0 {
for &token in past_tokens {
let idx = token as usize;
if idx < logits_vec.len() {
let score = logits_vec[idx];
logits_vec[idx] = if score < 0.0 {
score * self.config.repetition_penalty
} else {
score / self.config.repetition_penalty
};
}
}
}
if self.config.top_k == 0 || self.config.temperature == 0.0 {
// Greedy: argmax
let (max_idx, _) = logits_vec
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
})
.unwrap_or((0, &0.0));
Ok(max_idx as u32)
} else {
// Top-k sampling with temperature
let temperature = self.config.temperature;
// Apply temperature
for v in logits_vec.iter_mut() {
*v /= temperature;
}
// Sort indices by logit value (descending)
let mut indexed: Vec<(usize, f32)> =
logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
// Keep only top-k
let k = self.config.top_k.min(indexed.len());
let top_k = &indexed[..k];
// Softmax over top-k
let max_val = top_k[0].1;
let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum();
let probs: Vec<(usize, f32)> = top_k
.iter()
.map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum))
.collect();
// Sample from distribution (simple linear scan)
let r: f32 = random_float();
let mut cumulative = 0.0;
for (idx, prob) in &probs {
cumulative += prob;
if cumulative >= r {
return Ok(*idx as u32);
}
}
// Fallback to highest probability
Ok(probs[0].0 as u32)
}
}
/// Decode all frames in batch (non-streaming).
fn decode_batch(
&self,
frames: &[Vec<u32>],
) -> std::result::Result<Vec<AudioChunk>, TtsError> {
let num_codebooks = self.speech_tokenizer.num_codebooks();
// Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames]
let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
for frame in frames {
for (cb_idx, &code) in frame.iter().enumerate() {
if cb_idx < num_codebooks {
codes_by_codebook[cb_idx].push(code);
}
}
}
let samples = self
.speech_tokenizer
.decode(&codes_by_codebook)
.map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?;
Ok(vec![AudioChunk {
samples,
sample_rate: SAMPLE_RATE,
is_final: true,
}])
}
/// Decode frames incrementally (streaming).
fn decode_streaming(
&self,
frames: &[Vec<u32>],
) -> std::result::Result<Vec<AudioChunk>, TtsError> {
let mut chunks: Vec<AudioChunk> = Vec::new();
// Decode in groups of frames for efficiency
let chunk_size = 10; // ~800ms per chunk at 12.5Hz
let num_codebooks = self.speech_tokenizer.num_codebooks();
for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
// Check for cancellation between streaming chunks
if self.is_cancelled() {
tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len());
if let Some(last) = chunks.last_mut() {
last.is_final = true;
}
return Ok(chunks);
}
let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
// Transpose chunk frames
let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
for frame in frame_chunk {
for (cb_idx, &code) in frame.iter().enumerate() {
if cb_idx < num_codebooks {
codes_by_codebook[cb_idx].push(code);
}
}
}
let samples = self
.speech_tokenizer
.decode(&codes_by_codebook)
.map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?;
chunks.push(AudioChunk {
samples,
sample_rate: SAMPLE_RATE,
is_final: is_last,
});
}
Ok(chunks)
}
}
/// Simple pseudo-random float in `[0, 1)` using thread-local state.
///
/// Uses a xorshift64 step (shifts 13/7/17) on a thread-local state seeded with
/// a fixed constant. Each thread therefore replays the same deterministic
/// sequence, which keeps sampling reproducible without external dependencies.
/// It is not suitable for anything security-sensitive.
fn random_float() -> f32 {
    use std::cell::Cell;
    thread_local! {
        static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0);
    }
    STATE.with(|s| {
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        unit_f32_from_bits(x)
    })
}

/// Map a random `u64` to an `f32` uniformly spaced over `[0, 1)`.
///
/// Only the top 24 bits are used. An `f32` holds 24 significant bits, so
/// every result is exact and the largest one is `1 - 2^-24`. Dividing the
/// full value by `u64::MAX` as `f32` would instead round inputs near the top
/// of the range to exactly `1.0`.
fn unit_f32_from_bits(bits: u64) -> f32 {
    const MANTISSA_BITS: u32 = 24;
    (bits >> (64 - MANTISSA_BITS)) as f32 / (1u64 << MANTISSA_BITS) as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The defaults describe greedy, non-streaming decoding.
    #[test]
    fn test_generation_config_default() {
        let GenerationConfig {
            max_new_tokens,
            temperature,
            top_k,
            repetition_penalty,
            streaming,
        } = GenerationConfig::default();
        assert_eq!(max_new_tokens, 2048);
        assert_eq!(top_k, 0);
        assert_eq!(temperature, 1.0);
        assert_eq!(repetition_penalty, 1.2);
        assert!(!streaming);
    }

    /// Every sample from the PRNG lies in the half-open unit interval.
    #[test]
    fn test_random_float_range() {
        let samples: Vec<f32> = (0..100).map(|_| random_float()).collect();
        for r in samples {
            assert!(r >= 0.0);
            assert!(r < 1.0);
        }
    }

    /// Guard against accidental edits of the hard-coded vocabulary ids.
    #[test]
    fn test_special_tokens() {
        let expected = [
            (BOS_TOKEN_ID, 151_643),
            (EOS_TOKEN_ID, 151_645),
            (START_OF_SPEECH, 151_668),
            (END_OF_SPEECH, 151_669),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
    }
}