use std::path::{Path, PathBuf}; use std::fs; use hf_hub::api::sync::Api; use std::borrow::Cow; use ndarray::{ArrayD, Array2, Array3, Array4, IxDyn}; use ort::session::Session; use ort::value::{Value, DynValue}; use tokenizers::Tokenizer; use crate::audio; pub const SAMPLE_RATE: u32 = 24_000; const START_SPEECH_TOKEN: i64 = 6561; const STOP_SPEECH_TOKEN: i64 = 6562; const SILENCE_TOKEN: i64 = 4299; const NUM_LAYERS: usize = 24; const NUM_KV_HEADS: usize = 16; const HEAD_DIM: usize = 64; const MODEL_ID: &str = "ResembleAI/chatterbox-turbo-ONNX"; const DEFAULT_MODEL_DIR: &str = "models/chatterbox-turbo"; #[derive(Debug)] pub enum TtsError { ModelLoad(String), Inference(String), Tokenizer(String), Audio(audio::AudioError), Io(std::io::Error), VoiceRequired, } impl std::fmt::Display for TtsError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TtsError::ModelLoad(msg) => write!(f, "model load error: {msg}"), TtsError::Inference(msg) => write!(f, "inference error: {msg}"), TtsError::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"), TtsError::Audio(err) => write!(f, "audio error: {err}"), TtsError::Io(err) => write!(f, "io error: {err}"), TtsError::VoiceRequired => write!(f, "voice reference audio is required for chatterbox-turbo"), } } } impl std::error::Error for TtsError {} impl From for TtsError { fn from(value: audio::AudioError) -> Self { TtsError::Audio(value) } } impl From for TtsError { fn from(value: std::io::Error) -> Self { TtsError::Io(value) } } impl From for TtsError { fn from(value: ort::Error) -> Self { TtsError::ModelLoad(value.to_string()) } } pub struct ChatterboxTTS { speech_encoder: Session, embed_tokens: Session, language_model: Session, conditional_decoder: Session, tokenizer: Tokenizer, } struct VoiceCondition { audio_features: ArrayD, prompt_tokens: ArrayD, speaker_embeddings: ArrayD, speaker_features: ArrayD, } fn extract_f32_tensor(value: &Value) -> Result, TtsError> { let (shape, data) = value .try_extract_tensor::() .map_err(|e| TtsError::Inference(e.to_string()))?; let dims: Vec = shape.iter().map(|&d| d as usize).collect(); ArrayD::from_shape_vec(IxDyn(&dims), data.to_vec()) .map_err(|e| TtsError::Inference(e.to_string())) } fn extract_i64_tensor(value: &Value) -> Result, TtsError> { let (shape, data) = value .try_extract_tensor::() .map_err(|e| TtsError::Inference(e.to_string()))?; let dims: Vec = shape.iter().map(|&d| d as usize).collect(); ArrayD::from_shape_vec(IxDyn(&dims), data.to_vec()) .map_err(|e| TtsError::Inference(e.to_string())) } impl ChatterboxTTS { pub fn from_pretrained(model_dir: Option<&str>) -> Result { let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR)); if !model_path.exists() { download_models(&model_path)?; } Self::load_from_path(&model_path) } pub fn load_from_path(model_dir: &Path) -> Result { let speech_encoder = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("speech_encoder.onnx"))?; let embed_tokens = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("embed_tokens.onnx"))?; let language_model = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("language_model.onnx"))?; let conditional_decoder = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("conditional_decoder.onnx"))?; let tokenizer_path = model_dir.join("tokenizer.json"); let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| TtsError::Tokenizer(e.to_string()))?; Ok(Self { speech_encoder, embed_tokens, language_model, conditional_decoder, tokenizer, }) } pub fn generate_tts(&mut self, _text: &str) -> Result, TtsError> { // Chatterbox TTS requires voice reference audio Err(TtsError::VoiceRequired) } pub fn generate_tts_with_voice( &mut self, text: &str, sample_audio_path: &Path, ) -> Result, TtsError> { let audio = audio::to_16k_mono_from_path(sample_audio_path)?; let resampled = resample_to_24k(&audio.samples, audio.sample_rate); self.generate_tts_with_samples(text, &resampled, SAMPLE_RATE) } pub fn generate_tts_with_samples( &mut self, text: &str, samples: &[f32], sample_rate: u32, ) -> Result, TtsError> { let resampled = if sample_rate != SAMPLE_RATE { resample_to_24k(samples, sample_rate) } else { samples.to_vec() }; // 1. Encode reference audio let voice_condition = self.encode_voice(&resampled)?; // 2. Tokenize text let encoding = self .tokenizer .encode(text, true) .map_err(|e| TtsError::Tokenizer(e.to_string()))?; let text_input_ids: Vec = encoding.get_ids().iter().map(|&id| id as i64).collect(); // 3. Generate speech tokens let generated_tokens = self.generate_speech_tokens( &text_input_ids, &voice_condition.audio_features, )?; // 4. Prepare final speech tokens: prompt_tokens + generated + silence let prompt_tokens: Vec = voice_condition.prompt_tokens.iter().copied().collect(); let silence_tokens = vec![SILENCE_TOKEN; 3]; let mut final_tokens = Vec::with_capacity( prompt_tokens.len() + generated_tokens.len() + silence_tokens.len() ); final_tokens.extend_from_slice(&prompt_tokens); final_tokens.extend_from_slice(&generated_tokens); final_tokens.extend_from_slice(&silence_tokens); // 5. Decode to audio let audio_samples = self.decode_speech_tokens( &final_tokens, &voice_condition.speaker_embeddings, &voice_condition.speaker_features, )?; Ok(audio_samples) } fn encode_voice(&mut self, samples: &[f32]) -> Result { let audio_arr = Array2::from_shape_vec((1, samples.len()), samples.to_vec()) .map_err(|e| TtsError::Inference(e.to_string()))?; let audio_tensor = Value::from_array(audio_arr)?; let outputs = self.speech_encoder.run(ort::inputs!["audio_values" => audio_tensor])?; // Order: audio_features, audio_tokens (prompt_token), speaker_embeddings, speaker_features let audio_features = extract_f32_tensor(&outputs[0])?; let prompt_tokens = extract_i64_tensor(&outputs[1])?; let speaker_embeddings = extract_f32_tensor(&outputs[2])?; let speaker_features = extract_f32_tensor(&outputs[3])?; Ok(VoiceCondition { audio_features, prompt_tokens, speaker_embeddings, speaker_features, }) } fn generate_speech_tokens( &mut self, text_input_ids: &[i64], audio_features: &ArrayD, ) -> Result, TtsError> { let max_new_tokens: usize = 1024; let repetition_penalty: f32 = 1.2; // Start with START_SPEECH_TOKEN let mut generate_tokens: Vec = vec![START_SPEECH_TOKEN]; // Initialize empty KV cache (seq_len = 0) let mut past_key_values = self.init_kv_cache(0)?; let mut first_iteration = true; let mut total_seq_len: usize = 0; for _ in 0..max_new_tokens { // Get embeddings for current input_ids let current_input_ids = if first_iteration { // First iteration: use text input_ids text_input_ids.to_vec() } else { // Subsequent iterations: use last generated token vec![*generate_tokens.last().unwrap()] }; let input_ids_arr = Array2::from_shape_vec( (1, current_input_ids.len()), current_input_ids ).map_err(|e| TtsError::Inference(e.to_string()))?; let input_ids_tensor = Value::from_array(input_ids_arr)?; let inputs_embeds = { let embed_outputs = self.embed_tokens.run(ort::inputs![input_ids_tensor])?; extract_f32_tensor(&embed_outputs[0])? }; // On first iteration, concatenate audio features with text embeddings let inputs_embeds = if first_iteration { let audio_feat_3d = audio_features.view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; let text_emb_3d = inputs_embeds.view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; ndarray::concatenate(ndarray::Axis(1), &[audio_feat_3d, text_emb_3d]) .map_err(|e| TtsError::Inference(e.to_string()))? } else { inputs_embeds.view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))? .to_owned() }; let seq_len = inputs_embeds.shape()[1]; // Set up attention mask and position ids let (attention_mask, position_ids) = if first_iteration { total_seq_len = seq_len; let attention_mask: Array2 = Array2::ones((1, seq_len)); let position_ids = Array2::from_shape_fn((1, seq_len), |(_, j)| j as i64); (attention_mask, position_ids) } else { total_seq_len += 1; let attention_mask: Array2 = Array2::ones((1, total_seq_len)); let position_ids = Array2::from_shape_vec( (1, 1), vec![(total_seq_len - 1) as i64] ).map_err(|e| TtsError::Inference(e.to_string()))?; (attention_mask, position_ids) }; // Run language model let (logits, new_kv) = self.run_language_model( inputs_embeds, position_ids, attention_mask, past_key_values, )?; past_key_values = new_kv; // Get last logits let logits_3d = logits.view().into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; let last_idx = logits_3d.shape()[1] - 1; let mut current_logits: Vec = logits_3d .slice(ndarray::s![0, last_idx, ..]) .iter() .copied() .collect(); // Apply repetition penalty apply_repetition_penalty(&mut current_logits, &generate_tokens, repetition_penalty); // Get next token let next_token = argmax(¤t_logits); generate_tokens.push(next_token); if next_token == STOP_SPEECH_TOKEN { break; } first_iteration = false; } // Return tokens without START and STOP tokens: [1:-1] if generate_tokens.len() > 2 { Ok(generate_tokens[1..generate_tokens.len()-1].to_vec()) } else { Ok(Vec::new()) } } fn init_kv_cache(&self, seq_len: usize) -> Result>, TtsError> { let mut cache = Vec::with_capacity(NUM_LAYERS * 2); for _ in 0..NUM_LAYERS { let key = Array4::::zeros((1, NUM_KV_HEADS, seq_len, HEAD_DIM)); let value = Array4::::zeros((1, NUM_KV_HEADS, seq_len, HEAD_DIM)); cache.push(key); cache.push(value); } Ok(cache) } fn run_language_model( &mut self, inputs_embeds: Array3, position_ids: Array2, attention_mask: Array2, past_key_values: Vec>, ) -> Result<(ArrayD, Vec>), TtsError> { let mut inputs: Vec<(Cow, DynValue)> = Vec::new(); inputs.push((Cow::from("inputs_embeds"), Value::from_array(inputs_embeds)?.into_dyn())); inputs.push((Cow::from("position_ids"), Value::from_array(position_ids)?.into_dyn())); inputs.push((Cow::from("attention_mask"), Value::from_array(attention_mask)?.into_dyn())); // Add KV cache inputs for layer_idx in 0..NUM_LAYERS { let key_name = format!("past_key_values.{}.key", layer_idx); let value_name = format!("past_key_values.{}.value", layer_idx); let key_tensor = Value::from_array(past_key_values[layer_idx * 2].clone())?.into_dyn(); let value_tensor = Value::from_array(past_key_values[layer_idx * 2 + 1].clone())?.into_dyn(); inputs.push((Cow::from(key_name), key_tensor)); inputs.push((Cow::from(value_name), value_tensor)); } let outputs = self.language_model.run(inputs)?; let logits = extract_f32_tensor(&outputs[0])?; let mut new_kv = Vec::with_capacity(NUM_LAYERS * 2); for layer_idx in 0..NUM_LAYERS { let key_idx = 1 + layer_idx * 2; let value_idx = 2 + layer_idx * 2; let key_arr = extract_f32_tensor(&outputs[key_idx])?; let value_arr = extract_f32_tensor(&outputs[value_idx])?; let key_4d = key_arr.into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; let value_4d = value_arr.into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; new_kv.push(key_4d.to_owned()); new_kv.push(value_4d.to_owned()); } Ok((logits, new_kv)) } fn decode_speech_tokens( &mut self, speech_tokens: &[i64], speaker_embeddings: &ArrayD, speaker_features: &ArrayD, ) -> Result, TtsError> { if speech_tokens.is_empty() { return Ok(Vec::new()); } let tokens_arr = Array2::from_shape_vec((1, speech_tokens.len()), speech_tokens.to_vec()) .map_err(|e| TtsError::Inference(e.to_string()))?; let mut inputs: Vec<(Cow, DynValue)> = Vec::new(); inputs.push((Cow::from("speech_tokens"), Value::from_array(tokens_arr)?.into_dyn())); inputs.push((Cow::from("speaker_embeddings"), Value::from_array(speaker_embeddings.clone())?.into_dyn())); inputs.push((Cow::from("speaker_features"), Value::from_array(speaker_features.clone())?.into_dyn())); let outputs = self.conditional_decoder.run(inputs)?; let waveform = extract_f32_tensor(&outputs[0])?; Ok(waveform.iter().copied().collect()) } } fn download_models(target_dir: &Path) -> Result<(), TtsError> { fs::create_dir_all(target_dir)?; let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?; let repo = api.model(MODEL_ID.to_string()); let model_files = [ "onnx/speech_encoder.onnx", "onnx/speech_encoder.onnx_data", "onnx/embed_tokens.onnx", "onnx/embed_tokens.onnx_data", "onnx/language_model.onnx", "onnx/language_model.onnx_data", "onnx/conditional_decoder.onnx", "onnx/conditional_decoder.onnx_data", "tokenizer.json", ]; for file in &model_files { println!("Downloading {}...", file); let downloaded_path = repo.get(file).map_err(|e| TtsError::ModelLoad(e.to_string()))?; let filename = Path::new(file).file_name().unwrap(); let target_path = target_dir.join(filename); if !target_path.exists() { fs::copy(&downloaded_path, &target_path)?; } } println!("Models downloaded to {:?}", target_dir); Ok(()) } fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec { if input_rate == SAMPLE_RATE { return samples.to_vec(); } if samples.is_empty() { return Vec::new(); } let ratio = input_rate as f64 / SAMPLE_RATE as f64; let output_len = ((samples.len() as f64) / ratio).ceil() as usize; let mut output = Vec::with_capacity(output_len); for i in 0..output_len { let src_idx = (i as f64 * ratio) as usize; let sample = samples.get(src_idx).copied().unwrap_or(0.0); output.push(sample); } output } fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) { for &token in generated { if (token as usize) < logits.len() { let score = logits[token as usize]; // Note: opposite of standard - if score < 0, multiply; if > 0, divide logits[token as usize] = if score < 0.0 { score * penalty } else { score / penalty }; } } } fn argmax(logits: &[f32]) -> i64 { logits .iter() .enumerate() .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) .map(|(idx, _)| idx as i64) .unwrap_or(0) } pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> { let mut file = fs::File::create(path)?; write_wav(&mut file, samples, SAMPLE_RATE)?; Ok(()) } fn write_wav(writer: &mut W, samples: &[f32], sample_rate: u32) -> Result<(), std::io::Error> { let num_samples = samples.len() as u32; let num_channels: u16 = 1; let bits_per_sample: u16 = 16; let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; let block_align = num_channels * bits_per_sample / 8; let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8; let file_size = 36 + data_size; writer.write_all(b"RIFF")?; writer.write_all(&file_size.to_le_bytes())?; writer.write_all(b"WAVE")?; writer.write_all(b"fmt ")?; writer.write_all(&16u32.to_le_bytes())?; writer.write_all(&1u16.to_le_bytes())?; writer.write_all(&num_channels.to_le_bytes())?; writer.write_all(&sample_rate.to_le_bytes())?; writer.write_all(&byte_rate.to_le_bytes())?; writer.write_all(&block_align.to_le_bytes())?; writer.write_all(&bits_per_sample.to_le_bytes())?; writer.write_all(b"data")?; writer.write_all(&data_size.to_le_bytes())?; for &sample in samples { let clamped = sample.clamp(-1.0, 1.0); let int_sample = (clamped * 32767.0) as i16; writer.write_all(&int_sample.to_le_bytes())?; } Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_argmax() { let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2]; assert_eq!(argmax(&logits), 3); } #[test] fn test_resample_same_rate() { let samples = vec![0.1, 0.2, 0.3]; let resampled = resample_to_24k(&samples, SAMPLE_RATE); assert_eq!(resampled, samples); } #[test] fn test_repetition_penalty() { let mut logits = vec![1.0, 2.0, 3.0, 4.0]; let generated = vec![1, 3]; apply_repetition_penalty(&mut logits, &generated, 1.2); // score > 0 -> divide assert!((logits[1] - 2.0 / 1.2).abs() < 1e-6); assert!((logits[3] - 4.0 / 1.2).abs() < 1e-6); } }