//! TTS engine abstraction and implementations. //! //! Provides a trait-based TTS engine interface using Chatterbox ONNX-based TTS. use std::path::Path; use std::sync::atomic::AtomicBool; use std::sync::Arc; pub mod chatterbox; // Re-export primary types pub use chatterbox::ChatterboxTTS; /// Audio output sample rate (both engines output 24kHz). pub const SAMPLE_RATE: u32 = 24_000; /// A chunk of generated audio for streaming output. #[derive(Debug, Clone)] pub struct AudioChunk { /// PCM f32 samples in [-1.0, 1.0]. pub samples: Vec, /// Sample rate (always 24000 for both engines). pub sample_rate: u32, /// Whether this is the final chunk in the stream. pub is_final: bool, } impl AudioChunk { /// Convert to 16-bit PCM bytes (little-endian) for WebSocket streaming. pub fn to_pcm16_bytes(&self) -> Vec { let mut buf = Vec::with_capacity(self.samples.len() * 2); for &s in &self.samples { let clamped = s.clamp(-1.0, 1.0); let int_sample = (clamped * 32767.0) as i16; buf.extend_from_slice(&int_sample.to_le_bytes()); } buf } } /// Errors that can occur during TTS operations. #[derive(Debug)] pub enum TtsError { ModelLoad(String), Inference(String), Tokenizer(String), Audio(crate::audio::AudioError), Io(std::io::Error), VoiceRequired, } impl std::fmt::Display for TtsError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TtsError::ModelLoad(msg) => write!(f, "model load error: {msg}"), TtsError::Inference(msg) => write!(f, "inference error: {msg}"), TtsError::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"), TtsError::Audio(err) => write!(f, "audio error: {err}"), TtsError::Io(err) => write!(f, "io error: {err}"), TtsError::VoiceRequired => { write!(f, "voice reference audio is required") } } } } impl std::error::Error for TtsError {} impl From for TtsError { fn from(value: crate::audio::AudioError) -> Self { TtsError::Audio(value) } } impl From for TtsError { fn from(value: std::io::Error) -> Self { TtsError::Io(value) } } impl From for TtsError { fn from(value: ort::Error) -> Self { TtsError::ModelLoad(value.to_string()) } } /// TTS engine trait for text-to-speech synthesis. #[async_trait::async_trait] pub trait TtsEngine: Send + Sync { /// Generate complete audio from text with a voice reference. /// /// The optional `cancel_flag` can be set to `true` by another thread/task /// to request early termination of the generation loop. Engines that /// support cancellation will check this flag periodically and return /// whatever audio has been produced so far. async fn generate( &self, text: &str, reference_audio: Option<&[f32]>, reference_sample_rate: Option, cancel_flag: Option>, ) -> Result, TtsError>; /// Check if the engine is loaded and ready. fn is_ready(&self) -> bool; /// Get the engine's output sample rate. fn sample_rate(&self) -> u32 { SAMPLE_RATE } } /// Factory for creating TTS engines. pub struct TtsEngineFactory; impl TtsEngineFactory { /// Create a Chatterbox TTS engine. pub fn create(model_dir: Option<&str>) -> Result, TtsError> { let engine = ChatterboxTTS::from_pretrained(model_dir)?; Ok(Box::new(engine)) } } /// Save audio samples to a WAV file. pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> { let mut file = std::fs::File::create(path)?; write_wav(&mut file, samples, SAMPLE_RATE)?; Ok(()) } fn write_wav( writer: &mut W, samples: &[f32], sample_rate: u32, ) -> Result<(), std::io::Error> { let num_samples = samples.len() as u32; let num_channels: u16 = 1; let bits_per_sample: u16 = 16; let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; let block_align = num_channels * bits_per_sample / 8; let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8; let file_size = 36 + data_size; writer.write_all(b"RIFF")?; writer.write_all(&file_size.to_le_bytes())?; writer.write_all(b"WAVE")?; writer.write_all(b"fmt ")?; writer.write_all(&16u32.to_le_bytes())?; writer.write_all(&1u16.to_le_bytes())?; writer.write_all(&num_channels.to_le_bytes())?; writer.write_all(&sample_rate.to_le_bytes())?; writer.write_all(&byte_rate.to_le_bytes())?; writer.write_all(&block_align.to_le_bytes())?; writer.write_all(&bits_per_sample.to_le_bytes())?; writer.write_all(b"data")?; writer.write_all(&data_size.to_le_bytes())?; for &sample in samples { let clamped = sample.clamp(-1.0, 1.0); let int_sample = (clamped * 32767.0) as i16; writer.write_all(&int_sample.to_le_bytes())?; } Ok(()) } /// Resample audio to 24kHz using simple linear interpolation. pub fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec { if input_rate == SAMPLE_RATE { return samples.to_vec(); } if samples.is_empty() { return Vec::new(); } let ratio = input_rate as f64 / SAMPLE_RATE as f64; let output_len = ((samples.len() as f64) / ratio).ceil() as usize; let mut output = Vec::with_capacity(output_len); for i in 0..output_len { let src_idx = (i as f64 * ratio) as usize; let sample = samples.get(src_idx).copied().unwrap_or(0.0); output.push(sample); } output } /// Apply repetition penalty to logits based on previously generated tokens. pub fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) { for &token in generated { if (token as usize) < logits.len() { let score = logits[token as usize]; logits[token as usize] = if score < 0.0 { score * penalty } else { score / penalty }; } } } /// Return the index of the maximum value in logits. pub fn argmax(logits: &[f32]) -> i64 { logits .iter() .enumerate() .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) .map(|(idx, _)| idx as i64) .unwrap_or(0) } #[cfg(test)] mod tests { use super::*; #[test] fn test_argmax() { let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2]; assert_eq!(argmax(&logits), 3); } #[test] fn test_resample_same_rate() { let samples = vec![0.1, 0.2, 0.3]; let resampled = resample_to_24k(&samples, SAMPLE_RATE); assert_eq!(resampled, samples); } #[test] fn test_repetition_penalty() { let mut logits = vec![1.0, 2.0, 3.0, 4.0]; let generated = vec![1, 3]; apply_repetition_penalty(&mut logits, &generated, 1.2); assert!((logits[1] - 2.0 / 1.2).abs() < 1e-6); assert!((logits[3] - 4.0 / 1.2).abs() < 1e-6); } #[test] fn test_audio_chunk_to_pcm16() { let chunk = AudioChunk { samples: vec![0.0, 1.0, -1.0], sample_rate: 24_000, is_final: true, }; let bytes = chunk.to_pcm16_bytes(); assert_eq!(bytes.len(), 6); // 0.0 -> 0i16 assert_eq!(i16::from_le_bytes([bytes[0], bytes[1]]), 0); // 1.0 -> 32767i16 assert_eq!(i16::from_le_bytes([bytes[2], bytes[3]]), 32767); // -1.0 -> -32767i16 assert_eq!(i16::from_le_bytes([bytes[4], bytes[5]]), -32767); } }