1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
|
//! TTS engine abstraction and implementations.
//!
//! Provides a trait-based TTS engine interface using Chatterbox ONNX-based TTS.
use std::path::Path;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
pub mod chatterbox;
// Re-export primary types
pub use chatterbox::ChatterboxTTS;
/// Audio output sample rate in Hz (24 kHz).
///
/// This is the default returned by [`TtsEngine::sample_rate`], the rate
/// written into the header by [`save_wav`], and the target rate of
/// [`resample_to_24k`].
pub const SAMPLE_RATE: u32 = 24_000;
/// A chunk of generated audio for streaming output.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioChunk {
    /// PCM `f32` samples, nominally in `[-1.0, 1.0]`; values outside that
    /// range are clamped by [`AudioChunk::to_pcm16_bytes`].
    pub samples: Vec<f32>,
    /// Sample rate of `samples` in Hz; expected to be [`SAMPLE_RATE`] (24 kHz).
    pub sample_rate: u32,
    /// Whether this is the final chunk in the stream.
    pub is_final: bool,
}

impl AudioChunk {
    /// Convert to 16-bit signed PCM bytes (little-endian) for WebSocket streaming.
    ///
    /// Each sample is clamped to `[-1.0, 1.0]`, scaled by 32767 and truncated
    /// toward zero, so `1.0 -> 32767` and `-1.0 -> -32767`. The encoding is
    /// symmetric, so `i16::MIN` is never produced. NaN samples encode as 0
    /// because Rust's float-to-int `as` cast maps NaN to 0.
    ///
    /// The result is exactly `2 * samples.len()` bytes. `sample_rate` is not
    /// consulted; no resampling happens here.
    pub fn to_pcm16_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(self.samples.len() * 2);
        buf.extend(self.samples.iter().flat_map(|&s| {
            let int_sample = (s.clamp(-1.0, 1.0) * 32767.0) as i16;
            int_sample.to_le_bytes()
        }));
        buf
    }
}
/// Errors that can occur during TTS operations.
///
/// The `Display` output is a short lowercase prefix naming the failure class,
/// followed by the wrapped detail where one exists.
#[derive(Debug)]
pub enum TtsError {
    /// Loading a model failed. Every `ort::Error` converted via `From` is also
    /// reported through this variant.
    ModelLoad(String),
    /// An error reported while running inference.
    Inference(String),
    /// An error reported by the tokenizer.
    Tokenizer(String),
    /// A wrapped error from the crate's `audio` module.
    Audio(crate::audio::AudioError),
    /// A wrapped I/O error (for example from [`save_wav`]).
    Io(std::io::Error),
    /// No voice reference audio was supplied where one is required.
    VoiceRequired,
}

impl std::fmt::Display for TtsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelLoad(msg) => write!(f, "model load error: {msg}"),
            Self::Inference(msg) => write!(f, "inference error: {msg}"),
            Self::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"),
            Self::Audio(err) => write!(f, "audio error: {err}"),
            Self::Io(err) => write!(f, "io error: {err}"),
            Self::VoiceRequired => f.write_str("voice reference audio is required"),
        }
    }
}

impl std::error::Error for TtsError {}
impl From<crate::audio::AudioError> for TtsError {
fn from(value: crate::audio::AudioError) -> Self {
TtsError::Audio(value)
}
}
impl From<std::io::Error> for TtsError {
fn from(value: std::io::Error) -> Self {
TtsError::Io(value)
}
}
impl From<ort::Error> for TtsError {
fn from(value: ort::Error) -> Self {
TtsError::ModelLoad(value.to_string())
}
}
/// TTS engine trait for text-to-speech synthesis.
///
/// The `Send + Sync` supertraits allow one engine instance to be shared
/// across threads/tasks. `async_trait` lets the `async fn` below be called
/// through a `Box<dyn TtsEngine>` trait object, which is what
/// [`TtsEngineFactory::create`] returns.
#[async_trait::async_trait]
pub trait TtsEngine: Send + Sync {
    /// Generate complete audio from text with a voice reference.
    ///
    /// The optional `cancel_flag` can be set to `true` by another thread/task
    /// to request early termination of the generation loop. Engines that
    /// support cancellation will check this flag periodically and return
    /// whatever audio has been produced so far.
    ///
    /// # Arguments
    ///
    /// * `text` - the text to synthesize.
    /// * `reference_audio` - optional voice reference samples.
    ///   NOTE(review): presumably `f32` PCM in `[-1.0, 1.0]`; confirm against
    ///   the implementation.
    /// * `reference_sample_rate` - sample rate in Hz of `reference_audio`, if
    ///   known. Presumably used to bring the reference to 24 kHz (e.g. via
    ///   [`resample_to_24k`]); TODO confirm per implementation.
    /// * `cancel_flag` - shared cancellation request flag (see above).
    ///
    /// # Returns
    ///
    /// The generated audio as a list of [`AudioChunk`]s. Presumably the last
    /// chunk has `is_final == true`; confirm per implementation.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] on failure. Engines that cannot run without a
    /// reference voice presumably report a missing one as
    /// [`TtsError::VoiceRequired`]; confirm per implementation.
    async fn generate(
        &self,
        text: &str,
        reference_audio: Option<&[f32]>,
        reference_sample_rate: Option<u32>,
        cancel_flag: Option<Arc<AtomicBool>>,
    ) -> Result<Vec<AudioChunk>, TtsError>;

    /// Check if the engine is loaded and ready.
    fn is_ready(&self) -> bool;

    /// Get the engine's output sample rate in Hz.
    ///
    /// Defaults to [`SAMPLE_RATE`] (24 kHz); implementations may override it.
    fn sample_rate(&self) -> u32 {
        SAMPLE_RATE
    }
}
/// Factory for creating TTS engines.
///
/// The only backend wired up here is [`ChatterboxTTS`].
pub struct TtsEngineFactory;

impl TtsEngineFactory {
    /// Load a Chatterbox engine and return it as a [`TtsEngine`] trait object.
    ///
    /// `model_dir` is forwarded unchanged to `ChatterboxTTS::from_pretrained`.
    /// Any error it returns is converted into [`TtsError`] by `?`.
    pub fn create(model_dir: Option<&str>) -> Result<Box<dyn TtsEngine>, TtsError> {
        let boxed: Box<dyn TtsEngine> = Box::new(ChatterboxTTS::from_pretrained(model_dir)?);
        Ok(boxed)
    }
}
/// Save audio samples to a mono 16-bit PCM WAV file at [`SAMPLE_RATE`].
///
/// The file at `path` is created, or truncated if it exists. Output goes
/// through a `BufWriter` because `write_wav` issues many small `write_all`
/// calls, one per sample. On a bare `File`, each of those would be a separate
/// system call.
///
/// # Errors
///
/// Returns [`TtsError::Io`] if the file cannot be created or written, or if
/// the sample count is too large to describe in a WAV header.
pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> {
    let file = std::fs::File::create(path)?;
    let mut writer = std::io::BufWriter::new(file);
    write_wav(&mut writer, samples, SAMPLE_RATE)?;
    // Flush explicitly: BufWriter's Drop would silently discard a final write error.
    std::io::Write::flush(&mut writer)?;
    Ok(())
}
/// Encode `samples` as a mono 16-bit PCM WAV stream into `writer`.
///
/// Writes the canonical 44-byte RIFF/WAVE header (`fmt ` chunk with PCM
/// format tag 1, one channel, 16 bits per sample), then the samples. Each
/// sample is clamped to `[-1.0, 1.0]`, scaled by 32767 and truncated toward
/// zero. NaN encodes as 0 under Rust's saturating `as` cast.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` if `sample_rate` or the number of
/// samples would overflow the 32-bit fields of the WAV header (RIFF files are
/// capped at 4 GiB). Otherwise, errors from `writer` are propagated.
fn write_wav<W: std::io::Write>(
    writer: &mut W,
    samples: &[f32],
    sample_rate: u32,
) -> Result<(), std::io::Error> {
    /// Build an `InvalidInput` error for values the WAV header cannot encode.
    fn invalid_input(msg: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, msg)
    }

    // Header bytes after the RIFF size field, excluding the sample payload:
    // "WAVE" (4) + fmt chunk (8 + 16) + data chunk header (8).
    const HEADER_AFTER_RIFF_SIZE: u64 = 36;

    let num_channels: u16 = 1;
    let bits_per_sample: u16 = 16;
    let block_align = num_channels * bits_per_sample / 8;
    let byte_rate = sample_rate
        .checked_mul(u32::from(block_align))
        .ok_or_else(|| invalid_input("sample rate too large for a WAV header"))?;

    // Size the payload in u64 so the range check itself cannot overflow;
    // `usize as u64` is lossless on all supported targets.
    let data_size = samples.len() as u64 * u64::from(block_align);
    let riff_size = data_size + HEADER_AFTER_RIFF_SIZE;
    if riff_size > u64::from(u32::MAX) {
        return Err(invalid_input("too many samples for a WAV file"));
    }
    // Lossless: both values were just checked to fit in u32.
    let (data_size, riff_size) = (data_size as u32, riff_size as u32);

    writer.write_all(b"RIFF")?;
    writer.write_all(&riff_size.to_le_bytes())?;
    writer.write_all(b"WAVE")?;
    writer.write_all(b"fmt ")?;
    writer.write_all(&16u32.to_le_bytes())?; // fmt chunk body size
    writer.write_all(&1u16.to_le_bytes())?; // audio format 1 = integer PCM
    writer.write_all(&num_channels.to_le_bytes())?;
    writer.write_all(&sample_rate.to_le_bytes())?;
    writer.write_all(&byte_rate.to_le_bytes())?;
    writer.write_all(&block_align.to_le_bytes())?;
    writer.write_all(&bits_per_sample.to_le_bytes())?;
    writer.write_all(b"data")?;
    writer.write_all(&data_size.to_le_bytes())?;
    for &sample in samples {
        let int_sample = (sample.clamp(-1.0, 1.0) * 32767.0) as i16;
        writer.write_all(&int_sample.to_le_bytes())?;
    }
    Ok(())
}
/// Resample audio to 24 kHz ([`SAMPLE_RATE`]) using linear interpolation.
///
/// Output sample `i` is read at fractional input position
/// `i * input_rate / 24000`. It is linearly interpolated between the two
/// neighbouring input samples, and the last input sample is held at the end.
/// The output length is `ceil(samples.len() * 24000 / input_rate)`.
///
/// No anti-aliasing filter is applied, so downsampling can alias content
/// above the new Nyquist frequency.
///
/// Returns a copy of `samples` when `input_rate` is already 24 kHz. Returns an
/// empty vector when `samples` is empty or `input_rate` is 0, since no
/// resampling ratio exists for a zero rate.
pub fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec<f32> {
    if input_rate == SAMPLE_RATE {
        return samples.to_vec();
    }
    if samples.is_empty() || input_rate == 0 {
        return Vec::new();
    }
    // Input samples advanced per output sample.
    let step = f64::from(input_rate) / f64::from(SAMPLE_RATE);
    let output_len = (samples.len() as f64 / step).ceil() as usize;
    let last = samples.len() - 1;
    (0..output_len)
        .map(|i| {
            let pos = i as f64 * step;
            // Clamp guards against float rounding pushing `pos` past the end.
            let idx = (pos as usize).min(last);
            let next = (idx + 1).min(last);
            let frac = (pos - idx as f64) as f32;
            samples[idx] + (samples[next] - samples[idx]) * frac
        })
        .collect()
}
/// Apply repetition penalty to logits based on previously generated tokens.
///
/// Every distinct token id in `generated` that indexes into `logits` has its
/// score penalized once:
/// - negative scores are multiplied by `penalty`;
/// - non-negative scores are divided by `penalty`.
///
/// With `penalty > 1.0` this makes repeated tokens less likely. This matches
/// the CTRL / Hugging Face `RepetitionPenaltyLogitsProcessor` formulation.
/// A token appearing several times in `generated` is still penalized only
/// once, not compounded. Negative ids and ids `>= logits.len()` are ignored.
pub fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) {
    if generated.is_empty() {
        return;
    }
    let vocab = logits.len();
    // Tracks which ids were already penalized, so duplicates do not compound.
    let mut penalized = vec![false; vocab];
    for &token in generated {
        // Compare in u64 so out-of-range ids are not truncated into range by
        // an `i64 as usize` cast on 32-bit targets.
        if token < 0 || token as u64 >= vocab as u64 {
            continue;
        }
        let idx = token as usize;
        if penalized[idx] {
            continue;
        }
        penalized[idx] = true;
        let score = logits[idx];
        logits[idx] = if score < 0.0 {
            score * penalty
        } else {
            score / penalty
        };
    }
}
/// Return the index of the maximum value in `logits`.
///
/// NaN entries are ignored. Ties resolve to the *first* maximal index, the
/// same convention as `torch.argmax` / `numpy.argmax`. Returns 0 when
/// `logits` is empty or all-NaN, which callers cannot distinguish from a
/// genuine maximum at index 0.
pub fn argmax(logits: &[f32]) -> i64 {
    // An explicit scan replaces `max_by` + `partial_cmp(..).unwrap_or(Equal)`.
    // That version let a single NaN derail the fold into a non-maximal index,
    // and it returned the *last* of equal maxima.
    let mut best: Option<(usize, f32)> = None;
    for (idx, &value) in logits.iter().enumerate() {
        if value.is_nan() {
            continue;
        }
        // Strict `>` keeps the earliest index among equal maxima.
        let is_better = match best {
            Some((_, best_value)) => value > best_value,
            None => true,
        };
        if is_better {
            best = Some((idx, value));
        }
    }
    best.map_or(0, |(idx, _)| idx as i64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_argmax() {
        // The largest value (0.8) sits at index 3.
        assert_eq!(argmax(&[0.1, 0.5, 0.3, 0.8, 0.2]), 3);
    }

    #[test]
    fn test_resample_same_rate() {
        // Resampling at the output rate must be an identity copy.
        let input = [0.1_f32, 0.2, 0.3];
        assert_eq!(resample_to_24k(&input, SAMPLE_RATE), input.to_vec());
    }

    #[test]
    fn test_repetition_penalty() {
        let penalty = 1.2_f32;
        let mut logits = [1.0_f32, 2.0, 3.0, 4.0];
        apply_repetition_penalty(&mut logits, &[1, 3], penalty);
        let close = |a: f32, b: f32| (a - b).abs() < 1e-6;
        // Positive scores of previously generated tokens are divided by the penalty.
        assert!(close(logits[1], 2.0 / penalty));
        assert!(close(logits[3], 4.0 / penalty));
    }

    #[test]
    fn test_audio_chunk_to_pcm16() {
        let chunk = AudioChunk {
            samples: vec![0.0, 1.0, -1.0],
            sample_rate: 24_000,
            is_final: true,
        };
        let bytes = chunk.to_pcm16_bytes();
        assert_eq!(bytes.len(), 6);
        // Decode little-endian pairs: 0.0 -> 0, 1.0 -> 32767, -1.0 -> -32767.
        let decoded: Vec<i16> = bytes
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
            .collect();
        assert_eq!(decoded, [0, 32767, -32767]);
    }
}
|