diff options
| author | soryu <soryu@soryu.co> | 2025-12-21 01:27:02 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch) | |
| tree | 497bffd67001501a003739cfe0bb790502ffd50a /parakeet-rs/src | |
| parent | 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff) | |
| download | soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip | |
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'parakeet-rs/src')
| -rw-r--r-- | parakeet-rs/src/audio.rs | 179 | ||||
| -rw-r--r-- | parakeet-rs/src/config.rs | 51 | ||||
| -rw-r--r-- | parakeet-rs/src/decoder.rs | 211 | ||||
| -rw-r--r-- | parakeet-rs/src/decoder_tdt.rs | 63 | ||||
| -rw-r--r-- | parakeet-rs/src/error.rs | 52 | ||||
| -rw-r--r-- | parakeet-rs/src/execution.rs | 141 | ||||
| -rw-r--r-- | parakeet-rs/src/lib.rs | 74 | ||||
| -rw-r--r-- | parakeet-rs/src/model.rs | 93 | ||||
| -rw-r--r-- | parakeet-rs/src/model_eou.rs | 183 | ||||
| -rw-r--r-- | parakeet-rs/src/model_tdt.rs | 263 | ||||
| -rw-r--r-- | parakeet-rs/src/parakeet.rs | 210 | ||||
| -rw-r--r-- | parakeet-rs/src/parakeet_eou.rs | 304 | ||||
| -rw-r--r-- | parakeet-rs/src/parakeet_tdt.rs | 167 | ||||
| -rw-r--r-- | parakeet-rs/src/sortformer.rs | 1062 | ||||
| -rw-r--r-- | parakeet-rs/src/timestamps.rs | 280 | ||||
| -rw-r--r-- | parakeet-rs/src/vocab.rs | 63 |
16 files changed, 0 insertions, 3396 deletions
diff --git a/parakeet-rs/src/audio.rs b/parakeet-rs/src/audio.rs deleted file mode 100644 index 84d2616..0000000 --- a/parakeet-rs/src/audio.rs +++ /dev/null @@ -1,179 +0,0 @@ -use crate::config::PreprocessorConfig; -use crate::error::{Error, Result}; -use hound::{WavReader, WavSpec}; -use ndarray::Array2; -use std::f32::consts::PI; -use std::path::Path; - -pub fn load_audio<P: AsRef<Path>>(path: P) -> Result<(Vec<f32>, WavSpec)> { - let mut reader = WavReader::open(path)?; - let spec = reader.spec(); - - let samples: Vec<f32> = match spec.sample_format { - hound::SampleFormat::Float => reader - .samples::<f32>() - .collect::<std::result::Result<Vec<_>, _>>() - .map_err(|e| Error::Audio(format!("Failed to read float samples: {e}")))?, - hound::SampleFormat::Int => reader - .samples::<i16>() - .map(|s| s.map(|s| s as f32 / 32768.0)) - .collect::<std::result::Result<Vec<_>, _>>() - .map_err(|e| Error::Audio(format!("Failed to read int samples: {e}")))?, - }; - - Ok((samples, spec)) -} - -pub fn apply_preemphasis(audio: &[f32], coef: f32) -> Vec<f32> { - let mut result = Vec::with_capacity(audio.len()); - result.push(audio[0]); - - for i in 1..audio.len() { - result.push(audio[i] - coef * audio[i - 1]); - } - - result -} - -fn hann_window(window_length: usize) -> Vec<f32> { - (0..window_length) - .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / (window_length as f32 - 1.0)).cos()) - .collect() -} - -// We use proper FFT here instead of naive DFT because the model was trained -// on correctly computed spectrograms. Naive DFT produces wrong frequency bins -// and the model outputs all blank tokens. RustFFT gives us O(n log n) performance -// and numerically correct results that match what the model expects. 
-pub fn stft(audio: &[f32], n_fft: usize, hop_length: usize, win_length: usize) -> Array2<f32> { - use rustfft::{num_complex::Complex, FftPlanner}; - - let window = hann_window(win_length); - let num_frames = (audio.len() - win_length) / hop_length + 1; - let freq_bins = n_fft / 2 + 1; - let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames)); - - let mut planner = FftPlanner::<f32>::new(); - let fft = planner.plan_fft_forward(n_fft); - - for frame_idx in 0..num_frames { - let start = frame_idx * hop_length; - - let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft]; - for i in 0..win_length.min(audio.len() - start) { - frame[i] = Complex::new(audio[start + i] * window[i], 0.0); - } - - fft.process(&mut frame); - - for k in 0..freq_bins { - let magnitude = frame[k].norm(); - spectrogram[[k, frame_idx]] = magnitude * magnitude; - } - } - - spectrogram -} - -fn hz_to_mel(freq: f32) -> f32 { - 2595.0 * (1.0 + freq / 700.0).log10() -} - -fn mel_to_hz(mel: f32) -> f32 { - 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0) -} - -fn create_mel_filterbank(n_fft: usize, n_mels: usize, sample_rate: usize) -> Array2<f32> { - let freq_bins = n_fft / 2 + 1; - let mut filterbank = Array2::<f32>::zeros((n_mels, freq_bins)); - - let min_mel = hz_to_mel(0.0); - let max_mel = hz_to_mel(sample_rate as f32 / 2.0); - - let mel_points: Vec<f32> = (0..=n_mels + 1) - .map(|i| mel_to_hz(min_mel + (max_mel - min_mel) * i as f32 / (n_mels + 1) as f32)) - .collect(); - - let freq_bin_width = sample_rate as f32 / n_fft as f32; - - for mel_idx in 0..n_mels { - let left = mel_points[mel_idx]; - let center = mel_points[mel_idx + 1]; - let right = mel_points[mel_idx + 2]; - - for freq_idx in 0..freq_bins { - let freq = freq_idx as f32 * freq_bin_width; - - if freq >= left && freq <= center { - filterbank[[mel_idx, freq_idx]] = (freq - left) / (center - left); - } else if freq > center && freq <= right { - filterbank[[mel_idx, freq_idx]] = (right - freq) / (right - center); - } 
- } - } - - filterbank -} - -/// Extract mel spectrogram features from raw audio samples. -/// -/// # Arguments -/// -/// * `audio` - Audio samples as f32 values -/// * `sample_rate` - Sample rate in Hz -/// * `channels` - Number of audio channels -/// * `config` - Preprocessor configuration -/// -/// # Returns -/// -/// 2D array of mel spectrogram features (time_steps x feature_size) -pub fn extract_features_raw( - mut audio: Vec<f32>, - sample_rate: u32, - channels: u16, - config: &PreprocessorConfig, -) -> Result<Array2<f32>> { - if sample_rate != config.sampling_rate as u32 { - return Err(Error::Audio(format!( - "Audio sample rate {} doesn't match expected {}. Please resample your audio first.", - sample_rate, config.sampling_rate - ))); - } - - if channels > 1 { - let mono: Vec<f32> = audio - .chunks(channels as usize) - .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) - .collect(); - audio = mono; - } - - audio = apply_preemphasis(&audio, config.preemphasis); - - let spectrogram = stft(&audio, config.n_fft, config.hop_length, config.win_length); - - let mel_filterbank = - create_mel_filterbank(config.n_fft, config.feature_size, config.sampling_rate); - let mel_spectrogram = mel_filterbank.dot(&spectrogram); - let mel_spectrogram = mel_spectrogram.mapv(|x| (x.max(1e-10)).ln()); - - let mut mel_spectrogram = mel_spectrogram.t().to_owned(); - - // Normalize each feature dimension to mean=0, std=1 - let num_frames = mel_spectrogram.shape()[0]; - let num_features = mel_spectrogram.shape()[1]; - - for feat_idx in 0..num_features { - let mut column = mel_spectrogram.column_mut(feat_idx); - let mean: f32 = column.iter().sum::<f32>() / num_frames as f32; - let variance: f32 = - column.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_frames as f32; - let std = variance.sqrt().max(1e-10); - - for val in column.iter_mut() { - *val = (*val - mean) / std; - } - } - - Ok(mel_spectrogram) -} diff --git a/parakeet-rs/src/config.rs b/parakeet-rs/src/config.rs 
deleted file mode 100644 index 1dae890..0000000 --- a/parakeet-rs/src/config.rs +++ /dev/null @@ -1,51 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PreprocessorConfig { - pub feature_extractor_type: String, - pub feature_size: usize, - pub hop_length: usize, - pub n_fft: usize, - pub padding_side: String, - pub padding_value: f32, - pub preemphasis: f32, - pub processor_class: String, - pub return_attention_mask: bool, - pub sampling_rate: usize, - pub win_length: usize, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelConfig { - pub architectures: Vec<String>, - pub vocab_size: usize, - pub pad_token_id: usize, -} - -impl Default for PreprocessorConfig { - fn default() -> Self { - Self { - feature_extractor_type: "ParakeetFeatureExtractor".to_string(), - feature_size: 80, - hop_length: 160, - n_fft: 512, - padding_side: "right".to_string(), - padding_value: 0.0, - preemphasis: 0.97, - processor_class: "ParakeetProcessor".to_string(), - return_attention_mask: true, - sampling_rate: 16000, - win_length: 400, - } - } -} - -impl Default for ModelConfig { - fn default() -> Self { - Self { - architectures: vec!["ParakeetForCTC".to_string()], - vocab_size: 1025, - pad_token_id: 1024, - } - } -} diff --git a/parakeet-rs/src/decoder.rs b/parakeet-rs/src/decoder.rs deleted file mode 100644 index 6da6d65..0000000 --- a/parakeet-rs/src/decoder.rs +++ /dev/null @@ -1,211 +0,0 @@ -use crate::error::{Error, Result}; -use ndarray::Array2; -use std::path::Path; - -// Token with its timestamp information -// start and end are in seconds -#[derive(Debug, Clone)] -pub struct TimedToken { - pub text: String, - pub start: f32, - pub end: f32, -} - -#[derive(Debug, Clone)] -pub struct TranscriptionResult { - pub text: String, - pub tokens: Vec<TimedToken>, -} - -// CTC decoder for parakeet-ctc-0.6b model with token-level timestamps -pub struct ParakeetDecoder { - tokenizer: tokenizers::Tokenizer, - 
pad_token_id: usize, -} - -impl ParakeetDecoder { - pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> { - let tokenizer_path = tokenizer_path.as_ref(); - - let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path) - .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?; - - // Hardcoded pad_token_id for Parakeet-CTC-0.6b (constant across all models: please see def configs jsons: https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main) - let pad_token_id = 1024; - - Ok(Self { - tokenizer, - pad_token_id, - }) - } - - pub fn decode(&self, logits: &Array2<f32>) -> Result<String> { - let time_steps = logits.shape()[0]; - - let mut token_ids = Vec::new(); - for t in 0..time_steps { - let logits_t = logits.row(t); - let max_idx = logits_t - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) - .unwrap_or(0); - - token_ids.push(max_idx as u32); - } - - let collapsed = self.ctc_collapse(&token_ids); - - let text = self - .tokenizer - .decode(&collapsed, true) - .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?; - - Ok(text) - } - - fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> { - let mut result = Vec::new(); - let mut prev_token: Option<u32> = None; - - for &token_id in token_ids { - if token_id == self.pad_token_id as u32 { - prev_token = Some(token_id); - continue; - } - - if Some(token_id) != prev_token { - result.push(token_id); - } - - prev_token = Some(token_id); - } - - result - } - - // CTC collapse with frame tracking for timestamps - fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> { - let mut result: Vec<(u32, usize, usize)> = Vec::new(); - let mut prev_token: Option<u32> = None; - - for &(token_id, frame) in token_ids.iter() { - if token_id == self.pad_token_id as u32 { - prev_token = Some(token_id); - continue; - } - - if Some(token_id) != prev_token { - 
if let Some(prev) = prev_token { - if prev != self.pad_token_id as u32 { - // End previous token - if let Some(last) = result.last_mut() { - last.2 = frame; - } - } - } - // Start new token - result.push((token_id, frame, frame)); - } - - prev_token = Some(token_id); - } - - // Close last token - if let Some(last) = result.last_mut() { - last.2 = token_ids.len(); - } - - result - } - - // Decode with token-level timestamps - // hop_length and sample_rate are needed to convert frames to seconds - pub fn decode_with_timestamps( - &self, - logits: &Array2<f32>, - hop_length: usize, - sample_rate: usize, - ) -> Result<TranscriptionResult> { - let time_steps = logits.shape()[0]; - - let mut token_ids_with_frames = Vec::new(); - for t in 0..time_steps { - let logits_t = logits.row(t); - let max_idx = logits_t - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) - .unwrap_or(0); - - token_ids_with_frames.push((max_idx as u32, t)); - } - - // CTC collapse with frame tracking - let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames); - - // Extract just token IDs for decoding - let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|(id, _, _)| *id).collect(); - - // Decode full text - let full_text = self - .tokenizer - .decode(&token_ids, true) - .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?; - - // Progressive decode to detect word boundaries - // BPE tokenizers only add spaces when decoding sequences, not individual tokens - let mut timed_tokens = Vec::new(); - let mut prev_decode = String::new(); - - for (i, (_token_id, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() { - // Decode from start up to and including current token - let token_ids_so_far: Vec<u32> = collapsed_with_frames[0..=i] - .iter() - .map(|(id, _, _)| *id) - .collect(); - - if let Ok(curr_decode) = self.tokenizer.decode(&token_ids_so_far, true) { - // Find 
what this token added - let added_text = if curr_decode.len() > prev_decode.len() { - &curr_decode[prev_decode.len()..] - } else { - "" - }; - - if !added_text.is_empty() { - let start_time = (*start_frame * hop_length) as f32 / sample_rate as f32; - let end_time = (*end_frame * hop_length) as f32 / sample_rate as f32; - - timed_tokens.push(TimedToken { - text: added_text.to_string(), - start: start_time, - end: end_time, - }); - } - - prev_decode = curr_decode; - } - } - - Ok(TranscriptionResult { - text: full_text, - tokens: timed_tokens, - }) - } - - // Stub - falls back to greedy decoding. Full beam search with language model is TODO. - pub fn decode_with_beam_search( - &self, - logits: &Array2<f32>, - _beam_width: usize, - ) -> Result<String> { - self.decode(logits) - } - - pub fn pad_token_id(&self) -> usize { - self.pad_token_id - } -} diff --git a/parakeet-rs/src/decoder_tdt.rs b/parakeet-rs/src/decoder_tdt.rs deleted file mode 100644 index 65f576d..0000000 --- a/parakeet-rs/src/decoder_tdt.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::decoder::TranscriptionResult; -use crate::error::Result; -use crate::vocab::Vocabulary; - -/// TDT greedy decoder for Parakeet TDT models -#[derive(Debug)] -pub struct ParakeetTDTDecoder { - vocab: Vocabulary, -} - -impl ParakeetTDTDecoder { - /// Load decoder from vocab file - pub fn from_vocab(vocab: Vocabulary) -> Self { - Self { vocab } - } - - /// Decode tokens with timestamps - /// For TDT models, greedy decoding is done in the model, here we just convert to text - pub fn decode_with_timestamps( - &self, - tokens: &[usize], - frame_indices: &[usize], - _durations: &[usize], - hop_length: usize, - sample_rate: usize, - ) -> Result<TranscriptionResult> { - let mut result_tokens = Vec::new(); - let mut full_text = String::new(); - // TDT encoder does 8x subsampling - let encoder_stride = 8; - - for (i, &token_id) in tokens.iter().enumerate() { - if let Some(token_text) = self.vocab.id_to_text(token_id) { - let frame = 
frame_indices[i]; - let start = (frame * encoder_stride * hop_length) as f32 / sample_rate as f32; - let end = if i + 1 < frame_indices.len() { - (frame_indices[i + 1] * encoder_stride * hop_length) as f32 / sample_rate as f32 - } else { - start + 0.01 - }; - - // Handle SentencePiece format (▁ prefix for word start) - let display_text = token_text.replace('▁', " "); - - // Skip special tokens - if !(token_text.starts_with('<') && token_text.ends_with('>') && token_text != "<unk>") { - full_text.push_str(&display_text); - - result_tokens.push(crate::decoder::TimedToken { - text: display_text, - start, - end, - }); - } - } - } - - Ok(TranscriptionResult { - text: full_text.trim().to_string(), - tokens: result_tokens, - }) - } -} diff --git a/parakeet-rs/src/error.rs b/parakeet-rs/src/error.rs deleted file mode 100644 index 690e0e5..0000000 --- a/parakeet-rs/src/error.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::fmt; - -pub type Result<T> = std::result::Result<T, Error>; - -#[derive(Debug)] -pub enum Error { - Io(std::io::Error), - Ort(ort::Error), - Audio(String), - Model(String), - Tokenizer(String), - Config(String), -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Error::Io(e) => write!(f, "IO error: {e}"), - Error::Ort(e) => write!(f, "ONNX Runtime error: {e}"), - Error::Audio(msg) => write!(f, "Audio processing error: {msg}"), - Error::Model(msg) => write!(f, "Model error: {msg}"), - Error::Tokenizer(msg) => write!(f, "Tokenizer error: {msg}"), - Error::Config(msg) => write!(f, "Config error: {msg}"), - } - } -} - -impl std::error::Error for Error {} - -impl From<std::io::Error> for Error { - fn from(e: std::io::Error) -> Self { - Error::Io(e) - } -} - -impl From<ort::Error> for Error { - fn from(e: ort::Error) -> Self { - Error::Ort(e) - } -} - -impl From<serde_json::Error> for Error { - fn from(e: serde_json::Error) -> Self { - Error::Config(e.to_string()) - } -} - -impl From<hound::Error> for 
Error { - fn from(e: hound::Error) -> Self { - Error::Audio(e.to_string()) - } -} diff --git a/parakeet-rs/src/execution.rs b/parakeet-rs/src/execution.rs deleted file mode 100644 index e29aa1d..0000000 --- a/parakeet-rs/src/execution.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::error::Result; -use ort::session::builder::SessionBuilder; - -// Hardware acceleration options. CPU is default and most reliable. -// GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware. -// All GPU providers automatically fall back to CPU if they fail. -// -// Note: CoreML currently fails with this model due to unsupported operations. -// WebGPU is experimental and may produce incorrect results. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum ExecutionProvider { - #[default] - Cpu, - #[cfg(feature = "cuda")] - Cuda, - #[cfg(feature = "tensorrt")] - TensorRT, - #[cfg(feature = "coreml")] - CoreML, - #[cfg(feature = "directml")] - DirectML, - #[cfg(feature = "rocm")] - ROCm, - #[cfg(feature = "openvino")] - OpenVINO, - #[cfg(feature = "webgpu")] - WebGPU, -} - -#[derive(Debug, Clone)] -pub struct ModelConfig { - pub execution_provider: ExecutionProvider, - pub intra_threads: usize, - pub inter_threads: usize, -} - -impl Default for ModelConfig { - fn default() -> Self { - Self { - execution_provider: ExecutionProvider::default(), - intra_threads: 4, - inter_threads: 1, - } - } -} - -impl ModelConfig { - pub fn new() -> Self { - Self::default() - } - - pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self { - self.execution_provider = provider; - self - } - - pub fn with_intra_threads(mut self, threads: usize) -> Self { - self.intra_threads = threads; - self - } - - pub fn with_inter_threads(mut self, threads: usize) -> Self { - self.inter_threads = threads; - self - } - - pub(crate) fn apply_to_session_builder( - &self, - builder: SessionBuilder, - ) -> Result<SessionBuilder> { - use 
ort::session::builder::GraphOptimizationLevel; - #[cfg(any( - feature = "cuda", - feature = "tensorrt", - feature = "coreml", - feature = "directml", - feature = "rocm", - feature = "openvino", - feature = "webgpu" - ))] - use ort::execution_providers::CPUExecutionProvider; - - let mut builder = builder - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(self.intra_threads)? - .with_inter_threads(self.inter_threads)?; - - builder = match self.execution_provider { - ExecutionProvider::Cpu => builder, - - #[cfg(feature = "cuda")] - ExecutionProvider::Cuda => builder.with_execution_providers([ - ort::execution_providers::CUDAExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "tensorrt")] - ExecutionProvider::TensorRT => builder.with_execution_providers([ - ort::execution_providers::TensorRTExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "coreml")] - ExecutionProvider::CoreML => { - use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider}; - builder.with_execution_providers([ - CoreMLExecutionProvider::default() - .with_compute_units(CoreMLComputeUnits::CPUAndGPU) - .build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])? 
- } - - #[cfg(feature = "directml")] - ExecutionProvider::DirectML => builder.with_execution_providers([ - ort::execution_providers::DirectMLExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "rocm")] - ExecutionProvider::ROCm => builder.with_execution_providers([ - ort::execution_providers::ROCMExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "openvino")] - ExecutionProvider::OpenVINO => builder.with_execution_providers([ - ort::execution_providers::OpenVINOExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "webgpu")] - ExecutionProvider::WebGPU => builder.with_execution_providers([ - ort::execution_providers::WebGPUExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - }; - - Ok(builder) - } -} diff --git a/parakeet-rs/src/lib.rs b/parakeet-rs/src/lib.rs deleted file mode 100644 index 0aaefd1..0000000 --- a/parakeet-rs/src/lib.rs +++ /dev/null @@ -1,74 +0,0 @@ -//! # parakeet-rs -//! -//! Rust bindings for NVIDIA's Parakeet speech recognition model using ONNX Runtime. -//! -//! Parakeet is a state-of-the-art automatic speech recognition (ASR) model developed by NVIDIA, -//! based on the FastConformer-TDT architecture with 600 million parameters. -//! -//! ## Features -//! -//! - Easy-to-use API for speech-to-text transcription -//! - Support for ONNX format models -//! - 16kHz mono audio input -//! - Punctuation and capitalization included in output -//! - Fast inference using ONNX Runtime -//! -//! ## Quick Start -//! -//! ```ignore -//! use parakeet_rs::Parakeet; -//! -//! // Load the model -//! let parakeet = Parakeet::from_pretrained(".")?; -//! -//! // Transcribe audio file -//! let text = parakeet.transcribe_file("audio.wav")?; -//! println!("Transcription: {}", text); -//! 
``` -//! -//! ## Model Requirements -//! -//! Your model directory should contain: -//! - `model.onnx` - The ONNX model file -//! - `model.onnx_data` - External model weights -//! - `config.json` - Model configuration -//! - `preprocessor_config.json` - Audio preprocessing configuration -//! - `tokenizer.json` - Tokenizer vocabulary -//! - `tokenizer_config.json` - Tokenizer configuration -//! -//! ## Audio Requirements -//! -//! - Format: WAV -//! - Sample Rate: 16kHz -//! - Channels: Mono (stereo will be converted automatically) -//! - Bit Depth: 16-bit PCM or 32-bit float - -mod audio; -mod config; -mod decoder; -mod decoder_tdt; -mod error; -mod execution; -mod model; -mod model_tdt; -mod parakeet; -mod parakeet_tdt; -mod timestamps; -mod vocab; -mod model_eou; -mod parakeet_eou; -#[cfg(feature = "sortformer")] -pub mod sortformer; - -pub use error::{Error, Result}; -pub use execution::{ExecutionProvider, ModelConfig as ExecutionConfig}; -pub use parakeet::Parakeet; -pub use parakeet_tdt::ParakeetTDT; -pub use timestamps::TimestampMode; - -pub use config::{ModelConfig as ModelConfigJson, PreprocessorConfig}; - -pub use decoder::{ParakeetDecoder, TimedToken, TranscriptionResult}; -pub use model::ParakeetModel; -pub use model_eou::ParakeetEOUModel; -pub use parakeet_eou::ParakeetEOU;
\ No newline at end of file diff --git a/parakeet-rs/src/model.rs b/parakeet-rs/src/model.rs deleted file mode 100644 index b3cd131..0000000 --- a/parakeet-rs/src/model.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::config::ModelConfig; -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use ndarray::Array2; -use ort::session::Session; -use std::path::Path; - -pub struct ParakeetModel { - session: Session, - config: ModelConfig, -} - -impl ParakeetModel { - pub fn from_pretrained<P: AsRef<Path>>(model_path: P) -> Result<Self> { - Self::from_pretrained_with_config(model_path, ExecutionConfig::default()) - } - - pub fn from_pretrained_with_config<P: AsRef<Path>>( - model_path: P, - exec_config: ExecutionConfig, - ) -> Result<Self> { - let model_path = model_path.as_ref(); - - // Use default config (hardcoded constants for Parakeet-CTC-0.6b: please see: json files https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main) - let config = ModelConfig::default(); - - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let session = builder.commit_from_file(model_path)?; - - Ok(Self { session, config }) - } - pub fn forward(&mut self, features: Array2<f32>) -> Result<Array2<f32>> { - let batch_size = 1; - let time_steps = features.shape()[0]; - let feature_size = features.shape()[1]; - - let input = features - .to_shape((batch_size, time_steps, feature_size)) - .map_err(|e| Error::Model(format!("Failed to reshape input: {e}")))? 
- .to_owned(); - - use ndarray::Array2; - let attention_mask = Array2::<i64>::ones((batch_size, time_steps)); - - let input_value = ort::value::Value::from_array(input)?; - let attention_mask_value = ort::value::Value::from_array(attention_mask)?; - - let outputs = self.session.run(ort::inputs!( - "input_features" => input_value, - "attention_mask" => attention_mask_value - ))?; - - let logits_value = &outputs["logits"]; - let (shape, data) = logits_value - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; - - let shape_dims = shape.as_ref(); - if shape_dims.len() != 3 { - return Err(Error::Model(format!( - "Expected 3D logits, got shape: {shape_dims:?}" - ))); - } - - let batch_size = shape_dims[0] as usize; - let time_steps_out = shape_dims[1] as usize; - let vocab_size = shape_dims[2] as usize; - - if batch_size != 1 { - return Err(Error::Model(format!( - "Expected batch size 1, got {batch_size}" - ))); - } - - let logits_2d = Array2::from_shape_vec((time_steps_out, vocab_size), data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to create array: {e}")))?; - - Ok(logits_2d) - } - - pub fn config(&self) -> &ModelConfig { - &self.config - } - - pub fn vocab_size(&self) -> usize { - self.config.vocab_size - } - - pub fn pad_token_id(&self) -> usize { - self.config.pad_token_id - } -} diff --git a/parakeet-rs/src/model_eou.rs b/parakeet-rs/src/model_eou.rs deleted file mode 100644 index 5b56e6d..0000000 --- a/parakeet-rs/src/model_eou.rs +++ /dev/null @@ -1,183 +0,0 @@ -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use ndarray::{Array1, Array2, Array3, Array4}; -use ort::session::Session; -use std::path::Path; - -/// Encoder cache state for streaming inference -/// The cache maintains temporal context across chunks -pub struct EncoderCache { - /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback - pub cache_last_channel: Array4<f32>, - /// time cache: 
[1, 1, 512, 8] - batch=1, fixed 8 time steps - pub cache_last_time: Array4<f32>, - /// cache length: [1] with value 0 initially - pub cache_last_channel_len: Array1<i64>, -} - -impl EncoderCache { - /// 17 layers, batch=1, 70 frame lookback, 512 features - pub fn new() -> Self { - Self { - cache_last_channel: Array4::zeros((17, 1, 70, 512)), - cache_last_time: Array4::zeros((17, 1, 512, 8)), - cache_last_channel_len: Array1::from_vec(vec![0i64]), - } - } -} - -pub struct ParakeetEOUModel { - encoder: Session, - decoder_joint: Session, -} - -impl ParakeetEOUModel { - pub fn from_pretrained<P: AsRef<Path>>( - model_dir: P, - exec_config: ExecutionConfig, - ) -> Result<Self> { - let model_dir = model_dir.as_ref(); - - let encoder_path = model_dir.join("encoder.onnx"); - let decoder_path = model_dir.join("decoder_joint.onnx"); - - if !encoder_path.exists() || !decoder_path.exists() { - return Err(Error::Config(format!( - "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx", - model_dir.display() - ))); - } - - // Load encoder - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let encoder = builder.commit_from_file(&encoder_path)?; - - // Load decoder - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let decoder_joint = builder.commit_from_file(&decoder_path)?; - - Ok(Self { - encoder, - decoder_joint, - }) - } - - /// Run the stateful encoder with cache - /// Input: features [1, 128, T], cache state - /// Output: (encoded [1, 512, T], new_cache) - pub fn run_encoder( - &mut self, - features: &Array3<f32>, - length: i64, - cache: &EncoderCache - ) -> Result<(Array3<f32>, EncoderCache)> { - let length_arr = Array1::from_vec(vec![length]); - - let outputs = self.encoder.run(ort::inputs![ - "audio_signal" => ort::value::Value::from_array(features.clone())?, - "length" => ort::value::Value::from_array(length_arr)?, - "cache_last_channel" => 
ort::value::Value::from_array(cache.cache_last_channel.clone())?, - "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?, - "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())? - ])?; - - // Extract encoder output [1, 512, T] - let (shape, data) = outputs["outputs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; - - let shape_dims = shape.as_ref(); - let b = shape_dims[0] as usize; - let d = shape_dims[1] as usize; - let t = shape_dims[2] as usize; - - let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?; - - // Extract new cache states - let (ch_shape, ch_data) = outputs["new_cache_last_channel"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?; - - let (tm_shape, tm_data) = outputs["new_cache_last_time"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?; - - let (len_shape, len_data) = outputs["new_cache_last_channel_len"] - .try_extract_tensor::<i64>() - .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?; - - // Build new cache with extracted shapes - let new_cache = EncoderCache { - cache_last_channel: Array4::from_shape_vec( - (ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize), - ch_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?, - - cache_last_time: Array4::from_shape_vec( - (tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize), - tm_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?, - - cache_last_channel_len: Array1::from_shape_vec( - len_shape[0] as usize, - len_data.to_vec() - ).map_err(|e| 
Error::Model(format!("Failed to reshape cache_len: {e}")))?, - }; - - Ok((encoder_out, new_cache)) - } - - /// Run the stateful decoder - /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c) - pub fn run_decoder( - &mut self, - encoder_frame: &Array3<f32>, // [1, 512, 1] - last_token: &Array2<i32>, // [1, 1] - state_h: &Array3<f32>, // [1, 1, 640] - state_c: &Array3<f32>, // [1, 1, 640] - ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> { - - // Target length is always 1 for single step - let target_len = Array1::from_vec(vec![1i32]); - - let outputs = self.decoder_joint.run(ort::inputs![ - "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?, - "targets" => ort::value::Value::from_array(last_token.clone())?, - "target_length" => ort::value::Value::from_array(target_len)?, - "input_states_1" => ort::value::Value::from_array(state_h.clone())?, - "input_states_2" => ort::value::Value::from_array(state_c.clone())? - ])?; - - // 1. Extract Logits - let (l_shape, l_data) = outputs["outputs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; - - // 2. 
Extract States (output_states_1, output_states_2) - let (_h_shape, h_data) = outputs["output_states_1"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?; - - let (_c_shape, c_data) = outputs["output_states_2"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?; - - // Reconstruct Arrays - // Logits: I simplify to [1, 1, vocab] - let vocab_size = l_shape[3] as usize; - let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec()) - .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?; - - // States: [1, 1, 640] - let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec()) - .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?; - - let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec()) - .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?; - - Ok((logits, new_h, new_c)) - } -}
\ No newline at end of file diff --git a/parakeet-rs/src/model_tdt.rs b/parakeet-rs/src/model_tdt.rs deleted file mode 100644 index e00ebdc..0000000 --- a/parakeet-rs/src/model_tdt.rs +++ /dev/null @@ -1,263 +0,0 @@ -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use ndarray::{Array1, Array2, Array3}; -use ort::session::Session; -use std::path::{Path, PathBuf}; - -/// TDT model configs -#[derive(Debug, Clone)] -pub struct TDTModelConfig { - pub vocab_size: usize, -} - -impl Default for TDTModelConfig { - fn default() -> Self { - Self { - vocab_size: 8193, - } - } -} - -pub struct ParakeetTDTModel { - encoder: Session, - decoder_joint: Session, - config: TDTModelConfig, -} - -impl ParakeetTDTModel { - /// Load TDT model from directory containing encoder and decoder_joint ONNX files - pub fn from_pretrained<P: AsRef<Path>>( - model_dir: P, - exec_config: ExecutionConfig, - ) -> Result<Self> { - let model_dir = model_dir.as_ref(); - - // Find encoder and decoder_joint files - let encoder_path = Self::find_encoder(model_dir)?; - let decoder_joint_path = Self::find_decoder_joint(model_dir)?; - - let config = TDTModelConfig::default(); - - // Load encoder - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let encoder = builder.commit_from_file(&encoder_path)?; - - // Load decoder_joint - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let decoder_joint = builder.commit_from_file(&decoder_joint_path)?; - - - Ok(Self { - encoder, - decoder_joint, - config, - }) - } - - fn find_encoder(dir: &Path) -> Result<PathBuf> { - let candidates = ["encoder-model.onnx", "encoder.onnx"]; - for candidate in &candidates { - let path = dir.join(candidate); - if path.exists() { - return Ok(path); - } - } - Err(Error::Config(format!( - "No encoder model found in {}", - dir.display() - ))) - } - - fn find_decoder_joint(dir: &Path) -> 
Result<PathBuf> { - let candidates = [ - "decoder_joint-model.onnx", - "decoder_joint.onnx", - "decoder-model.onnx", - ]; - for candidate in &candidates { - let path = dir.join(candidate); - if path.exists() { - return Ok(path); - } - } - Err(Error::Config(format!( - "No decoder_joint model found in {}", - dir.display() - ))) - } - - /// Run greedy decoding - returns (token_ids, frame_indices, durations) - pub fn forward(&mut self, features: Array2<f32>) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> { - // Run encoder - let (encoder_out, encoder_len) = self.run_encoder(&features)?; - - // Run greedy decoding with decoder_joint - let (tokens, frame_indices, durations) = self.greedy_decode(&encoder_out, encoder_len)?; - - Ok((tokens, frame_indices, durations)) - } - - fn run_encoder(&mut self, features: &Array2<f32>) -> Result<(Array3<f32>, i64)> { - let batch_size = 1; - let time_steps = features.shape()[0]; - let feature_size = features.shape()[1]; - - // TDT encoder expects (batch, features, time) not (batch, time, features) - let input = features - .t() - .to_shape((batch_size, feature_size, time_steps)) - .map_err(|e| Error::Model(format!("Failed to reshape encoder input: {e}")))? 
- .to_owned(); - - let input_length = Array1::from_vec(vec![time_steps as i64]); - - let input_value = ort::value::Value::from_array(input)?; - let length_value = ort::value::Value::from_array(input_length)?; - - let outputs = self.encoder.run(ort::inputs!( - "audio_signal" => input_value, - "length" => length_value - ))?; - - let encoder_out = &outputs["outputs"]; - let encoder_lens = &outputs["encoded_lengths"]; - - let (shape, data) = encoder_out - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; - - let (_, lens_data) = encoder_lens - .try_extract_tensor::<i64>() - .map_err(|e| Error::Model(format!("Failed to extract encoder lengths: {e}")))?; - - let shape_dims = shape.as_ref(); - if shape_dims.len() != 3 { - return Err(Error::Model(format!( - "Expected 3D encoder output, got shape: {shape_dims:?}" - ))); - } - - let b = shape_dims[0] as usize; - let t = shape_dims[1] as usize; - let d = shape_dims[2] as usize; - - let encoder_array = Array3::from_shape_vec((b, t, d), data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to create encoder array: {e}")))?; - - // TDT encoder outputs [batch, encoder_dim, time] directly - Ok((encoder_array, lens_data[0])) - } - - fn greedy_decode(&mut self, encoder_out: &Array3<f32>, _encoder_len: i64) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> { - // encoder_out shape: [batch, encoder_dim, time] - let encoder_dim = encoder_out.shape()[1]; - let time_steps = encoder_out.shape()[2]; - let vocab_size = self.config.vocab_size; - let max_tokens_per_step = 10; - let blank_id = vocab_size - 1; - - // States: (num_layers=2, batch=1, hidden_dim=640) - let mut state_h = Array3::<f32>::zeros((2, 1, 640)); - let mut state_c = Array3::<f32>::zeros((2, 1, 640)); - - let mut tokens = Vec::new(); - let mut frame_indices = Vec::new(); - let mut durations = Vec::new(); - - let mut t = 0; - let mut emitted_tokens = 0; - let mut last_emitted_token = blank_id as i32; - - // 
Frame-by-frame RNN-T/TDT greedy decoding - while t < time_steps { - // Get single encoder frame: slice [0, :, t] and reshape to [1, encoder_dim, 1] - let frame = encoder_out.slice(ndarray::s![0, .., t]).to_owned(); - let frame_reshaped = frame - .to_shape((1, encoder_dim, 1)) - .map_err(|e| Error::Model(format!("Failed to reshape frame: {e}")))? - .to_owned(); - - // Current token for prediction network - let targets = Array2::from_shape_vec((1, 1), vec![last_emitted_token]) - .map_err(|e| Error::Model(format!("Failed to create targets: {e}")))?; - - // Run decoder_joint - let outputs = self.decoder_joint.run(ort::inputs!( - "encoder_outputs" => ort::value::Value::from_array(frame_reshaped)?, - "targets" => ort::value::Value::from_array(targets)?, - "target_length" => ort::value::Value::from_array(Array1::from_vec(vec![1i32]))?, - "input_states_1" => ort::value::Value::from_array(state_h.clone())?, - "input_states_2" => ort::value::Value::from_array(state_c.clone())? - ))?; - - // Extract logits - let (_, logits_data) = outputs["outputs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; - - // TDT outputs vocab_size + 5 durations (8193 + 5 = 8198) - let vocab_logits: Vec<f32> = logits_data.iter().take(vocab_size).copied().collect(); - let duration_logits: Vec<f32> = logits_data.iter().skip(vocab_size).copied().collect(); - - let token_id = vocab_logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) - .unwrap_or(blank_id); - - let duration_step = if !duration_logits.is_empty() { - duration_logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) - .unwrap_or(0) - } else { - 0 - }; - - // Check if blank token - if token_id != blank_id { - // Update states when we emit a token - if let Ok((h_shape, h_data)) = 
outputs["output_states_1"].try_extract_tensor::<f32>() { - let dims = h_shape.as_ref(); - state_h = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), h_data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to update state_h: {e}")))?; - } - if let Ok((c_shape, c_data)) = outputs["output_states_2"].try_extract_tensor::<f32>() { - let dims = c_shape.as_ref(); - state_c = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), c_data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to update state_c: {e}")))?; - } - - tokens.push(token_id); - frame_indices.push(t); - durations.push(duration_step); - last_emitted_token = token_id as i32; - emitted_tokens += 1; - - // Don't advance yet - try to emit more tokens from the same frame - } else { - // Blank token - advance frame pointer - // Duration prediction applies when we finally move to next frame after emitting tokens - if duration_step > 0 && emitted_tokens > 0 { - t += duration_step; - } else { - t += 1; - } - emitted_tokens = 0; - } - - // Safety check: if we've emitted too many tokens from the same frame, advance - if emitted_tokens >= max_tokens_per_step { - t += 1; - emitted_tokens = 0; - } - } - - Ok((tokens, frame_indices, durations)) - } -} diff --git a/parakeet-rs/src/parakeet.rs b/parakeet-rs/src/parakeet.rs deleted file mode 100644 index d2aabdd..0000000 --- a/parakeet-rs/src/parakeet.rs +++ /dev/null @@ -1,210 +0,0 @@ -use crate::audio; -use crate::config::PreprocessorConfig; -use crate::decoder::{ParakeetDecoder, TranscriptionResult}; -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use crate::model::ParakeetModel; -use crate::timestamps::{process_timestamps, TimestampMode}; -use std::path::{Path, PathBuf}; - -pub struct Parakeet { - model: ParakeetModel, - decoder: ParakeetDecoder, - preprocessor_config: PreprocessorConfig, - model_dir: PathBuf, -} - -impl Parakeet { - /// Load Parakeet model from 
path with optional configuration. - /// - /// # Arguments - /// * `path` - Directory containing model files, or path to specific model file - /// * `config` - Optional execution configuration (defaults to CPU if None) - /// - /// # Examples - /// ```no_run - /// use parakeet_rs::Parakeet; - /// - /// // Load from directory with CPU (default) - /// let parakeet = Parakeet::from_pretrained(".", None)?; - /// - /// // Or load from specific model file - /// let parakeet = Parakeet::from_pretrained("model_q4.onnx", None)?; - /// # Ok::<(), Box<dyn std::error::Error>>(()) - /// ``` - /// - /// For GPU acceleration, enable the corresponding feature (cuda, tensorrt, webgpu, etc.) - /// and pass an `ExecutionConfig` with the desired execution provider. - pub fn from_pretrained<P: AsRef<Path>>( - path: P, - config: Option<ExecutionConfig>, - ) -> Result<Self> { - let path = path.as_ref(); - - // Determine if path is a directory or file - let (model_path, tokenizer_path, model_dir) = if path.is_dir() { - // Directory mode: auto-detect model file - let model_path = Self::find_model_file(path)?; - let tokenizer_path = path.join("tokenizer.json"); - (model_path, tokenizer_path, path.to_path_buf()) - } else if path.is_file() { - // File mode: path points directly to model file - let model_dir = path - .parent() - .ok_or_else(|| Error::Config("Invalid model path".to_string()))?; - let tokenizer_path = model_dir.join("tokenizer.json"); - (path.to_path_buf(), tokenizer_path, model_dir.to_path_buf()) - } else { - return Err(Error::Config(format!( - "Path does not exist: {}", - path.display() - ))); - }; - - // Check tokenizer exists - if !tokenizer_path.exists() { - return Err(Error::Config(format!( - "Required file 'tokenizer.json' not found in {}", - model_dir.display() - ))); - } - - let preprocessor_config = PreprocessorConfig::default(); - let exec_config = config.unwrap_or_default(); - - let model = ParakeetModel::from_pretrained_with_config(&model_path, exec_config)?; - let 
decoder = ParakeetDecoder::from_pretrained(&tokenizer_path)?; - - Ok(Self { - model, - decoder, - preprocessor_config, - model_dir, - }) - } - - fn find_model_file(dir: &Path) -> Result<PathBuf> { - // Priority order: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx - let candidates = [ - "model.onnx", - "model_fp16.onnx", - "model_int8.onnx", - "model_q4.onnx", - ]; - - for candidate in &candidates { - let path = dir.join(candidate); - if path.exists() { - return Ok(path); - } - } - - // If none of the standard names found, search for any .onnx file - if let Ok(entries) = std::fs::read_dir(dir) { - for entry in entries.flatten() { - let path = entry.path(); - if path.extension().and_then(|s| s.to_str()) == Some("onnx") { - return Ok(path); - } - } - } - - Err(Error::Config(format!( - "No model file (*.onnx) found in directory: {}", - dir.display() - ))) - } - - /// Transcribe audio samples. - /// - /// # Arguments - /// - /// * `audio` - Audio samples as f32 values - /// * `sample_rate` - Sample rate in Hz - /// * `channels` - Number of audio channels - /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) - /// - /// # Returns - /// - /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested level. 
- pub fn transcribe_samples( - &mut self, - audio: Vec<f32>, - sample_rate: u32, - channels: u16, - mode: Option<TimestampMode>, - ) -> Result<TranscriptionResult> { - let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; - let logits = self.model.forward(features)?; - - let mut result = self.decoder.decode_with_timestamps( - &logits, - self.preprocessor_config.hop_length, - self.preprocessor_config.sampling_rate, - )?; - - // Process timestamps to requested output mode - let mode = mode.unwrap_or(TimestampMode::Tokens); - result.tokens = process_timestamps(&result.tokens, mode); - - // Rebuild full text from processed tokens to ensure consistency - result.text = result.tokens.iter() - .map(|t| t.text.as_str()) - .collect::<Vec<_>>() - .join(" "); - - Ok(result) - } - - /// Transcribe an audio file with timestamps - /// - /// # Arguments - /// - /// * `audio_path` - A path to the audio file that needs to be transcribed. - /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) - /// - /// # Returns - /// - /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level. - pub fn transcribe_file<P: AsRef<Path>>( - &mut self, - audio_path: P, - mode: Option<TimestampMode>, - ) -> Result<TranscriptionResult> { - let audio_path = audio_path.as_ref(); - let (audio, spec) = audio::load_audio(audio_path)?; - - self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) - } - - /// Transcribes multiple audio files in batch. - /// - /// # Arguments - /// - /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. - /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) - /// - /// # Returns - /// - /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level. 
- pub fn transcribe_file_batch<P: AsRef<Path>>( - &mut self, - audio_paths: &[P], - mode: Option<TimestampMode>, - ) -> Result<Vec<TranscriptionResult>> { - let mut results = Vec::with_capacity(audio_paths.len()); - for path in audio_paths { - let result = self.transcribe_file(path, mode)?; - results.push(result); - } - Ok(results) - } - - pub fn model_dir(&self) -> &Path { - &self.model_dir - } - - pub fn preprocessor_config(&self) -> &PreprocessorConfig { - &self.preprocessor_config - } -} diff --git a/parakeet-rs/src/parakeet_eou.rs b/parakeet-rs/src/parakeet_eou.rs deleted file mode 100644 index 25c7d64..0000000 --- a/parakeet-rs/src/parakeet_eou.rs +++ /dev/null @@ -1,304 +0,0 @@ -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use crate::model_eou::{EncoderCache, ParakeetEOUModel}; -use ndarray::{s, Array2, Array3}; -use rustfft::{num_complex::Complex, FftPlanner}; -use std::collections::VecDeque; -use std::f32::consts::PI; -use std::path::Path; - -const SAMPLE_RATE: usize = 16000; - -const N_FFT: usize = 512; -const WIN_LENGTH: usize = 400; -const HOP_LENGTH: usize = 160; -const N_MELS: usize = 128; -const PREEMPH: f32 = 0.97; -const LOG_ZERO_GUARD: f32 = 5.960464478e-8; -const FMAX: f32 = 8000.0; - -/// Parakeet RealTime EOU model for streaming ASR with end-of-utterance detection. -/// Uses cache-aware streaming with audio buffering for pre-encode context. 
-pub struct ParakeetEOU { - model: ParakeetEOUModel, - tokenizer: tokenizers::Tokenizer, - encoder_cache: EncoderCache, - state_h: Array3<f32>, - state_c: Array3<f32>, - last_token: Array2<i32>, - blank_id: i32, - eou_id: i32, - mel_basis: Array2<f32>, - window: Vec<f32>, - audio_buffer: VecDeque<f32>, - buffer_size_samples: usize, -} - -impl ParakeetEOU { - /// Load Parakeet EOU model from path - /// - /// # Arguments - /// * `path` - Directory containing encoder.onnx, decoder_joint.onnx, and tokenizer.json - /// * `config` - Optional execution configuration (defaults to CPU if None) - pub fn from_pretrained<P: AsRef<Path>>(path: P, config: Option<ExecutionConfig>) -> Result<Self> { - let path = path.as_ref(); - let tokenizer_path = path.join("tokenizer.json"); - let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path) - .map_err(|e| Error::Config(format!("Failed to load tokenizer: {e}")))?; - - let vocab_size = tokenizer.get_vocab_size(true); - let blank_id = (vocab_size - 1) as i32; - let blank_id = if blank_id < 1000 { 1026 } else { blank_id }; - let eou_id = tokenizer.token_to_id("<EOU>").map(|id| id as i32).unwrap_or(1024); - - let exec_config = config.unwrap_or_default(); - let model = ParakeetEOUModel::from_pretrained(path, exec_config)?; - - // Buffer size: 4 seconds of audio - // Provides long history for feature extraction context - // Note that, I pick those "magic numbers" by looking NeMo's ring buffer approach. - let buffer_size_samples = SAMPLE_RATE * 4; // 4 seconds = 64000 samples - - Ok(Self { - model, - tokenizer, - encoder_cache: EncoderCache::new(), - state_h: Array3::zeros((1, 1, 640)), - state_c: Array3::zeros((1, 1, 640)), - last_token: Array2::from_elem((1, 1), blank_id), - blank_id, - eou_id, - mel_basis: Self::create_mel_filterbank(), - window: Self::create_window(), - audio_buffer: VecDeque::with_capacity(buffer_size_samples), - buffer_size_samples, - }) - } - - /// Transcribe a chunk of audio samples. 
- /// - /// # Arguments - /// * `chunk` - Audio chunk (typically 160ms / 2560 samples at 16kHz) - /// * `reset_on_eou` - If true, reset decoder state when end-of-utterance is detected - /// - /// # Streaming Behavior - /// Cache-aware streaming - /// - Maintains 4-second ring buffer for feature extraction context - /// - Extracts features from full buffer - /// - Slices last (pre_encode_cache + new_frames) for encoder input - /// - pre_encode_cache=9 frames, new_frames=~16, total=~25 frames to encoder - pub fn transcribe(&mut self, chunk: &[f32], reset_on_eou: bool) -> Result<String> { - // Add new chunk to rolling buffer - self.audio_buffer.extend(chunk.iter().copied()); - - // Trim buffer to keep only the most recent samples - while self.audio_buffer.len() > self.buffer_size_samples { - self.audio_buffer.pop_front(); - } - - // Wait until buffer has minimum samples (at least 1 second for stable features) - const MIN_BUFFER_SAMPLES: usize = SAMPLE_RATE; // 1 second - if self.audio_buffer.len() < MIN_BUFFER_SAMPLES { - return Ok(String::new()); - } - - // Extract features from FULL buffer (provides context for feature extraction) - let buffer_slice: Vec<f32> = self.audio_buffer.iter().copied().collect(); - let full_features = self.extract_mel_features(&buffer_slice); - let total_frames = full_features.shape()[2]; - - // Slice to take only (pre_encode_cache + new_frames) for encoder - // pre_encode_cache = 9 frames, new_frames = ~16 for 160ms chunk - const PRE_ENCODE_CACHE: usize = 9; - const FRAMES_PER_CHUNK: usize = 16; - const SLICE_LEN: usize = PRE_ENCODE_CACHE + FRAMES_PER_CHUNK; - - let start_frame = if total_frames > SLICE_LEN { - total_frames - SLICE_LEN - } else { - 0 - }; - - let features = full_features.slice(s![.., .., start_frame..]).to_owned(); - let time_steps = features.shape()[2]; - - // Encode with cache - encoder sees full buffer context - let (encoder_out, new_cache) = self.model.run_encoder(&features, time_steps as i64, &self.encoder_cache)?; - 
self.encoder_cache = new_cache; - - let total_frames = encoder_out.shape()[2]; - if total_frames == 0 { - return Ok(String::new()); - } - - // Process all output frames (typically 1 frame per chunk) - let new_frames = encoder_out; - - let mut text_output = String::new(); - - for t in 0..new_frames.shape()[2] { - let current_frame = new_frames.slice(s![.., .., t..t + 1]).to_owned(); - let mut syms_added = 0; - - while syms_added < 5 { - let (logits, new_h, new_c) = self.model.run_decoder( - &current_frame, - &self.last_token, - &self.state_h, - &self.state_c, - )?; - - let vocab = logits.slice(s![0, 0, ..]); - - let mut max_idx = 0; - let mut max_val = f32::NEG_INFINITY; - for (i, &val) in vocab.iter().enumerate() { - if val.is_finite() && val > max_val { - max_val = val; - max_idx = i as i32; - } - } - - if max_idx == self.blank_id || max_idx == 0 { - break; - } - - if max_idx == self.eou_id { - if reset_on_eou { - self.reset_states(); - return Ok(text_output + " [EOU]"); - } - break; - } - - if max_idx as usize >= self.tokenizer.get_vocab_size(true) { - break; - } - - self.state_h = new_h; - self.state_c = new_c; - self.last_token.fill(max_idx); - - if let Some(token) = self.tokenizer.id_to_token(max_idx as u32) { - let clean = token.replace('▁', " "); - text_output.push_str(&clean); - } - syms_added += 1; - } - } - Ok(text_output) - } - - fn reset_states(&mut self) { - // Soft reset: Only reset decoder states - // at this state, we need to keep encoder cache and audio buffer flowing for continuous context - // self.encoder_cache = EncoderCache::new(); // DON'T reset!!! - self.state_h.fill(0.0); - self.state_c.fill(0.0); - self.last_token.fill(self.blank_id); - // self.audio_buffer.clear(); // DON'T clear!! 
- } - - fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> { - let audio_pre = Self::apply_preemphasis(audio); - let spec = self.stft(&audio_pre); - let mel = self.mel_basis.dot(&spec); - let mel_log = mel.mapv(|x| (x.max(0.0) + LOG_ZERO_GUARD).ln()); - mel_log.insert_axis(ndarray::Axis(0)) - } - - fn apply_preemphasis(audio: &[f32]) -> Vec<f32> { - let mut result = Vec::with_capacity(audio.len()); - if audio.is_empty() { - return result; - } - - let safe_x = |x: f32| if x.is_finite() { x } else { 0.0 }; - - result.push(safe_x(audio[0])); - for i in 1..audio.len() { - result.push(safe_x(audio[i]) - PREEMPH * safe_x(audio[i - 1])); - } - result - } - - fn stft(&self, audio: &[f32]) -> Array2<f32> { - let mut planner = FftPlanner::<f32>::new(); - let fft = planner.plan_fft_forward(N_FFT); - - let pad_amount = N_FFT / 2; - let mut padded_audio = vec![0.0; pad_amount]; - padded_audio.extend_from_slice(audio); - padded_audio.extend(std::iter::repeat(0.0).take(pad_amount)); - - let num_frames = 1 + (padded_audio.len().saturating_sub(WIN_LENGTH)) / HOP_LENGTH; - let freq_bins = N_FFT / 2 + 1; - let mut spec = Array2::zeros((freq_bins, num_frames)); - - for frame_idx in 0..num_frames { - let start = frame_idx * HOP_LENGTH; - if start + WIN_LENGTH > padded_audio.len() { - break; - } - - let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT]; - for i in 0..WIN_LENGTH { - buffer[i] = Complex::new(padded_audio[start + i] * self.window[i], 0.0); - } - fft.process(&mut buffer); - for (i, val) in buffer.iter().take(freq_bins).enumerate() { - let mag_sq = val.norm_sqr(); - spec[[i, frame_idx]] = if mag_sq.is_finite() { mag_sq } else { 0.0 }; - } - } - spec - } - - fn create_window() -> Vec<f32> { - (0..WIN_LENGTH) - .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / ((WIN_LENGTH - 1) as f32)).cos()) - .collect() - } - - fn create_mel_filterbank() -> Array2<f32> { - let num_freqs = N_FFT / 2 + 1; - - let hz_to_mel = |hz: f32| 2595.0 * (1.0 + hz / 
700.0).log10(); - let mel_to_hz = |mel: f32| 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0); - - let mel_min = hz_to_mel(0.0); - let mel_max = hz_to_mel(FMAX); - - let mel_points: Vec<f32> = (0..=N_MELS + 1) - .map(|i| mel_to_hz(mel_min + (mel_max - mel_min) * i as f32 / (N_MELS + 1) as f32)) - .collect(); - - let fft_freqs: Vec<f32> = (0..num_freqs) - .map(|i| (SAMPLE_RATE as f32 / N_FFT as f32) * i as f32) - .collect(); - - let mut weights = Array2::zeros((N_MELS, num_freqs)); - - for i in 0..N_MELS { - let left = mel_points[i]; - let center = mel_points[i + 1]; - let right = mel_points[i + 2]; - for (j, &freq) in fft_freqs.iter().enumerate() { - if freq >= left && freq <= center { - weights[[i, j]] = (freq - left) / (center - left); - } else if freq > center && freq <= right { - weights[[i, j]] = (right - freq) / (right - center); - } - } - } - - for i in 0..N_MELS { - let enorm = 2.0 / (mel_points[i + 2] - mel_points[i]); - for j in 0..num_freqs { - weights[[i, j]] *= enorm; - } - } - - weights - } -} diff --git a/parakeet-rs/src/parakeet_tdt.rs b/parakeet-rs/src/parakeet_tdt.rs deleted file mode 100644 index 719ae75..0000000 --- a/parakeet-rs/src/parakeet_tdt.rs +++ /dev/null @@ -1,167 +0,0 @@ -use crate::audio; -use crate::config::PreprocessorConfig; -use crate::decoder::TranscriptionResult; -use crate::decoder_tdt::ParakeetTDTDecoder; -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use crate::model_tdt::ParakeetTDTModel; -use crate::timestamps::{process_timestamps, TimestampMode}; -use crate::vocab::Vocabulary; -use std::path::{Path, PathBuf}; - -/// Parakeet TDT model for multilingual ASR -pub struct ParakeetTDT { - model: ParakeetTDTModel, - decoder: ParakeetTDTDecoder, - preprocessor_config: PreprocessorConfig, - model_dir: PathBuf, -} - -impl ParakeetTDT { - /// Load Parakeet TDT model from path with optional configuration. 
- /// - /// # Arguments - /// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt - /// * `config` - Optional execution configuration (defaults to CPU if None) - pub fn from_pretrained<P: AsRef<Path>>( - path: P, - config: Option<ExecutionConfig>, - ) -> Result<Self> { - let path = path.as_ref(); - - if !path.is_dir() { - return Err(Error::Config(format!( - "TDT model path must be a directory: {}", - path.display() - ))); - } - - let vocab_path = path.join("vocab.txt"); - if !vocab_path.exists() { - return Err(Error::Config(format!( - "vocab.txt not found in {}", - path.display() - ))); - } - - // TDT-specific preprocessor config (128 features instead of 80) - let preprocessor_config = PreprocessorConfig { - feature_extractor_type: "ParakeetFeatureExtractor".to_string(), - feature_size: 128, - hop_length: 160, - n_fft: 512, - padding_side: "right".to_string(), - padding_value: 0.0, - preemphasis: 0.97, - processor_class: "ParakeetProcessor".to_string(), - return_attention_mask: true, - sampling_rate: 16000, - win_length: 400, - }; - - let exec_config = config.unwrap_or_default(); - - let model = ParakeetTDTModel::from_pretrained(path, exec_config)?; - let vocab = Vocabulary::from_file(&vocab_path)?; - let decoder = ParakeetTDTDecoder::from_vocab(vocab); - - Ok(Self { - model, - decoder, - preprocessor_config, - model_dir: path.to_path_buf(), - }) - } - - /// Transcribe audio samples. - /// - /// # Arguments - /// - /// * `audio` - Audio samples as f32 values - /// * `sample_rate` - Sample rate in Hz - /// * `channels` - Number of audio channels - /// * `mode` - Optional timestamp mode (Token, Word, or Segment) - /// - /// # Returns - /// - /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode. 
- pub fn transcribe_samples( - &mut self, - audio: Vec<f32>, - sample_rate: u32, - channels: u16, - mode: Option<TimestampMode>, - ) -> Result<TranscriptionResult> { - let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; - let (tokens, frame_indices, durations) = self.model.forward(features)?; - - let mut result = self.decoder.decode_with_timestamps( - &tokens, - &frame_indices, - &durations, - self.preprocessor_config.hop_length, - self.preprocessor_config.sampling_rate, - )?; - - // Apply timestamp mode conversion - let mode = mode.unwrap_or(TimestampMode::Tokens); - result.tokens = process_timestamps(&result.tokens, mode); - - // Rebuild full text from processed tokens - result.text = result.tokens.iter() - .map(|t| t.text.as_str()) - .collect::<Vec<_>>() - .join(" "); - - Ok(result) - } - - /// Transcribe an audio file with timestamps - /// - /// # Arguments - /// - /// * `audio_path` - A path to the audio file that needs to be transcribed. - /// * `mode` - Optional timestamp mode (Token, Word, or Segment) - /// - /// # Returns - /// - /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. - pub fn transcribe_file<P: AsRef<Path>>( - &mut self, - audio_path: P, - mode: Option<TimestampMode>, - ) -> Result<TranscriptionResult> { - let audio_path = audio_path.as_ref(); - let (audio, spec) = audio::load_audio(audio_path)?; - - self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) - } - - /// Transcribes multiple audio files in batch. - /// - /// # Arguments - /// - /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. - /// * `mode` - Optional timestamp mode (Token, Word, or Segment) - /// - /// # Returns - /// - /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. 
- pub fn transcribe_file_batch<P: AsRef<Path>>( - &mut self, - audio_paths: &[P], - mode: Option<TimestampMode>, - ) -> Result<Vec<TranscriptionResult>> { - let mut results = Vec::with_capacity(audio_paths.len()); - for path in audio_paths { - let result = self.transcribe_file(path, mode)?; - results.push(result); - } - Ok(results) - } - - /// Get model directory path - pub fn model_dir(&self) -> &Path { - &self.model_dir - } -} diff --git a/parakeet-rs/src/sortformer.rs b/parakeet-rs/src/sortformer.rs deleted file mode 100644 index 2b1e5a3..0000000 --- a/parakeet-rs/src/sortformer.rs +++ /dev/null @@ -1,1062 +0,0 @@ -//! NVIDIA Sortformer v2 Streaming Speaker Diarization -//! -//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization. -//! -//! Key features: -//! - Streaming inference with ~10s chunks (124 frames at 80ms each) -//! - FIFO buffer for context management -//! - Smart speaker cache compression (keeps important frames, not just recent) -//! - Silence profile tracking -//! - Post-processing: median filtering, hysteresis thresholding -//! - Supports up to 4 speakers -//! -//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 -//! Note that, my ONNX export: -//! CHUNK_LEN = 124 -//! FIFO_LEN = 124 -//! CACHE_LEN = 188 -//! FEAT_DIM = 128 -//! EMB_DIM = 512 -//! 
Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html - -use crate::error::{Error, Result}; -use crate::execution::ModelConfig; -use ndarray::{s, Array1, Array2, Array3, Axis}; -use ort::session::Session; -use rustfft::{num_complex::Complex, FftPlanner}; -use std::f32::consts::PI; -use std::path::Path; - -// Model constants -const N_FFT: usize = 512; -const WIN_LENGTH: usize = 400; -const HOP_LENGTH: usize = 160; -const N_MELS: usize = 128; -const PREEMPH: f32 = 0.97; -const LOG_ZERO_GUARD: f32 = 5.960464478e-8; -const SAMPLE_RATE: usize = 16000; -const FMIN: f32 = 0.0; -const FMAX: f32 = 8000.0; - -// Streaming constants -const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms) -const FIFO_LEN: usize = 124; // FIFO buffer length -const SPKCACHE_LEN: usize = 188; // Speaker cache length -const SPKCACHE_UPDATE_PERIOD: usize = 124; -const SUBSAMPLING: usize = 8; // Audio frames -> model frames -const EMB_DIM: usize = 512; // Embedding dimension -const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers -const FRAME_DURATION: f32 = 0.08; // 80ms per frame - -// Cache compression params (from NeMo) -const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3; -const PRED_SCORE_THRESHOLD: f32 = 0.25; -const STRONG_BOOST_RATE: f32 = 0.75; -const WEAK_BOOST_RATE: f32 = 1.5; -const MIN_POS_SCORES_RATE: f32 = 0.5; -const SIL_THRESHOLD: f32 = 0.2; -const MAX_INDEX: usize = 99999; - -/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs) -/// -/// Controls how raw model predictions are converted into speaker segments. -/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI). 
-/// -/// # Parameters -/// - `onset`: Probability threshold to START a speaker segment (higher = more strict) -/// - `offset`: Probability threshold to END a speaker segment (lower = longer segments) -/// - `pad_onset`: Seconds to subtract from segment start times -/// - `pad_offset`: Seconds to add to segment end times -/// - `min_duration_on`: Minimum segment length in seconds (filters short blips) -/// - `min_duration_off`: Minimum gap between segments before merging -/// - `median_window`: Smoothing window size (odd number, higher = smoother) -/// -/// # Pre-tuned Configs -/// - `callhome()` - (default) -/// - `dihard3()` -/// -/// # Custom Config -/// Use `custom(onset, offset)` to create your own config for fine-tuning. -/// -/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer -#[derive(Debug, Clone)] -pub struct DiarizationConfig { - pub onset: f32, - pub offset: f32, - pub pad_onset: f32, - pub pad_offset: f32, - pub min_duration_on: f32, - pub min_duration_off: f32, - pub median_window: usize, -} - -impl Default for DiarizationConfig { - fn default() -> Self { - Self::callhome() - } -} - -impl DiarizationConfig { - /// CallHome dataset config for v2 (default) - /// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml - pub fn callhome() -> Self { - Self { - onset: 0.641, - offset: 0.561, - pad_onset: 0.229, - pad_offset: 0.079, - min_duration_on: 0.511, - min_duration_off: 0.296, - median_window: 11, - } - } - - /// DIHARD3 dataset config for v2 - /// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml - pub fn dihard3() -> Self { - Self { - onset: 0.56, - offset: 1.0, - pad_onset: 0.063, - pad_offset: 0.002, - min_duration_on: 0.007, - min_duration_off: 0.151, - median_window: 11, - } - } - - /// Create a custom config for fine-tuning diarization behavior. 
- /// - /// # Arguments - /// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7) - /// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6) - /// - /// # Example - /// ```rust - /// use parakeet_rs::sortformer::DiarizationConfig; - /// - /// // More sensitive detection (lower thresholds) - /// let sensitive = DiarizationConfig::custom(0.5, 0.4); - /// - /// // Stricter detection (higher thresholds, fewer false positives) - /// let strict = DiarizationConfig::custom(0.7, 0.6); - /// - /// // Full customization - /// let mut config = DiarizationConfig::custom(0.6, 0.5); - /// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms - /// config.median_window = 15; // More smoothing - /// ``` - pub fn custom(onset: f32, offset: f32) -> Self { - Self { - onset, - offset, - pad_onset: 0.0, - pad_offset: 0.0, - min_duration_on: 0.1, - min_duration_off: 0.1, - median_window: 11, - } - } -} - -/// Speaker segment with start time, end time, and speaker ID -#[derive(Debug, Clone)] -pub struct SpeakerSegment { - pub start: f32, - pub end: f32, - pub speaker_id: usize, -} - -/// Streaming Sortformer v2 speaker diarization engine -pub struct Sortformer { - session: Session, - config: DiarizationConfig, - // Streaming state. 
note that, Same way as Nemo - spkcache: Array3<f32>, // (1, 0..SPKCACHE_LEN, EMB_DIM) - spkcache_preds: Option<Array3<f32>>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS) - fifo: Array3<f32>, // (1, 0..FIFO_LEN, EMB_DIM) - fifo_preds: Array3<f32>, // (1, 0..FIFO_LEN, NUM_SPEAKERS) - mean_sil_emb: Array2<f32>, // (1, EMB_DIM) - n_sil_frames: usize, - // Mel filterbank (cached) - mel_basis: Array2<f32>, -} - -impl Sortformer { - /// a new Sortformer instance from ONNX model path - pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> { - Self::with_config(model_path, None, DiarizationConfig::default()) - } - - /// Create with custom config - pub fn with_config<P: AsRef<Path>>( - model_path: P, - execution_config: Option<ModelConfig>, - config: DiarizationConfig, - ) -> Result<Self> { - let config_to_use = execution_config.unwrap_or_default(); - - let session = config_to_use - .apply_to_session_builder(Session::builder()?)? - .commit_from_file(model_path.as_ref())?; - - let mel_basis = Self::create_mel_filterbank(); - - let mut instance = Self { - session, - config, - spkcache: Array3::zeros((1, 0, EMB_DIM)), - spkcache_preds: None, - fifo: Array3::zeros((1, 0, EMB_DIM)), - fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)), - mean_sil_emb: Array2::zeros((1, EMB_DIM)), - n_sil_frames: 0, - mel_basis, - }; - instance.reset_state(); - Ok(instance) - } - - /// Reset streaming state - pub fn reset_state(&mut self) { - self.spkcache = Array3::zeros((1, 0, EMB_DIM)); - self.spkcache_preds = None; - self.fifo = Array3::zeros((1, 0, EMB_DIM)); - self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS)); - self.mean_sil_emb = Array2::zeros((1, EMB_DIM)); - self.n_sil_frames = 0; - } - - /// Main diarization entry point - pub fn diarize( - &mut self, - mut audio: Vec<f32>, - sample_rate: u32, - channels: u16, - ) -> Result<Vec<SpeakerSegment>> { - // Resample if needed - if sample_rate != SAMPLE_RATE as u32 { - return Err(Error::Audio(format!( - "Expected {} Hz, got {} Hz", - SAMPLE_RATE, 
sample_rate - ))); - } - - // Convert to mono - if channels > 1 { - audio = audio - .chunks(channels as usize) - .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) - .collect(); - } - - // Reset state for new audio - self.reset_state(); - - // Extract mel features (B, T, D) - let features = self.extract_mel_features(&audio); - let total_frames = features.shape()[1]; - - // Process in chunks - let chunk_stride = CHUNK_LEN * SUBSAMPLING; - let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; - - let mut all_chunk_preds = Vec::new(); - - for chunk_idx in 0..num_chunks { - let start = chunk_idx * chunk_stride; - let end = (start + chunk_stride).min(total_frames); - let current_len = end - start; - - // Extract chunk features - let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); - - // Pad last chunk if needed - if current_len < chunk_stride { - let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); - padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); - chunk_feat = padded; - } - - // Run streaming update - let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; - all_chunk_preds.push(chunk_preds); - } - - // Concatenate all predictions - let full_preds = Self::concat_predictions(&all_chunk_preds); - - // Apply median filtering - let filtered_preds = if self.config.median_window > 1 { - self.median_filter(&full_preds) - } else { - full_preds - }; - - // Binarize to segments - let segments = self.binarize(&filtered_preds); - - Ok(segments) - } - - /// Streaming diarization that maintains state across calls. - /// - /// Unlike `diarize()`, this method does NOT reset the internal state, - /// allowing speaker embeddings to be preserved across multiple audio chunks. - /// Call `reset_state()` manually when starting a new audio session. - /// - /// This enables consistent speaker identification across long audio streams - /// by maintaining the speaker cache between processing windows. 
- /// - /// # Arguments - /// * `audio` - Audio samples (will be converted to mono if multi-channel) - /// * `sample_rate` - Must be 16000 Hz - /// * `channels` - Number of audio channels - /// - /// # Example - /// ```ignore - /// // Start of session - /// sortformer.reset_state(); - /// - /// // Process sliding windows - /// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?; - /// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs - /// ``` - pub fn diarize_streaming( - &mut self, - mut audio: Vec<f32>, - sample_rate: u32, - channels: u16, - ) -> Result<Vec<SpeakerSegment>> { - // Resample if needed - if sample_rate != SAMPLE_RATE as u32 { - return Err(Error::Audio(format!( - "Expected {} Hz, got {} Hz", - SAMPLE_RATE, sample_rate - ))); - } - - // Convert to mono - if channels > 1 { - audio = audio - .chunks(channels as usize) - .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) - .collect(); - } - - // NOTE: Unlike diarize(), we do NOT call reset_state() here - // This preserves speaker embeddings across calls - - // Extract mel features (B, T, D) - let features = self.extract_mel_features(&audio); - let total_frames = features.shape()[1]; - - // Process in chunks - let chunk_stride = CHUNK_LEN * SUBSAMPLING; - let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; - - let mut all_chunk_preds = Vec::new(); - - for chunk_idx in 0..num_chunks { - let start = chunk_idx * chunk_stride; - let end = (start + chunk_stride).min(total_frames); - let current_len = end - start; - - // Extract chunk features - let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); - - // Pad last chunk if needed - if current_len < chunk_stride { - let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); - padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); - chunk_feat = padded; - } - - // Run streaming update - let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; - 
all_chunk_preds.push(chunk_preds); - } - - // Concatenate all predictions - let full_preds = Self::concat_predictions(&all_chunk_preds); - - // Apply median filtering - let filtered_preds = if self.config.median_window > 1 { - self.median_filter(&full_preds) - } else { - full_preds - }; - - // Binarize to segments - let segments = self.binarize(&filtered_preds); - - Ok(segments) - } - - /// NeMo's streaming_update with smart cache compression. - /// Public to allow incremental streaming diarization. - pub fn streaming_update(&mut self, chunk_feat: &Array3<f32>, current_len: usize) -> Result<Array2<f32>> { - let spkcache_len = self.spkcache.shape()[1]; - let fifo_len = self.fifo.shape()[1]; - - // Prepare inputs - let chunk_lengths = Array1::from_vec(vec![current_len as i64]); - let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]); - let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]); - - // Prepare FIFO input - let fifo_input = if fifo_len > 0 { - self.fifo.clone() - } else { - Array3::zeros((1, 0, EMB_DIM)) - }; - - // Prepare spkcache input (may be empty) - let spkcache_input = if spkcache_len > 0 { - self.spkcache.clone() - } else { - Array3::zeros((1, 0, EMB_DIM)) - }; - - // Create input values - let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?; - let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?; - let spkcache_value = ort::value::Value::from_array(spkcache_input)?; - let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?; - let fifo_value = ort::value::Value::from_array(fifo_input)?; - let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?; - - // Run ONNX inference and extract all data in a block to release borrow - let (preds, new_embs, chunk_len) = { - let outputs = self.session.run(ort::inputs!( - "chunk" => chunk_value, - "chunk_lengths" => chunk_lengths_value, - "spkcache" => spkcache_value, - "spkcache_lengths" => spkcache_lengths_value, - "fifo" => 
fifo_value, - "fifo_lengths" => fifo_lengths_value - ))?; - - // Extract outputs - let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?; - let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?; - - // Convert to ndarray - let preds_dims = preds_shape.as_ref(); - let embs_dims = embs_shape.as_ref(); - - let preds = Array3::from_shape_vec( - (preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize), - preds_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?; - - let new_embs = Array3::from_shape_vec( - (embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize), - embs_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?; - - // Calculate valid frames - let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING; - - (preds, new_embs, valid_frames) - }; - - // Extract predictions for different parts - let fifo_preds = if fifo_len > 0 { - preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned() - } else { - Array2::zeros((0, NUM_SPEAKERS)) - }; - - let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned(); - let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned(); - - // Append chunk embeddings to FIFO - self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0))); - - // Update FIFO predictions - if fifo_len > 0 { - let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds); - self.fifo_preds = combined.insert_axis(Axis(0)); - } else { - self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0)); - } - - let fifo_len_after = self.fifo.shape()[1]; - - // Move from FIFO to cache when FIFO exceeds limit - if fifo_len_after > FIFO_LEN { - let mut 
pop_out_len = SPKCACHE_UPDATE_PERIOD; - pop_out_len = pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len); - pop_out_len = pop_out_len.min(fifo_len_after); - - let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned(); - let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned(); - - // Update silence profile - self.update_silence_profile(&pop_out_embs, &pop_out_preds); - - // Remove from FIFO - self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned(); - self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned(); - - // Append to cache - self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs); - - if let Some(ref cache_preds) = self.spkcache_preds { - self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds)); - } - - // Smart compression when cache exceeds limit - if self.spkcache.shape()[1] > SPKCACHE_LEN { - if self.spkcache_preds.is_none() { - // Initialize cache predictions from initial output - let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned(); - let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds); - self.spkcache_preds = Some(combined); - } - - // Use smart compression - self.compress_spkcache(); - } - } - - Ok(chunk_preds) - } - - /// Update mean silence embedding - fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) { - let preds_2d = preds.slice(s![0, .., ..]); - - for t in 0..preds_2d.shape()[0] { - let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum(); - if sum < SIL_THRESHOLD { - // This is a silence frame - let emb = embs.slice(s![0, t, ..]); - - // Update running mean - let old_sum: Vec<f32> = self.mean_sil_emb.slice(s![0, ..]).iter() - .map(|&x| x * self.n_sil_frames as f32) - .collect(); - - self.n_sil_frames += 1; - - for i in 0..EMB_DIM { - self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32; - } - } - } - } - - /// Smart cache compression 
- fn compress_spkcache(&mut self) { - let cache_preds = match &self.spkcache_preds { - Some(p) => p.clone(), - None => return, - }; - - let n_frames = self.spkcache.shape()[1]; - let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK; - let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize; - let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize; - let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize; - - // Calculate quality scores - let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned(); - let mut scores = self.get_log_pred_scores(&preds_2d); - - // Disable low scores - scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk); - - // Boost important frames - scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0); - scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0); - - // Add silence frames placeholder - if SPKCACHE_SIL_FRAMES_PER_SPK > 0 { - let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY); - padded.slice_mut(s![..n_frames, ..]).assign(&scores); - for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK { - for j in 0..NUM_SPEAKERS { - padded[[i, j]] = f32::INFINITY; - } - } - scores = padded; - } - - // Select top frames - let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames); - - // Gather embeddings - let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled); - - self.spkcache = new_embs; - self.spkcache_preds = Some(new_preds); - } - - /// Calculate quality scores - fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> { - let mut scores = Array2::zeros(preds.dim()); - - for t in 0..preds.shape()[0] { - let mut log_1_probs_sum = 0.0f32; - for s in 0..NUM_SPEAKERS { - let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); - let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); 
- log_1_probs_sum += log_1_p; - } - - for s in 0..NUM_SPEAKERS { - let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); - let log_p = p.ln(); - let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); - scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln(); - } - } - - scores - } - - /// Disable non-speech and overlapped speech - fn disable_low_scores(&self, preds: &Array2<f32>, mut scores: Array2<f32>, min_pos_scores_per_spk: usize) -> Array2<f32> { - // Count positive scores per speaker - let mut pos_count = vec![0usize; NUM_SPEAKERS]; - for t in 0..scores.shape()[0] { - for s in 0..NUM_SPEAKERS { - if scores[[t, s]] > 0.0 { - pos_count[s] += 1; - } - } - } - - for t in 0..preds.shape()[0] { - for s in 0..NUM_SPEAKERS { - let is_speech = preds[[t, s]] > 0.5; - - if !is_speech { - scores[[t, s]] = f32::NEG_INFINITY; - } else { - let is_pos = scores[[t, s]] > 0.0; - if !is_pos && pos_count[s] >= min_pos_scores_per_spk { - scores[[t, s]] = f32::NEG_INFINITY; - } - } - } - } - - scores - } - - /// Boost top K frames per speaker - fn boost_topk_scores(&self, mut scores: Array2<f32>, n_boost_per_spk: usize, scale_factor: f32) -> Array2<f32> { - for s in 0..NUM_SPEAKERS { - // Get column for this speaker - let col: Vec<(usize, f32)> = (0..scores.shape()[0]) - .map(|t| (t, scores[[t, s]])) - .collect(); - - // Sort by score descending - let mut sorted = col.clone(); - sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Boost top K - for i in 0..n_boost_per_spk.min(sorted.len()) { - let t = sorted[i].0; - if scores[[t, s]] != f32::NEG_INFINITY { - scores[[t, s]] -= scale_factor * 0.5f32.ln(); - } - } - } - - scores - } - - /// Get indices of top frames - fn get_topk_indices(&self, scores: &Array2<f32>, n_frames_no_sil: usize) -> (Vec<usize>, Vec<bool>) { - let n_frames = scores.shape()[0]; - - // Flatten scores as (S, T) then reshape to (S*T,) - // This means we iterate: speaker 0 all times, then speaker 1 all times, etc. 
- // flat_index = speaker * n_frames + time - let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS); - for s in 0..NUM_SPEAKERS { - for t in 0..n_frames { - let flat_idx = s * n_frames + t; - flat_scores.push((flat_idx, scores[[t, s]])); - } - } - - // Sort by score descending to get top-K - flat_scores.sort_by(|a, b| { - b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) - }); - - // Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX - let mut topk_flat: Vec<usize> = flat_scores - .iter() - .take(SPKCACHE_LEN) - .map(|(idx, score)| { - if *score == f32::NEG_INFINITY { - MAX_INDEX - } else { - *idx - } - }) - .collect(); - - // Sort flat indices ascending (this puts MAX_INDEX at the end) - topk_flat.sort(); - - // Compute is_disabled and convert to frame indices - let mut is_disabled = vec![false; SPKCACHE_LEN]; - let mut frame_indices = vec![0usize; SPKCACHE_LEN]; - - for (i, &flat_idx) in topk_flat.iter().enumerate() { - if flat_idx == MAX_INDEX { - // Invalid entries are disabled - is_disabled[i] = true; - frame_indices[i] = 0; // We set disabled to 0 - } else { - // convert to frame index - let frame_idx = flat_idx % n_frames; - - // check if frame is beyond valid range - if frame_idx >= n_frames_no_sil { - is_disabled[i] = true; - frame_indices[i] = 0; // same as abov: set disabled to 0 - } else { - frame_indices[i] = frame_idx; - } - } - } - - (frame_indices, is_disabled) - } - - /// Gather selected frames - fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3<f32>, Array3<f32>) { - let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM)); - let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS)); - - let cache_preds = self.spkcache_preds.as_ref().unwrap(); - - for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() { - if i >= SPKCACHE_LEN { - break; - } - - if disabled { - // Use silence embedding - new_embs.slice_mut(s![0, i, 
..]).assign(&self.mean_sil_emb.slice(s![0, ..])); - // Predictions stay zero - } else if idx < self.spkcache.shape()[1] { - new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..])); - new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..])); - } - } - - (new_embs, new_preds) - } - - /// Concatenate along axis 1 for 3D arrays - fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> { - if a.shape()[1] == 0 { - return b.clone(); - } - if b.shape()[1] == 0 { - return a.clone(); - } - ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap() - } - - /// Concatenate along axis 0 for 2D arrays - fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> { - if a.shape()[0] == 0 { - return b.clone(); - } - if b.shape()[0] == 0 { - return a.clone(); - } - ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap() - } - - /// Concatenate predictions - fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> { - if preds.is_empty() { - return Array2::zeros((0, NUM_SPEAKERS)); - } - if preds.len() == 1 { - return preds[0].clone(); - } - - let views: Vec<_> = preds.iter().map(|p| p.view()).collect(); - ndarray::concatenate(Axis(0), &views).unwrap() - } - - /// Apply median filter to predictions - fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> { - let window = self.config.median_window; - let half = window / 2; - let mut filtered = preds.clone(); - - for spk in 0..NUM_SPEAKERS { - for t in 0..preds.shape()[0] { - let start = t.saturating_sub(half); - let end = (t + half + 1).min(preds.shape()[0]); - - let mut values: Vec<f32> = (start..end) - .map(|i| preds[[i, spk]]) - .collect(); - values.sort_by(|a, b| a.partial_cmp(b).unwrap()); - - filtered[[t, spk]] = values[values.len() / 2]; - } - } - - filtered - } - - /// Binarize predictions to segments (padding applied during thresholding) - fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> { - let mut segments = Vec::new(); - let 
num_frames = preds.shape()[0]; - - for spk in 0..NUM_SPEAKERS { - let mut in_seg = false; - let mut seg_start = 0; - let mut temp_segments = Vec::new(); - - for t in 0..num_frames { - let p = preds[[t, spk]]; - - if p >= self.config.onset && !in_seg { - in_seg = true; - seg_start = t; - } else if p < self.config.offset && in_seg { - in_seg = false; - - // Apply padding during conversion - let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); - let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset; - - if end_t - start_t >= self.config.min_duration_on { - temp_segments.push(SpeakerSegment { - start: start_t, - end: end_t, - speaker_id: spk, - }); - } - } - } - - // Handle segment at end - if in_seg { - let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); - let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset; - - if end_t - start_t >= self.config.min_duration_on { - temp_segments.push(SpeakerSegment { - start: start_t, - end: end_t, - speaker_id: spk, - }); - } - } - - // Merge close segments (min_duration_off) - if temp_segments.len() > 1 { - let mut filtered = vec![temp_segments[0].clone()]; - for seg in temp_segments.into_iter().skip(1) { - let last = filtered.last_mut().unwrap(); - let gap = seg.start - last.end; - if gap < self.config.min_duration_off { - last.end = seg.end; // Merge - } else { - filtered.push(seg); - } - } - segments.extend(filtered); - } else { - segments.extend(temp_segments); - } - } - - // Sort by start time - segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap()); - segments - } - - - fn apply_preemphasis(audio: &[f32]) -> Vec<f32> { - let mut result = Vec::with_capacity(audio.len()); - result.push(audio[0]); - for i in 1..audio.len() { - result.push(audio[i] - PREEMPH * audio[i - 1]); - } - result - } - - fn hann_window(window_length: usize) -> Vec<f32> { - // Librosa uses periodic window (fftbins=True): divide by N, not N-1 - 
(0..window_length) - .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos()) - .collect() - } - - fn stft(audio: &[f32]) -> Array2<f32> { - let mut planner = FftPlanner::<f32>::new(); - let fft = planner.plan_fft_forward(N_FFT); - - // Create Hann window of length win_length, then zero-pad to n_fft (centered) - // This is exactly what librosa does: util.pad_center(fft_window, size=n_fft) - let hann = Self::hann_window(WIN_LENGTH); - let win_offset = (N_FFT - WIN_LENGTH) / 2; - let mut fft_window = vec![0.0f32; N_FFT]; - for i in 0..WIN_LENGTH { - fft_window[win_offset + i] = hann[i]; - } - - // Pad signal for center=True (like librosa/torch.stft) - // Padding is n_fft // 2 on each side - let pad_amount = N_FFT / 2; - let mut padded_audio = vec![0.0; pad_amount]; - padded_audio.extend_from_slice(audio); - padded_audio.extend(vec![0.0; pad_amount]); - - let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1; - let freq_bins = N_FFT / 2 + 1; - let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames)); - - for frame_idx in 0..num_frames { - let start = frame_idx * HOP_LENGTH; - let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT]; - - // Extract n_fft samples and multiply by zero-padded window - for i in 0..N_FFT { - if start + i < padded_audio.len() { - frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0); - } - } - - fft.process(&mut frame); - for k in 0..freq_bins { - let magnitude = frame[k].norm(); - // Power spectrum (magnitude^2) - NeMo uses mag_power=2.0 - spectrogram[[k, frame_idx]] = magnitude * magnitude; - } - } - - spectrogram - } - - // Librosa's Slaney mel scale (htk=False, which is the default) - fn hz_to_mel_slaney(hz: f64) -> f64 { - let f_min = 0.0; - let f_sp = 200.0 / 3.0; - let min_log_hz = 1000.0; - let min_log_mel = (min_log_hz - f_min) / f_sp; - let logstep = (6.4f64).ln() / 27.0; - - if hz >= min_log_hz { - min_log_mel + (hz / min_log_hz).ln() / logstep - } else { - (hz - 
f_min) / f_sp - } - } - - fn mel_to_hz_slaney(mel: f64) -> f64 { - let f_min = 0.0; - let f_sp = 200.0 / 3.0; - let min_log_hz = 1000.0; - let min_log_mel = (min_log_hz - f_min) / f_sp; - let logstep = (6.4f64).ln() / 27.0; - - if mel >= min_log_mel { - min_log_hz * (logstep * (mel - min_log_mel)).exp() - } else { - f_min + f_sp * mel - } - } - - fn create_mel_filterbank() -> Array2<f32> { - // lets use f64 for intermediate calculations to avoid precision loss - let freq_bins = N_FFT / 2 + 1; - let mut filterbank = Array2::<f32>::zeros((N_MELS, freq_bins)); - - // FFT frequencies: fftfreqs[k] = k * sr / n_fft - let fftfreqs: Vec<f64> = (0..freq_bins) - .map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64) - .collect(); - - // Mel center frequencies using Slaney scale (librosa default, htk=False) - let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64); - let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64); - let mel_f: Vec<f64> = (0..=N_MELS + 1) - .map(|i| { - let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64; - Self::mel_to_hz_slaney(mel) - }) - .collect(); - - // Differences between consecutive mel frequencies - let fdiff: Vec<f64> = mel_f.windows(2).map(|w| w[1] - w[0]).collect(); - - // Compute filterbank weights (reference: librosa's ramp method) - // https://librosa.org/doc/main/generated/librosa.stft.html - for i in 0..N_MELS { - for k in 0..freq_bins { - // Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i] - let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i]; - // Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1] - let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1]; - // Weight is max(0, min(lower, upper)) - filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32; - } - } - - // Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i]) - for i in 0..N_MELS { - let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]); - for k in 0..freq_bins { - filterbank[[i, k]] *= enorm as f32; - } - } - - filterbank - } - - fn extract_mel_features(&self, 
audio: &[f32]) -> Array3<f32> { - // 1. Add dither (small random noise to prevent log(0)) - // NeMo uses dither=1e-5, but for determinism we skip random noise - // The log_zero_guard handles zero values - - // 2. Apply preemphasis (NeMo uses preemph=0.97) - let preemphasized = Self::apply_preemphasis(audio); - - // 3. STFT - let spectrogram = Self::stft(&preemphasized); - - // 4. Apply mel filterbank (with Slaney normalization) - let mel_spec = self.mel_basis.dot(&spectrogram); - - // 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24) - // NeMo uses normalize='NA' which means NO normalization - let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln()); - - let num_frames = log_mel_spec.shape()[1]; - let mut features = Array3::<f32>::zeros((1, num_frames, N_MELS)); - - // Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D) - for t in 0..num_frames { - for m in 0..N_MELS { - features[[0, t, m]] = log_mel_spec[[m, t]]; - } - } - - features - } -} diff --git a/parakeet-rs/src/timestamps.rs b/parakeet-rs/src/timestamps.rs deleted file mode 100644 index 81ea600..0000000 --- a/parakeet-rs/src/timestamps.rs +++ /dev/null @@ -1,280 +0,0 @@ -use crate::decoder::TimedToken; - -/// Timestamp output mode for transcription results -/// -/// Determines how token-level timestamps are grouped and presented: -/// - `Tokens`: Raw token-level output from the model (most detailed) -/// - `Words`: Tokens grouped into individual words -/// - `Sentences`: Tokens grouped by sentence boundaries (., ?, !) -/// -/// # Model-Specific Recommendations -/// -/// - **Parakeet CTC (English)**: Use `Words` mode. The CTC model only outputs lowercase -/// alphabet without punctuation, so sentence segmentation is not possible. -/// - **Parakeet TDT (Multilingual)**: Use `Sentences` mode. The TDT model predicts -/// punctuation, enabling natural sentence boundaries. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TimestampMode { - /// Raw token-level timestamps from the model - Tokens, - /// Word-level timestamps (groups subword tokens) - Words, - /// Sentence-level timestamps (groups by punctuation) - /// - /// Note: Only works with models that predict punctuation (e.g., Parakeet TDT). - /// CTC models don't predict punctuation, so use `Words` mode instead. - Sentences, -} - -impl Default for TimestampMode { - fn default() -> Self { - Self::Tokens - } -} - -/// Convert token timestamps to the requested output mode -/// -/// Takes raw token-level timestamps from the model and optionally groups them -/// into words or sentences while preserving the original timing information. -/// -/// # Arguments -/// -/// * `tokens` - Raw token-level timestamps from model output -/// * `mode` - Desired grouping level (Tokens, Words, or Sentences) -/// -/// # Returns -/// -/// Vector of TimedToken with timestamps at the requested granularity -pub fn process_timestamps(tokens: &[TimedToken], mode: TimestampMode) -> Vec<TimedToken> { - match mode { - TimestampMode::Tokens => tokens.to_vec(), - TimestampMode::Words => group_by_words(tokens), - TimestampMode::Sentences => group_by_sentences(tokens), - } -} - -// Group tokens into words based on word boundary markers -fn group_by_words(tokens: &[TimedToken]) -> Vec<TimedToken> { - if tokens.is_empty() { - return Vec::new(); - } - - let mut words = Vec::new(); - let mut current_word_text = String::new(); - let mut current_word_start = 0.0; - let mut last_word_lower = String::new(); - - for (i, token) in tokens.iter().enumerate() { - // Skip empty tokens - if token.text.trim().is_empty() { - continue; - } - - // Check if this starts a new word (SentencePiece uses ▁ or space prefix) - // Also treat PURE punctuation marks (like ".", ",") as separate words - // But NOT contractions like "'re" or "'s" which should attach to previous word - let is_pure_punctuation = !token.text.is_empty() && - 
token.text.chars().all(|c| c.is_ascii_punctuation()); - - // Check if this is a contraction suffix - // These should NOT start a new word - they attach to the previous word - let token_without_marker = token.text.trim_start_matches('▁').trim_start_matches(' '); - let is_contraction = token_without_marker.starts_with('\''); - - let starts_word = (token.text.starts_with('▁') - || token.text.starts_with(' ') - || is_pure_punctuation) - && !is_contraction - || i == 0; - - if starts_word && !current_word_text.is_empty() { - // Save previous word (with deduplication) - let word_lower = current_word_text.to_lowercase(); - if word_lower != last_word_lower { - words.push(TimedToken { - text: current_word_text.clone(), - start: current_word_start, - end: tokens[i - 1].end, - }); - last_word_lower = word_lower; - } - current_word_text.clear(); - } - - // Start new word or append to current - if current_word_text.is_empty() { - current_word_start = token.start; - } - - // Add token text, removing word boundary markers - let token_text = token - .text - .trim_start_matches('▁') - .trim_start_matches(' '); - current_word_text.push_str(token_text); - } - - // Add final word - if !current_word_text.is_empty() { - let word_lower = current_word_text.to_lowercase(); - if word_lower != last_word_lower { - words.push(TimedToken { - text: current_word_text, - start: current_word_start, - end: tokens.last().unwrap().end, - }); - } - } - - words -} - -// Group words into sentences based on punctuation -fn group_by_sentences(tokens: &[TimedToken]) -> Vec<TimedToken> { - // First get word-level grouping - let words = group_by_words(tokens); - if words.is_empty() { - return Vec::new(); - } - - let mut sentences = Vec::new(); - let mut current_sentence = Vec::new(); - - for word in words { - current_sentence.push(word.clone()); - - // Check if word ends with sentence terminator - let ends_sentence = word.text.contains('.') - || word.text.contains('?') - || word.text.contains('!'); - - if 
ends_sentence { - let sentence_text = format_sentence(&current_sentence); - let start = current_sentence.first().unwrap().start; - let end = current_sentence.last().unwrap().end; - - if !sentence_text.is_empty() { - sentences.push(TimedToken { - text: sentence_text, - start, - end, - }); - } - current_sentence.clear(); - } - } - - // Add final sentence if exists - if !current_sentence.is_empty() { - let sentence_text = format_sentence(&current_sentence); - let start = current_sentence.first().unwrap().start; - let end = current_sentence.last().unwrap().end; - - if !sentence_text.is_empty() { - sentences.push(TimedToken { - text: sentence_text, - start, - end, - }); - } - } - - sentences -} - -// Join words with punctuation spacing -fn format_sentence(words: &[TimedToken]) -> String { - let result: Vec<&str> = words.iter().map(|w| w.text.as_str()).collect(); - - // Join words, but don't add space before certain punctuation - let mut output = String::new(); - for (i, word) in result.iter().enumerate() { - // Check if this word is standalone punctuation that shouldn't have space before it - // Contractions like "'re" or "'s" should have spaces before them - let is_standalone_punct = word.len() == 1 && - word.chars().all(|c| matches!(c, '.' | ',' | '!' | '?' 
| ';' | ':' | ')')); - - if i > 0 && !is_standalone_punct { - output.push(' '); - } - output.push_str(word); - } - output -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_word_grouping() { - let tokens = vec![ - TimedToken { - text: "▁Hello".to_string(), - start: 0.0, - end: 0.5, - }, - TimedToken { - text: "▁world".to_string(), - start: 0.5, - end: 1.0, - }, - ]; - - let words = group_by_words(&tokens); - assert_eq!(words.len(), 2); - assert_eq!(words[0].text, "Hello"); - assert_eq!(words[1].text, "world"); - } - - #[test] - fn test_sentence_grouping() { - let tokens = vec![ - TimedToken { - text: "▁Hello".to_string(), - start: 0.0, - end: 0.5, - }, - TimedToken { - text: "▁world".to_string(), - start: 0.5, - end: 1.0, - }, - TimedToken { - text: ".".to_string(), - start: 1.0, - end: 1.1, - }, - ]; - - let sentences = group_by_sentences(&tokens); - assert_eq!(sentences.len(), 1); - assert_eq!(sentences[0].text, "Hello world."); - assert_eq!(sentences[0].start, 0.0); - assert_eq!(sentences[0].end, 1.1); - } - - #[test] - fn test_repetition_preservation() { - let words = vec![ - TimedToken { - text: "uh".to_string(), - start: 0.0, - end: 0.5, - }, - TimedToken { - text: "uh".to_string(), - start: 0.5, - end: 1.0, - }, - TimedToken { - text: "hello".to_string(), - start: 1.0, - end: 1.5, - }, - ]; - - let result = format_sentence(&words); - assert_eq!(result, "uh uh hello"); - } -} diff --git a/parakeet-rs/src/vocab.rs b/parakeet-rs/src/vocab.rs deleted file mode 100644 index 888568e..0000000 --- a/parakeet-rs/src/vocab.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::error::{Error, Result}; -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::path::Path; - -/// Vocabulary parser for vocab.txt format used by TDT models -#[derive(Debug, Clone)] -pub struct Vocabulary { - pub id_to_token: Vec<String>, - pub _blank_id: usize, -} - -impl Vocabulary { - /// Load vocabulary from vocab.txt file - pub fn from_file<P: AsRef<Path>>(path: P) -> 
Result<Self> { - let file = File::open(path.as_ref()).map_err(|e| { - Error::Config(format!("Failed to open vocab file: {}", e)) - })?; - - let reader = BufReader::new(file); - let mut id_to_token = Vec::new(); - let mut blank_id = 0; - - for line in reader.lines() { - let line = line.map_err(|e| { - Error::Config(format!("Failed to read vocab file: {}", e)) - })?; - - let parts: Vec<&str> = line.splitn(2, ' ').collect(); - if parts.len() == 2 { - let token = parts[0].to_string(); - let id: usize = parts[1].parse().map_err(|e| { - Error::Config(format!("Invalid token ID in vocab: {}", e)) - })?; - - if id >= id_to_token.len() { - id_to_token.resize(id + 1, String::new()); - } - id_to_token[id] = token.clone(); - - // Track blank token - if token == "<blk>" || token == "<blank>" { - blank_id = id; - } - } - } - - // Default to last token if no blank found - if blank_id == 0 && !id_to_token.is_empty() { - blank_id = id_to_token.len() - 1; - } - - Ok(Self { - id_to_token, - _blank_id: blank_id, - }) - } - - /// Get token by ID - pub fn id_to_text(&self, id: usize) -> Option<&str> { - self.id_to_token.get(id).map(|s| s.as_str()) - } -} |
