//! Chatterbox TTS engine — ONNX-based (legacy). //! //! This is the existing Chatterbox TTS implementation moved from `tts.rs`, //! now implementing the `TtsEngine` trait for unified access. use std::borrow::Cow; use std::fs; use std::path::{Path, PathBuf}; use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex}; use hf_hub::api::sync::Api; use ndarray::{Array2, Array3, Array4, ArrayD, IxDyn}; use ort::session::Session; use ort::value::{DynValue, Value}; use tokenizers::Tokenizer; use crate::audio; use super::{ apply_repetition_penalty, argmax, resample_to_24k, AudioChunk, TtsEngine, TtsError, SAMPLE_RATE, }; const START_SPEECH_TOKEN: i64 = 6561; const STOP_SPEECH_TOKEN: i64 = 6562; const SILENCE_TOKEN: i64 = 4299; const NUM_LAYERS: usize = 24; const NUM_KV_HEADS: usize = 16; const HEAD_DIM: usize = 64; const MODEL_ID: &str = "ResembleAI/chatterbox-turbo-ONNX"; const DEFAULT_MODEL_DIR: &str = "models/chatterbox-turbo"; struct VoiceCondition { audio_features: ArrayD, prompt_tokens: ArrayD, speaker_embeddings: ArrayD, speaker_features: ArrayD, } fn extract_f32_tensor(value: &Value) -> Result, TtsError> { let (shape, data) = value .try_extract_tensor::() .map_err(|e| TtsError::Inference(e.to_string()))?; let dims: Vec = shape.iter().map(|&d| d as usize).collect(); ArrayD::from_shape_vec(IxDyn(&dims), data.to_vec()) .map_err(|e| TtsError::Inference(e.to_string())) } fn extract_i64_tensor(value: &Value) -> Result, TtsError> { let (shape, data) = value .try_extract_tensor::() .map_err(|e| TtsError::Inference(e.to_string()))?; let dims: Vec = shape.iter().map(|&d| d as usize).collect(); ArrayD::from_shape_vec(IxDyn(&dims), data.to_vec()) .map_err(|e| TtsError::Inference(e.to_string())) } pub struct ChatterboxTTS { speech_encoder: Mutex, embed_tokens: Mutex, language_model: Mutex, conditional_decoder: Mutex, tokenizer: Tokenizer, } // SAFETY: Sessions are behind Mutex, Tokenizer is Send+Sync unsafe impl Send for ChatterboxTTS {} unsafe impl Sync for ChatterboxTTS {} impl ChatterboxTTS { pub fn from_pretrained(model_dir: Option<&str>) -> Result { let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR)); if !model_path.exists() { download_models(&model_path)?; } Self::load_from_path(&model_path) } pub fn load_from_path(model_dir: &Path) -> Result { let speech_encoder = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("speech_encoder.onnx"))?; let embed_tokens = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("embed_tokens.onnx"))?; let language_model = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("language_model.onnx"))?; let conditional_decoder = Session::builder()? .with_intra_threads(4)? .commit_from_file(model_dir.join("conditional_decoder.onnx"))?; let tokenizer_path = model_dir.join("tokenizer.json"); let tokenizer = Tokenizer::from_file(&tokenizer_path) .map_err(|e| TtsError::Tokenizer(e.to_string()))?; Ok(Self { speech_encoder: Mutex::new(speech_encoder), embed_tokens: Mutex::new(embed_tokens), language_model: Mutex::new(language_model), conditional_decoder: Mutex::new(conditional_decoder), tokenizer, }) } pub fn generate_tts(&self) -> Result, TtsError> { Err(TtsError::VoiceRequired) } pub fn generate_tts_with_voice( &self, text: &str, sample_audio_path: &Path, ) -> Result, TtsError> { let audio = audio::to_16k_mono_from_path(sample_audio_path)?; let resampled = resample_to_24k(&audio.samples, audio.sample_rate); self.generate_tts_with_samples(text, &resampled, SAMPLE_RATE) } pub fn generate_tts_with_samples( &self, text: &str, samples: &[f32], sample_rate: u32, ) -> Result, TtsError> { let resampled = if sample_rate != SAMPLE_RATE { resample_to_24k(samples, sample_rate) } else { samples.to_vec() }; let voice_condition = self.encode_voice(&resampled)?; let encoding = self .tokenizer .encode(text, true) .map_err(|e| TtsError::Tokenizer(e.to_string()))?; let text_input_ids: Vec = encoding.get_ids().iter().map(|&id| id as i64).collect(); let generated_tokens = self.generate_speech_tokens(&text_input_ids, &voice_condition.audio_features)?; let prompt_tokens: Vec = voice_condition.prompt_tokens.iter().copied().collect(); let silence_tokens = vec![SILENCE_TOKEN; 3]; let mut final_tokens = Vec::with_capacity(prompt_tokens.len() + generated_tokens.len() + silence_tokens.len()); final_tokens.extend_from_slice(&prompt_tokens); final_tokens.extend_from_slice(&generated_tokens); final_tokens.extend_from_slice(&silence_tokens); let audio_samples = self.decode_speech_tokens( &final_tokens, &voice_condition.speaker_embeddings, &voice_condition.speaker_features, )?; Ok(audio_samples) } fn encode_voice(&self, samples: &[f32]) -> Result { let audio_arr = Array2::from_shape_vec((1, samples.len()), samples.to_vec()) .map_err(|e| TtsError::Inference(e.to_string()))?; let audio_tensor = Value::from_array(audio_arr)?; let mut encoder = self .speech_encoder .lock() .map_err(|e| TtsError::Inference(e.to_string()))?; let outputs = encoder.run(ort::inputs!["audio_values" => audio_tensor])?; let audio_features = extract_f32_tensor(&outputs[0])?; let prompt_tokens = extract_i64_tensor(&outputs[1])?; let speaker_embeddings = extract_f32_tensor(&outputs[2])?; let speaker_features = extract_f32_tensor(&outputs[3])?; Ok(VoiceCondition { audio_features, prompt_tokens, speaker_embeddings, speaker_features, }) } fn generate_speech_tokens( &self, text_input_ids: &[i64], audio_features: &ArrayD, ) -> Result, TtsError> { let max_new_tokens: usize = 1024; let repetition_penalty: f32 = 1.2; let mut generate_tokens: Vec = vec![START_SPEECH_TOKEN]; let mut past_key_values = Self::init_kv_cache(0); let mut first_iteration = true; let mut total_seq_len: usize = 0; for _ in 0..max_new_tokens { let current_input_ids = if first_iteration { text_input_ids.to_vec() } else { vec![*generate_tokens.last().unwrap()] }; let input_ids_arr = Array2::from_shape_vec((1, current_input_ids.len()), current_input_ids) .map_err(|e| TtsError::Inference(e.to_string()))?; let input_ids_tensor = Value::from_array(input_ids_arr)?; let inputs_embeds = { let mut embed = self .embed_tokens .lock() .map_err(|e| TtsError::Inference(e.to_string()))?; let embed_outputs = embed.run(ort::inputs![input_ids_tensor])?; extract_f32_tensor(&embed_outputs[0])? }; let inputs_embeds = if first_iteration { let audio_feat_3d = audio_features .view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; let text_emb_3d = inputs_embeds .view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; ndarray::concatenate(ndarray::Axis(1), &[audio_feat_3d, text_emb_3d]) .map_err(|e| TtsError::Inference(e.to_string()))? } else { inputs_embeds .view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))? .to_owned() }; let seq_len = inputs_embeds.shape()[1]; let (attention_mask, position_ids) = if first_iteration { total_seq_len = seq_len; let attention_mask: Array2 = Array2::ones((1, seq_len)); let position_ids = Array2::from_shape_fn((1, seq_len), |(_, j)| j as i64); (attention_mask, position_ids) } else { total_seq_len += 1; let attention_mask: Array2 = Array2::ones((1, total_seq_len)); let position_ids = Array2::from_shape_vec((1, 1), vec![(total_seq_len - 1) as i64]) .map_err(|e| TtsError::Inference(e.to_string()))?; (attention_mask, position_ids) }; let (logits, new_kv) = self.run_language_model( inputs_embeds, position_ids, attention_mask, past_key_values, )?; past_key_values = new_kv; let logits_3d = logits .view() .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; let last_idx = logits_3d.shape()[1] - 1; let mut current_logits: Vec = logits_3d .slice(ndarray::s![0, last_idx, ..]) .iter() .copied() .collect(); apply_repetition_penalty(&mut current_logits, &generate_tokens, repetition_penalty); let next_token = argmax(¤t_logits); generate_tokens.push(next_token); if next_token == STOP_SPEECH_TOKEN { break; } first_iteration = false; } if generate_tokens.len() > 2 { Ok(generate_tokens[1..generate_tokens.len() - 1].to_vec()) } else { Ok(Vec::new()) } } fn init_kv_cache(seq_len: usize) -> Vec> { let mut cache = Vec::with_capacity(NUM_LAYERS * 2); for _ in 0..NUM_LAYERS { let key = Array4::::zeros((1, NUM_KV_HEADS, seq_len, HEAD_DIM)); let value = Array4::::zeros((1, NUM_KV_HEADS, seq_len, HEAD_DIM)); cache.push(key); cache.push(value); } cache } fn run_language_model( &self, inputs_embeds: Array3, position_ids: Array2, attention_mask: Array2, past_key_values: Vec>, ) -> Result<(ArrayD, Vec>), TtsError> { let mut inputs: Vec<(Cow, DynValue)> = Vec::new(); inputs.push(( Cow::from("inputs_embeds"), Value::from_array(inputs_embeds)?.into_dyn(), )); inputs.push(( Cow::from("position_ids"), Value::from_array(position_ids)?.into_dyn(), )); inputs.push(( Cow::from("attention_mask"), Value::from_array(attention_mask)?.into_dyn(), )); for layer_idx in 0..NUM_LAYERS { let key_name = format!("past_key_values.{}.key", layer_idx); let value_name = format!("past_key_values.{}.value", layer_idx); let key_tensor = Value::from_array(past_key_values[layer_idx * 2].clone())?.into_dyn(); let value_tensor = Value::from_array(past_key_values[layer_idx * 2 + 1].clone())?.into_dyn(); inputs.push((Cow::from(key_name), key_tensor)); inputs.push((Cow::from(value_name), value_tensor)); } let mut lm = self .language_model .lock() .map_err(|e| TtsError::Inference(e.to_string()))?; let outputs = lm.run(inputs)?; let logits = extract_f32_tensor(&outputs[0])?; let mut new_kv = Vec::with_capacity(NUM_LAYERS * 2); for layer_idx in 0..NUM_LAYERS { let key_idx = 1 + layer_idx * 2; let value_idx = 2 + layer_idx * 2; let key_arr = extract_f32_tensor(&outputs[key_idx])?; let value_arr = extract_f32_tensor(&outputs[value_idx])?; let key_4d = key_arr .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; let value_4d = value_arr .into_dimensionality::() .map_err(|e| TtsError::Inference(e.to_string()))?; new_kv.push(key_4d.to_owned()); new_kv.push(value_4d.to_owned()); } Ok((logits, new_kv)) } fn decode_speech_tokens( &self, speech_tokens: &[i64], speaker_embeddings: &ArrayD, speaker_features: &ArrayD, ) -> Result, TtsError> { if speech_tokens.is_empty() { return Ok(Vec::new()); } let tokens_arr = Array2::from_shape_vec((1, speech_tokens.len()), speech_tokens.to_vec()) .map_err(|e| TtsError::Inference(e.to_string()))?; let mut inputs: Vec<(Cow, DynValue)> = Vec::new(); inputs.push(( Cow::from("speech_tokens"), Value::from_array(tokens_arr)?.into_dyn(), )); inputs.push(( Cow::from("speaker_embeddings"), Value::from_array(speaker_embeddings.clone())?.into_dyn(), )); inputs.push(( Cow::from("speaker_features"), Value::from_array(speaker_features.clone())?.into_dyn(), )); let mut decoder = self .conditional_decoder .lock() .map_err(|e| TtsError::Inference(e.to_string()))?; let outputs = decoder.run(inputs)?; let waveform = extract_f32_tensor(&outputs[0])?; Ok(waveform.iter().copied().collect()) } } #[async_trait::async_trait] impl TtsEngine for ChatterboxTTS { async fn generate( &self, text: &str, reference_audio: Option<&[f32]>, reference_sample_rate: Option, _cancel_flag: Option>, ) -> Result, TtsError> { let samples = match reference_audio { Some(audio) => { let sr = reference_sample_rate.unwrap_or(SAMPLE_RATE); self.generate_tts_with_samples(text, audio, sr)? } None => return Err(TtsError::VoiceRequired), }; Ok(vec![AudioChunk { samples, sample_rate: SAMPLE_RATE, is_final: true, }]) } fn is_ready(&self) -> bool { true } } fn download_models(target_dir: &Path) -> Result<(), TtsError> { fs::create_dir_all(target_dir)?; let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?; let repo = api.model(MODEL_ID.to_string()); let model_files = [ "onnx/speech_encoder.onnx", "onnx/speech_encoder.onnx_data", "onnx/embed_tokens.onnx", "onnx/embed_tokens.onnx_data", "onnx/language_model.onnx", "onnx/language_model.onnx_data", "onnx/conditional_decoder.onnx", "onnx/conditional_decoder.onnx_data", "tokenizer.json", ]; for file in &model_files { println!("Downloading {}...", file); let downloaded_path = repo .get(file) .map_err(|e| TtsError::ModelLoad(e.to_string()))?; let filename = Path::new(file).file_name().unwrap(); let target_path = target_dir.join(filename); if !target_path.exists() { fs::copy(&downloaded_path, &target_path)?; } } println!("Models downloaded to {:?}", target_dir); Ok(()) }