summaryrefslogtreecommitdiff
path: root/parakeet-rs/src
diff options
context:
space:
mode:
Diffstat (limited to 'parakeet-rs/src')
-rw-r--r--parakeet-rs/src/audio.rs179
-rw-r--r--parakeet-rs/src/config.rs51
-rw-r--r--parakeet-rs/src/decoder.rs211
-rw-r--r--parakeet-rs/src/decoder_tdt.rs63
-rw-r--r--parakeet-rs/src/error.rs52
-rw-r--r--parakeet-rs/src/execution.rs141
-rw-r--r--parakeet-rs/src/lib.rs74
-rw-r--r--parakeet-rs/src/model.rs93
-rw-r--r--parakeet-rs/src/model_eou.rs183
-rw-r--r--parakeet-rs/src/model_tdt.rs263
-rw-r--r--parakeet-rs/src/parakeet.rs210
-rw-r--r--parakeet-rs/src/parakeet_eou.rs304
-rw-r--r--parakeet-rs/src/parakeet_tdt.rs167
-rw-r--r--parakeet-rs/src/sortformer.rs1062
-rw-r--r--parakeet-rs/src/timestamps.rs280
-rw-r--r--parakeet-rs/src/vocab.rs63
16 files changed, 0 insertions, 3396 deletions
diff --git a/parakeet-rs/src/audio.rs b/parakeet-rs/src/audio.rs
deleted file mode 100644
index 84d2616..0000000
--- a/parakeet-rs/src/audio.rs
+++ /dev/null
@@ -1,179 +0,0 @@
-use crate::config::PreprocessorConfig;
-use crate::error::{Error, Result};
-use hound::{WavReader, WavSpec};
-use ndarray::Array2;
-use std::f32::consts::PI;
-use std::path::Path;
-
-pub fn load_audio<P: AsRef<Path>>(path: P) -> Result<(Vec<f32>, WavSpec)> {
- let mut reader = WavReader::open(path)?;
- let spec = reader.spec();
-
- let samples: Vec<f32> = match spec.sample_format {
- hound::SampleFormat::Float => reader
- .samples::<f32>()
- .collect::<std::result::Result<Vec<_>, _>>()
- .map_err(|e| Error::Audio(format!("Failed to read float samples: {e}")))?,
- hound::SampleFormat::Int => reader
- .samples::<i16>()
- .map(|s| s.map(|s| s as f32 / 32768.0))
- .collect::<std::result::Result<Vec<_>, _>>()
- .map_err(|e| Error::Audio(format!("Failed to read int samples: {e}")))?,
- };
-
- Ok((samples, spec))
-}
-
-pub fn apply_preemphasis(audio: &[f32], coef: f32) -> Vec<f32> {
- let mut result = Vec::with_capacity(audio.len());
- result.push(audio[0]);
-
- for i in 1..audio.len() {
- result.push(audio[i] - coef * audio[i - 1]);
- }
-
- result
-}
-
-fn hann_window(window_length: usize) -> Vec<f32> {
- (0..window_length)
- .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / (window_length as f32 - 1.0)).cos())
- .collect()
-}
-
-// We use proper FFT here instead of naive DFT because the model was trained
-// on correctly computed spectrograms. Naive DFT produces wrong frequency bins
-// and the model outputs all blank tokens. RustFFT gives us O(n log n) performance
-// and numerically correct results that match what the model expects.
-pub fn stft(audio: &[f32], n_fft: usize, hop_length: usize, win_length: usize) -> Array2<f32> {
- use rustfft::{num_complex::Complex, FftPlanner};
-
- let window = hann_window(win_length);
- let num_frames = (audio.len() - win_length) / hop_length + 1;
- let freq_bins = n_fft / 2 + 1;
- let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames));
-
- let mut planner = FftPlanner::<f32>::new();
- let fft = planner.plan_fft_forward(n_fft);
-
- for frame_idx in 0..num_frames {
- let start = frame_idx * hop_length;
-
- let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft];
- for i in 0..win_length.min(audio.len() - start) {
- frame[i] = Complex::new(audio[start + i] * window[i], 0.0);
- }
-
- fft.process(&mut frame);
-
- for k in 0..freq_bins {
- let magnitude = frame[k].norm();
- spectrogram[[k, frame_idx]] = magnitude * magnitude;
- }
- }
-
- spectrogram
-}
-
-fn hz_to_mel(freq: f32) -> f32 {
- 2595.0 * (1.0 + freq / 700.0).log10()
-}
-
-fn mel_to_hz(mel: f32) -> f32 {
- 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
-}
-
-fn create_mel_filterbank(n_fft: usize, n_mels: usize, sample_rate: usize) -> Array2<f32> {
- let freq_bins = n_fft / 2 + 1;
- let mut filterbank = Array2::<f32>::zeros((n_mels, freq_bins));
-
- let min_mel = hz_to_mel(0.0);
- let max_mel = hz_to_mel(sample_rate as f32 / 2.0);
-
- let mel_points: Vec<f32> = (0..=n_mels + 1)
- .map(|i| mel_to_hz(min_mel + (max_mel - min_mel) * i as f32 / (n_mels + 1) as f32))
- .collect();
-
- let freq_bin_width = sample_rate as f32 / n_fft as f32;
-
- for mel_idx in 0..n_mels {
- let left = mel_points[mel_idx];
- let center = mel_points[mel_idx + 1];
- let right = mel_points[mel_idx + 2];
-
- for freq_idx in 0..freq_bins {
- let freq = freq_idx as f32 * freq_bin_width;
-
- if freq >= left && freq <= center {
- filterbank[[mel_idx, freq_idx]] = (freq - left) / (center - left);
- } else if freq > center && freq <= right {
- filterbank[[mel_idx, freq_idx]] = (right - freq) / (right - center);
- }
- }
- }
-
- filterbank
-}
-
-/// Extract mel spectrogram features from raw audio samples.
-///
-/// # Arguments
-///
-/// * `audio` - Audio samples as f32 values
-/// * `sample_rate` - Sample rate in Hz
-/// * `channels` - Number of audio channels
-/// * `config` - Preprocessor configuration
-///
-/// # Returns
-///
-/// 2D array of mel spectrogram features (time_steps x feature_size)
-pub fn extract_features_raw(
- mut audio: Vec<f32>,
- sample_rate: u32,
- channels: u16,
- config: &PreprocessorConfig,
-) -> Result<Array2<f32>> {
- if sample_rate != config.sampling_rate as u32 {
- return Err(Error::Audio(format!(
- "Audio sample rate {} doesn't match expected {}. Please resample your audio first.",
- sample_rate, config.sampling_rate
- )));
- }
-
- if channels > 1 {
- let mono: Vec<f32> = audio
- .chunks(channels as usize)
- .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
- .collect();
- audio = mono;
- }
-
- audio = apply_preemphasis(&audio, config.preemphasis);
-
- let spectrogram = stft(&audio, config.n_fft, config.hop_length, config.win_length);
-
- let mel_filterbank =
- create_mel_filterbank(config.n_fft, config.feature_size, config.sampling_rate);
- let mel_spectrogram = mel_filterbank.dot(&spectrogram);
- let mel_spectrogram = mel_spectrogram.mapv(|x| (x.max(1e-10)).ln());
-
- let mut mel_spectrogram = mel_spectrogram.t().to_owned();
-
- // Normalize each feature dimension to mean=0, std=1
- let num_frames = mel_spectrogram.shape()[0];
- let num_features = mel_spectrogram.shape()[1];
-
- for feat_idx in 0..num_features {
- let mut column = mel_spectrogram.column_mut(feat_idx);
- let mean: f32 = column.iter().sum::<f32>() / num_frames as f32;
- let variance: f32 =
- column.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_frames as f32;
- let std = variance.sqrt().max(1e-10);
-
- for val in column.iter_mut() {
- *val = (*val - mean) / std;
- }
- }
-
- Ok(mel_spectrogram)
-}
diff --git a/parakeet-rs/src/config.rs b/parakeet-rs/src/config.rs
deleted file mode 100644
index 1dae890..0000000
--- a/parakeet-rs/src/config.rs
+++ /dev/null
@@ -1,51 +0,0 @@
-use serde::{Deserialize, Serialize};
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PreprocessorConfig {
- pub feature_extractor_type: String,
- pub feature_size: usize,
- pub hop_length: usize,
- pub n_fft: usize,
- pub padding_side: String,
- pub padding_value: f32,
- pub preemphasis: f32,
- pub processor_class: String,
- pub return_attention_mask: bool,
- pub sampling_rate: usize,
- pub win_length: usize,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ModelConfig {
- pub architectures: Vec<String>,
- pub vocab_size: usize,
- pub pad_token_id: usize,
-}
-
-impl Default for PreprocessorConfig {
- fn default() -> Self {
- Self {
- feature_extractor_type: "ParakeetFeatureExtractor".to_string(),
- feature_size: 80,
- hop_length: 160,
- n_fft: 512,
- padding_side: "right".to_string(),
- padding_value: 0.0,
- preemphasis: 0.97,
- processor_class: "ParakeetProcessor".to_string(),
- return_attention_mask: true,
- sampling_rate: 16000,
- win_length: 400,
- }
- }
-}
-
-impl Default for ModelConfig {
- fn default() -> Self {
- Self {
- architectures: vec!["ParakeetForCTC".to_string()],
- vocab_size: 1025,
- pad_token_id: 1024,
- }
- }
-}
diff --git a/parakeet-rs/src/decoder.rs b/parakeet-rs/src/decoder.rs
deleted file mode 100644
index 6da6d65..0000000
--- a/parakeet-rs/src/decoder.rs
+++ /dev/null
@@ -1,211 +0,0 @@
-use crate::error::{Error, Result};
-use ndarray::Array2;
-use std::path::Path;
-
-// Token with its timestamp information
-// start and end are in seconds
-#[derive(Debug, Clone)]
-pub struct TimedToken {
- pub text: String,
- pub start: f32,
- pub end: f32,
-}
-
-#[derive(Debug, Clone)]
-pub struct TranscriptionResult {
- pub text: String,
- pub tokens: Vec<TimedToken>,
-}
-
-// CTC decoder for parakeet-ctc-0.6b model with token-level timestamps
-pub struct ParakeetDecoder {
- tokenizer: tokenizers::Tokenizer,
- pad_token_id: usize,
-}
-
-impl ParakeetDecoder {
- pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> {
- let tokenizer_path = tokenizer_path.as_ref();
-
- let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
- .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?;
-
- // Hardcoded pad_token_id for Parakeet-CTC-0.6b (constant across all models: please see def configs jsons: https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main)
- let pad_token_id = 1024;
-
- Ok(Self {
- tokenizer,
- pad_token_id,
- })
- }
-
- pub fn decode(&self, logits: &Array2<f32>) -> Result<String> {
- let time_steps = logits.shape()[0];
-
- let mut token_ids = Vec::new();
- for t in 0..time_steps {
- let logits_t = logits.row(t);
- let max_idx = logits_t
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
- .map(|(idx, _)| idx)
- .unwrap_or(0);
-
- token_ids.push(max_idx as u32);
- }
-
- let collapsed = self.ctc_collapse(&token_ids);
-
- let text = self
- .tokenizer
- .decode(&collapsed, true)
- .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
-
- Ok(text)
- }
-
- fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> {
- let mut result = Vec::new();
- let mut prev_token: Option<u32> = None;
-
- for &token_id in token_ids {
- if token_id == self.pad_token_id as u32 {
- prev_token = Some(token_id);
- continue;
- }
-
- if Some(token_id) != prev_token {
- result.push(token_id);
- }
-
- prev_token = Some(token_id);
- }
-
- result
- }
-
- // CTC collapse with frame tracking for timestamps
- fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> {
- let mut result: Vec<(u32, usize, usize)> = Vec::new();
- let mut prev_token: Option<u32> = None;
-
- for &(token_id, frame) in token_ids.iter() {
- if token_id == self.pad_token_id as u32 {
- prev_token = Some(token_id);
- continue;
- }
-
- if Some(token_id) != prev_token {
- if let Some(prev) = prev_token {
- if prev != self.pad_token_id as u32 {
- // End previous token
- if let Some(last) = result.last_mut() {
- last.2 = frame;
- }
- }
- }
- // Start new token
- result.push((token_id, frame, frame));
- }
-
- prev_token = Some(token_id);
- }
-
- // Close last token
- if let Some(last) = result.last_mut() {
- last.2 = token_ids.len();
- }
-
- result
- }
-
- // Decode with token-level timestamps
- // hop_length and sample_rate are needed to convert frames to seconds
- pub fn decode_with_timestamps(
- &self,
- logits: &Array2<f32>,
- hop_length: usize,
- sample_rate: usize,
- ) -> Result<TranscriptionResult> {
- let time_steps = logits.shape()[0];
-
- let mut token_ids_with_frames = Vec::new();
- for t in 0..time_steps {
- let logits_t = logits.row(t);
- let max_idx = logits_t
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
- .map(|(idx, _)| idx)
- .unwrap_or(0);
-
- token_ids_with_frames.push((max_idx as u32, t));
- }
-
- // CTC collapse with frame tracking
- let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames);
-
- // Extract just token IDs for decoding
- let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|(id, _, _)| *id).collect();
-
- // Decode full text
- let full_text = self
- .tokenizer
- .decode(&token_ids, true)
- .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
-
- // Progressive decode to detect word boundaries
- // BPE tokenizers only add spaces when decoding sequences, not individual tokens
- let mut timed_tokens = Vec::new();
- let mut prev_decode = String::new();
-
- for (i, (_token_id, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() {
- // Decode from start up to and including current token
- let token_ids_so_far: Vec<u32> = collapsed_with_frames[0..=i]
- .iter()
- .map(|(id, _, _)| *id)
- .collect();
-
- if let Ok(curr_decode) = self.tokenizer.decode(&token_ids_so_far, true) {
- // Find what this token added
- let added_text = if curr_decode.len() > prev_decode.len() {
- &curr_decode[prev_decode.len()..]
- } else {
- ""
- };
-
- if !added_text.is_empty() {
- let start_time = (*start_frame * hop_length) as f32 / sample_rate as f32;
- let end_time = (*end_frame * hop_length) as f32 / sample_rate as f32;
-
- timed_tokens.push(TimedToken {
- text: added_text.to_string(),
- start: start_time,
- end: end_time,
- });
- }
-
- prev_decode = curr_decode;
- }
- }
-
- Ok(TranscriptionResult {
- text: full_text,
- tokens: timed_tokens,
- })
- }
-
- // Stub - falls back to greedy decoding. Full beam search with language model is TODO.
- pub fn decode_with_beam_search(
- &self,
- logits: &Array2<f32>,
- _beam_width: usize,
- ) -> Result<String> {
- self.decode(logits)
- }
-
- pub fn pad_token_id(&self) -> usize {
- self.pad_token_id
- }
-}
diff --git a/parakeet-rs/src/decoder_tdt.rs b/parakeet-rs/src/decoder_tdt.rs
deleted file mode 100644
index 65f576d..0000000
--- a/parakeet-rs/src/decoder_tdt.rs
+++ /dev/null
@@ -1,63 +0,0 @@
-use crate::decoder::TranscriptionResult;
-use crate::error::Result;
-use crate::vocab::Vocabulary;
-
-/// TDT greedy decoder for Parakeet TDT models
-#[derive(Debug)]
-pub struct ParakeetTDTDecoder {
- vocab: Vocabulary,
-}
-
-impl ParakeetTDTDecoder {
- /// Load decoder from vocab file
- pub fn from_vocab(vocab: Vocabulary) -> Self {
- Self { vocab }
- }
-
- /// Decode tokens with timestamps
- /// For TDT models, greedy decoding is done in the model, here we just convert to text
- pub fn decode_with_timestamps(
- &self,
- tokens: &[usize],
- frame_indices: &[usize],
- _durations: &[usize],
- hop_length: usize,
- sample_rate: usize,
- ) -> Result<TranscriptionResult> {
- let mut result_tokens = Vec::new();
- let mut full_text = String::new();
- // TDT encoder does 8x subsampling
- let encoder_stride = 8;
-
- for (i, &token_id) in tokens.iter().enumerate() {
- if let Some(token_text) = self.vocab.id_to_text(token_id) {
- let frame = frame_indices[i];
- let start = (frame * encoder_stride * hop_length) as f32 / sample_rate as f32;
- let end = if i + 1 < frame_indices.len() {
- (frame_indices[i + 1] * encoder_stride * hop_length) as f32 / sample_rate as f32
- } else {
- start + 0.01
- };
-
- // Handle SentencePiece format (▁ prefix for word start)
- let display_text = token_text.replace('▁', " ");
-
- // Skip special tokens
- if !(token_text.starts_with('<') && token_text.ends_with('>') && token_text != "<unk>") {
- full_text.push_str(&display_text);
-
- result_tokens.push(crate::decoder::TimedToken {
- text: display_text,
- start,
- end,
- });
- }
- }
- }
-
- Ok(TranscriptionResult {
- text: full_text.trim().to_string(),
- tokens: result_tokens,
- })
- }
-}
diff --git a/parakeet-rs/src/error.rs b/parakeet-rs/src/error.rs
deleted file mode 100644
index 690e0e5..0000000
--- a/parakeet-rs/src/error.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-use std::fmt;
-
-pub type Result<T> = std::result::Result<T, Error>;
-
-#[derive(Debug)]
-pub enum Error {
- Io(std::io::Error),
- Ort(ort::Error),
- Audio(String),
- Model(String),
- Tokenizer(String),
- Config(String),
-}
-
-impl fmt::Display for Error {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- match self {
- Error::Io(e) => write!(f, "IO error: {e}"),
- Error::Ort(e) => write!(f, "ONNX Runtime error: {e}"),
- Error::Audio(msg) => write!(f, "Audio processing error: {msg}"),
- Error::Model(msg) => write!(f, "Model error: {msg}"),
- Error::Tokenizer(msg) => write!(f, "Tokenizer error: {msg}"),
- Error::Config(msg) => write!(f, "Config error: {msg}"),
- }
- }
-}
-
-impl std::error::Error for Error {}
-
-impl From<std::io::Error> for Error {
- fn from(e: std::io::Error) -> Self {
- Error::Io(e)
- }
-}
-
-impl From<ort::Error> for Error {
- fn from(e: ort::Error) -> Self {
- Error::Ort(e)
- }
-}
-
-impl From<serde_json::Error> for Error {
- fn from(e: serde_json::Error) -> Self {
- Error::Config(e.to_string())
- }
-}
-
-impl From<hound::Error> for Error {
- fn from(e: hound::Error) -> Self {
- Error::Audio(e.to_string())
- }
-}
diff --git a/parakeet-rs/src/execution.rs b/parakeet-rs/src/execution.rs
deleted file mode 100644
index e29aa1d..0000000
--- a/parakeet-rs/src/execution.rs
+++ /dev/null
@@ -1,141 +0,0 @@
-use crate::error::Result;
-use ort::session::builder::SessionBuilder;
-
-// Hardware acceleration options. CPU is default and most reliable.
-// GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware.
-// All GPU providers automatically fall back to CPU if they fail.
-//
-// Note: CoreML currently fails with this model due to unsupported operations.
-// WebGPU is experimental and may produce incorrect results.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
-pub enum ExecutionProvider {
- #[default]
- Cpu,
- #[cfg(feature = "cuda")]
- Cuda,
- #[cfg(feature = "tensorrt")]
- TensorRT,
- #[cfg(feature = "coreml")]
- CoreML,
- #[cfg(feature = "directml")]
- DirectML,
- #[cfg(feature = "rocm")]
- ROCm,
- #[cfg(feature = "openvino")]
- OpenVINO,
- #[cfg(feature = "webgpu")]
- WebGPU,
-}
-
-#[derive(Debug, Clone)]
-pub struct ModelConfig {
- pub execution_provider: ExecutionProvider,
- pub intra_threads: usize,
- pub inter_threads: usize,
-}
-
-impl Default for ModelConfig {
- fn default() -> Self {
- Self {
- execution_provider: ExecutionProvider::default(),
- intra_threads: 4,
- inter_threads: 1,
- }
- }
-}
-
-impl ModelConfig {
- pub fn new() -> Self {
- Self::default()
- }
-
- pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
- self.execution_provider = provider;
- self
- }
-
- pub fn with_intra_threads(mut self, threads: usize) -> Self {
- self.intra_threads = threads;
- self
- }
-
- pub fn with_inter_threads(mut self, threads: usize) -> Self {
- self.inter_threads = threads;
- self
- }
-
- pub(crate) fn apply_to_session_builder(
- &self,
- builder: SessionBuilder,
- ) -> Result<SessionBuilder> {
- use ort::session::builder::GraphOptimizationLevel;
- #[cfg(any(
- feature = "cuda",
- feature = "tensorrt",
- feature = "coreml",
- feature = "directml",
- feature = "rocm",
- feature = "openvino",
- feature = "webgpu"
- ))]
- use ort::execution_providers::CPUExecutionProvider;
-
- let mut builder = builder
- .with_optimization_level(GraphOptimizationLevel::Level3)?
- .with_intra_threads(self.intra_threads)?
- .with_inter_threads(self.inter_threads)?;
-
- builder = match self.execution_provider {
- ExecutionProvider::Cpu => builder,
-
- #[cfg(feature = "cuda")]
- ExecutionProvider::Cuda => builder.with_execution_providers([
- ort::execution_providers::CUDAExecutionProvider::default().build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?,
-
- #[cfg(feature = "tensorrt")]
- ExecutionProvider::TensorRT => builder.with_execution_providers([
- ort::execution_providers::TensorRTExecutionProvider::default().build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?,
-
- #[cfg(feature = "coreml")]
- ExecutionProvider::CoreML => {
- use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider};
- builder.with_execution_providers([
- CoreMLExecutionProvider::default()
- .with_compute_units(CoreMLComputeUnits::CPUAndGPU)
- .build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?
- }
-
- #[cfg(feature = "directml")]
- ExecutionProvider::DirectML => builder.with_execution_providers([
- ort::execution_providers::DirectMLExecutionProvider::default().build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?,
-
- #[cfg(feature = "rocm")]
- ExecutionProvider::ROCm => builder.with_execution_providers([
- ort::execution_providers::ROCMExecutionProvider::default().build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?,
-
- #[cfg(feature = "openvino")]
- ExecutionProvider::OpenVINO => builder.with_execution_providers([
- ort::execution_providers::OpenVINOExecutionProvider::default().build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?,
-
- #[cfg(feature = "webgpu")]
- ExecutionProvider::WebGPU => builder.with_execution_providers([
- ort::execution_providers::WebGPUExecutionProvider::default().build(),
- CPUExecutionProvider::default().build().error_on_failure(),
- ])?,
- };
-
- Ok(builder)
- }
-}
diff --git a/parakeet-rs/src/lib.rs b/parakeet-rs/src/lib.rs
deleted file mode 100644
index 0aaefd1..0000000
--- a/parakeet-rs/src/lib.rs
+++ /dev/null
@@ -1,74 +0,0 @@
-//! # parakeet-rs
-//!
-//! Rust bindings for NVIDIA's Parakeet speech recognition model using ONNX Runtime.
-//!
-//! Parakeet is a state-of-the-art automatic speech recognition (ASR) model developed by NVIDIA,
-//! based on the FastConformer-TDT architecture with 600 million parameters.
-//!
-//! ## Features
-//!
-//! - Easy-to-use API for speech-to-text transcription
-//! - Support for ONNX format models
-//! - 16kHz mono audio input
-//! - Punctuation and capitalization included in output
-//! - Fast inference using ONNX Runtime
-//!
-//! ## Quick Start
-//!
-//! ```ignore
-//! use parakeet_rs::Parakeet;
-//!
-//! // Load the model
-//! let parakeet = Parakeet::from_pretrained(".")?;
-//!
-//! // Transcribe audio file
-//! let text = parakeet.transcribe_file("audio.wav")?;
-//! println!("Transcription: {}", text);
-//! ```
-//!
-//! ## Model Requirements
-//!
-//! Your model directory should contain:
-//! - `model.onnx` - The ONNX model file
-//! - `model.onnx_data` - External model weights
-//! - `config.json` - Model configuration
-//! - `preprocessor_config.json` - Audio preprocessing configuration
-//! - `tokenizer.json` - Tokenizer vocabulary
-//! - `tokenizer_config.json` - Tokenizer configuration
-//!
-//! ## Audio Requirements
-//!
-//! - Format: WAV
-//! - Sample Rate: 16kHz
-//! - Channels: Mono (stereo will be converted automatically)
-//! - Bit Depth: 16-bit PCM or 32-bit float
-
-mod audio;
-mod config;
-mod decoder;
-mod decoder_tdt;
-mod error;
-mod execution;
-mod model;
-mod model_tdt;
-mod parakeet;
-mod parakeet_tdt;
-mod timestamps;
-mod vocab;
-mod model_eou;
-mod parakeet_eou;
-#[cfg(feature = "sortformer")]
-pub mod sortformer;
-
-pub use error::{Error, Result};
-pub use execution::{ExecutionProvider, ModelConfig as ExecutionConfig};
-pub use parakeet::Parakeet;
-pub use parakeet_tdt::ParakeetTDT;
-pub use timestamps::TimestampMode;
-
-pub use config::{ModelConfig as ModelConfigJson, PreprocessorConfig};
-
-pub use decoder::{ParakeetDecoder, TimedToken, TranscriptionResult};
-pub use model::ParakeetModel;
-pub use model_eou::ParakeetEOUModel;
-pub use parakeet_eou::ParakeetEOU; \ No newline at end of file
diff --git a/parakeet-rs/src/model.rs b/parakeet-rs/src/model.rs
deleted file mode 100644
index b3cd131..0000000
--- a/parakeet-rs/src/model.rs
+++ /dev/null
@@ -1,93 +0,0 @@
-use crate::config::ModelConfig;
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use ndarray::Array2;
-use ort::session::Session;
-use std::path::Path;
-
-pub struct ParakeetModel {
- session: Session,
- config: ModelConfig,
-}
-
-impl ParakeetModel {
- pub fn from_pretrained<P: AsRef<Path>>(model_path: P) -> Result<Self> {
- Self::from_pretrained_with_config(model_path, ExecutionConfig::default())
- }
-
- pub fn from_pretrained_with_config<P: AsRef<Path>>(
- model_path: P,
- exec_config: ExecutionConfig,
- ) -> Result<Self> {
- let model_path = model_path.as_ref();
-
- // Use default config (hardcoded constants for Parakeet-CTC-0.6b: please see: json files https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main)
- let config = ModelConfig::default();
-
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let session = builder.commit_from_file(model_path)?;
-
- Ok(Self { session, config })
- }
- pub fn forward(&mut self, features: Array2<f32>) -> Result<Array2<f32>> {
- let batch_size = 1;
- let time_steps = features.shape()[0];
- let feature_size = features.shape()[1];
-
- let input = features
- .to_shape((batch_size, time_steps, feature_size))
- .map_err(|e| Error::Model(format!("Failed to reshape input: {e}")))?
- .to_owned();
-
- use ndarray::Array2;
- let attention_mask = Array2::<i64>::ones((batch_size, time_steps));
-
- let input_value = ort::value::Value::from_array(input)?;
- let attention_mask_value = ort::value::Value::from_array(attention_mask)?;
-
- let outputs = self.session.run(ort::inputs!(
- "input_features" => input_value,
- "attention_mask" => attention_mask_value
- ))?;
-
- let logits_value = &outputs["logits"];
- let (shape, data) = logits_value
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
-
- let shape_dims = shape.as_ref();
- if shape_dims.len() != 3 {
- return Err(Error::Model(format!(
- "Expected 3D logits, got shape: {shape_dims:?}"
- )));
- }
-
- let batch_size = shape_dims[0] as usize;
- let time_steps_out = shape_dims[1] as usize;
- let vocab_size = shape_dims[2] as usize;
-
- if batch_size != 1 {
- return Err(Error::Model(format!(
- "Expected batch size 1, got {batch_size}"
- )));
- }
-
- let logits_2d = Array2::from_shape_vec((time_steps_out, vocab_size), data.to_vec())
- .map_err(|e| Error::Model(format!("Failed to create array: {e}")))?;
-
- Ok(logits_2d)
- }
-
- pub fn config(&self) -> &ModelConfig {
- &self.config
- }
-
- pub fn vocab_size(&self) -> usize {
- self.config.vocab_size
- }
-
- pub fn pad_token_id(&self) -> usize {
- self.config.pad_token_id
- }
-}
diff --git a/parakeet-rs/src/model_eou.rs b/parakeet-rs/src/model_eou.rs
deleted file mode 100644
index 5b56e6d..0000000
--- a/parakeet-rs/src/model_eou.rs
+++ /dev/null
@@ -1,183 +0,0 @@
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use ndarray::{Array1, Array2, Array3, Array4};
-use ort::session::Session;
-use std::path::Path;
-
-/// Encoder cache state for streaming inference
-/// The cache maintains temporal context across chunks
-pub struct EncoderCache {
- /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback
- pub cache_last_channel: Array4<f32>,
- /// time cache: [1, 1, 512, 8] - batch=1, fixed 8 time steps
- pub cache_last_time: Array4<f32>,
- /// cache length: [1] with value 0 initially
- pub cache_last_channel_len: Array1<i64>,
-}
-
-impl EncoderCache {
- /// 17 layers, batch=1, 70 frame lookback, 512 features
- pub fn new() -> Self {
- Self {
- cache_last_channel: Array4::zeros((17, 1, 70, 512)),
- cache_last_time: Array4::zeros((17, 1, 512, 8)),
- cache_last_channel_len: Array1::from_vec(vec![0i64]),
- }
- }
-}
-
-pub struct ParakeetEOUModel {
- encoder: Session,
- decoder_joint: Session,
-}
-
-impl ParakeetEOUModel {
- pub fn from_pretrained<P: AsRef<Path>>(
- model_dir: P,
- exec_config: ExecutionConfig,
- ) -> Result<Self> {
- let model_dir = model_dir.as_ref();
-
- let encoder_path = model_dir.join("encoder.onnx");
- let decoder_path = model_dir.join("decoder_joint.onnx");
-
- if !encoder_path.exists() || !decoder_path.exists() {
- return Err(Error::Config(format!(
- "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx",
- model_dir.display()
- )));
- }
-
- // Load encoder
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let encoder = builder.commit_from_file(&encoder_path)?;
-
- // Load decoder
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let decoder_joint = builder.commit_from_file(&decoder_path)?;
-
- Ok(Self {
- encoder,
- decoder_joint,
- })
- }
-
- /// Run the stateful encoder with cache
- /// Input: features [1, 128, T], cache state
- /// Output: (encoded [1, 512, T], new_cache)
- pub fn run_encoder(
- &mut self,
- features: &Array3<f32>,
- length: i64,
- cache: &EncoderCache
- ) -> Result<(Array3<f32>, EncoderCache)> {
- let length_arr = Array1::from_vec(vec![length]);
-
- let outputs = self.encoder.run(ort::inputs![
- "audio_signal" => ort::value::Value::from_array(features.clone())?,
- "length" => ort::value::Value::from_array(length_arr)?,
- "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
- "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
- "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
- ])?;
-
- // Extract encoder output [1, 512, T]
- let (shape, data) = outputs["outputs"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
-
- let shape_dims = shape.as_ref();
- let b = shape_dims[0] as usize;
- let d = shape_dims[1] as usize;
- let t = shape_dims[2] as usize;
-
- let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
- .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;
-
- // Extract new cache states
- let (ch_shape, ch_data) = outputs["new_cache_last_channel"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;
-
- let (tm_shape, tm_data) = outputs["new_cache_last_time"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;
-
- let (len_shape, len_data) = outputs["new_cache_last_channel_len"]
- .try_extract_tensor::<i64>()
- .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;
-
- // Build new cache with extracted shapes
- let new_cache = EncoderCache {
- cache_last_channel: Array4::from_shape_vec(
- (ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize),
- ch_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,
-
- cache_last_time: Array4::from_shape_vec(
- (tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize),
- tm_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,
-
- cache_last_channel_len: Array1::from_shape_vec(
- len_shape[0] as usize,
- len_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
- };
-
- Ok((encoder_out, new_cache))
- }
-
- /// Run the stateful decoder
- /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c)
- pub fn run_decoder(
- &mut self,
- encoder_frame: &Array3<f32>, // [1, 512, 1]
- last_token: &Array2<i32>, // [1, 1]
- state_h: &Array3<f32>, // [1, 1, 640]
- state_c: &Array3<f32>, // [1, 1, 640]
- ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> {
-
- // Target length is always 1 for single step
- let target_len = Array1::from_vec(vec![1i32]);
-
- let outputs = self.decoder_joint.run(ort::inputs![
- "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
- "targets" => ort::value::Value::from_array(last_token.clone())?,
- "target_length" => ort::value::Value::from_array(target_len)?,
- "input_states_1" => ort::value::Value::from_array(state_h.clone())?,
- "input_states_2" => ort::value::Value::from_array(state_c.clone())?
- ])?;
-
- // 1. Extract Logits
- let (l_shape, l_data) = outputs["outputs"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
-
- // 2. Extract States (output_states_1, output_states_2)
- let (_h_shape, h_data) = outputs["output_states_1"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?;
-
- let (_c_shape, c_data) = outputs["output_states_2"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?;
-
- // Reconstruct Arrays
- // Logits: I simplify to [1, 1, vocab]
- let vocab_size = l_shape[3] as usize;
- let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec())
- .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?;
-
- // States: [1, 1, 640]
- let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec())
- .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?;
-
- let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec())
- .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?;
-
- Ok((logits, new_h, new_c))
- }
-} \ No newline at end of file
diff --git a/parakeet-rs/src/model_tdt.rs b/parakeet-rs/src/model_tdt.rs
deleted file mode 100644
index e00ebdc..0000000
--- a/parakeet-rs/src/model_tdt.rs
+++ /dev/null
@@ -1,263 +0,0 @@
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use ndarray::{Array1, Array2, Array3};
-use ort::session::Session;
-use std::path::{Path, PathBuf};
-
-/// TDT model configs
-#[derive(Debug, Clone)]
-pub struct TDTModelConfig {
- pub vocab_size: usize,
-}
-
-impl Default for TDTModelConfig {
- fn default() -> Self {
- Self {
- vocab_size: 8193,
- }
- }
-}
-
-pub struct ParakeetTDTModel {
- encoder: Session,
- decoder_joint: Session,
- config: TDTModelConfig,
-}
-
-impl ParakeetTDTModel {
- /// Load TDT model from directory containing encoder and decoder_joint ONNX files
- pub fn from_pretrained<P: AsRef<Path>>(
- model_dir: P,
- exec_config: ExecutionConfig,
- ) -> Result<Self> {
- let model_dir = model_dir.as_ref();
-
- // Find encoder and decoder_joint files
- let encoder_path = Self::find_encoder(model_dir)?;
- let decoder_joint_path = Self::find_decoder_joint(model_dir)?;
-
- let config = TDTModelConfig::default();
-
- // Load encoder
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let encoder = builder.commit_from_file(&encoder_path)?;
-
- // Load decoder_joint
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let decoder_joint = builder.commit_from_file(&decoder_joint_path)?;
-
-
- Ok(Self {
- encoder,
- decoder_joint,
- config,
- })
- }
-
- fn find_encoder(dir: &Path) -> Result<PathBuf> {
- let candidates = ["encoder-model.onnx", "encoder.onnx"];
- for candidate in &candidates {
- let path = dir.join(candidate);
- if path.exists() {
- return Ok(path);
- }
- }
- Err(Error::Config(format!(
- "No encoder model found in {}",
- dir.display()
- )))
- }
-
- fn find_decoder_joint(dir: &Path) -> Result<PathBuf> {
- let candidates = [
- "decoder_joint-model.onnx",
- "decoder_joint.onnx",
- "decoder-model.onnx",
- ];
- for candidate in &candidates {
- let path = dir.join(candidate);
- if path.exists() {
- return Ok(path);
- }
- }
- Err(Error::Config(format!(
- "No decoder_joint model found in {}",
- dir.display()
- )))
- }
-
- /// Run greedy decoding - returns (token_ids, frame_indices, durations)
- pub fn forward(&mut self, features: Array2<f32>) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
- // Run encoder
- let (encoder_out, encoder_len) = self.run_encoder(&features)?;
-
- // Run greedy decoding with decoder_joint
- let (tokens, frame_indices, durations) = self.greedy_decode(&encoder_out, encoder_len)?;
-
- Ok((tokens, frame_indices, durations))
- }
-
- fn run_encoder(&mut self, features: &Array2<f32>) -> Result<(Array3<f32>, i64)> {
- let batch_size = 1;
- let time_steps = features.shape()[0];
- let feature_size = features.shape()[1];
-
- // TDT encoder expects (batch, features, time) not (batch, time, features)
- let input = features
- .t()
- .to_shape((batch_size, feature_size, time_steps))
- .map_err(|e| Error::Model(format!("Failed to reshape encoder input: {e}")))?
- .to_owned();
-
- let input_length = Array1::from_vec(vec![time_steps as i64]);
-
- let input_value = ort::value::Value::from_array(input)?;
- let length_value = ort::value::Value::from_array(input_length)?;
-
- let outputs = self.encoder.run(ort::inputs!(
- "audio_signal" => input_value,
- "length" => length_value
- ))?;
-
- let encoder_out = &outputs["outputs"];
- let encoder_lens = &outputs["encoded_lengths"];
-
- let (shape, data) = encoder_out
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
-
- let (_, lens_data) = encoder_lens
- .try_extract_tensor::<i64>()
- .map_err(|e| Error::Model(format!("Failed to extract encoder lengths: {e}")))?;
-
- let shape_dims = shape.as_ref();
- if shape_dims.len() != 3 {
- return Err(Error::Model(format!(
- "Expected 3D encoder output, got shape: {shape_dims:?}"
- )));
- }
-
- let b = shape_dims[0] as usize;
- let t = shape_dims[1] as usize;
- let d = shape_dims[2] as usize;
-
- let encoder_array = Array3::from_shape_vec((b, t, d), data.to_vec())
- .map_err(|e| Error::Model(format!("Failed to create encoder array: {e}")))?;
-
- // TDT encoder outputs [batch, encoder_dim, time] directly
- Ok((encoder_array, lens_data[0]))
- }
-
- fn greedy_decode(&mut self, encoder_out: &Array3<f32>, _encoder_len: i64) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> {
- // encoder_out shape: [batch, encoder_dim, time]
- let encoder_dim = encoder_out.shape()[1];
- let time_steps = encoder_out.shape()[2];
- let vocab_size = self.config.vocab_size;
- let max_tokens_per_step = 10;
- let blank_id = vocab_size - 1;
-
- // States: (num_layers=2, batch=1, hidden_dim=640)
- let mut state_h = Array3::<f32>::zeros((2, 1, 640));
- let mut state_c = Array3::<f32>::zeros((2, 1, 640));
-
- let mut tokens = Vec::new();
- let mut frame_indices = Vec::new();
- let mut durations = Vec::new();
-
- let mut t = 0;
- let mut emitted_tokens = 0;
- let mut last_emitted_token = blank_id as i32;
-
- // Frame-by-frame RNN-T/TDT greedy decoding
- while t < time_steps {
- // Get single encoder frame: slice [0, :, t] and reshape to [1, encoder_dim, 1]
- let frame = encoder_out.slice(ndarray::s![0, .., t]).to_owned();
- let frame_reshaped = frame
- .to_shape((1, encoder_dim, 1))
- .map_err(|e| Error::Model(format!("Failed to reshape frame: {e}")))?
- .to_owned();
-
- // Current token for prediction network
- let targets = Array2::from_shape_vec((1, 1), vec![last_emitted_token])
- .map_err(|e| Error::Model(format!("Failed to create targets: {e}")))?;
-
- // Run decoder_joint
- let outputs = self.decoder_joint.run(ort::inputs!(
- "encoder_outputs" => ort::value::Value::from_array(frame_reshaped)?,
- "targets" => ort::value::Value::from_array(targets)?,
- "target_length" => ort::value::Value::from_array(Array1::from_vec(vec![1i32]))?,
- "input_states_1" => ort::value::Value::from_array(state_h.clone())?,
- "input_states_2" => ort::value::Value::from_array(state_c.clone())?
- ))?;
-
- // Extract logits
- let (_, logits_data) = outputs["outputs"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
-
- // TDT outputs vocab_size + 5 durations (8193 + 5 = 8198)
- let vocab_logits: Vec<f32> = logits_data.iter().take(vocab_size).copied().collect();
- let duration_logits: Vec<f32> = logits_data.iter().skip(vocab_size).copied().collect();
-
- let token_id = vocab_logits
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
- .map(|(idx, _)| idx)
- .unwrap_or(blank_id);
-
- let duration_step = if !duration_logits.is_empty() {
- duration_logits
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
- .map(|(idx, _)| idx)
- .unwrap_or(0)
- } else {
- 0
- };
-
- // Check if blank token
- if token_id != blank_id {
- // Update states when we emit a token
- if let Ok((h_shape, h_data)) = outputs["output_states_1"].try_extract_tensor::<f32>() {
- let dims = h_shape.as_ref();
- state_h = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), h_data.to_vec())
- .map_err(|e| Error::Model(format!("Failed to update state_h: {e}")))?;
- }
- if let Ok((c_shape, c_data)) = outputs["output_states_2"].try_extract_tensor::<f32>() {
- let dims = c_shape.as_ref();
- state_c = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), c_data.to_vec())
- .map_err(|e| Error::Model(format!("Failed to update state_c: {e}")))?;
- }
-
- tokens.push(token_id);
- frame_indices.push(t);
- durations.push(duration_step);
- last_emitted_token = token_id as i32;
- emitted_tokens += 1;
-
- // Don't advance yet - try to emit more tokens from the same frame
- } else {
- // Blank token - advance frame pointer
- // Duration prediction applies when we finally move to next frame after emitting tokens
- if duration_step > 0 && emitted_tokens > 0 {
- t += duration_step;
- } else {
- t += 1;
- }
- emitted_tokens = 0;
- }
-
- // Safety check: if we've emitted too many tokens from the same frame, advance
- if emitted_tokens >= max_tokens_per_step {
- t += 1;
- emitted_tokens = 0;
- }
- }
-
- Ok((tokens, frame_indices, durations))
- }
-}
diff --git a/parakeet-rs/src/parakeet.rs b/parakeet-rs/src/parakeet.rs
deleted file mode 100644
index d2aabdd..0000000
--- a/parakeet-rs/src/parakeet.rs
+++ /dev/null
@@ -1,210 +0,0 @@
-use crate::audio;
-use crate::config::PreprocessorConfig;
-use crate::decoder::{ParakeetDecoder, TranscriptionResult};
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use crate::model::ParakeetModel;
-use crate::timestamps::{process_timestamps, TimestampMode};
-use std::path::{Path, PathBuf};
-
-pub struct Parakeet {
- model: ParakeetModel,
- decoder: ParakeetDecoder,
- preprocessor_config: PreprocessorConfig,
- model_dir: PathBuf,
-}
-
-impl Parakeet {
- /// Load Parakeet model from path with optional configuration.
- ///
- /// # Arguments
- /// * `path` - Directory containing model files, or path to specific model file
- /// * `config` - Optional execution configuration (defaults to CPU if None)
- ///
- /// # Examples
- /// ```no_run
- /// use parakeet_rs::Parakeet;
- ///
- /// // Load from directory with CPU (default)
- /// let parakeet = Parakeet::from_pretrained(".", None)?;
- ///
- /// // Or load from specific model file
- /// let parakeet = Parakeet::from_pretrained("model_q4.onnx", None)?;
- /// # Ok::<(), Box<dyn std::error::Error>>(())
- /// ```
- ///
- /// For GPU acceleration, enable the corresponding feature (cuda, tensorrt, webgpu, etc.)
- /// and pass an `ExecutionConfig` with the desired execution provider.
- pub fn from_pretrained<P: AsRef<Path>>(
- path: P,
- config: Option<ExecutionConfig>,
- ) -> Result<Self> {
- let path = path.as_ref();
-
- // Determine if path is a directory or file
- let (model_path, tokenizer_path, model_dir) = if path.is_dir() {
- // Directory mode: auto-detect model file
- let model_path = Self::find_model_file(path)?;
- let tokenizer_path = path.join("tokenizer.json");
- (model_path, tokenizer_path, path.to_path_buf())
- } else if path.is_file() {
- // File mode: path points directly to model file
- let model_dir = path
- .parent()
- .ok_or_else(|| Error::Config("Invalid model path".to_string()))?;
- let tokenizer_path = model_dir.join("tokenizer.json");
- (path.to_path_buf(), tokenizer_path, model_dir.to_path_buf())
- } else {
- return Err(Error::Config(format!(
- "Path does not exist: {}",
- path.display()
- )));
- };
-
- // Check tokenizer exists
- if !tokenizer_path.exists() {
- return Err(Error::Config(format!(
- "Required file 'tokenizer.json' not found in {}",
- model_dir.display()
- )));
- }
-
- let preprocessor_config = PreprocessorConfig::default();
- let exec_config = config.unwrap_or_default();
-
- let model = ParakeetModel::from_pretrained_with_config(&model_path, exec_config)?;
- let decoder = ParakeetDecoder::from_pretrained(&tokenizer_path)?;
-
- Ok(Self {
- model,
- decoder,
- preprocessor_config,
- model_dir,
- })
- }
-
- fn find_model_file(dir: &Path) -> Result<PathBuf> {
- // Priority order: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx
- let candidates = [
- "model.onnx",
- "model_fp16.onnx",
- "model_int8.onnx",
- "model_q4.onnx",
- ];
-
- for candidate in &candidates {
- let path = dir.join(candidate);
- if path.exists() {
- return Ok(path);
- }
- }
-
- // If none of the standard names found, search for any .onnx file
- if let Ok(entries) = std::fs::read_dir(dir) {
- for entry in entries.flatten() {
- let path = entry.path();
- if path.extension().and_then(|s| s.to_str()) == Some("onnx") {
- return Ok(path);
- }
- }
- }
-
- Err(Error::Config(format!(
- "No model file (*.onnx) found in directory: {}",
- dir.display()
- )))
- }
-
- /// Transcribe audio samples.
- ///
- /// # Arguments
- ///
- /// * `audio` - Audio samples as f32 values
- /// * `sample_rate` - Sample rate in Hz
- /// * `channels` - Number of audio channels
- /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
- ///
- /// # Returns
- ///
- /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested level.
- pub fn transcribe_samples(
- &mut self,
- audio: Vec<f32>,
- sample_rate: u32,
- channels: u16,
- mode: Option<TimestampMode>,
- ) -> Result<TranscriptionResult> {
- let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
- let logits = self.model.forward(features)?;
-
- let mut result = self.decoder.decode_with_timestamps(
- &logits,
- self.preprocessor_config.hop_length,
- self.preprocessor_config.sampling_rate,
- )?;
-
- // Process timestamps to requested output mode
- let mode = mode.unwrap_or(TimestampMode::Tokens);
- result.tokens = process_timestamps(&result.tokens, mode);
-
- // Rebuild full text from processed tokens to ensure consistency
- result.text = result.tokens.iter()
- .map(|t| t.text.as_str())
- .collect::<Vec<_>>()
- .join(" ");
-
- Ok(result)
- }
-
- /// Transcribe an audio file with timestamps
- ///
- /// # Arguments
- ///
- /// * `audio_path` - A path to the audio file that needs to be transcribed.
- /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
- ///
- /// # Returns
- ///
- /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level.
- pub fn transcribe_file<P: AsRef<Path>>(
- &mut self,
- audio_path: P,
- mode: Option<TimestampMode>,
- ) -> Result<TranscriptionResult> {
- let audio_path = audio_path.as_ref();
- let (audio, spec) = audio::load_audio(audio_path)?;
-
- self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode)
- }
-
- /// Transcribes multiple audio files in batch.
- ///
- /// # Arguments
- ///
- /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed.
- /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
- ///
- /// # Returns
- ///
- /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level.
- pub fn transcribe_file_batch<P: AsRef<Path>>(
- &mut self,
- audio_paths: &[P],
- mode: Option<TimestampMode>,
- ) -> Result<Vec<TranscriptionResult>> {
- let mut results = Vec::with_capacity(audio_paths.len());
- for path in audio_paths {
- let result = self.transcribe_file(path, mode)?;
- results.push(result);
- }
- Ok(results)
- }
-
- pub fn model_dir(&self) -> &Path {
- &self.model_dir
- }
-
- pub fn preprocessor_config(&self) -> &PreprocessorConfig {
- &self.preprocessor_config
- }
-}
diff --git a/parakeet-rs/src/parakeet_eou.rs b/parakeet-rs/src/parakeet_eou.rs
deleted file mode 100644
index 25c7d64..0000000
--- a/parakeet-rs/src/parakeet_eou.rs
+++ /dev/null
@@ -1,304 +0,0 @@
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use crate::model_eou::{EncoderCache, ParakeetEOUModel};
-use ndarray::{s, Array2, Array3};
-use rustfft::{num_complex::Complex, FftPlanner};
-use std::collections::VecDeque;
-use std::f32::consts::PI;
-use std::path::Path;
-
-const SAMPLE_RATE: usize = 16000;
-
-const N_FFT: usize = 512;
-const WIN_LENGTH: usize = 400;
-const HOP_LENGTH: usize = 160;
-const N_MELS: usize = 128;
-const PREEMPH: f32 = 0.97;
-const LOG_ZERO_GUARD: f32 = 5.960464478e-8;
-const FMAX: f32 = 8000.0;
-
-/// Parakeet RealTime EOU model for streaming ASR with end-of-utterance detection.
-/// Uses cache-aware streaming with audio buffering for pre-encode context.
-pub struct ParakeetEOU {
- model: ParakeetEOUModel,
- tokenizer: tokenizers::Tokenizer,
- encoder_cache: EncoderCache,
- state_h: Array3<f32>,
- state_c: Array3<f32>,
- last_token: Array2<i32>,
- blank_id: i32,
- eou_id: i32,
- mel_basis: Array2<f32>,
- window: Vec<f32>,
- audio_buffer: VecDeque<f32>,
- buffer_size_samples: usize,
-}
-
-impl ParakeetEOU {
- /// Load Parakeet EOU model from path
- ///
- /// # Arguments
- /// * `path` - Directory containing encoder.onnx, decoder_joint.onnx, and tokenizer.json
- /// * `config` - Optional execution configuration (defaults to CPU if None)
- pub fn from_pretrained<P: AsRef<Path>>(path: P, config: Option<ExecutionConfig>) -> Result<Self> {
- let path = path.as_ref();
- let tokenizer_path = path.join("tokenizer.json");
- let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
- .map_err(|e| Error::Config(format!("Failed to load tokenizer: {e}")))?;
-
- let vocab_size = tokenizer.get_vocab_size(true);
- let blank_id = (vocab_size - 1) as i32;
- let blank_id = if blank_id < 1000 { 1026 } else { blank_id };
- let eou_id = tokenizer.token_to_id("<EOU>").map(|id| id as i32).unwrap_or(1024);
-
- let exec_config = config.unwrap_or_default();
- let model = ParakeetEOUModel::from_pretrained(path, exec_config)?;
-
- // Buffer size: 4 seconds of audio
- // Provides long history for feature extraction context
- // Note that, I pick those "magic numbers" by looking NeMo's ring buffer approach.
- let buffer_size_samples = SAMPLE_RATE * 4; // 4 seconds = 64000 samples
-
- Ok(Self {
- model,
- tokenizer,
- encoder_cache: EncoderCache::new(),
- state_h: Array3::zeros((1, 1, 640)),
- state_c: Array3::zeros((1, 1, 640)),
- last_token: Array2::from_elem((1, 1), blank_id),
- blank_id,
- eou_id,
- mel_basis: Self::create_mel_filterbank(),
- window: Self::create_window(),
- audio_buffer: VecDeque::with_capacity(buffer_size_samples),
- buffer_size_samples,
- })
- }
-
- /// Transcribe a chunk of audio samples.
- ///
- /// # Arguments
- /// * `chunk` - Audio chunk (typically 160ms / 2560 samples at 16kHz)
- /// * `reset_on_eou` - If true, reset decoder state when end-of-utterance is detected
- ///
- /// # Streaming Behavior
- /// Cache-aware streaming
- /// - Maintains 4-second ring buffer for feature extraction context
- /// - Extracts features from full buffer
- /// - Slices last (pre_encode_cache + new_frames) for encoder input
- /// - pre_encode_cache=9 frames, new_frames=~16, total=~25 frames to encoder
- pub fn transcribe(&mut self, chunk: &[f32], reset_on_eou: bool) -> Result<String> {
- // Add new chunk to rolling buffer
- self.audio_buffer.extend(chunk.iter().copied());
-
- // Trim buffer to keep only the most recent samples
- while self.audio_buffer.len() > self.buffer_size_samples {
- self.audio_buffer.pop_front();
- }
-
- // Wait until buffer has minimum samples (at least 1 second for stable features)
- const MIN_BUFFER_SAMPLES: usize = SAMPLE_RATE; // 1 second
- if self.audio_buffer.len() < MIN_BUFFER_SAMPLES {
- return Ok(String::new());
- }
-
- // Extract features from FULL buffer (provides context for feature extraction)
- let buffer_slice: Vec<f32> = self.audio_buffer.iter().copied().collect();
- let full_features = self.extract_mel_features(&buffer_slice);
- let total_frames = full_features.shape()[2];
-
- // Slice to take only (pre_encode_cache + new_frames) for encoder
- // pre_encode_cache = 9 frames, new_frames = ~16 for 160ms chunk
- const PRE_ENCODE_CACHE: usize = 9;
- const FRAMES_PER_CHUNK: usize = 16;
- const SLICE_LEN: usize = PRE_ENCODE_CACHE + FRAMES_PER_CHUNK;
-
- let start_frame = if total_frames > SLICE_LEN {
- total_frames - SLICE_LEN
- } else {
- 0
- };
-
- let features = full_features.slice(s![.., .., start_frame..]).to_owned();
- let time_steps = features.shape()[2];
-
- // Encode with cache - encoder sees full buffer context
- let (encoder_out, new_cache) = self.model.run_encoder(&features, time_steps as i64, &self.encoder_cache)?;
- self.encoder_cache = new_cache;
-
- let total_frames = encoder_out.shape()[2];
- if total_frames == 0 {
- return Ok(String::new());
- }
-
- // Process all output frames (typically 1 frame per chunk)
- let new_frames = encoder_out;
-
- let mut text_output = String::new();
-
- for t in 0..new_frames.shape()[2] {
- let current_frame = new_frames.slice(s![.., .., t..t + 1]).to_owned();
- let mut syms_added = 0;
-
- while syms_added < 5 {
- let (logits, new_h, new_c) = self.model.run_decoder(
- &current_frame,
- &self.last_token,
- &self.state_h,
- &self.state_c,
- )?;
-
- let vocab = logits.slice(s![0, 0, ..]);
-
- let mut max_idx = 0;
- let mut max_val = f32::NEG_INFINITY;
- for (i, &val) in vocab.iter().enumerate() {
- if val.is_finite() && val > max_val {
- max_val = val;
- max_idx = i as i32;
- }
- }
-
- if max_idx == self.blank_id || max_idx == 0 {
- break;
- }
-
- if max_idx == self.eou_id {
- if reset_on_eou {
- self.reset_states();
- return Ok(text_output + " [EOU]");
- }
- break;
- }
-
- if max_idx as usize >= self.tokenizer.get_vocab_size(true) {
- break;
- }
-
- self.state_h = new_h;
- self.state_c = new_c;
- self.last_token.fill(max_idx);
-
- if let Some(token) = self.tokenizer.id_to_token(max_idx as u32) {
- let clean = token.replace('▁', " ");
- text_output.push_str(&clean);
- }
- syms_added += 1;
- }
- }
- Ok(text_output)
- }
-
- fn reset_states(&mut self) {
- // Soft reset: Only reset decoder states
- // at this state, we need to keep encoder cache and audio buffer flowing for continuous context
- // self.encoder_cache = EncoderCache::new(); // DON'T reset!!!
- self.state_h.fill(0.0);
- self.state_c.fill(0.0);
- self.last_token.fill(self.blank_id);
- // self.audio_buffer.clear(); // DON'T clear!!
- }
-
- fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> {
- let audio_pre = Self::apply_preemphasis(audio);
- let spec = self.stft(&audio_pre);
- let mel = self.mel_basis.dot(&spec);
- let mel_log = mel.mapv(|x| (x.max(0.0) + LOG_ZERO_GUARD).ln());
- mel_log.insert_axis(ndarray::Axis(0))
- }
-
- fn apply_preemphasis(audio: &[f32]) -> Vec<f32> {
- let mut result = Vec::with_capacity(audio.len());
- if audio.is_empty() {
- return result;
- }
-
- let safe_x = |x: f32| if x.is_finite() { x } else { 0.0 };
-
- result.push(safe_x(audio[0]));
- for i in 1..audio.len() {
- result.push(safe_x(audio[i]) - PREEMPH * safe_x(audio[i - 1]));
- }
- result
- }
-
- fn stft(&self, audio: &[f32]) -> Array2<f32> {
- let mut planner = FftPlanner::<f32>::new();
- let fft = planner.plan_fft_forward(N_FFT);
-
- let pad_amount = N_FFT / 2;
- let mut padded_audio = vec![0.0; pad_amount];
- padded_audio.extend_from_slice(audio);
- padded_audio.extend(std::iter::repeat(0.0).take(pad_amount));
-
- let num_frames = 1 + (padded_audio.len().saturating_sub(WIN_LENGTH)) / HOP_LENGTH;
- let freq_bins = N_FFT / 2 + 1;
- let mut spec = Array2::zeros((freq_bins, num_frames));
-
- for frame_idx in 0..num_frames {
- let start = frame_idx * HOP_LENGTH;
- if start + WIN_LENGTH > padded_audio.len() {
- break;
- }
-
- let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT];
- for i in 0..WIN_LENGTH {
- buffer[i] = Complex::new(padded_audio[start + i] * self.window[i], 0.0);
- }
- fft.process(&mut buffer);
- for (i, val) in buffer.iter().take(freq_bins).enumerate() {
- let mag_sq = val.norm_sqr();
- spec[[i, frame_idx]] = if mag_sq.is_finite() { mag_sq } else { 0.0 };
- }
- }
- spec
- }
-
- fn create_window() -> Vec<f32> {
- (0..WIN_LENGTH)
- .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / ((WIN_LENGTH - 1) as f32)).cos())
- .collect()
- }
-
- fn create_mel_filterbank() -> Array2<f32> {
- let num_freqs = N_FFT / 2 + 1;
-
- let hz_to_mel = |hz: f32| 2595.0 * (1.0 + hz / 700.0).log10();
- let mel_to_hz = |mel: f32| 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0);
-
- let mel_min = hz_to_mel(0.0);
- let mel_max = hz_to_mel(FMAX);
-
- let mel_points: Vec<f32> = (0..=N_MELS + 1)
- .map(|i| mel_to_hz(mel_min + (mel_max - mel_min) * i as f32 / (N_MELS + 1) as f32))
- .collect();
-
- let fft_freqs: Vec<f32> = (0..num_freqs)
- .map(|i| (SAMPLE_RATE as f32 / N_FFT as f32) * i as f32)
- .collect();
-
- let mut weights = Array2::zeros((N_MELS, num_freqs));
-
- for i in 0..N_MELS {
- let left = mel_points[i];
- let center = mel_points[i + 1];
- let right = mel_points[i + 2];
- for (j, &freq) in fft_freqs.iter().enumerate() {
- if freq >= left && freq <= center {
- weights[[i, j]] = (freq - left) / (center - left);
- } else if freq > center && freq <= right {
- weights[[i, j]] = (right - freq) / (right - center);
- }
- }
- }
-
- for i in 0..N_MELS {
- let enorm = 2.0 / (mel_points[i + 2] - mel_points[i]);
- for j in 0..num_freqs {
- weights[[i, j]] *= enorm;
- }
- }
-
- weights
- }
-}
diff --git a/parakeet-rs/src/parakeet_tdt.rs b/parakeet-rs/src/parakeet_tdt.rs
deleted file mode 100644
index 719ae75..0000000
--- a/parakeet-rs/src/parakeet_tdt.rs
+++ /dev/null
@@ -1,167 +0,0 @@
-use crate::audio;
-use crate::config::PreprocessorConfig;
-use crate::decoder::TranscriptionResult;
-use crate::decoder_tdt::ParakeetTDTDecoder;
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use crate::model_tdt::ParakeetTDTModel;
-use crate::timestamps::{process_timestamps, TimestampMode};
-use crate::vocab::Vocabulary;
-use std::path::{Path, PathBuf};
-
-/// Parakeet TDT model for multilingual ASR
-pub struct ParakeetTDT {
- model: ParakeetTDTModel,
- decoder: ParakeetTDTDecoder,
- preprocessor_config: PreprocessorConfig,
- model_dir: PathBuf,
-}
-
-impl ParakeetTDT {
- /// Load Parakeet TDT model from path with optional configuration.
- ///
- /// # Arguments
- /// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt
- /// * `config` - Optional execution configuration (defaults to CPU if None)
- pub fn from_pretrained<P: AsRef<Path>>(
- path: P,
- config: Option<ExecutionConfig>,
- ) -> Result<Self> {
- let path = path.as_ref();
-
- if !path.is_dir() {
- return Err(Error::Config(format!(
- "TDT model path must be a directory: {}",
- path.display()
- )));
- }
-
- let vocab_path = path.join("vocab.txt");
- if !vocab_path.exists() {
- return Err(Error::Config(format!(
- "vocab.txt not found in {}",
- path.display()
- )));
- }
-
- // TDT-specific preprocessor config (128 features instead of 80)
- let preprocessor_config = PreprocessorConfig {
- feature_extractor_type: "ParakeetFeatureExtractor".to_string(),
- feature_size: 128,
- hop_length: 160,
- n_fft: 512,
- padding_side: "right".to_string(),
- padding_value: 0.0,
- preemphasis: 0.97,
- processor_class: "ParakeetProcessor".to_string(),
- return_attention_mask: true,
- sampling_rate: 16000,
- win_length: 400,
- };
-
- let exec_config = config.unwrap_or_default();
-
- let model = ParakeetTDTModel::from_pretrained(path, exec_config)?;
- let vocab = Vocabulary::from_file(&vocab_path)?;
- let decoder = ParakeetTDTDecoder::from_vocab(vocab);
-
- Ok(Self {
- model,
- decoder,
- preprocessor_config,
- model_dir: path.to_path_buf(),
- })
- }
-
- /// Transcribe audio samples.
- ///
- /// # Arguments
- ///
- /// * `audio` - Audio samples as f32 values
- /// * `sample_rate` - Sample rate in Hz
- /// * `channels` - Number of audio channels
- /// * `mode` - Optional timestamp mode (Token, Word, or Segment)
- ///
- /// # Returns
- ///
- /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode.
- pub fn transcribe_samples(
- &mut self,
- audio: Vec<f32>,
- sample_rate: u32,
- channels: u16,
- mode: Option<TimestampMode>,
- ) -> Result<TranscriptionResult> {
- let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
- let (tokens, frame_indices, durations) = self.model.forward(features)?;
-
- let mut result = self.decoder.decode_with_timestamps(
- &tokens,
- &frame_indices,
- &durations,
- self.preprocessor_config.hop_length,
- self.preprocessor_config.sampling_rate,
- )?;
-
- // Apply timestamp mode conversion
- let mode = mode.unwrap_or(TimestampMode::Tokens);
- result.tokens = process_timestamps(&result.tokens, mode);
-
- // Rebuild full text from processed tokens
- result.text = result.tokens.iter()
- .map(|t| t.text.as_str())
- .collect::<Vec<_>>()
- .join(" ");
-
- Ok(result)
- }
-
- /// Transcribe an audio file with timestamps
- ///
- /// # Arguments
- ///
- /// * `audio_path` - A path to the audio file that needs to be transcribed.
- /// * `mode` - Optional timestamp mode (Token, Word, or Segment)
- ///
- /// # Returns
- ///
- /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode.
- pub fn transcribe_file<P: AsRef<Path>>(
- &mut self,
- audio_path: P,
- mode: Option<TimestampMode>,
- ) -> Result<TranscriptionResult> {
- let audio_path = audio_path.as_ref();
- let (audio, spec) = audio::load_audio(audio_path)?;
-
- self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode)
- }
-
- /// Transcribes multiple audio files in batch.
- ///
- /// # Arguments
- ///
- /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed.
- /// * `mode` - Optional timestamp mode (Token, Word, or Segment)
- ///
- /// # Returns
- ///
- /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode.
- pub fn transcribe_file_batch<P: AsRef<Path>>(
- &mut self,
- audio_paths: &[P],
- mode: Option<TimestampMode>,
- ) -> Result<Vec<TranscriptionResult>> {
- let mut results = Vec::with_capacity(audio_paths.len());
- for path in audio_paths {
- let result = self.transcribe_file(path, mode)?;
- results.push(result);
- }
- Ok(results)
- }
-
- /// Get model directory path
- pub fn model_dir(&self) -> &Path {
- &self.model_dir
- }
-}
diff --git a/parakeet-rs/src/sortformer.rs b/parakeet-rs/src/sortformer.rs
deleted file mode 100644
index 2b1e5a3..0000000
--- a/parakeet-rs/src/sortformer.rs
+++ /dev/null
@@ -1,1062 +0,0 @@
-//! NVIDIA Sortformer v2 Streaming Speaker Diarization
-//!
-//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization.
-//!
-//! Key features:
-//! - Streaming inference with ~10s chunks (124 frames at 80ms each)
-//! - FIFO buffer for context management
-//! - Smart speaker cache compression (keeps important frames, not just recent)
-//! - Silence profile tracking
-//! - Post-processing: median filtering, hysteresis thresholding
-//! - Supports up to 4 speakers
-//!
-//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
-//! Note that, my ONNX export:
-//! CHUNK_LEN = 124
-//! FIFO_LEN = 124
-//! CACHE_LEN = 188
-//! FEAT_DIM = 128
-//! EMB_DIM = 512
-//! Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html
-
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig;
-use ndarray::{s, Array1, Array2, Array3, Axis};
-use ort::session::Session;
-use rustfft::{num_complex::Complex, FftPlanner};
-use std::f32::consts::PI;
-use std::path::Path;
-
-// Model constants
-const N_FFT: usize = 512;
-const WIN_LENGTH: usize = 400;
-const HOP_LENGTH: usize = 160;
-const N_MELS: usize = 128;
-const PREEMPH: f32 = 0.97;
-const LOG_ZERO_GUARD: f32 = 5.960464478e-8;
-const SAMPLE_RATE: usize = 16000;
-const FMIN: f32 = 0.0;
-const FMAX: f32 = 8000.0;
-
-// Streaming constants
-const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms)
-const FIFO_LEN: usize = 124; // FIFO buffer length
-const SPKCACHE_LEN: usize = 188; // Speaker cache length
-const SPKCACHE_UPDATE_PERIOD: usize = 124;
-const SUBSAMPLING: usize = 8; // Audio frames -> model frames
-const EMB_DIM: usize = 512; // Embedding dimension
-const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers
-const FRAME_DURATION: f32 = 0.08; // 80ms per frame
-
-// Cache compression params (from NeMo)
-const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3;
-const PRED_SCORE_THRESHOLD: f32 = 0.25;
-const STRONG_BOOST_RATE: f32 = 0.75;
-const WEAK_BOOST_RATE: f32 = 1.5;
-const MIN_POS_SCORES_RATE: f32 = 0.5;
-const SIL_THRESHOLD: f32 = 0.2;
-const MAX_INDEX: usize = 99999;
-
-/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs)
-///
-/// Controls how raw model predictions are converted into speaker segments.
-/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI).
-///
-/// # Parameters
-/// - `onset`: Probability threshold to START a speaker segment (higher = more strict)
-/// - `offset`: Probability threshold to END a speaker segment (lower = longer segments)
-/// - `pad_onset`: Seconds to subtract from segment start times
-/// - `pad_offset`: Seconds to add to segment end times
-/// - `min_duration_on`: Minimum segment length in seconds (filters short blips)
-/// - `min_duration_off`: Minimum gap between segments before merging
-/// - `median_window`: Smoothing window size (odd number, higher = smoother)
-///
-/// # Pre-tuned Configs
-/// - `callhome()` - (default)
-/// - `dihard3()`
-///
-/// # Custom Config
-/// Use `custom(onset, offset)` to create your own config for fine-tuning.
-///
-/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer
-#[derive(Debug, Clone)]
-pub struct DiarizationConfig {
- pub onset: f32,
- pub offset: f32,
- pub pad_onset: f32,
- pub pad_offset: f32,
- pub min_duration_on: f32,
- pub min_duration_off: f32,
- pub median_window: usize,
-}
-
-impl Default for DiarizationConfig {
- fn default() -> Self {
- Self::callhome()
- }
-}
-
-impl DiarizationConfig {
- /// CallHome dataset config for v2 (default)
- /// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml
- pub fn callhome() -> Self {
- Self {
- onset: 0.641,
- offset: 0.561,
- pad_onset: 0.229,
- pad_offset: 0.079,
- min_duration_on: 0.511,
- min_duration_off: 0.296,
- median_window: 11,
- }
- }
-
- /// DIHARD3 dataset config for v2
- /// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml
- pub fn dihard3() -> Self {
- Self {
- onset: 0.56,
- offset: 1.0,
- pad_onset: 0.063,
- pad_offset: 0.002,
- min_duration_on: 0.007,
- min_duration_off: 0.151,
- median_window: 11,
- }
- }
-
- /// Create a custom config for fine-tuning diarization behavior.
- ///
- /// # Arguments
- /// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7)
- /// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6)
- ///
- /// # Example
- /// ```rust
- /// use parakeet_rs::sortformer::DiarizationConfig;
- ///
- /// // More sensitive detection (lower thresholds)
- /// let sensitive = DiarizationConfig::custom(0.5, 0.4);
- ///
- /// // Stricter detection (higher thresholds, fewer false positives)
- /// let strict = DiarizationConfig::custom(0.7, 0.6);
- ///
- /// // Full customization
- /// let mut config = DiarizationConfig::custom(0.6, 0.5);
- /// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms
- /// config.median_window = 15; // More smoothing
- /// ```
- pub fn custom(onset: f32, offset: f32) -> Self {
- Self {
- onset,
- offset,
- pad_onset: 0.0,
- pad_offset: 0.0,
- min_duration_on: 0.1,
- min_duration_off: 0.1,
- median_window: 11,
- }
- }
-}
-
-/// Speaker segment with start time, end time, and speaker ID
-#[derive(Debug, Clone)]
-pub struct SpeakerSegment {
- pub start: f32,
- pub end: f32,
- pub speaker_id: usize,
-}
-
-/// Streaming Sortformer v2 speaker diarization engine
-pub struct Sortformer {
- session: Session,
- config: DiarizationConfig,
- // Streaming state. note that, Same way as Nemo
- spkcache: Array3<f32>, // (1, 0..SPKCACHE_LEN, EMB_DIM)
- spkcache_preds: Option<Array3<f32>>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS)
- fifo: Array3<f32>, // (1, 0..FIFO_LEN, EMB_DIM)
- fifo_preds: Array3<f32>, // (1, 0..FIFO_LEN, NUM_SPEAKERS)
- mean_sil_emb: Array2<f32>, // (1, EMB_DIM)
- n_sil_frames: usize,
- // Mel filterbank (cached)
- mel_basis: Array2<f32>,
-}
-
-impl Sortformer {
- /// a new Sortformer instance from ONNX model path
- pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
- Self::with_config(model_path, None, DiarizationConfig::default())
- }
-
- /// Create with custom config
- pub fn with_config<P: AsRef<Path>>(
- model_path: P,
- execution_config: Option<ModelConfig>,
- config: DiarizationConfig,
- ) -> Result<Self> {
- let config_to_use = execution_config.unwrap_or_default();
-
- let session = config_to_use
- .apply_to_session_builder(Session::builder()?)?
- .commit_from_file(model_path.as_ref())?;
-
- let mel_basis = Self::create_mel_filterbank();
-
- let mut instance = Self {
- session,
- config,
- spkcache: Array3::zeros((1, 0, EMB_DIM)),
- spkcache_preds: None,
- fifo: Array3::zeros((1, 0, EMB_DIM)),
- fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)),
- mean_sil_emb: Array2::zeros((1, EMB_DIM)),
- n_sil_frames: 0,
- mel_basis,
- };
- instance.reset_state();
- Ok(instance)
- }
-
- /// Reset streaming state
- pub fn reset_state(&mut self) {
- self.spkcache = Array3::zeros((1, 0, EMB_DIM));
- self.spkcache_preds = None;
- self.fifo = Array3::zeros((1, 0, EMB_DIM));
- self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS));
- self.mean_sil_emb = Array2::zeros((1, EMB_DIM));
- self.n_sil_frames = 0;
- }
-
- /// Main diarization entry point
- pub fn diarize(
- &mut self,
- mut audio: Vec<f32>,
- sample_rate: u32,
- channels: u16,
- ) -> Result<Vec<SpeakerSegment>> {
- // Resample if needed
- if sample_rate != SAMPLE_RATE as u32 {
- return Err(Error::Audio(format!(
- "Expected {} Hz, got {} Hz",
- SAMPLE_RATE, sample_rate
- )));
- }
-
- // Convert to mono
- if channels > 1 {
- audio = audio
- .chunks(channels as usize)
- .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
- .collect();
- }
-
- // Reset state for new audio
- self.reset_state();
-
- // Extract mel features (B, T, D)
- let features = self.extract_mel_features(&audio);
- let total_frames = features.shape()[1];
-
- // Process in chunks
- let chunk_stride = CHUNK_LEN * SUBSAMPLING;
- let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride;
-
- let mut all_chunk_preds = Vec::new();
-
- for chunk_idx in 0..num_chunks {
- let start = chunk_idx * chunk_stride;
- let end = (start + chunk_stride).min(total_frames);
- let current_len = end - start;
-
- // Extract chunk features
- let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
-
- // Pad last chunk if needed
- if current_len < chunk_stride {
- let mut padded = Array3::zeros((1, chunk_stride, N_MELS));
- padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat);
- chunk_feat = padded;
- }
-
- // Run streaming update
- let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
- all_chunk_preds.push(chunk_preds);
- }
-
- // Concatenate all predictions
- let full_preds = Self::concat_predictions(&all_chunk_preds);
-
- // Apply median filtering
- let filtered_preds = if self.config.median_window > 1 {
- self.median_filter(&full_preds)
- } else {
- full_preds
- };
-
- // Binarize to segments
- let segments = self.binarize(&filtered_preds);
-
- Ok(segments)
- }
-
- /// Streaming diarization that maintains state across calls.
- ///
- /// Unlike `diarize()`, this method does NOT reset the internal state,
- /// allowing speaker embeddings to be preserved across multiple audio chunks.
- /// Call `reset_state()` manually when starting a new audio session.
- ///
- /// This enables consistent speaker identification across long audio streams
- /// by maintaining the speaker cache between processing windows.
- ///
- /// # Arguments
- /// * `audio` - Audio samples (will be converted to mono if multi-channel)
- /// * `sample_rate` - Must be 16000 Hz
- /// * `channels` - Number of audio channels
- ///
- /// # Example
- /// ```ignore
- /// // Start of session
- /// sortformer.reset_state();
- ///
- /// // Process sliding windows
- /// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?;
- /// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs
- /// ```
- pub fn diarize_streaming(
- &mut self,
- mut audio: Vec<f32>,
- sample_rate: u32,
- channels: u16,
- ) -> Result<Vec<SpeakerSegment>> {
- // Resample if needed
- if sample_rate != SAMPLE_RATE as u32 {
- return Err(Error::Audio(format!(
- "Expected {} Hz, got {} Hz",
- SAMPLE_RATE, sample_rate
- )));
- }
-
- // Convert to mono
- if channels > 1 {
- audio = audio
- .chunks(channels as usize)
- .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
- .collect();
- }
-
- // NOTE: Unlike diarize(), we do NOT call reset_state() here
- // This preserves speaker embeddings across calls
-
- // Extract mel features (B, T, D)
- let features = self.extract_mel_features(&audio);
- let total_frames = features.shape()[1];
-
- // Process in chunks
- let chunk_stride = CHUNK_LEN * SUBSAMPLING;
- let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride;
-
- let mut all_chunk_preds = Vec::new();
-
- for chunk_idx in 0..num_chunks {
- let start = chunk_idx * chunk_stride;
- let end = (start + chunk_stride).min(total_frames);
- let current_len = end - start;
-
- // Extract chunk features
- let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
-
- // Pad last chunk if needed
- if current_len < chunk_stride {
- let mut padded = Array3::zeros((1, chunk_stride, N_MELS));
- padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat);
- chunk_feat = padded;
- }
-
- // Run streaming update
- let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
- all_chunk_preds.push(chunk_preds);
- }
-
- // Concatenate all predictions
- let full_preds = Self::concat_predictions(&all_chunk_preds);
-
- // Apply median filtering
- let filtered_preds = if self.config.median_window > 1 {
- self.median_filter(&full_preds)
- } else {
- full_preds
- };
-
- // Binarize to segments
- let segments = self.binarize(&filtered_preds);
-
- Ok(segments)
- }
-
- /// NeMo's streaming_update with smart cache compression.
- /// Public to allow incremental streaming diarization.
- pub fn streaming_update(&mut self, chunk_feat: &Array3<f32>, current_len: usize) -> Result<Array2<f32>> {
- let spkcache_len = self.spkcache.shape()[1];
- let fifo_len = self.fifo.shape()[1];
-
- // Prepare inputs
- let chunk_lengths = Array1::from_vec(vec![current_len as i64]);
- let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]);
- let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]);
-
- // Prepare FIFO input
- let fifo_input = if fifo_len > 0 {
- self.fifo.clone()
- } else {
- Array3::zeros((1, 0, EMB_DIM))
- };
-
- // Prepare spkcache input (may be empty)
- let spkcache_input = if spkcache_len > 0 {
- self.spkcache.clone()
- } else {
- Array3::zeros((1, 0, EMB_DIM))
- };
-
- // Create input values
- let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?;
- let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?;
- let spkcache_value = ort::value::Value::from_array(spkcache_input)?;
- let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?;
- let fifo_value = ort::value::Value::from_array(fifo_input)?;
- let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?;
-
- // Run ONNX inference and extract all data in a block to release borrow
- let (preds, new_embs, chunk_len) = {
- let outputs = self.session.run(ort::inputs!(
- "chunk" => chunk_value,
- "chunk_lengths" => chunk_lengths_value,
- "spkcache" => spkcache_value,
- "spkcache_lengths" => spkcache_lengths_value,
- "fifo" => fifo_value,
- "fifo_lengths" => fifo_lengths_value
- ))?;
-
- // Extract outputs
- let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?;
- let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?;
-
- // Convert to ndarray
- let preds_dims = preds_shape.as_ref();
- let embs_dims = embs_shape.as_ref();
-
- let preds = Array3::from_shape_vec(
- (preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize),
- preds_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?;
-
- let new_embs = Array3::from_shape_vec(
- (embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize),
- embs_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?;
-
- // Calculate valid frames
- let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING;
-
- (preds, new_embs, valid_frames)
- };
-
- // Extract predictions for different parts
- let fifo_preds = if fifo_len > 0 {
- preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned()
- } else {
- Array2::zeros((0, NUM_SPEAKERS))
- };
-
- let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned();
- let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned();
-
- // Append chunk embeddings to FIFO
- self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0)));
-
- // Update FIFO predictions
- if fifo_len > 0 {
- let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds);
- self.fifo_preds = combined.insert_axis(Axis(0));
- } else {
- self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0));
- }
-
- let fifo_len_after = self.fifo.shape()[1];
-
- // Move from FIFO to cache when FIFO exceeds limit
- if fifo_len_after > FIFO_LEN {
- let mut pop_out_len = SPKCACHE_UPDATE_PERIOD;
- pop_out_len = pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len);
- pop_out_len = pop_out_len.min(fifo_len_after);
-
- let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned();
- let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned();
-
- // Update silence profile
- self.update_silence_profile(&pop_out_embs, &pop_out_preds);
-
- // Remove from FIFO
- self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned();
- self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned();
-
- // Append to cache
- self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs);
-
- if let Some(ref cache_preds) = self.spkcache_preds {
- self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds));
- }
-
- // Smart compression when cache exceeds limit
- if self.spkcache.shape()[1] > SPKCACHE_LEN {
- if self.spkcache_preds.is_none() {
- // Initialize cache predictions from initial output
- let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned();
- let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds);
- self.spkcache_preds = Some(combined);
- }
-
- // Use smart compression
- self.compress_spkcache();
- }
- }
-
- Ok(chunk_preds)
- }
-
- /// Update mean silence embedding
- fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) {
- let preds_2d = preds.slice(s![0, .., ..]);
-
- for t in 0..preds_2d.shape()[0] {
- let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum();
- if sum < SIL_THRESHOLD {
- // This is a silence frame
- let emb = embs.slice(s![0, t, ..]);
-
- // Update running mean
- let old_sum: Vec<f32> = self.mean_sil_emb.slice(s![0, ..]).iter()
- .map(|&x| x * self.n_sil_frames as f32)
- .collect();
-
- self.n_sil_frames += 1;
-
- for i in 0..EMB_DIM {
- self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32;
- }
- }
- }
- }
-
- /// Smart cache compression
- fn compress_spkcache(&mut self) {
- let cache_preds = match &self.spkcache_preds {
- Some(p) => p.clone(),
- None => return,
- };
-
- let n_frames = self.spkcache.shape()[1];
- let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK;
- let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize;
- let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize;
- let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize;
-
- // Calculate quality scores
- let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned();
- let mut scores = self.get_log_pred_scores(&preds_2d);
-
- // Disable low scores
- scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk);
-
- // Boost important frames
- scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0);
- scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0);
-
- // Add silence frames placeholder
- if SPKCACHE_SIL_FRAMES_PER_SPK > 0 {
- let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY);
- padded.slice_mut(s![..n_frames, ..]).assign(&scores);
- for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK {
- for j in 0..NUM_SPEAKERS {
- padded[[i, j]] = f32::INFINITY;
- }
- }
- scores = padded;
- }
-
- // Select top frames
- let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames);
-
- // Gather embeddings
- let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled);
-
- self.spkcache = new_embs;
- self.spkcache_preds = Some(new_preds);
- }
-
- /// Calculate quality scores
- fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> {
- let mut scores = Array2::zeros(preds.dim());
-
- for t in 0..preds.shape()[0] {
- let mut log_1_probs_sum = 0.0f32;
- for s in 0..NUM_SPEAKERS {
- let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
- let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
- log_1_probs_sum += log_1_p;
- }
-
- for s in 0..NUM_SPEAKERS {
- let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
- let log_p = p.ln();
- let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
- scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln();
- }
- }
-
- scores
- }
-
- /// Disable non-speech and overlapped speech
- fn disable_low_scores(&self, preds: &Array2<f32>, mut scores: Array2<f32>, min_pos_scores_per_spk: usize) -> Array2<f32> {
- // Count positive scores per speaker
- let mut pos_count = vec![0usize; NUM_SPEAKERS];
- for t in 0..scores.shape()[0] {
- for s in 0..NUM_SPEAKERS {
- if scores[[t, s]] > 0.0 {
- pos_count[s] += 1;
- }
- }
- }
-
- for t in 0..preds.shape()[0] {
- for s in 0..NUM_SPEAKERS {
- let is_speech = preds[[t, s]] > 0.5;
-
- if !is_speech {
- scores[[t, s]] = f32::NEG_INFINITY;
- } else {
- let is_pos = scores[[t, s]] > 0.0;
- if !is_pos && pos_count[s] >= min_pos_scores_per_spk {
- scores[[t, s]] = f32::NEG_INFINITY;
- }
- }
- }
- }
-
- scores
- }
-
- /// Boost top K frames per speaker
- fn boost_topk_scores(&self, mut scores: Array2<f32>, n_boost_per_spk: usize, scale_factor: f32) -> Array2<f32> {
- for s in 0..NUM_SPEAKERS {
- // Get column for this speaker
- let col: Vec<(usize, f32)> = (0..scores.shape()[0])
- .map(|t| (t, scores[[t, s]]))
- .collect();
-
- // Sort by score descending
- let mut sorted = col.clone();
- sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
-
- // Boost top K
- for i in 0..n_boost_per_spk.min(sorted.len()) {
- let t = sorted[i].0;
- if scores[[t, s]] != f32::NEG_INFINITY {
- scores[[t, s]] -= scale_factor * 0.5f32.ln();
- }
- }
- }
-
- scores
- }
-
- /// Get indices of top frames
- fn get_topk_indices(&self, scores: &Array2<f32>, n_frames_no_sil: usize) -> (Vec<usize>, Vec<bool>) {
- let n_frames = scores.shape()[0];
-
- // Flatten scores as (S, T) then reshape to (S*T,)
- // This means we iterate: speaker 0 all times, then speaker 1 all times, etc.
- // flat_index = speaker * n_frames + time
- let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS);
- for s in 0..NUM_SPEAKERS {
- for t in 0..n_frames {
- let flat_idx = s * n_frames + t;
- flat_scores.push((flat_idx, scores[[t, s]]));
- }
- }
-
- // Sort by score descending to get top-K
- flat_scores.sort_by(|a, b| {
- b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
- });
-
- // Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX
- let mut topk_flat: Vec<usize> = flat_scores
- .iter()
- .take(SPKCACHE_LEN)
- .map(|(idx, score)| {
- if *score == f32::NEG_INFINITY {
- MAX_INDEX
- } else {
- *idx
- }
- })
- .collect();
-
- // Sort flat indices ascending (this puts MAX_INDEX at the end)
- topk_flat.sort();
-
- // Compute is_disabled and convert to frame indices
- let mut is_disabled = vec![false; SPKCACHE_LEN];
- let mut frame_indices = vec![0usize; SPKCACHE_LEN];
-
- for (i, &flat_idx) in topk_flat.iter().enumerate() {
- if flat_idx == MAX_INDEX {
- // Invalid entries are disabled
- is_disabled[i] = true;
- frame_indices[i] = 0; // We set disabled to 0
- } else {
- // convert to frame index
- let frame_idx = flat_idx % n_frames;
-
- // check if frame is beyond valid range
- if frame_idx >= n_frames_no_sil {
- is_disabled[i] = true;
- frame_indices[i] = 0; // same as abov: set disabled to 0
- } else {
- frame_indices[i] = frame_idx;
- }
- }
- }
-
- (frame_indices, is_disabled)
- }
-
- /// Gather selected frames
- fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3<f32>, Array3<f32>) {
- let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM));
- let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS));
-
- let cache_preds = self.spkcache_preds.as_ref().unwrap();
-
- for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() {
- if i >= SPKCACHE_LEN {
- break;
- }
-
- if disabled {
- // Use silence embedding
- new_embs.slice_mut(s![0, i, ..]).assign(&self.mean_sil_emb.slice(s![0, ..]));
- // Predictions stay zero
- } else if idx < self.spkcache.shape()[1] {
- new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..]));
- new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..]));
- }
- }
-
- (new_embs, new_preds)
- }
-
- /// Concatenate along axis 1 for 3D arrays
- fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
- if a.shape()[1] == 0 {
- return b.clone();
- }
- if b.shape()[1] == 0 {
- return a.clone();
- }
- ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap()
- }
-
- /// Concatenate along axis 0 for 2D arrays
- fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> {
- if a.shape()[0] == 0 {
- return b.clone();
- }
- if b.shape()[0] == 0 {
- return a.clone();
- }
- ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
- }
-
- /// Concatenate predictions
- fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> {
- if preds.is_empty() {
- return Array2::zeros((0, NUM_SPEAKERS));
- }
- if preds.len() == 1 {
- return preds[0].clone();
- }
-
- let views: Vec<_> = preds.iter().map(|p| p.view()).collect();
- ndarray::concatenate(Axis(0), &views).unwrap()
- }
-
- /// Apply median filter to predictions
- fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> {
- let window = self.config.median_window;
- let half = window / 2;
- let mut filtered = preds.clone();
-
- for spk in 0..NUM_SPEAKERS {
- for t in 0..preds.shape()[0] {
- let start = t.saturating_sub(half);
- let end = (t + half + 1).min(preds.shape()[0]);
-
- let mut values: Vec<f32> = (start..end)
- .map(|i| preds[[i, spk]])
- .collect();
- values.sort_by(|a, b| a.partial_cmp(b).unwrap());
-
- filtered[[t, spk]] = values[values.len() / 2];
- }
- }
-
- filtered
- }
-
- /// Binarize predictions to segments (padding applied during thresholding)
- fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> {
- let mut segments = Vec::new();
- let num_frames = preds.shape()[0];
-
- for spk in 0..NUM_SPEAKERS {
- let mut in_seg = false;
- let mut seg_start = 0;
- let mut temp_segments = Vec::new();
-
- for t in 0..num_frames {
- let p = preds[[t, spk]];
-
- if p >= self.config.onset && !in_seg {
- in_seg = true;
- seg_start = t;
- } else if p < self.config.offset && in_seg {
- in_seg = false;
-
- // Apply padding during conversion
- let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0);
- let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset;
-
- if end_t - start_t >= self.config.min_duration_on {
- temp_segments.push(SpeakerSegment {
- start: start_t,
- end: end_t,
- speaker_id: spk,
- });
- }
- }
- }
-
- // Handle segment at end
- if in_seg {
- let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0);
- let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset;
-
- if end_t - start_t >= self.config.min_duration_on {
- temp_segments.push(SpeakerSegment {
- start: start_t,
- end: end_t,
- speaker_id: spk,
- });
- }
- }
-
- // Merge close segments (min_duration_off)
- if temp_segments.len() > 1 {
- let mut filtered = vec![temp_segments[0].clone()];
- for seg in temp_segments.into_iter().skip(1) {
- let last = filtered.last_mut().unwrap();
- let gap = seg.start - last.end;
- if gap < self.config.min_duration_off {
- last.end = seg.end; // Merge
- } else {
- filtered.push(seg);
- }
- }
- segments.extend(filtered);
- } else {
- segments.extend(temp_segments);
- }
- }
-
- // Sort by start time
- segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap());
- segments
- }
-
-
- fn apply_preemphasis(audio: &[f32]) -> Vec<f32> {
- let mut result = Vec::with_capacity(audio.len());
- result.push(audio[0]);
- for i in 1..audio.len() {
- result.push(audio[i] - PREEMPH * audio[i - 1]);
- }
- result
- }
-
- fn hann_window(window_length: usize) -> Vec<f32> {
- // Librosa uses periodic window (fftbins=True): divide by N, not N-1
- (0..window_length)
- .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos())
- .collect()
- }
-
- fn stft(audio: &[f32]) -> Array2<f32> {
- let mut planner = FftPlanner::<f32>::new();
- let fft = planner.plan_fft_forward(N_FFT);
-
- // Create Hann window of length win_length, then zero-pad to n_fft (centered)
- // This is exactly what librosa does: util.pad_center(fft_window, size=n_fft)
- let hann = Self::hann_window(WIN_LENGTH);
- let win_offset = (N_FFT - WIN_LENGTH) / 2;
- let mut fft_window = vec![0.0f32; N_FFT];
- for i in 0..WIN_LENGTH {
- fft_window[win_offset + i] = hann[i];
- }
-
- // Pad signal for center=True (like librosa/torch.stft)
- // Padding is n_fft // 2 on each side
- let pad_amount = N_FFT / 2;
- let mut padded_audio = vec![0.0; pad_amount];
- padded_audio.extend_from_slice(audio);
- padded_audio.extend(vec![0.0; pad_amount]);
-
- let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1;
- let freq_bins = N_FFT / 2 + 1;
- let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames));
-
- for frame_idx in 0..num_frames {
- let start = frame_idx * HOP_LENGTH;
- let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT];
-
- // Extract n_fft samples and multiply by zero-padded window
- for i in 0..N_FFT {
- if start + i < padded_audio.len() {
- frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0);
- }
- }
-
- fft.process(&mut frame);
- for k in 0..freq_bins {
- let magnitude = frame[k].norm();
- // Power spectrum (magnitude^2) - NeMo uses mag_power=2.0
- spectrogram[[k, frame_idx]] = magnitude * magnitude;
- }
- }
-
- spectrogram
- }
-
- // Librosa's Slaney mel scale (htk=False, which is the default)
- fn hz_to_mel_slaney(hz: f64) -> f64 {
- let f_min = 0.0;
- let f_sp = 200.0 / 3.0;
- let min_log_hz = 1000.0;
- let min_log_mel = (min_log_hz - f_min) / f_sp;
- let logstep = (6.4f64).ln() / 27.0;
-
- if hz >= min_log_hz {
- min_log_mel + (hz / min_log_hz).ln() / logstep
- } else {
- (hz - f_min) / f_sp
- }
- }
-
- fn mel_to_hz_slaney(mel: f64) -> f64 {
- let f_min = 0.0;
- let f_sp = 200.0 / 3.0;
- let min_log_hz = 1000.0;
- let min_log_mel = (min_log_hz - f_min) / f_sp;
- let logstep = (6.4f64).ln() / 27.0;
-
- if mel >= min_log_mel {
- min_log_hz * (logstep * (mel - min_log_mel)).exp()
- } else {
- f_min + f_sp * mel
- }
- }
-
- fn create_mel_filterbank() -> Array2<f32> {
- // lets use f64 for intermediate calculations to avoid precision loss
- let freq_bins = N_FFT / 2 + 1;
- let mut filterbank = Array2::<f32>::zeros((N_MELS, freq_bins));
-
- // FFT frequencies: fftfreqs[k] = k * sr / n_fft
- let fftfreqs: Vec<f64> = (0..freq_bins)
- .map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64)
- .collect();
-
- // Mel center frequencies using Slaney scale (librosa default, htk=False)
- let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64);
- let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64);
- let mel_f: Vec<f64> = (0..=N_MELS + 1)
- .map(|i| {
- let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64;
- Self::mel_to_hz_slaney(mel)
- })
- .collect();
-
- // Differences between consecutive mel frequencies
- let fdiff: Vec<f64> = mel_f.windows(2).map(|w| w[1] - w[0]).collect();
-
- // Compute filterbank weights (reference: librosa's ramp method)
- // https://librosa.org/doc/main/generated/librosa.stft.html
- for i in 0..N_MELS {
- for k in 0..freq_bins {
- // Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i]
- let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i];
- // Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1]
- let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1];
- // Weight is max(0, min(lower, upper))
- filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32;
- }
- }
-
- // Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i])
- for i in 0..N_MELS {
- let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]);
- for k in 0..freq_bins {
- filterbank[[i, k]] *= enorm as f32;
- }
- }
-
- filterbank
- }
-
- fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> {
- // 1. Add dither (small random noise to prevent log(0))
- // NeMo uses dither=1e-5, but for determinism we skip random noise
- // The log_zero_guard handles zero values
-
- // 2. Apply preemphasis (NeMo uses preemph=0.97)
- let preemphasized = Self::apply_preemphasis(audio);
-
- // 3. STFT
- let spectrogram = Self::stft(&preemphasized);
-
- // 4. Apply mel filterbank (with Slaney normalization)
- let mel_spec = self.mel_basis.dot(&spectrogram);
-
- // 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24)
- // NeMo uses normalize='NA' which means NO normalization
- let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln());
-
- let num_frames = log_mel_spec.shape()[1];
- let mut features = Array3::<f32>::zeros((1, num_frames, N_MELS));
-
- // Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D)
- for t in 0..num_frames {
- for m in 0..N_MELS {
- features[[0, t, m]] = log_mel_spec[[m, t]];
- }
- }
-
- features
- }
-}
diff --git a/parakeet-rs/src/timestamps.rs b/parakeet-rs/src/timestamps.rs
deleted file mode 100644
index 81ea600..0000000
--- a/parakeet-rs/src/timestamps.rs
+++ /dev/null
@@ -1,280 +0,0 @@
-use crate::decoder::TimedToken;
-
-/// Timestamp output mode for transcription results
-///
-/// Determines how token-level timestamps are grouped and presented:
-/// - `Tokens`: Raw token-level output from the model (most detailed)
-/// - `Words`: Tokens grouped into individual words
-/// - `Sentences`: Tokens grouped by sentence boundaries (., ?, !)
-///
-/// # Model-Specific Recommendations
-///
-/// - **Parakeet CTC (English)**: Use `Words` mode. The CTC model only outputs lowercase
-/// alphabet without punctuation, so sentence segmentation is not possible.
-/// - **Parakeet TDT (Multilingual)**: Use `Sentences` mode. The TDT model predicts
-/// punctuation, enabling natural sentence boundaries.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum TimestampMode {
- /// Raw token-level timestamps from the model
- Tokens,
- /// Word-level timestamps (groups subword tokens)
- Words,
- /// Sentence-level timestamps (groups by punctuation)
- ///
- /// Note: Only works with models that predict punctuation (e.g., Parakeet TDT).
- /// CTC models don't predict punctuation, so use `Words` mode instead.
- Sentences,
-}
-
-impl Default for TimestampMode {
- fn default() -> Self {
- Self::Tokens
- }
-}
-
-/// Convert token timestamps to the requested output mode
-///
-/// Takes raw token-level timestamps from the model and optionally groups them
-/// into words or sentences while preserving the original timing information.
-///
-/// # Arguments
-///
-/// * `tokens` - Raw token-level timestamps from model output
-/// * `mode` - Desired grouping level (Tokens, Words, or Sentences)
-///
-/// # Returns
-///
-/// Vector of TimedToken with timestamps at the requested granularity
-pub fn process_timestamps(tokens: &[TimedToken], mode: TimestampMode) -> Vec<TimedToken> {
- match mode {
- TimestampMode::Tokens => tokens.to_vec(),
- TimestampMode::Words => group_by_words(tokens),
- TimestampMode::Sentences => group_by_sentences(tokens),
- }
-}
-
-// Group tokens into words based on word boundary markers
-fn group_by_words(tokens: &[TimedToken]) -> Vec<TimedToken> {
- if tokens.is_empty() {
- return Vec::new();
- }
-
- let mut words = Vec::new();
- let mut current_word_text = String::new();
- let mut current_word_start = 0.0;
- let mut last_word_lower = String::new();
-
- for (i, token) in tokens.iter().enumerate() {
- // Skip empty tokens
- if token.text.trim().is_empty() {
- continue;
- }
-
- // Check if this starts a new word (SentencePiece uses ▁ or space prefix)
- // Also treat PURE punctuation marks (like ".", ",") as separate words
- // But NOT contractions like "'re" or "'s" which should attach to previous word
- let is_pure_punctuation = !token.text.is_empty() &&
- token.text.chars().all(|c| c.is_ascii_punctuation());
-
- // Check if this is a contraction suffix
- // These should NOT start a new word - they attach to the previous word
- let token_without_marker = token.text.trim_start_matches('▁').trim_start_matches(' ');
- let is_contraction = token_without_marker.starts_with('\'');
-
- let starts_word = (token.text.starts_with('▁')
- || token.text.starts_with(' ')
- || is_pure_punctuation)
- && !is_contraction
- || i == 0;
-
- if starts_word && !current_word_text.is_empty() {
- // Save previous word (with deduplication)
- let word_lower = current_word_text.to_lowercase();
- if word_lower != last_word_lower {
- words.push(TimedToken {
- text: current_word_text.clone(),
- start: current_word_start,
- end: tokens[i - 1].end,
- });
- last_word_lower = word_lower;
- }
- current_word_text.clear();
- }
-
- // Start new word or append to current
- if current_word_text.is_empty() {
- current_word_start = token.start;
- }
-
- // Add token text, removing word boundary markers
- let token_text = token
- .text
- .trim_start_matches('▁')
- .trim_start_matches(' ');
- current_word_text.push_str(token_text);
- }
-
- // Add final word
- if !current_word_text.is_empty() {
- let word_lower = current_word_text.to_lowercase();
- if word_lower != last_word_lower {
- words.push(TimedToken {
- text: current_word_text,
- start: current_word_start,
- end: tokens.last().unwrap().end,
- });
- }
- }
-
- words
-}
-
-// Group words into sentences based on punctuation
-fn group_by_sentences(tokens: &[TimedToken]) -> Vec<TimedToken> {
- // First get word-level grouping
- let words = group_by_words(tokens);
- if words.is_empty() {
- return Vec::new();
- }
-
- let mut sentences = Vec::new();
- let mut current_sentence = Vec::new();
-
- for word in words {
- current_sentence.push(word.clone());
-
- // Check if word ends with sentence terminator
- let ends_sentence = word.text.contains('.')
- || word.text.contains('?')
- || word.text.contains('!');
-
- if ends_sentence {
- let sentence_text = format_sentence(&current_sentence);
- let start = current_sentence.first().unwrap().start;
- let end = current_sentence.last().unwrap().end;
-
- if !sentence_text.is_empty() {
- sentences.push(TimedToken {
- text: sentence_text,
- start,
- end,
- });
- }
- current_sentence.clear();
- }
- }
-
- // Add final sentence if exists
- if !current_sentence.is_empty() {
- let sentence_text = format_sentence(&current_sentence);
- let start = current_sentence.first().unwrap().start;
- let end = current_sentence.last().unwrap().end;
-
- if !sentence_text.is_empty() {
- sentences.push(TimedToken {
- text: sentence_text,
- start,
- end,
- });
- }
- }
-
- sentences
-}
-
-// Join words with punctuation spacing
-fn format_sentence(words: &[TimedToken]) -> String {
- let result: Vec<&str> = words.iter().map(|w| w.text.as_str()).collect();
-
- // Join words, but don't add space before certain punctuation
- let mut output = String::new();
- for (i, word) in result.iter().enumerate() {
- // Check if this word is standalone punctuation that shouldn't have space before it
- // Contractions like "'re" or "'s" should have spaces before them
- let is_standalone_punct = word.len() == 1 &&
- word.chars().all(|c| matches!(c, '.' | ',' | '!' | '?' | ';' | ':' | ')'));
-
- if i > 0 && !is_standalone_punct {
- output.push(' ');
- }
- output.push_str(word);
- }
- output
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_word_grouping() {
- let tokens = vec![
- TimedToken {
- text: "▁Hello".to_string(),
- start: 0.0,
- end: 0.5,
- },
- TimedToken {
- text: "▁world".to_string(),
- start: 0.5,
- end: 1.0,
- },
- ];
-
- let words = group_by_words(&tokens);
- assert_eq!(words.len(), 2);
- assert_eq!(words[0].text, "Hello");
- assert_eq!(words[1].text, "world");
- }
-
- #[test]
- fn test_sentence_grouping() {
- let tokens = vec![
- TimedToken {
- text: "▁Hello".to_string(),
- start: 0.0,
- end: 0.5,
- },
- TimedToken {
- text: "▁world".to_string(),
- start: 0.5,
- end: 1.0,
- },
- TimedToken {
- text: ".".to_string(),
- start: 1.0,
- end: 1.1,
- },
- ];
-
- let sentences = group_by_sentences(&tokens);
- assert_eq!(sentences.len(), 1);
- assert_eq!(sentences[0].text, "Hello world.");
- assert_eq!(sentences[0].start, 0.0);
- assert_eq!(sentences[0].end, 1.1);
- }
-
- #[test]
- fn test_repetition_preservation() {
- let words = vec![
- TimedToken {
- text: "uh".to_string(),
- start: 0.0,
- end: 0.5,
- },
- TimedToken {
- text: "uh".to_string(),
- start: 0.5,
- end: 1.0,
- },
- TimedToken {
- text: "hello".to_string(),
- start: 1.0,
- end: 1.5,
- },
- ];
-
- let result = format_sentence(&words);
- assert_eq!(result, "uh uh hello");
- }
-}
diff --git a/parakeet-rs/src/vocab.rs b/parakeet-rs/src/vocab.rs
deleted file mode 100644
index 888568e..0000000
--- a/parakeet-rs/src/vocab.rs
+++ /dev/null
@@ -1,63 +0,0 @@
-use crate::error::{Error, Result};
-use std::fs::File;
-use std::io::{BufRead, BufReader};
-use std::path::Path;
-
-/// Vocabulary parser for vocab.txt format used by TDT models
-#[derive(Debug, Clone)]
-pub struct Vocabulary {
- pub id_to_token: Vec<String>,
- pub _blank_id: usize,
-}
-
-impl Vocabulary {
- /// Load vocabulary from vocab.txt file
- pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
- let file = File::open(path.as_ref()).map_err(|e| {
- Error::Config(format!("Failed to open vocab file: {}", e))
- })?;
-
- let reader = BufReader::new(file);
- let mut id_to_token = Vec::new();
- let mut blank_id = 0;
-
- for line in reader.lines() {
- let line = line.map_err(|e| {
- Error::Config(format!("Failed to read vocab file: {}", e))
- })?;
-
- let parts: Vec<&str> = line.splitn(2, ' ').collect();
- if parts.len() == 2 {
- let token = parts[0].to_string();
- let id: usize = parts[1].parse().map_err(|e| {
- Error::Config(format!("Invalid token ID in vocab: {}", e))
- })?;
-
- if id >= id_to_token.len() {
- id_to_token.resize(id + 1, String::new());
- }
- id_to_token[id] = token.clone();
-
- // Track blank token
- if token == "<blk>" || token == "<blank>" {
- blank_id = id;
- }
- }
- }
-
- // Default to last token if no blank found
- if blank_id == 0 && !id_to_token.is_empty() {
- blank_id = id_to_token.len() - 1;
- }
-
- Ok(Self {
- id_to_token,
- _blank_id: blank_id,
- })
- }
-
- /// Get token by ID
- pub fn id_to_text(&self, id: usize) -> Option<&str> {
- self.id_to_token.get(id).map(|s| s.as_str())
- }
-}