1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
|
use crate::config::PreprocessorConfig;
use crate::error::{Error, Result};
use hound::{WavReader, WavSpec};
use ndarray::Array2;
use std::f32::consts::PI;
use std::path::Path;
pub fn load_audio<P: AsRef<Path>>(path: P) -> Result<(Vec<f32>, WavSpec)> {
let mut reader = WavReader::open(path)?;
let spec = reader.spec();
let samples: Vec<f32> = match spec.sample_format {
hound::SampleFormat::Float => reader
.samples::<f32>()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Audio(format!("Failed to read float samples: {e}")))?,
hound::SampleFormat::Int => reader
.samples::<i16>()
.map(|s| s.map(|s| s as f32 / 32768.0))
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Audio(format!("Failed to read int samples: {e}")))?,
};
Ok((samples, spec))
}
/// Applies a first-order pre-emphasis filter to `audio`.
///
/// The output has the same length as the input, with `y[0] = x[0]` and
/// `y[n] = x[n] - coef * x[n - 1]` for `n >= 1`.
///
/// Empty input yields an empty vector rather than panicking.
pub fn apply_preemphasis(audio: &[f32], coef: f32) -> Vec<f32> {
    let mut result = Vec::with_capacity(audio.len());
    if let Some(&first) = audio.first() {
        // The first sample has no predecessor and is passed through as-is.
        result.push(first);
        result.extend(audio.windows(2).map(|pair| pair[1] - coef * pair[0]));
    }
    result
}
/// Builds a symmetric Hann window of `window_length` points:
/// `w[i] = 0.5 - 0.5 * cos(2 * PI * i / (window_length - 1))`.
///
/// A zero-length request returns an empty vector. A one-point window is
/// defined as `[1.0]`; the general formula would divide zero by zero there
/// and produce NaN.
fn hann_window(window_length: usize) -> Vec<f32> {
    if window_length == 1 {
        return vec![1.0];
    }
    let last_index = window_length as f32 - 1.0;
    (0..window_length)
        .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / last_index).cos())
        .collect()
}
// We use proper FFT here instead of naive DFT because the model was trained
// on correctly computed spectrograms. Naive DFT produces wrong frequency bins
// and the model outputs all blank tokens. RustFFT gives us O(n log n) performance
// and numerically correct results that match what the model expects.
/// Computes the power spectrogram (squared FFT magnitude) of `audio`.
///
/// Frames of `win_length` samples are taken every `hop_length` samples,
/// multiplied by a Hann window, zero-padded to `n_fft`, and transformed.
/// No padding is applied at the signal edges, so only frames that fit
/// entirely inside `audio` are produced.
///
/// # Returns
///
/// An array of shape `(n_fft / 2 + 1, num_frames)`. When `audio` is shorter
/// than `win_length` no frame fits and the array has zero columns.
///
/// # Panics
///
/// Panics if `n_fft` or `hop_length` is zero, or if `win_length > n_fft`.
pub fn stft(audio: &[f32], n_fft: usize, hop_length: usize, win_length: usize) -> Array2<f32> {
    use rustfft::{num_complex::Complex, FftPlanner};

    assert!(n_fft > 0, "n_fft must be greater than zero");
    assert!(hop_length > 0, "hop_length must be greater than zero");
    assert!(
        win_length <= n_fft,
        "win_length ({}) must not exceed n_fft ({})",
        win_length,
        n_fft
    );

    let freq_bins = n_fft / 2 + 1;
    // `checked_sub` avoids the unsigned underflow that occurs when the audio
    // is shorter than a single window.
    let num_frames = audio
        .len()
        .checked_sub(win_length)
        .map_or(0, |span| span / hop_length + 1);
    let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames));
    if num_frames == 0 {
        return spectrogram;
    }

    let window = hann_window(win_length);
    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(n_fft);

    // One scratch buffer is reused for every frame; the FFT runs in place,
    // so the buffer is fully rewritten before each transform.
    let zero = Complex::new(0.0_f32, 0.0);
    let mut frame = vec![zero; n_fft];
    for frame_idx in 0..num_frames {
        let start = frame_idx * hop_length;
        // `num_frames` guarantees `start + win_length <= audio.len()`.
        let samples = &audio[start..start + win_length];
        for ((slot, &sample), &weight) in frame.iter_mut().zip(samples).zip(&window) {
            *slot = Complex::new(sample * weight, 0.0);
        }
        frame[win_length..].fill(zero);

        fft.process(&mut frame);

        // Keep only the non-redundant half of the spectrum of a real signal.
        for (k, bin) in frame.iter().take(freq_bins).enumerate() {
            spectrogram[[k, frame_idx]] = bin.norm_sqr();
        }
    }
    spectrogram
}
/// Converts a frequency in Hz to mels using `2595 * log10(1 + f / 700)`.
fn hz_to_mel(freq: f32) -> f32 {
    let ratio = 1.0 + freq / 700.0;
    2595.0 * ratio.log10()
}
/// Converts a mel value back to Hz using `700 * (10^(mel / 2595) - 1)`.
fn mel_to_hz(mel: f32) -> f32 {
    let exponent = mel / 2595.0;
    let ratio = 10.0_f32.powf(exponent);
    700.0 * (ratio - 1.0)
}
/// Builds a bank of `n_mels` triangular filters over the `n_fft / 2 + 1`
/// FFT bins.
///
/// Band edges are spaced evenly on the mel scale between 0 Hz and
/// `sample_rate / 2`. Each filter rises linearly from 0 at its lower edge to
/// 1 at its centre and falls back to 0 at its upper edge. The result has
/// shape `(n_mels, n_fft / 2 + 1)`.
fn create_mel_filterbank(n_fft: usize, n_mels: usize, sample_rate: usize) -> Array2<f32> {
    let freq_bins = n_fft / 2 + 1;
    let hz_per_bin = sample_rate as f32 / n_fft as f32;

    let mel_lo = hz_to_mel(0.0);
    let mel_hi = hz_to_mel(sample_rate as f32 / 2.0);
    let mel_span = mel_hi - mel_lo;
    let steps = (n_mels + 1) as f32;

    // n_mels + 2 edge frequencies (in Hz): every filter uses three consecutive ones.
    let edges_hz: Vec<f32> = (0..=n_mels + 1)
        .map(|i| mel_to_hz(mel_lo + mel_span * i as f32 / steps))
        .collect();

    let mut filterbank = Array2::<f32>::zeros((n_mels, freq_bins));
    for (mel_idx, band) in edges_hz.windows(3).enumerate() {
        let (lower, peak, upper) = (band[0], band[1], band[2]);
        for freq_idx in 0..freq_bins {
            let freq = freq_idx as f32 * hz_per_bin;
            let weight = if freq >= lower && freq <= peak {
                // Rising slope of the triangle.
                Some((freq - lower) / (peak - lower))
            } else if freq > peak && freq <= upper {
                // Falling slope of the triangle.
                Some((upper - freq) / (upper - peak))
            } else {
                None
            };
            if let Some(w) = weight {
                filterbank[[mel_idx, freq_idx]] = w;
            }
        }
    }
    filterbank
}
/// Extract mel spectrogram features from raw audio samples.
///
/// The pipeline is: down-mix to mono (mean over channels), pre-emphasis,
/// power STFT, mel filterbank, natural log (floored at `1e-10`), then
/// per-feature normalization across time to zero mean and unit variance.
///
/// # Arguments
///
/// * `audio` - Audio samples as f32 values; for multi-channel input the
///   samples are expected to be interleaved frame by frame
/// * `sample_rate` - Sample rate in Hz
/// * `channels` - Number of audio channels
/// * `config` - Preprocessor configuration
///
/// # Returns
///
/// 2D array of mel spectrogram features (time_steps x feature_size)
///
/// # Errors
///
/// Returns [`Error::Audio`] if the sample rate does not match the
/// configuration, if the STFT configuration is unusable (`n_fft` or
/// `hop_length` of zero, or `win_length` larger than `n_fft`), or if the
/// mono signal is empty or shorter than one analysis window.
pub fn extract_features_raw(
    mut audio: Vec<f32>,
    sample_rate: u32,
    channels: u16,
    config: &PreprocessorConfig,
) -> Result<Array2<f32>> {
    if sample_rate != config.sampling_rate as u32 {
        return Err(Error::Audio(format!(
            "Audio sample rate {} doesn't match expected {}. Please resample your audio first.",
            sample_rate, config.sampling_rate
        )));
    }

    // Reject configurations that would otherwise panic inside the STFT
    // (division by zero or out-of-bounds frame writes).
    if config.n_fft == 0 || config.hop_length == 0 || config.win_length > config.n_fft {
        return Err(Error::Audio(format!(
            "Invalid STFT configuration: n_fft={}, hop_length={}, win_length={}",
            config.n_fft, config.hop_length, config.win_length
        )));
    }

    if channels > 1 {
        // `chunks_exact` drops a trailing incomplete frame instead of
        // averaging it with the wrong divisor.
        let mono: Vec<f32> = audio
            .chunks_exact(usize::from(channels))
            .map(|frame| frame.iter().sum::<f32>() / f32::from(channels))
            .collect();
        audio = mono;
    }

    // At least one full analysis window is required to produce a frame.
    if audio.is_empty() || audio.len() < config.win_length {
        return Err(Error::Audio(format!(
            "Audio is too short: got {} samples, need at least {}",
            audio.len(),
            config.win_length.max(1)
        )));
    }

    audio = apply_preemphasis(&audio, config.preemphasis);
    let spectrogram = stft(&audio, config.n_fft, config.hop_length, config.win_length);
    let mel_filterbank =
        create_mel_filterbank(config.n_fft, config.feature_size, config.sampling_rate);

    // (n_mels x freq_bins) . (freq_bins x frames) -> (n_mels x frames)
    let mel_spectrogram = mel_filterbank.dot(&spectrogram);
    // Floor before the log so silent bins do not become -inf.
    let mel_spectrogram = mel_spectrogram.mapv(|x| (x.max(1e-10)).ln());
    // Transpose to (frames x n_mels).
    let mut mel_spectrogram = mel_spectrogram.t().to_owned();

    // Normalize each feature dimension to mean=0, std=1
    let num_frames = mel_spectrogram.shape()[0];
    let num_features = mel_spectrogram.shape()[1];
    for feat_idx in 0..num_features {
        let mut column = mel_spectrogram.column_mut(feat_idx);
        let mean: f32 = column.iter().sum::<f32>() / num_frames as f32;
        let variance: f32 =
            column.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_frames as f32;
        // Floor the standard deviation so constant features do not divide by zero.
        let std = variance.sqrt().max(1e-10);
        for val in column.iter_mut() {
            *val = (*val - mean) / std;
        }
    }
    Ok(mel_spectrogram)
}
|