summaryrefslogblamecommitdiff
path: root/makima/src/tts/qwen3/generate.rs
blob: 30d165b812b24745bed6e30d3d7f71bf3aaf4308 (plain) (tree)
1
2
3
4
5
6
7
8
9








                                                                        


                                              




















































                                                                            


                                                                            









                                              
                                             







                             
                        


         






                                                         
















































































































                                                                                              





                                                                                                   

















































































































































                                                                                                                
                                                     





                                                                               








                                                                                               












































































                                                                                            
//! Autoregressive generation loop for Qwen3-TTS.
//!
//! Orchestrates the full inference pipeline:
//! 1. Encode reference audio → codec tokens via the speech tokenizer (voice prompt; not yet consumed by the LM)
//! 2. Tokenize text → token IDs
//! 3. Autoregressive LM generation → zeroth codebook tokens
//! 4. Code predictor → remaining 15 codebook tokens per frame
//! 5. Speech tokenizer decoder → waveform audio

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use tokenizers::Tokenizer;

use super::code_predictor::CodePredictor;
use super::model::{KvCache, Qwen3Model};
use super::speech_tokenizer::SpeechTokenizer;
use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE};

// Special tokens for the Qwen3-TTS vocabulary.

/// Beginning-of-sequence token; `GenerationContext::generate` prepends it to every prompt.
pub const BOS_TOKEN_ID: u32 = 151_643;
/// End-of-sequence token id. The generation loop stops on [`END_OF_SPEECH`], not on this id.
pub const EOS_TOKEN_ID: u32 = 151_645;
/// Padding token id. It deliberately shares its value with [`BOS_TOKEN_ID`].
pub const PAD_TOKEN_ID: u32 = 151_643;

// Speech-specific control tokens.
// These are placeholders — actual values come from the tokenizer config.
// NOTE(review): confirm both ids against the shipped tokenizer config.

/// Appended after the text tokens to mark where speech generation begins.
pub const START_OF_SPEECH: u32 = 151_668;
/// Generation stops when this token is sampled as the zeroth-codebook token.
pub const END_OF_SPEECH: u32 = 151_669;

/// Tunable knobs for a single generation run.
///
/// Decoding is greedy (argmax) whenever `top_k == 0` or `temperature == 0.0`.
/// Otherwise tokens are drawn by temperature-scaled top-k sampling.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum number of zeroth-codebook tokens (one per audio frame) to generate.
    pub max_new_tokens: usize,
    /// Divisor applied to logits before top-k sampling; `0.0` forces greedy decoding.
    pub temperature: f32,
    /// Number of highest-scoring candidates to sample from; `0` selects greedy argmax.
    pub top_k: usize,
    /// Penalty applied to the logits of already-generated tokens; `1.0` disables it.
    pub repetition_penalty: f32,
    /// When `true`, the generated frames are decoded into several audio chunks
    /// instead of a single one.
    pub streaming: bool,
}

impl Default for GenerationConfig {
    /// Greedy decoding of up to 2048 frames with a 1.2 repetition penalty, non-streaming.
    fn default() -> Self {
        // A `top_k` of zero means "no sampling": plain argmax.
        let greedy_top_k = 0;
        GenerationConfig {
            streaming: false,
            repetition_penalty: 1.2,
            top_k: greedy_top_k,
            temperature: 1.0,
            max_new_tokens: 2048,
        }
    }
}

/// Manages the full generation pipeline.
pub struct GenerationContext<'a> {
    model: &'a Qwen3Model,
    code_predictor: &'a CodePredictor,
    speech_tokenizer: &'a SpeechTokenizer,
    tokenizer: &'a Tokenizer,
    device: &'a Device,
    config: GenerationConfig,
    /// Optional cancellation flag. When set to `true`, the generation loop
    /// will break early and return whatever audio has been produced so far.
    cancel_flag: Option<Arc<AtomicBool>>,
}

impl<'a> GenerationContext<'a> {
    pub fn new(
        model: &'a Qwen3Model,
        code_predictor: &'a CodePredictor,
        speech_tokenizer: &'a SpeechTokenizer,
        tokenizer: &'a Tokenizer,
        device: &'a Device,
        config: GenerationConfig,
        cancel_flag: Option<Arc<AtomicBool>>,
    ) -> Self {
        Self {
            model,
            code_predictor,
            speech_tokenizer,
            tokenizer,
            device,
            config,
            cancel_flag,
        }
    }

    /// Check whether cancellation has been requested.
    fn is_cancelled(&self) -> bool {
        self.cancel_flag
            .as_ref()
            .map_or(false, |f| f.load(Ordering::Relaxed))
    }

    /// Generate audio from text, optionally with a voice reference.
    ///
    /// Returns a list of audio chunks. If `streaming` is false, returns
    /// a single chunk with the complete audio.
    pub fn generate(
        &self,
        text: &str,
        reference_audio: Option<&[f32]>,
    ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
        // 1. Encode reference audio if provided
        let reference_codes = match reference_audio {
            Some(audio) => Some(
                self.speech_tokenizer
                    .encode(audio)
                    .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?,
            ),
            None => None,
        };

        // 2. Tokenize text
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| TtsError::Tokenizer(e.to_string()))?;

        let text_token_ids: Vec<u32> = encoding.get_ids().to_vec();

        // 3. Prepare input sequence
        // Format: [BOS] [text_tokens] [START_OF_SPEECH]
        let mut input_ids = Vec::new();
        input_ids.push(BOS_TOKEN_ID);
        input_ids.extend_from_slice(&text_token_ids);
        input_ids.push(START_OF_SPEECH);

        // 4. Run autoregressive generation
        let generated_frames = self
            .autoregressive_generate(&input_ids, reference_codes.as_deref())
            .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?;

        if generated_frames.is_empty() {
            return Ok(vec![AudioChunk {
                samples: vec![],
                sample_rate: SAMPLE_RATE,
                is_final: true,
            }]);
        }

        // 5. Decode all frames to audio
        if self.config.streaming {
            self.decode_streaming(&generated_frames)
        } else {
            self.decode_batch(&generated_frames)
        }
    }

    /// Autoregressive generation loop.
    ///
    /// Generates zeroth codebook tokens one at a time, then uses the code
    /// predictor to fill in the remaining 15 codebooks per frame.
    ///
    /// Returns: Vec of frames, each frame is [num_codebooks] tokens.
    fn autoregressive_generate(
        &self,
        input_ids: &[u32],
        _reference_codes: Option<&[Vec<u32>]>,
    ) -> Result<Vec<Vec<u32>>> {
        let _num_codebooks = self.code_predictor.num_code_groups();
        let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers())
            .map(|_| KvCache::new())
            .collect();

        let mut generated_frames: Vec<Vec<u32>> = Vec::new();
        let mut past_zeroth_tokens: Vec<u32> = Vec::new();

        // === First iteration: process the full input sequence ===
        let input_tensor = Tensor::from_vec(
            input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
            (1, input_ids.len()),
            self.device,
        )?
        .to_dtype(DType::I64)?;

        let seq_len = input_ids.len();
        let attention_mask =
            Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?;

        let logits =
            self.model
                .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?;

        // Get the logits for the last position
        let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size]
        let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?;

        if first_token == END_OF_SPEECH as u32 {
            return Ok(generated_frames);
        }

        // Use code predictor for all codebooks
        let lm_hidden = self
            .model
            .last_hidden_state()
            .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
        let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?;

        let frame_codes = self
            .code_predictor
            .predict(&last_hidden, first_token, self.device)?;
        generated_frames.push(frame_codes);
        past_zeroth_tokens.push(first_token);

        // === Subsequent iterations: one token at a time ===
        for _step in 1..self.config.max_new_tokens {
            // Check for cancellation each iteration
            if self.is_cancelled() {
                tracing::info!("TTS generation cancelled after {} frames", generated_frames.len());
                break;
            }

            let past_len = kv_caches[0].seq_len();

            // Input: just the last generated zeroth codebook token
            let last_token = *past_zeroth_tokens.last().unwrap();
            let token_tensor = Tensor::from_vec(
                vec![last_token as i64],
                (1, 1),
                self.device,
            )?
            .to_dtype(DType::I64)?;

            // Single-token attention mask
            let attention_mask =
                Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?;

            let logits =
                self.model
                    .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?;

            let next_logits = logits.i((0, 0, ..))?; // [vocab_size]
            let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?;

            if next_token == END_OF_SPEECH as u32 {
                break;
            }

            // Predict all codebooks for this frame
            let lm_hidden = self
                .model
                .last_hidden_state()
                .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;

            let frame_codes = self
                .code_predictor
                .predict(&lm_hidden, next_token, self.device)?;
            generated_frames.push(frame_codes);
            past_zeroth_tokens.push(next_token);
        }

        Ok(generated_frames)
    }

    /// Sample a token from logits.
    fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> {
        let mut logits_vec: Vec<f32> = logits.to_vec1()?;

        // Apply repetition penalty
        if self.config.repetition_penalty != 1.0 {
            for &token in past_tokens {
                let idx = token as usize;
                if idx < logits_vec.len() {
                    let score = logits_vec[idx];
                    logits_vec[idx] = if score < 0.0 {
                        score * self.config.repetition_penalty
                    } else {
                        score / self.config.repetition_penalty
                    };
                }
            }
        }

        if self.config.top_k == 0 || self.config.temperature == 0.0 {
            // Greedy: argmax
            let (max_idx, _) = logits_vec
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| {
                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                })
                .unwrap_or((0, &0.0));
            Ok(max_idx as u32)
        } else {
            // Top-k sampling with temperature
            let temperature = self.config.temperature;

            // Apply temperature
            for v in logits_vec.iter_mut() {
                *v /= temperature;
            }

            // Sort indices by logit value (descending)
            let mut indexed: Vec<(usize, f32)> =
                logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            // Keep only top-k
            let k = self.config.top_k.min(indexed.len());
            let top_k = &indexed[..k];

            // Softmax over top-k
            let max_val = top_k[0].1;
            let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum();
            let probs: Vec<(usize, f32)> = top_k
                .iter()
                .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum))
                .collect();

            // Sample from distribution (simple linear scan)
            let r: f32 = random_float();
            let mut cumulative = 0.0;
            for (idx, prob) in &probs {
                cumulative += prob;
                if cumulative >= r {
                    return Ok(*idx as u32);
                }
            }

            // Fallback to highest probability
            Ok(probs[0].0 as u32)
        }
    }

    /// Decode all frames in batch (non-streaming).
    fn decode_batch(
        &self,
        frames: &[Vec<u32>],
    ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
        let num_codebooks = self.speech_tokenizer.num_codebooks();

        // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames]
        let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
        for frame in frames {
            for (cb_idx, &code) in frame.iter().enumerate() {
                if cb_idx < num_codebooks {
                    codes_by_codebook[cb_idx].push(code);
                }
            }
        }

        let samples = self
            .speech_tokenizer
            .decode(&codes_by_codebook)
            .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?;

        Ok(vec![AudioChunk {
            samples,
            sample_rate: SAMPLE_RATE,
            is_final: true,
        }])
    }

    /// Decode frames incrementally (streaming).
    fn decode_streaming(
        &self,
        frames: &[Vec<u32>],
    ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
        let mut chunks: Vec<AudioChunk> = Vec::new();

        // Decode in groups of frames for efficiency
        let chunk_size = 10; // ~800ms per chunk at 12.5Hz
        let num_codebooks = self.speech_tokenizer.num_codebooks();

        for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
            // Check for cancellation between streaming chunks
            if self.is_cancelled() {
                tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len());
                if let Some(last) = chunks.last_mut() {
                    last.is_final = true;
                }
                return Ok(chunks);
            }

            let is_last = (chunk_idx + 1) * chunk_size >= frames.len();

            // Transpose chunk frames
            let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
            for frame in frame_chunk {
                for (cb_idx, &code) in frame.iter().enumerate() {
                    if cb_idx < num_codebooks {
                        codes_by_codebook[cb_idx].push(code);
                    }
                }
            }

            let samples = self
                .speech_tokenizer
                .decode(&codes_by_codebook)
                .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?;

            chunks.push(AudioChunk {
                samples,
                sample_rate: SAMPLE_RATE,
                is_final: is_last,
            });
        }

        Ok(chunks)
    }
}

/// Simple pseudo-random float in `[0, 1)` using thread-local state.
///
/// Uses a xorshift64 generator (shifts 13, 7, 17) with a fixed seed. Every
/// thread therefore sees the same deterministic sequence, with no external
/// dependencies.
///
/// The result is built from the top 24 bits of the state, which is exactly
/// `f32`'s mantissa precision. Every output is `k / 2^24` with
/// `0 <= k < 2^24`, so it is exactly representable and strictly below `1.0`.
/// Dividing the full `u64` by `u64::MAX` instead rounds both operands to
/// `2^64` and can yield exactly `1.0`.
fn random_float() -> f32 {
    use std::cell::Cell;
    thread_local! {
        static STATE: Cell<u64> = const { Cell::new(0x1234_5678_9ABC_DEF0) };
    }

    // 1 / 2^24, exact in f32.
    const SCALE: f32 = 1.0 / 16_777_216.0;

    STATE.with(|s| {
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        // `x >> 40` keeps the 24 high bits, which fit losslessly in an f32.
        (x >> 40) as f32 * SCALE
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must stay greedy, non-streaming, with the documented limits.
    #[test]
    fn test_generation_config_default() {
        let GenerationConfig {
            max_new_tokens,
            temperature,
            top_k,
            repetition_penalty,
            streaming,
        } = GenerationConfig::default();

        assert_eq!(max_new_tokens, 2048);
        assert_eq!(top_k, 0);
        assert_eq!(temperature, 1.0);
        assert_eq!(repetition_penalty, 1.2);
        assert!(!streaming);
    }

    /// Samples from the PRNG must lie in the half-open unit interval.
    #[test]
    fn test_random_float_range() {
        let unit_interval = 0.0f32..1.0;
        assert!((0..100)
            .map(|_| random_float())
            .all(|sample| unit_interval.contains(&sample)));
    }

    /// Guard against accidental edits to the special-token ids.
    #[test]
    fn test_special_tokens() {
        let expected = [
            (BOS_TOKEN_ID, 151_643),
            (EOS_TOKEN_ID, 151_645),
            (START_OF_SPEECH, 151_668),
            (END_OF_SPEECH, 151_669),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
    }
}