summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/generate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/tts/qwen3/generate.rs')
-rw-r--r--makima/src/tts/qwen3/generate.rs456
1 files changed, 0 insertions, 456 deletions
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs
deleted file mode 100644
index 30d165b..0000000
--- a/makima/src/tts/qwen3/generate.rs
+++ /dev/null
@@ -1,456 +0,0 @@
-//! Autoregressive generation loop for Qwen3-TTS.
-//!
-//! Orchestrates the full inference pipeline:
-//! 1. Encode reference audio → speaker embedding via speech tokenizer
-//! 2. Tokenize text → token IDs
-//! 3. Autoregressive LM generation → zeroth codebook tokens
-//! 4. Code predictor → remaining 15 codebook tokens per frame
-//! 5. Speech tokenizer decoder → waveform audio
-
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-
-use candle_core::{DType, Device, IndexOp, Result, Tensor};
-use tokenizers::Tokenizer;
-
-use super::code_predictor::CodePredictor;
-use super::model::{KvCache, Qwen3Model};
-use super::speech_tokenizer::SpeechTokenizer;
-use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE};
-
-/// Special tokens for the Qwen3-TTS vocabulary.
-pub const BOS_TOKEN_ID: u32 = 151_643;
-pub const EOS_TOKEN_ID: u32 = 151_645;
-pub const PAD_TOKEN_ID: u32 = 151_643;
-
-/// Speech-specific control tokens.
-/// These are placeholders — actual values come from the tokenizer config.
-pub const START_OF_SPEECH: u32 = 151_668;
-pub const END_OF_SPEECH: u32 = 151_669;
-
-/// Generation configuration.
-#[derive(Debug, Clone)]
-pub struct GenerationConfig {
- /// Maximum number of speech tokens to generate.
- pub max_new_tokens: usize,
- /// Temperature for sampling (1.0 = greedy if top_k=1).
- pub temperature: f32,
- /// Top-k sampling (0 = disabled, use greedy argmax).
- pub top_k: usize,
- /// Repetition penalty.
- pub repetition_penalty: f32,
- /// Whether to generate audio chunks incrementally (streaming).
- pub streaming: bool,
-}
-
-impl Default for GenerationConfig {
- fn default() -> Self {
- Self {
- max_new_tokens: 2048,
- temperature: 1.0,
- top_k: 0, // Greedy by default
- repetition_penalty: 1.2,
- streaming: false,
- }
- }
-}
-
-/// Manages the full generation pipeline.
-pub struct GenerationContext<'a> {
- model: &'a Qwen3Model,
- code_predictor: &'a CodePredictor,
- speech_tokenizer: &'a SpeechTokenizer,
- tokenizer: &'a Tokenizer,
- device: &'a Device,
- config: GenerationConfig,
- /// Optional cancellation flag. When set to `true`, the generation loop
- /// will break early and return whatever audio has been produced so far.
- cancel_flag: Option<Arc<AtomicBool>>,
-}
-
-impl<'a> GenerationContext<'a> {
- pub fn new(
- model: &'a Qwen3Model,
- code_predictor: &'a CodePredictor,
- speech_tokenizer: &'a SpeechTokenizer,
- tokenizer: &'a Tokenizer,
- device: &'a Device,
- config: GenerationConfig,
- cancel_flag: Option<Arc<AtomicBool>>,
- ) -> Self {
- Self {
- model,
- code_predictor,
- speech_tokenizer,
- tokenizer,
- device,
- config,
- cancel_flag,
- }
- }
-
- /// Check whether cancellation has been requested.
- fn is_cancelled(&self) -> bool {
- self.cancel_flag
- .as_ref()
- .map_or(false, |f| f.load(Ordering::Relaxed))
- }
-
- /// Generate audio from text, optionally with a voice reference.
- ///
- /// Returns a list of audio chunks. If `streaming` is false, returns
- /// a single chunk with the complete audio.
- pub fn generate(
- &self,
- text: &str,
- reference_audio: Option<&[f32]>,
- ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- // 1. Encode reference audio if provided
- let reference_codes = match reference_audio {
- Some(audio) => Some(
- self.speech_tokenizer
- .encode(audio)
- .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?,
- ),
- None => None,
- };
-
- // 2. Tokenize text
- let encoding = self
- .tokenizer
- .encode(text, true)
- .map_err(|e| TtsError::Tokenizer(e.to_string()))?;
-
- let text_token_ids: Vec<u32> = encoding.get_ids().to_vec();
-
- // 3. Prepare input sequence
- // Format: [BOS] [text_tokens] [START_OF_SPEECH]
- let mut input_ids = Vec::new();
- input_ids.push(BOS_TOKEN_ID);
- input_ids.extend_from_slice(&text_token_ids);
- input_ids.push(START_OF_SPEECH);
-
- // 4. Run autoregressive generation
- let generated_frames = self
- .autoregressive_generate(&input_ids, reference_codes.as_deref())
- .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?;
-
- if generated_frames.is_empty() {
- return Ok(vec![AudioChunk {
- samples: vec![],
- sample_rate: SAMPLE_RATE,
- is_final: true,
- }]);
- }
-
- // 5. Decode all frames to audio
- if self.config.streaming {
- self.decode_streaming(&generated_frames)
- } else {
- self.decode_batch(&generated_frames)
- }
- }
-
- /// Autoregressive generation loop.
- ///
- /// Generates zeroth codebook tokens one at a time, then uses the code
- /// predictor to fill in the remaining 15 codebooks per frame.
- ///
- /// Returns: Vec of frames, each frame is [num_codebooks] tokens.
- fn autoregressive_generate(
- &self,
- input_ids: &[u32],
- _reference_codes: Option<&[Vec<u32>]>,
- ) -> Result<Vec<Vec<u32>>> {
- let _num_codebooks = self.code_predictor.num_code_groups();
- let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers())
- .map(|_| KvCache::new())
- .collect();
-
- let mut generated_frames: Vec<Vec<u32>> = Vec::new();
- let mut past_zeroth_tokens: Vec<u32> = Vec::new();
-
- // === First iteration: process the full input sequence ===
- let input_tensor = Tensor::from_vec(
- input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
- (1, input_ids.len()),
- self.device,
- )?
- .to_dtype(DType::I64)?;
-
- let seq_len = input_ids.len();
- let attention_mask =
- Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?;
-
- let logits =
- self.model
- .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?;
-
- // Get the logits for the last position
- let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size]
- let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?;
-
- if first_token == END_OF_SPEECH as u32 {
- return Ok(generated_frames);
- }
-
- // Use code predictor for all codebooks
- let lm_hidden = self
- .model
- .last_hidden_state()
- .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
- let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?;
-
- let frame_codes = self
- .code_predictor
- .predict(&last_hidden, first_token, self.device)?;
- generated_frames.push(frame_codes);
- past_zeroth_tokens.push(first_token);
-
- // === Subsequent iterations: one token at a time ===
- for _step in 1..self.config.max_new_tokens {
- // Check for cancellation each iteration
- if self.is_cancelled() {
- tracing::info!("TTS generation cancelled after {} frames", generated_frames.len());
- break;
- }
-
- let past_len = kv_caches[0].seq_len();
-
- // Input: just the last generated zeroth codebook token
- let last_token = *past_zeroth_tokens.last().unwrap();
- let token_tensor = Tensor::from_vec(
- vec![last_token as i64],
- (1, 1),
- self.device,
- )?
- .to_dtype(DType::I64)?;
-
- // Single-token attention mask
- let attention_mask =
- Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?;
-
- let logits =
- self.model
- .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?;
-
- let next_logits = logits.i((0, 0, ..))?; // [vocab_size]
- let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?;
-
- if next_token == END_OF_SPEECH as u32 {
- break;
- }
-
- // Predict all codebooks for this frame
- let lm_hidden = self
- .model
- .last_hidden_state()
- .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
-
- let frame_codes = self
- .code_predictor
- .predict(&lm_hidden, next_token, self.device)?;
- generated_frames.push(frame_codes);
- past_zeroth_tokens.push(next_token);
- }
-
- Ok(generated_frames)
- }
-
- /// Sample a token from logits.
- fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> {
- let mut logits_vec: Vec<f32> = logits.to_vec1()?;
-
- // Apply repetition penalty
- if self.config.repetition_penalty != 1.0 {
- for &token in past_tokens {
- let idx = token as usize;
- if idx < logits_vec.len() {
- let score = logits_vec[idx];
- logits_vec[idx] = if score < 0.0 {
- score * self.config.repetition_penalty
- } else {
- score / self.config.repetition_penalty
- };
- }
- }
- }
-
- if self.config.top_k == 0 || self.config.temperature == 0.0 {
- // Greedy: argmax
- let (max_idx, _) = logits_vec
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| {
- a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
- })
- .unwrap_or((0, &0.0));
- Ok(max_idx as u32)
- } else {
- // Top-k sampling with temperature
- let temperature = self.config.temperature;
-
- // Apply temperature
- for v in logits_vec.iter_mut() {
- *v /= temperature;
- }
-
- // Sort indices by logit value (descending)
- let mut indexed: Vec<(usize, f32)> =
- logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
- indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
-
- // Keep only top-k
- let k = self.config.top_k.min(indexed.len());
- let top_k = &indexed[..k];
-
- // Softmax over top-k
- let max_val = top_k[0].1;
- let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum();
- let probs: Vec<(usize, f32)> = top_k
- .iter()
- .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum))
- .collect();
-
- // Sample from distribution (simple linear scan)
- let r: f32 = random_float();
- let mut cumulative = 0.0;
- for (idx, prob) in &probs {
- cumulative += prob;
- if cumulative >= r {
- return Ok(*idx as u32);
- }
- }
-
- // Fallback to highest probability
- Ok(probs[0].0 as u32)
- }
- }
-
- /// Decode all frames in batch (non-streaming).
- fn decode_batch(
- &self,
- frames: &[Vec<u32>],
- ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- let num_codebooks = self.speech_tokenizer.num_codebooks();
-
- // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames]
- let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
- for frame in frames {
- for (cb_idx, &code) in frame.iter().enumerate() {
- if cb_idx < num_codebooks {
- codes_by_codebook[cb_idx].push(code);
- }
- }
- }
-
- let samples = self
- .speech_tokenizer
- .decode(&codes_by_codebook)
- .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?;
-
- Ok(vec![AudioChunk {
- samples,
- sample_rate: SAMPLE_RATE,
- is_final: true,
- }])
- }
-
- /// Decode frames incrementally (streaming).
- fn decode_streaming(
- &self,
- frames: &[Vec<u32>],
- ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
- let mut chunks: Vec<AudioChunk> = Vec::new();
-
- // Decode in groups of frames for efficiency
- let chunk_size = 10; // ~800ms per chunk at 12.5Hz
- let num_codebooks = self.speech_tokenizer.num_codebooks();
-
- for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
- // Check for cancellation between streaming chunks
- if self.is_cancelled() {
- tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len());
- if let Some(last) = chunks.last_mut() {
- last.is_final = true;
- }
- return Ok(chunks);
- }
-
- let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
-
- // Transpose chunk frames
- let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
- for frame in frame_chunk {
- for (cb_idx, &code) in frame.iter().enumerate() {
- if cb_idx < num_codebooks {
- codes_by_codebook[cb_idx].push(code);
- }
- }
- }
-
- let samples = self
- .speech_tokenizer
- .decode(&codes_by_codebook)
- .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?;
-
- chunks.push(AudioChunk {
- samples,
- sample_rate: SAMPLE_RATE,
- is_final: is_last,
- });
- }
-
- Ok(chunks)
- }
-}
-
-/// Simple pseudo-random float in [0, 1) using thread-local state.
-/// Uses a basic xorshift for reproducibility without external deps.
-fn random_float() -> f32 {
- use std::cell::Cell;
- thread_local! {
- static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0);
- }
-
- STATE.with(|s| {
- let mut x = s.get();
- x ^= x << 13;
- x ^= x >> 7;
- x ^= x << 17;
- s.set(x);
- (x as f32) / (u64::MAX as f32)
- })
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_generation_config_default() {
- let config = GenerationConfig::default();
- assert_eq!(config.max_new_tokens, 2048);
- assert_eq!(config.top_k, 0);
- assert_eq!(config.temperature, 1.0);
- assert_eq!(config.repetition_penalty, 1.2);
- assert!(!config.streaming);
- }
-
- #[test]
- fn test_random_float_range() {
- for _ in 0..100 {
- let r = random_float();
- assert!(r >= 0.0);
- assert!(r < 1.0);
- }
- }
-
- #[test]
- fn test_special_tokens() {
- assert_eq!(BOS_TOKEN_ID, 151_643);
- assert_eq!(EOS_TOKEN_ID, 151_645);
- assert_eq!(START_OF_SPEECH, 151_668);
- assert_eq!(END_OF_SPEECH, 151_669);
- }
-}