From 55cacf6e1a087c0fa6950a1ddeb09060f787e541 Mon Sep 17 00:00:00 2001 From: soryu Date: Sun, 21 Dec 2025 00:40:04 +0000 Subject: Add EOU detection and streaming diarization --- parakeet-rs/src/sortformer.rs | 1062 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1062 insertions(+) create mode 100644 parakeet-rs/src/sortformer.rs (limited to 'parakeet-rs/src/sortformer.rs') diff --git a/parakeet-rs/src/sortformer.rs b/parakeet-rs/src/sortformer.rs new file mode 100644 index 0000000..2b1e5a3 --- /dev/null +++ b/parakeet-rs/src/sortformer.rs @@ -0,0 +1,1062 @@ +//! NVIDIA Sortformer v2 Streaming Speaker Diarization +//! +//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization. +//! +//! Key features: +//! - Streaming inference with ~10s chunks (124 frames at 80ms each) +//! - FIFO buffer for context management +//! - Smart speaker cache compression (keeps important frames, not just recent) +//! - Silence profile tracking +//! - Post-processing: median filtering, hysteresis thresholding +//! - Supports up to 4 speakers +//! +//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 +//! Note: the ONNX export used here was built with the following dimensions: +//! CHUNK_LEN = 124 +//! FIFO_LEN = 124 +//! CACHE_LEN = 188 +//! FEAT_DIM = 128 +//! EMB_DIM = 512 +//! 
Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html + +use crate::error::{Error, Result}; +use crate::execution::ModelConfig; +use ndarray::{s, Array1, Array2, Array3, Axis}; +use ort::session::Session; +use rustfft::{num_complex::Complex, FftPlanner}; +use std::f32::consts::PI; +use std::path::Path; + +// Model constants +const N_FFT: usize = 512; +const WIN_LENGTH: usize = 400; +const HOP_LENGTH: usize = 160; +const N_MELS: usize = 128; +const PREEMPH: f32 = 0.97; +const LOG_ZERO_GUARD: f32 = 5.960464478e-8; +const SAMPLE_RATE: usize = 16000; +const FMIN: f32 = 0.0; +const FMAX: f32 = 8000.0; + +// Streaming constants +const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms) +const FIFO_LEN: usize = 124; // FIFO buffer length +const SPKCACHE_LEN: usize = 188; // Speaker cache length +const SPKCACHE_UPDATE_PERIOD: usize = 124; +const SUBSAMPLING: usize = 8; // Audio frames -> model frames +const EMB_DIM: usize = 512; // Embedding dimension +const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers +const FRAME_DURATION: f32 = 0.08; // 80ms per frame + +// Cache compression params (from NeMo) +const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3; +const PRED_SCORE_THRESHOLD: f32 = 0.25; +const STRONG_BOOST_RATE: f32 = 0.75; +const WEAK_BOOST_RATE: f32 = 1.5; +const MIN_POS_SCORES_RATE: f32 = 0.5; +const SIL_THRESHOLD: f32 = 0.2; +const MAX_INDEX: usize = 99999; + +/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs) +/// +/// Controls how raw model predictions are converted into speaker segments. +/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI). 
/// Post-processing settings that turn frame-level speaker probabilities into segments.
///
/// The built-in presets carry the values NVIDIA ships in its Sortformer v2 YAML configs
/// (CallHome and DIHARD3).
///
/// # Fields
/// - `onset`: a segment opens on the first frame whose probability is at or above this value
/// - `offset`: an open segment closes on the first frame whose probability is below this value
/// - `pad_onset`: seconds subtracted from every segment start
/// - `pad_offset`: seconds added to every segment end
/// - `min_duration_on`: segments shorter than this many seconds are discarded
/// - `min_duration_off`: neighbouring segments separated by less than this many seconds are merged
/// - `median_window`: width in frames of the median smoothing window (odd; `<= 1` disables it)
///
/// # Constructors
/// - [`DiarizationConfig::callhome`] (also what `Default` returns)
/// - [`DiarizationConfig::dihard3`]
/// - [`DiarizationConfig::custom`] for hand-picked thresholds
///
/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer
#[derive(Debug, Clone)]
pub struct DiarizationConfig {
    pub onset: f32,
    pub offset: f32,
    pub pad_onset: f32,
    pub pad_offset: f32,
    pub min_duration_on: f32,
    pub min_duration_off: f32,
    pub median_window: usize,
}

impl Default for DiarizationConfig {
    /// The default configuration is the CallHome preset.
    fn default() -> Self {
        DiarizationConfig::callhome()
    }
}

impl DiarizationConfig {
    /// Median smoothing window, in frames, shared by every built-in constructor.
    const BUILTIN_MEDIAN_WINDOW: usize = 11;

    /// Preset tuned on CallHome (v2); this is the default.
    ///
    /// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml
    pub fn callhome() -> Self {
        DiarizationConfig {
            median_window: Self::BUILTIN_MEDIAN_WINDOW,
            min_duration_off: 0.296,
            min_duration_on: 0.511,
            pad_offset: 0.079,
            pad_onset: 0.229,
            offset: 0.561,
            onset: 0.641,
        }
    }

    /// Preset tuned on DIHARD3 (v2).
    ///
    /// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml
    ///
    /// With `offset` at 1.0 an open segment closes on the next frame whose probability is
    /// below 1.0, so continuous speech is stitched back together by the `min_duration_off`
    /// merge step rather than by hysteresis.
    pub fn dihard3() -> Self {
        DiarizationConfig {
            median_window: Self::BUILTIN_MEDIAN_WINDOW,
            min_duration_off: 0.151,
            min_duration_on: 0.007,
            pad_offset: 0.002,
            pad_onset: 0.063,
            offset: 1.0,
            onset: 0.56,
        }
    }

    /// Build a configuration from hand-picked thresholds.
    ///
    /// Padding is zero, both minimum durations are 0.1 s and the median window is 11 frames;
    /// adjust the public fields afterwards for full control.
    ///
    /// # Arguments
    /// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7)
    /// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6)
    ///
    /// # Example
    /// ```rust
    /// use parakeet_rs::sortformer::DiarizationConfig;
    ///
    /// // More sensitive detection (lower thresholds)
    /// let sensitive = DiarizationConfig::custom(0.5, 0.4);
    ///
    /// // Stricter detection (higher thresholds, fewer false positives)
    /// let strict = DiarizationConfig::custom(0.7, 0.6);
    ///
    /// // Full customization
    /// let mut config = DiarizationConfig::custom(0.6, 0.5);
    /// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms
    /// config.median_window = 15; // More smoothing
    /// ```
    pub fn custom(onset: f32, offset: f32) -> Self {
        DiarizationConfig {
            median_window: Self::BUILTIN_MEDIAN_WINDOW,
            min_duration_off: 0.1,
            min_duration_on: 0.1,
            pad_offset: 0.0,
            pad_onset: 0.0,
            offset,
            onset,
        }
    }
}
note that, Same way as Nemo + spkcache: Array3, // (1, 0..SPKCACHE_LEN, EMB_DIM) + spkcache_preds: Option>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS) + fifo: Array3, // (1, 0..FIFO_LEN, EMB_DIM) + fifo_preds: Array3, // (1, 0..FIFO_LEN, NUM_SPEAKERS) + mean_sil_emb: Array2, // (1, EMB_DIM) + n_sil_frames: usize, + // Mel filterbank (cached) + mel_basis: Array2, +} + +impl Sortformer { + /// a new Sortformer instance from ONNX model path + pub fn new>(model_path: P) -> Result { + Self::with_config(model_path, None, DiarizationConfig::default()) + } + + /// Create with custom config + pub fn with_config>( + model_path: P, + execution_config: Option, + config: DiarizationConfig, + ) -> Result { + let config_to_use = execution_config.unwrap_or_default(); + + let session = config_to_use + .apply_to_session_builder(Session::builder()?)? + .commit_from_file(model_path.as_ref())?; + + let mel_basis = Self::create_mel_filterbank(); + + let mut instance = Self { + session, + config, + spkcache: Array3::zeros((1, 0, EMB_DIM)), + spkcache_preds: None, + fifo: Array3::zeros((1, 0, EMB_DIM)), + fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)), + mean_sil_emb: Array2::zeros((1, EMB_DIM)), + n_sil_frames: 0, + mel_basis, + }; + instance.reset_state(); + Ok(instance) + } + + /// Reset streaming state + pub fn reset_state(&mut self) { + self.spkcache = Array3::zeros((1, 0, EMB_DIM)); + self.spkcache_preds = None; + self.fifo = Array3::zeros((1, 0, EMB_DIM)); + self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS)); + self.mean_sil_emb = Array2::zeros((1, EMB_DIM)); + self.n_sil_frames = 0; + } + + /// Main diarization entry point + pub fn diarize( + &mut self, + mut audio: Vec, + sample_rate: u32, + channels: u16, + ) -> Result> { + // Resample if needed + if sample_rate != SAMPLE_RATE as u32 { + return Err(Error::Audio(format!( + "Expected {} Hz, got {} Hz", + SAMPLE_RATE, sample_rate + ))); + } + + // Convert to mono + if channels > 1 { + audio = audio + .chunks(channels as usize) + 
.map(|chunk| chunk.iter().sum::() / channels as f32) + .collect(); + } + + // Reset state for new audio + self.reset_state(); + + // Extract mel features (B, T, D) + let features = self.extract_mel_features(&audio); + let total_frames = features.shape()[1]; + + // Process in chunks + let chunk_stride = CHUNK_LEN * SUBSAMPLING; + let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; + + let mut all_chunk_preds = Vec::new(); + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_stride; + let end = (start + chunk_stride).min(total_frames); + let current_len = end - start; + + // Extract chunk features + let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); + + // Pad last chunk if needed + if current_len < chunk_stride { + let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); + padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); + chunk_feat = padded; + } + + // Run streaming update + let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; + all_chunk_preds.push(chunk_preds); + } + + // Concatenate all predictions + let full_preds = Self::concat_predictions(&all_chunk_preds); + + // Apply median filtering + let filtered_preds = if self.config.median_window > 1 { + self.median_filter(&full_preds) + } else { + full_preds + }; + + // Binarize to segments + let segments = self.binarize(&filtered_preds); + + Ok(segments) + } + + /// Streaming diarization that maintains state across calls. + /// + /// Unlike `diarize()`, this method does NOT reset the internal state, + /// allowing speaker embeddings to be preserved across multiple audio chunks. + /// Call `reset_state()` manually when starting a new audio session. + /// + /// This enables consistent speaker identification across long audio streams + /// by maintaining the speaker cache between processing windows. 
+ /// + /// # Arguments + /// * `audio` - Audio samples (will be converted to mono if multi-channel) + /// * `sample_rate` - Must be 16000 Hz + /// * `channels` - Number of audio channels + /// + /// # Example + /// ```ignore + /// // Start of session + /// sortformer.reset_state(); + /// + /// // Process sliding windows + /// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?; + /// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs + /// ``` + pub fn diarize_streaming( + &mut self, + mut audio: Vec, + sample_rate: u32, + channels: u16, + ) -> Result> { + // Resample if needed + if sample_rate != SAMPLE_RATE as u32 { + return Err(Error::Audio(format!( + "Expected {} Hz, got {} Hz", + SAMPLE_RATE, sample_rate + ))); + } + + // Convert to mono + if channels > 1 { + audio = audio + .chunks(channels as usize) + .map(|chunk| chunk.iter().sum::() / channels as f32) + .collect(); + } + + // NOTE: Unlike diarize(), we do NOT call reset_state() here + // This preserves speaker embeddings across calls + + // Extract mel features (B, T, D) + let features = self.extract_mel_features(&audio); + let total_frames = features.shape()[1]; + + // Process in chunks + let chunk_stride = CHUNK_LEN * SUBSAMPLING; + let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; + + let mut all_chunk_preds = Vec::new(); + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_stride; + let end = (start + chunk_stride).min(total_frames); + let current_len = end - start; + + // Extract chunk features + let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); + + // Pad last chunk if needed + if current_len < chunk_stride { + let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); + padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); + chunk_feat = padded; + } + + // Run streaming update + let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; + all_chunk_preds.push(chunk_preds); 
+ } + + // Concatenate all predictions + let full_preds = Self::concat_predictions(&all_chunk_preds); + + // Apply median filtering + let filtered_preds = if self.config.median_window > 1 { + self.median_filter(&full_preds) + } else { + full_preds + }; + + // Binarize to segments + let segments = self.binarize(&filtered_preds); + + Ok(segments) + } + + /// NeMo's streaming_update with smart cache compression. + /// Public to allow incremental streaming diarization. + pub fn streaming_update(&mut self, chunk_feat: &Array3, current_len: usize) -> Result> { + let spkcache_len = self.spkcache.shape()[1]; + let fifo_len = self.fifo.shape()[1]; + + // Prepare inputs + let chunk_lengths = Array1::from_vec(vec![current_len as i64]); + let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]); + let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]); + + // Prepare FIFO input + let fifo_input = if fifo_len > 0 { + self.fifo.clone() + } else { + Array3::zeros((1, 0, EMB_DIM)) + }; + + // Prepare spkcache input (may be empty) + let spkcache_input = if spkcache_len > 0 { + self.spkcache.clone() + } else { + Array3::zeros((1, 0, EMB_DIM)) + }; + + // Create input values + let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?; + let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?; + let spkcache_value = ort::value::Value::from_array(spkcache_input)?; + let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?; + let fifo_value = ort::value::Value::from_array(fifo_input)?; + let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?; + + // Run ONNX inference and extract all data in a block to release borrow + let (preds, new_embs, chunk_len) = { + let outputs = self.session.run(ort::inputs!( + "chunk" => chunk_value, + "chunk_lengths" => chunk_lengths_value, + "spkcache" => spkcache_value, + "spkcache_lengths" => spkcache_lengths_value, + "fifo" => fifo_value, + "fifo_lengths" => fifo_lengths_value + 
))?; + + // Extract outputs + let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"] + .try_extract_tensor::() + .map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?; + let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"] + .try_extract_tensor::() + .map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?; + + // Convert to ndarray + let preds_dims = preds_shape.as_ref(); + let embs_dims = embs_shape.as_ref(); + + let preds = Array3::from_shape_vec( + (preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize), + preds_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?; + + let new_embs = Array3::from_shape_vec( + (embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize), + embs_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?; + + // Calculate valid frames + let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING; + + (preds, new_embs, valid_frames) + }; + + // Extract predictions for different parts + let fifo_preds = if fifo_len > 0 { + preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned() + } else { + Array2::zeros((0, NUM_SPEAKERS)) + }; + + let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned(); + let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned(); + + // Append chunk embeddings to FIFO + self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0))); + + // Update FIFO predictions + if fifo_len > 0 { + let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds); + self.fifo_preds = combined.insert_axis(Axis(0)); + } else { + self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0)); + } + + let fifo_len_after = self.fifo.shape()[1]; + + // Move from FIFO to cache when FIFO exceeds limit + if fifo_len_after > FIFO_LEN { + let mut pop_out_len = SPKCACHE_UPDATE_PERIOD; + pop_out_len = 
pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len); + pop_out_len = pop_out_len.min(fifo_len_after); + + let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned(); + let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned(); + + // Update silence profile + self.update_silence_profile(&pop_out_embs, &pop_out_preds); + + // Remove from FIFO + self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned(); + self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned(); + + // Append to cache + self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs); + + if let Some(ref cache_preds) = self.spkcache_preds { + self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds)); + } + + // Smart compression when cache exceeds limit + if self.spkcache.shape()[1] > SPKCACHE_LEN { + if self.spkcache_preds.is_none() { + // Initialize cache predictions from initial output + let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned(); + let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds); + self.spkcache_preds = Some(combined); + } + + // Use smart compression + self.compress_spkcache(); + } + } + + Ok(chunk_preds) + } + + /// Update mean silence embedding + fn update_silence_profile(&mut self, embs: &Array3, preds: &Array3) { + let preds_2d = preds.slice(s![0, .., ..]); + + for t in 0..preds_2d.shape()[0] { + let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum(); + if sum < SIL_THRESHOLD { + // This is a silence frame + let emb = embs.slice(s![0, t, ..]); + + // Update running mean + let old_sum: Vec = self.mean_sil_emb.slice(s![0, ..]).iter() + .map(|&x| x * self.n_sil_frames as f32) + .collect(); + + self.n_sil_frames += 1; + + for i in 0..EMB_DIM { + self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32; + } + } + } + } + + /// Smart cache compression + fn compress_spkcache(&mut self) { + let cache_preds = match 
&self.spkcache_preds { + Some(p) => p.clone(), + None => return, + }; + + let n_frames = self.spkcache.shape()[1]; + let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK; + let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize; + let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize; + let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize; + + // Calculate quality scores + let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned(); + let mut scores = self.get_log_pred_scores(&preds_2d); + + // Disable low scores + scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk); + + // Boost important frames + scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0); + scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0); + + // Add silence frames placeholder + if SPKCACHE_SIL_FRAMES_PER_SPK > 0 { + let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY); + padded.slice_mut(s![..n_frames, ..]).assign(&scores); + for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK { + for j in 0..NUM_SPEAKERS { + padded[[i, j]] = f32::INFINITY; + } + } + scores = padded; + } + + // Select top frames + let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames); + + // Gather embeddings + let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled); + + self.spkcache = new_embs; + self.spkcache_preds = Some(new_preds); + } + + /// Calculate quality scores + fn get_log_pred_scores(&self, preds: &Array2) -> Array2 { + let mut scores = Array2::zeros(preds.dim()); + + for t in 0..preds.shape()[0] { + let mut log_1_probs_sum = 0.0f32; + for s in 0..NUM_SPEAKERS { + let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); + let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); + log_1_probs_sum += log_1_p; + } + + for s in 0..NUM_SPEAKERS { + let 
p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); + let log_p = p.ln(); + let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); + scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln(); + } + } + + scores + } + + /// Disable non-speech and overlapped speech + fn disable_low_scores(&self, preds: &Array2, mut scores: Array2, min_pos_scores_per_spk: usize) -> Array2 { + // Count positive scores per speaker + let mut pos_count = vec![0usize; NUM_SPEAKERS]; + for t in 0..scores.shape()[0] { + for s in 0..NUM_SPEAKERS { + if scores[[t, s]] > 0.0 { + pos_count[s] += 1; + } + } + } + + for t in 0..preds.shape()[0] { + for s in 0..NUM_SPEAKERS { + let is_speech = preds[[t, s]] > 0.5; + + if !is_speech { + scores[[t, s]] = f32::NEG_INFINITY; + } else { + let is_pos = scores[[t, s]] > 0.0; + if !is_pos && pos_count[s] >= min_pos_scores_per_spk { + scores[[t, s]] = f32::NEG_INFINITY; + } + } + } + } + + scores + } + + /// Boost top K frames per speaker + fn boost_topk_scores(&self, mut scores: Array2, n_boost_per_spk: usize, scale_factor: f32) -> Array2 { + for s in 0..NUM_SPEAKERS { + // Get column for this speaker + let col: Vec<(usize, f32)> = (0..scores.shape()[0]) + .map(|t| (t, scores[[t, s]])) + .collect(); + + // Sort by score descending + let mut sorted = col.clone(); + sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Boost top K + for i in 0..n_boost_per_spk.min(sorted.len()) { + let t = sorted[i].0; + if scores[[t, s]] != f32::NEG_INFINITY { + scores[[t, s]] -= scale_factor * 0.5f32.ln(); + } + } + } + + scores + } + + /// Get indices of top frames + fn get_topk_indices(&self, scores: &Array2, n_frames_no_sil: usize) -> (Vec, Vec) { + let n_frames = scores.shape()[0]; + + // Flatten scores as (S, T) then reshape to (S*T,) + // This means we iterate: speaker 0 all times, then speaker 1 all times, etc. 
+ // flat_index = speaker * n_frames + time + let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS); + for s in 0..NUM_SPEAKERS { + for t in 0..n_frames { + let flat_idx = s * n_frames + t; + flat_scores.push((flat_idx, scores[[t, s]])); + } + } + + // Sort by score descending to get top-K + flat_scores.sort_by(|a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + + // Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX + let mut topk_flat: Vec = flat_scores + .iter() + .take(SPKCACHE_LEN) + .map(|(idx, score)| { + if *score == f32::NEG_INFINITY { + MAX_INDEX + } else { + *idx + } + }) + .collect(); + + // Sort flat indices ascending (this puts MAX_INDEX at the end) + topk_flat.sort(); + + // Compute is_disabled and convert to frame indices + let mut is_disabled = vec![false; SPKCACHE_LEN]; + let mut frame_indices = vec![0usize; SPKCACHE_LEN]; + + for (i, &flat_idx) in topk_flat.iter().enumerate() { + if flat_idx == MAX_INDEX { + // Invalid entries are disabled + is_disabled[i] = true; + frame_indices[i] = 0; // We set disabled to 0 + } else { + // convert to frame index + let frame_idx = flat_idx % n_frames; + + // check if frame is beyond valid range + if frame_idx >= n_frames_no_sil { + is_disabled[i] = true; + frame_indices[i] = 0; // same as abov: set disabled to 0 + } else { + frame_indices[i] = frame_idx; + } + } + } + + (frame_indices, is_disabled) + } + + /// Gather selected frames + fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3, Array3) { + let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM)); + let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS)); + + let cache_preds = self.spkcache_preds.as_ref().unwrap(); + + for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() { + if i >= SPKCACHE_LEN { + break; + } + + if disabled { + // Use silence embedding + new_embs.slice_mut(s![0, i, 
..]).assign(&self.mean_sil_emb.slice(s![0, ..])); + // Predictions stay zero + } else if idx < self.spkcache.shape()[1] { + new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..])); + new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..])); + } + } + + (new_embs, new_preds) + } + + /// Concatenate along axis 1 for 3D arrays + fn concat_axis1(a: &Array3, b: &Array3) -> Array3 { + if a.shape()[1] == 0 { + return b.clone(); + } + if b.shape()[1] == 0 { + return a.clone(); + } + ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap() + } + + /// Concatenate along axis 0 for 2D arrays + fn concat_axis1_2d(a: &Array2, b: &Array2) -> Array2 { + if a.shape()[0] == 0 { + return b.clone(); + } + if b.shape()[0] == 0 { + return a.clone(); + } + ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap() + } + + /// Concatenate predictions + fn concat_predictions(preds: &[Array2]) -> Array2 { + if preds.is_empty() { + return Array2::zeros((0, NUM_SPEAKERS)); + } + if preds.len() == 1 { + return preds[0].clone(); + } + + let views: Vec<_> = preds.iter().map(|p| p.view()).collect(); + ndarray::concatenate(Axis(0), &views).unwrap() + } + + /// Apply median filter to predictions + fn median_filter(&self, preds: &Array2) -> Array2 { + let window = self.config.median_window; + let half = window / 2; + let mut filtered = preds.clone(); + + for spk in 0..NUM_SPEAKERS { + for t in 0..preds.shape()[0] { + let start = t.saturating_sub(half); + let end = (t + half + 1).min(preds.shape()[0]); + + let mut values: Vec = (start..end) + .map(|i| preds[[i, spk]]) + .collect(); + values.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + filtered[[t, spk]] = values[values.len() / 2]; + } + } + + filtered + } + + /// Binarize predictions to segments (padding applied during thresholding) + fn binarize(&self, preds: &Array2) -> Vec { + let mut segments = Vec::new(); + let num_frames = preds.shape()[0]; + + for spk in 0..NUM_SPEAKERS { + let mut in_seg = 
false; + let mut seg_start = 0; + let mut temp_segments = Vec::new(); + + for t in 0..num_frames { + let p = preds[[t, spk]]; + + if p >= self.config.onset && !in_seg { + in_seg = true; + seg_start = t; + } else if p < self.config.offset && in_seg { + in_seg = false; + + // Apply padding during conversion + let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); + let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset; + + if end_t - start_t >= self.config.min_duration_on { + temp_segments.push(SpeakerSegment { + start: start_t, + end: end_t, + speaker_id: spk, + }); + } + } + } + + // Handle segment at end + if in_seg { + let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); + let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset; + + if end_t - start_t >= self.config.min_duration_on { + temp_segments.push(SpeakerSegment { + start: start_t, + end: end_t, + speaker_id: spk, + }); + } + } + + // Merge close segments (min_duration_off) + if temp_segments.len() > 1 { + let mut filtered = vec![temp_segments[0].clone()]; + for seg in temp_segments.into_iter().skip(1) { + let last = filtered.last_mut().unwrap(); + let gap = seg.start - last.end; + if gap < self.config.min_duration_off { + last.end = seg.end; // Merge + } else { + filtered.push(seg); + } + } + segments.extend(filtered); + } else { + segments.extend(temp_segments); + } + } + + // Sort by start time + segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap()); + segments + } + + + fn apply_preemphasis(audio: &[f32]) -> Vec { + let mut result = Vec::with_capacity(audio.len()); + result.push(audio[0]); + for i in 1..audio.len() { + result.push(audio[i] - PREEMPH * audio[i - 1]); + } + result + } + + fn hann_window(window_length: usize) -> Vec { + // Librosa uses periodic window (fftbins=True): divide by N, not N-1 + (0..window_length) + .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos()) + 
.collect() + } + + fn stft(audio: &[f32]) -> Array2 { + let mut planner = FftPlanner::::new(); + let fft = planner.plan_fft_forward(N_FFT); + + // Create Hann window of length win_length, then zero-pad to n_fft (centered) + // This is exactly what librosa does: util.pad_center(fft_window, size=n_fft) + let hann = Self::hann_window(WIN_LENGTH); + let win_offset = (N_FFT - WIN_LENGTH) / 2; + let mut fft_window = vec![0.0f32; N_FFT]; + for i in 0..WIN_LENGTH { + fft_window[win_offset + i] = hann[i]; + } + + // Pad signal for center=True (like librosa/torch.stft) + // Padding is n_fft // 2 on each side + let pad_amount = N_FFT / 2; + let mut padded_audio = vec![0.0; pad_amount]; + padded_audio.extend_from_slice(audio); + padded_audio.extend(vec![0.0; pad_amount]); + + let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1; + let freq_bins = N_FFT / 2 + 1; + let mut spectrogram = Array2::::zeros((freq_bins, num_frames)); + + for frame_idx in 0..num_frames { + let start = frame_idx * HOP_LENGTH; + let mut frame: Vec> = vec![Complex::new(0.0, 0.0); N_FFT]; + + // Extract n_fft samples and multiply by zero-padded window + for i in 0..N_FFT { + if start + i < padded_audio.len() { + frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0); + } + } + + fft.process(&mut frame); + for k in 0..freq_bins { + let magnitude = frame[k].norm(); + // Power spectrum (magnitude^2) - NeMo uses mag_power=2.0 + spectrogram[[k, frame_idx]] = magnitude * magnitude; + } + } + + spectrogram + } + + // Librosa's Slaney mel scale (htk=False, which is the default) + fn hz_to_mel_slaney(hz: f64) -> f64 { + let f_min = 0.0; + let f_sp = 200.0 / 3.0; + let min_log_hz = 1000.0; + let min_log_mel = (min_log_hz - f_min) / f_sp; + let logstep = (6.4f64).ln() / 27.0; + + if hz >= min_log_hz { + min_log_mel + (hz / min_log_hz).ln() / logstep + } else { + (hz - f_min) / f_sp + } + } + + fn mel_to_hz_slaney(mel: f64) -> f64 { + let f_min = 0.0; + let f_sp = 200.0 / 3.0; + let min_log_hz 
= 1000.0; + let min_log_mel = (min_log_hz - f_min) / f_sp; + let logstep = (6.4f64).ln() / 27.0; + + if mel >= min_log_mel { + min_log_hz * (logstep * (mel - min_log_mel)).exp() + } else { + f_min + f_sp * mel + } + } + + fn create_mel_filterbank() -> Array2 { + // lets use f64 for intermediate calculations to avoid precision loss + let freq_bins = N_FFT / 2 + 1; + let mut filterbank = Array2::::zeros((N_MELS, freq_bins)); + + // FFT frequencies: fftfreqs[k] = k * sr / n_fft + let fftfreqs: Vec = (0..freq_bins) + .map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64) + .collect(); + + // Mel center frequencies using Slaney scale (librosa default, htk=False) + let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64); + let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64); + let mel_f: Vec = (0..=N_MELS + 1) + .map(|i| { + let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64; + Self::mel_to_hz_slaney(mel) + }) + .collect(); + + // Differences between consecutive mel frequencies + let fdiff: Vec = mel_f.windows(2).map(|w| w[1] - w[0]).collect(); + + // Compute filterbank weights (reference: librosa's ramp method) + // https://librosa.org/doc/main/generated/librosa.stft.html + for i in 0..N_MELS { + for k in 0..freq_bins { + // Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i] + let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i]; + // Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1] + let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1]; + // Weight is max(0, min(lower, upper)) + filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32; + } + } + + // Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i]) + for i in 0..N_MELS { + let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]); + for k in 0..freq_bins { + filterbank[[i, k]] *= enorm as f32; + } + } + + filterbank + } + + fn extract_mel_features(&self, audio: &[f32]) -> Array3 { + // 1. 
Add dither (small random noise to prevent log(0)) + // NeMo uses dither=1e-5, but for determinism we skip random noise + // The log_zero_guard handles zero values + + // 2. Apply preemphasis (NeMo uses preemph=0.97) + let preemphasized = Self::apply_preemphasis(audio); + + // 3. STFT + let spectrogram = Self::stft(&preemphasized); + + // 4. Apply mel filterbank (with Slaney normalization) + let mel_spec = self.mel_basis.dot(&spectrogram); + + // 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24) + // NeMo uses normalize='NA' which means NO normalization + let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln()); + + let num_frames = log_mel_spec.shape()[1]; + let mut features = Array3::::zeros((1, num_frames, N_MELS)); + + // Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D) + for t in 0..num_frames { + for m in 0..N_MELS { + features[[0, t, m]] = log_mel_spec[[m, t]]; + } + } + + features + } +} -- cgit v1.2.3