summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/speech_tokenizer.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-02-02 22:52:05 +0000
committersoryu <soryu@soryu.co>2026-02-02 22:52:05 +0000
commit0f06a7f9968816e5e2553c4f1c2104f2fa504f96 (patch)
tree53d8db119c17d7d22f3127ae5a54e12a3f384e29 /makima/src/tts/qwen3/speech_tokenizer.rs
parent151e9d87e117b7980e6aad522ac8f3633eeca87a (diff)
downloadsoryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.tar.gz
soryu-0f06a7f9968816e5e2553c4f1c2104f2fa504f96.zip
Release in makima repo
Also remove all other TTS models
Diffstat (limited to 'makima/src/tts/qwen3/speech_tokenizer.rs')
-rw-r--r--makima/src/tts/qwen3/speech_tokenizer.rs613
1 files changed, 0 insertions, 613 deletions
diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs
deleted file mode 100644
index 86e00f2..0000000
--- a/makima/src/tts/qwen3/speech_tokenizer.rs
+++ /dev/null
@@ -1,613 +0,0 @@
-//! Speech Tokenizer — ConvNet encoder/decoder with RVQ codebooks.
-//!
-//! Two sub-components:
-//!
-//! **Encoder** (voice cloning): converts reference audio waveform to discrete
-//! multi-codebook tokens via a causal 1D ConvNet + RVQ.
-//!
-//! **Decoder** (audio synthesis): reconstructs waveform from discrete codebook
-//! indices via embedding lookup + causal 1D ConvNet.
-//!
-//! The speech tokenizer is a separate model (~682MB) loaded from
-//! `Qwen/Qwen3-TTS-Tokenizer-12Hz`.
-
-use candle_core::{Device, Module, Result, Tensor, D};
-use candle_nn::{
- conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder,
-};
-
-use super::config::SpeechTokenizerConfig;
-
-// ---------------------------------------------------------------------------
-// Weight-Normalized Conv1d
-// ---------------------------------------------------------------------------
-
-/// A 1D convolution with optional weight normalization and activation.
-pub struct ConvBlock {
- conv: Conv1d,
- activation: ConvActivation,
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum ConvActivation {
- None,
- Elu,
- Tanh,
-}
-
-impl ConvBlock {
- pub fn new(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- stride: usize,
- padding: usize,
- dilation: usize,
- activation: ConvActivation,
- vb: VarBuilder,
- ) -> Result<Self> {
- let config = Conv1dConfig {
- stride,
- padding,
- dilation,
- groups: 1,
- };
- let conv = conv1d(in_channels, out_channels, kernel_size, config, vb.pp("conv"))?;
-
- Ok(Self { conv, activation })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let out = self.conv.forward(x)?;
- match self.activation {
- ConvActivation::None => Ok(out),
- ConvActivation::Elu => elu(&out, 1.0),
- ConvActivation::Tanh => out.tanh(),
- }
- }
-}
-
-/// ELU activation: x if x >= 0, alpha * (exp(x) - 1) if x < 0
-fn elu(x: &Tensor, alpha: f64) -> Result<Tensor> {
- let zeros = x.zeros_like()?;
- let positive = x.maximum(&zeros)?;
- let negative_mask = x.lt(&zeros)?.to_dtype(x.dtype())?;
- let exp_x = x.exp()?;
- let one = Tensor::ones_like(&exp_x)?;
- let negative = ((exp_x - one)? * alpha)?.broadcast_mul(&negative_mask)?;
- positive + negative
-}
-
-// ---------------------------------------------------------------------------
-// Residual Unit
-// ---------------------------------------------------------------------------
-
-/// Residual convolutional unit with dilated convolutions.
-pub struct ResidualUnit {
- conv1: ConvBlock,
- conv2: ConvBlock,
-}
-
-impl ResidualUnit {
- pub fn new(
- channels: usize,
- dilation: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- // Dilated causal conv (kernel=7, dilation varies)
- let padding = (7 - 1) * dilation / 2; // causal-ish padding
- let conv1 = ConvBlock::new(
- channels,
- channels,
- 7,
- 1,
- padding,
- dilation,
- ConvActivation::Elu,
- vb.pp("block.0"),
- )?;
-
- // Pointwise conv (kernel=1)
- let conv2 = ConvBlock::new(
- channels,
- channels,
- 1,
- 1,
- 0,
- 1,
- ConvActivation::Elu,
- vb.pp("block.1"),
- )?;
-
- Ok(Self { conv1, conv2 })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let residual = x;
- let out = self.conv1.forward(x)?;
- let out = self.conv2.forward(&out)?;
- // Match sequence lengths if needed (causal conv may change length)
- let out_len = out.dim(D::Minus1)?;
- let res_len = residual.dim(D::Minus1)?;
- if out_len != res_len {
- let start = res_len.saturating_sub(out_len);
- let residual = residual.narrow(D::Minus1, start, out_len)?;
- residual + out
- } else {
- residual + out
- }
- }
-}
-
-// ---------------------------------------------------------------------------
-// Encoder Block
-// ---------------------------------------------------------------------------
-
-/// Encoder downsampling block: residual units + strided conv.
-pub struct EncoderBlock {
- residual_units: Vec<ResidualUnit>,
- downsample: ConvBlock,
-}
-
-impl EncoderBlock {
- pub fn new(
- in_channels: usize,
- out_channels: usize,
- stride: usize,
- num_residuals: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- let mut residual_units = Vec::with_capacity(num_residuals);
- for i in 0..num_residuals {
- let dilation = 3usize.pow(i as u32); // 1, 3, 9
- let unit = ResidualUnit::new(in_channels, dilation, vb.pp(format!("residuals.{i}")))?;
- residual_units.push(unit);
- }
-
- // Strided downsampling convolution
- let kernel_size = stride * 2;
- let padding = stride / 2;
- let downsample = ConvBlock::new(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- 1,
- ConvActivation::Elu,
- vb.pp("downsample"),
- )?;
-
- Ok(Self {
- residual_units,
- downsample,
- })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let mut out = x.clone();
- for unit in &self.residual_units {
- out = unit.forward(&out)?;
- }
- self.downsample.forward(&out)
- }
-}
-
-// ---------------------------------------------------------------------------
-// Decoder Block
-// ---------------------------------------------------------------------------
-
-/// Decoder upsampling block: transposed conv + residual units.
-pub struct DecoderBlock {
- upsample: ConvBlock,
- residual_units: Vec<ResidualUnit>,
-}
-
-impl DecoderBlock {
- pub fn new(
- in_channels: usize,
- out_channels: usize,
- stride: usize,
- num_residuals: usize,
- vb: VarBuilder,
- ) -> Result<Self> {
- // Strided upsampling (transpose conv simulated by regular conv + padding)
- let kernel_size = stride * 2;
- let padding = stride / 2;
- let upsample = ConvBlock::new(
- in_channels,
- out_channels,
- kernel_size,
- 1, // stride=1 for output; upsample via repeat/interpolation
- padding,
- 1,
- ConvActivation::Elu,
- vb.pp("upsample"),
- )?;
-
- let mut residual_units = Vec::with_capacity(num_residuals);
- for i in 0..num_residuals {
- let dilation = 3usize.pow(i as u32);
- let unit =
- ResidualUnit::new(out_channels, dilation, vb.pp(format!("residuals.{i}")))?;
- residual_units.push(unit);
- }
-
- Ok(Self {
- upsample,
- residual_units,
- })
- }
-
- pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let mut out = self.upsample.forward(x)?;
- for unit in &self.residual_units {
- out = unit.forward(&out)?;
- }
- Ok(out)
- }
-}
-
-// ---------------------------------------------------------------------------
-// RVQ Codebook
-// ---------------------------------------------------------------------------
-
-/// Residual Vector Quantization codebook.
-///
-/// Contains `num_codebooks` embedding tables, each mapping
-/// `codebook_size` indices to `codebook_dim`-dimensional vectors.
-pub struct RvqCodebook {
- codebooks: Vec<Embedding>,
- num_codebooks: usize,
- #[allow(dead_code)]
- codebook_dim: usize,
-}
-
-impl RvqCodebook {
- pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder) -> Result<Self> {
- let mut codebooks = Vec::with_capacity(config.num_codebooks);
- for i in 0..config.num_codebooks {
- let cb = embedding(
- config.codebook_size,
- config.codebook_dim,
- vb.pp(format!("codebooks.{i}")),
- )?;
- codebooks.push(cb);
- }
-
- Ok(Self {
- codebooks,
- num_codebooks: config.num_codebooks,
- codebook_dim: config.codebook_dim,
- })
- }
-
- /// Look up codebook embeddings for all codebook layers.
- ///
- /// `codes`: [num_codebooks, seq_len] — codebook indices per layer
- /// Returns: [1, codebook_dim, seq_len] — sum of all codebook embeddings
- pub fn decode(&self, codes: &[Vec<u32>], device: &Device) -> Result<Tensor> {
- assert_eq!(codes.len(), self.num_codebooks, "Expected {} codebook layers", self.num_codebooks);
-
- let seq_len = codes[0].len();
- let mut sum: Option<Tensor> = None;
-
- for (i, code_layer) in codes.iter().enumerate() {
- assert_eq!(code_layer.len(), seq_len, "Codebook layer {i} length mismatch");
-
- let indices = Tensor::from_vec(
- code_layer.clone(),
- (1, seq_len),
- device,
- )?;
-
- // [1, seq_len, codebook_dim]
- let emb = self.codebooks[i].forward(&indices)?;
-
- sum = Some(match sum {
- Some(prev) => (prev + emb)?,
- None => emb,
- });
- }
-
- // [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
- let result = sum.unwrap().transpose(1, 2)?;
- Ok(result)
- }
-
- /// Number of codebooks.
- pub fn num_codebooks(&self) -> usize {
- self.num_codebooks
- }
-}
-
-// ---------------------------------------------------------------------------
-// Speech Tokenizer (Encoder + Decoder)
-// ---------------------------------------------------------------------------
-
-/// The complete speech tokenizer with encoder and decoder.
-pub struct SpeechTokenizer {
- /// Encoder: waveform -> latent (for voice cloning).
- encoder_input_conv: ConvBlock,
- encoder_blocks: Vec<EncoderBlock>,
- encoder_output_conv: ConvBlock,
-
- /// RVQ codebooks for quantization.
- codebook: RvqCodebook,
-
- /// Decoder: codes -> waveform.
- decoder_input_conv: ConvBlock,
- decoder_blocks: Vec<DecoderBlock>,
- decoder_output_conv: ConvBlock,
-
- /// Projection from codebook dim to decoder hidden channels.
- decoder_proj: Linear,
-
- config: SpeechTokenizerConfig,
- device: Device,
-}
-
-impl SpeechTokenizer {
- /// Load the speech tokenizer from safetensors.
- pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder, device: &Device) -> Result<Self> {
- let hidden = config.hidden_channels; // 512
-
- // ===== Encoder =====
- // Input: [batch, 1, samples] -> [batch, hidden/8, ...]
- let encoder_input_conv = ConvBlock::new(
- 1,
- hidden / 8, // 64
- 7,
- 1,
- 3,
- 1,
- ConvActivation::Elu,
- vb.pp("encoder.input_conv"),
- )?;
-
- // Downsampling blocks with increasing channels
- let strides = [8, 5, 4, 3]; // Total downsampling: 8*5*4*3 = 480
- let channels = [hidden / 8, hidden / 4, hidden / 2, hidden]; // 64, 128, 256, 512
- let mut encoder_blocks = Vec::with_capacity(strides.len());
- for (i, (&stride, &out_ch)) in strides.iter().zip(channels.iter().skip(0)).enumerate() {
- let in_ch = if i == 0 { hidden / 8 } else { channels[i - 1] };
- let block = EncoderBlock::new(
- in_ch,
- out_ch,
- stride,
- 3, // 3 residual units per block
- vb.pp(format!("encoder.blocks.{i}")),
- )?;
- encoder_blocks.push(block);
- }
-
- // Encoder output projection to codebook dim
- let encoder_output_conv = ConvBlock::new(
- hidden,
- config.codebook_dim,
- 3,
- 1,
- 1,
- 1,
- ConvActivation::None,
- vb.pp("encoder.output_conv"),
- )?;
-
- // ===== RVQ Codebook =====
- let codebook = RvqCodebook::new(config, vb.pp("quantizer"))?;
-
- // ===== Decoder =====
- // Projection from codebook dim to decoder hidden
- let decoder_proj = linear_no_bias(
- config.codebook_dim,
- hidden,
- vb.pp("decoder.proj"),
- )?;
-
- // Input conv
- let decoder_input_conv = ConvBlock::new(
- hidden,
- hidden,
- 7,
- 1,
- 3,
- 1,
- ConvActivation::Elu,
- vb.pp("decoder.input_conv"),
- )?;
-
- // Upsampling blocks (reverse order of encoder)
- let dec_strides = [3, 4, 5, 8];
- let dec_channels = [hidden, hidden / 2, hidden / 4, hidden / 8]; // 512, 256, 128, 64
- let mut decoder_blocks = Vec::with_capacity(dec_strides.len());
- for (i, (&stride, &out_ch)) in dec_strides.iter().zip(dec_channels.iter().skip(0)).enumerate()
- {
- let in_ch = if i == 0 { hidden } else { dec_channels[i - 1] };
- let block = DecoderBlock::new(
- in_ch,
- out_ch,
- stride,
- 3,
- vb.pp(format!("decoder.blocks.{i}")),
- )?;
- decoder_blocks.push(block);
- }
-
- // Output conv: hidden/8 -> 1 channel (waveform)
- let decoder_output_conv = ConvBlock::new(
- hidden / 8,
- 1,
- 7,
- 1,
- 3,
- 1,
- ConvActivation::Tanh,
- vb.pp("decoder.output_conv"),
- )?;
-
- Ok(Self {
- encoder_input_conv,
- encoder_blocks,
- encoder_output_conv,
- codebook,
- decoder_input_conv,
- decoder_blocks,
- decoder_output_conv,
- decoder_proj,
- config: config.clone(),
- device: device.clone(),
- })
- }
-
- /// Encode reference audio waveform to discrete codebook tokens.
- ///
- /// `audio`: [num_samples] — mono 24kHz audio
- /// Returns: Vec of `num_codebooks` vectors, each containing token indices.
- pub fn encode(&self, audio: &[f32]) -> Result<Vec<Vec<u32>>> {
- // [1, 1, num_samples]
- let x = Tensor::from_vec(audio.to_vec(), (1, 1, audio.len()), &self.device)?;
-
- // Run encoder
- let mut hidden = self.encoder_input_conv.forward(&x)?;
- for block in &self.encoder_blocks {
- hidden = block.forward(&hidden)?;
- }
- let latent = self.encoder_output_conv.forward(&hidden)?;
-
- // latent: [1, codebook_dim, seq_len]
- // Quantize via nearest-neighbor lookup in each codebook
- let seq_len = latent.dim(D::Minus1)?;
- let mut all_codes = Vec::with_capacity(self.config.num_codebooks);
-
- // Residual quantization: subtract each codebook's contribution
- let mut residual = latent.clone();
-
- for cb_idx in 0..self.config.num_codebooks {
- // residual: [1, codebook_dim, seq_len] -> find nearest codebook entry per timestep
- let codes = self.quantize_layer(&residual, cb_idx, seq_len)?;
-
- // Look up the quantized vectors and subtract from residual
- let code_indices =
- Tensor::from_vec(codes.clone(), (1, seq_len), &self.device)?;
- let quantized = self.codebook.codebooks[cb_idx].forward(&code_indices)?;
- // quantized: [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
- let quantized = quantized.transpose(1, 2)?;
- residual = (residual - quantized)?;
-
- all_codes.push(codes);
- }
-
- Ok(all_codes)
- }
-
- /// Quantize a single RVQ layer by finding the nearest codebook entry.
- fn quantize_layer(
- &self,
- residual: &Tensor,
- codebook_idx: usize,
- _seq_len: usize,
- ) -> Result<Vec<u32>> {
- // residual: [1, codebook_dim, seq_len]
- // codebook weights: [codebook_size, codebook_dim]
- let cb_weight = self.codebook.codebooks[codebook_idx]
- .embeddings()
- .clone(); // [codebook_size, codebook_dim]
-
- // Transpose residual: [1, seq_len, codebook_dim]
- let residual_t = residual.transpose(1, 2)?.squeeze(0)?; // [seq_len, codebook_dim]
-
- // Compute L2 distances: ||r - c||^2 = ||r||^2 - 2*r*c^T + ||c||^2
- let r_sq = residual_t.sqr()?.sum(D::Minus1)?; // [seq_len]
- let c_sq = cb_weight.sqr()?.sum(D::Minus1)?; // [codebook_size]
- let rc = residual_t.matmul(&cb_weight.t()?)?; // [seq_len, codebook_size]
-
- let r_sq = r_sq.unsqueeze(1)?; // [seq_len, 1]
- let c_sq = c_sq.unsqueeze(0)?; // [1, codebook_size]
-
- let distances = (r_sq.broadcast_add(&c_sq)? - (rc * 2.0)?)?; // [seq_len, codebook_size]
-
- // Argmin per timestep
- let indices = distances.argmin(D::Minus1)?; // [seq_len]
- let codes: Vec<u32> = indices.to_vec1()?;
-
- Ok(codes)
- }
-
- /// Decode discrete codebook tokens to audio waveform.
- ///
- /// `codes`: Vec of `num_codebooks` vectors of token indices.
- /// Returns: Vec<f32> — mono 24kHz audio samples.
- pub fn decode(&self, codes: &[Vec<u32>]) -> Result<Vec<f32>> {
- // Look up and sum all codebook embeddings
- let embeddings = self.codebook.decode(codes, &self.device)?;
- // embeddings: [1, codebook_dim, seq_len]
-
- // Project to decoder hidden size: [1, seq_len, codebook_dim] -> [1, seq_len, hidden]
- let emb_t = embeddings.transpose(1, 2)?; // [1, seq_len, codebook_dim]
- let projected = self.decoder_proj.forward(&emb_t)?; // [1, seq_len, hidden]
- let mut hidden = projected.transpose(1, 2)?; // [1, hidden, seq_len]
-
- // Run decoder
- hidden = self.decoder_input_conv.forward(&hidden)?;
- for block in &self.decoder_blocks {
- hidden = block.forward(&hidden)?;
- }
- let waveform = self.decoder_output_conv.forward(&hidden)?;
-
- // [1, 1, num_samples] -> Vec<f32>
- let samples: Vec<f32> = waveform.flatten_all()?.to_vec1()?;
- Ok(samples)
- }
-
- /// Decode a single frame's codes to audio samples (for streaming).
- ///
- /// `frame_codes`: [num_codebooks] — one token per codebook for a single frame
- /// Returns: audio samples for this frame (~1920 samples at 24kHz / 12.5Hz)
- pub fn decode_frame(&self, frame_codes: &[u32]) -> Result<Vec<f32>> {
- let codes: Vec<Vec<u32>> = frame_codes.iter().map(|&c| vec![c]).collect();
- self.decode(&codes)
- }
-
- /// Get the number of codebooks.
- pub fn num_codebooks(&self) -> usize {
- self.config.num_codebooks
- }
-
- /// Get the output sample rate.
- pub fn sample_rate(&self) -> u32 {
- self.config.sample_rate
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_elu_positive() {
- let device = Device::Cpu;
- let x = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap();
- let result = elu(&x, 1.0).unwrap();
- let values: Vec<f32> = result.to_vec1().unwrap();
- assert!((values[0] - 1.0).abs() < 1e-5);
- assert!((values[1] - 2.0).abs() < 1e-5);
- }
-
- #[test]
- fn test_elu_negative() {
- let device = Device::Cpu;
- let x = Tensor::from_vec(vec![-1.0f32], (1,), &device).unwrap();
- let result = elu(&x, 1.0).unwrap();
- let values: Vec<f32> = result.to_vec1().unwrap();
- // ELU(-1) = exp(-1) - 1 ≈ -0.6321
- assert!((values[0] - (-0.6321)).abs() < 0.01);
- }
-
- #[test]
- fn test_speech_tokenizer_config() {
- let config = SpeechTokenizerConfig::default();
- assert_eq!(config.num_codebooks, 16);
- assert_eq!(config.codebook_size, 2048);
- assert_eq!(config.sample_rate, 24_000);
- }
-}