From 55cacf6e1a087c0fa6950a1ddeb09060f787e541 Mon Sep 17 00:00:00 2001 From: soryu Date: Sun, 21 Dec 2025 00:40:04 +0000 Subject: Add EOU detection and streaming diarization --- parakeet-rs/src/audio.rs | 179 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 parakeet-rs/src/audio.rs (limited to 'parakeet-rs/src/audio.rs') diff --git a/parakeet-rs/src/audio.rs b/parakeet-rs/src/audio.rs new file mode 100644 index 0000000..84d2616 --- /dev/null +++ b/parakeet-rs/src/audio.rs @@ -0,0 +1,179 @@ +use crate::config::PreprocessorConfig; +use crate::error::{Error, Result}; +use hound::{WavReader, WavSpec}; +use ndarray::Array2; +use std::f32::consts::PI; +use std::path::Path; + +pub fn load_audio>(path: P) -> Result<(Vec, WavSpec)> { + let mut reader = WavReader::open(path)?; + let spec = reader.spec(); + + let samples: Vec = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::() + .collect::, _>>() + .map_err(|e| Error::Audio(format!("Failed to read float samples: {e}")))?, + hound::SampleFormat::Int => reader + .samples::() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::, _>>() + .map_err(|e| Error::Audio(format!("Failed to read int samples: {e}")))?, + }; + + Ok((samples, spec)) +} + +pub fn apply_preemphasis(audio: &[f32], coef: f32) -> Vec { + let mut result = Vec::with_capacity(audio.len()); + result.push(audio[0]); + + for i in 1..audio.len() { + result.push(audio[i] - coef * audio[i - 1]); + } + + result +} + +fn hann_window(window_length: usize) -> Vec { + (0..window_length) + .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / (window_length as f32 - 1.0)).cos()) + .collect() +} + +// We use proper FFT here instead of naive DFT because the model was trained +// on correctly computed spectrograms. Naive DFT produces wrong frequency bins +// and the model outputs all blank tokens. RustFFT gives us O(n log n) performance +// and numerically correct results that match what the model expects. +pub fn stft(audio: &[f32], n_fft: usize, hop_length: usize, win_length: usize) -> Array2 { + use rustfft::{num_complex::Complex, FftPlanner}; + + let window = hann_window(win_length); + let num_frames = (audio.len() - win_length) / hop_length + 1; + let freq_bins = n_fft / 2 + 1; + let mut spectrogram = Array2::::zeros((freq_bins, num_frames)); + + let mut planner = FftPlanner::::new(); + let fft = planner.plan_fft_forward(n_fft); + + for frame_idx in 0..num_frames { + let start = frame_idx * hop_length; + + let mut frame: Vec> = vec![Complex::new(0.0, 0.0); n_fft]; + for i in 0..win_length.min(audio.len() - start) { + frame[i] = Complex::new(audio[start + i] * window[i], 0.0); + } + + fft.process(&mut frame); + + for k in 0..freq_bins { + let magnitude = frame[k].norm(); + spectrogram[[k, frame_idx]] = magnitude * magnitude; + } + } + + spectrogram +} + +fn hz_to_mel(freq: f32) -> f32 { + 2595.0 * (1.0 + freq / 700.0).log10() +} + +fn mel_to_hz(mel: f32) -> f32 { + 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0) +} + +fn create_mel_filterbank(n_fft: usize, n_mels: usize, sample_rate: usize) -> Array2 { + let freq_bins = n_fft / 2 + 1; + let mut filterbank = Array2::::zeros((n_mels, freq_bins)); + + let min_mel = hz_to_mel(0.0); + let max_mel = hz_to_mel(sample_rate as f32 / 2.0); + + let mel_points: Vec = (0..=n_mels + 1) + .map(|i| mel_to_hz(min_mel + (max_mel - min_mel) * i as f32 / (n_mels + 1) as f32)) + .collect(); + + let freq_bin_width = sample_rate as f32 / n_fft as f32; + + for mel_idx in 0..n_mels { + let left = mel_points[mel_idx]; + let center = mel_points[mel_idx + 1]; + let right = mel_points[mel_idx + 2]; + + for freq_idx in 0..freq_bins { + let freq = freq_idx as f32 * freq_bin_width; + + if freq >= left && freq <= center { + filterbank[[mel_idx, freq_idx]] = (freq - left) / (center - left); + } else if freq > center && freq <= right { + filterbank[[mel_idx, freq_idx]] = (right - freq) / (right - center); + } + } + } + + filterbank +} + +/// Extract mel spectrogram features from raw audio samples. +/// +/// # Arguments +/// +/// * `audio` - Audio samples as f32 values +/// * `sample_rate` - Sample rate in Hz +/// * `channels` - Number of audio channels +/// * `config` - Preprocessor configuration +/// +/// # Returns +/// +/// 2D array of mel spectrogram features (time_steps x feature_size) +pub fn extract_features_raw( + mut audio: Vec, + sample_rate: u32, + channels: u16, + config: &PreprocessorConfig, +) -> Result> { + if sample_rate != config.sampling_rate as u32 { + return Err(Error::Audio(format!( + "Audio sample rate {} doesn't match expected {}. Please resample your audio first.", + sample_rate, config.sampling_rate + ))); + } + + if channels > 1 { + let mono: Vec = audio + .chunks(channels as usize) + .map(|chunk| chunk.iter().sum::() / channels as f32) + .collect(); + audio = mono; + } + + audio = apply_preemphasis(&audio, config.preemphasis); + + let spectrogram = stft(&audio, config.n_fft, config.hop_length, config.win_length); + + let mel_filterbank = + create_mel_filterbank(config.n_fft, config.feature_size, config.sampling_rate); + let mel_spectrogram = mel_filterbank.dot(&spectrogram); + let mel_spectrogram = mel_spectrogram.mapv(|x| (x.max(1e-10)).ln()); + + let mut mel_spectrogram = mel_spectrogram.t().to_owned(); + + // Normalize each feature dimension to mean=0, std=1 + let num_frames = mel_spectrogram.shape()[0]; + let num_features = mel_spectrogram.shape()[1]; + + for feat_idx in 0..num_features { + let mut column = mel_spectrogram.column_mut(feat_idx); + let mean: f32 = column.iter().sum::() / num_frames as f32; + let variance: f32 = + column.iter().map(|&x| (x - mean).powi(2)).sum::() / num_frames as f32; + let std = variance.sqrt().max(1e-10); + + for val in column.iter_mut() { + *val = (*val - mean) / std; + } + } + + Ok(mel_spectrogram) +} -- cgit v1.2.3