diff options
Diffstat (limited to 'parakeet-rs/src/sortformer.rs')
| -rw-r--r-- | parakeet-rs/src/sortformer.rs | 1062 |
1 files changed, 0 insertions, 1062 deletions
diff --git a/parakeet-rs/src/sortformer.rs b/parakeet-rs/src/sortformer.rs deleted file mode 100644 index 2b1e5a3..0000000 --- a/parakeet-rs/src/sortformer.rs +++ /dev/null @@ -1,1062 +0,0 @@ -//! NVIDIA Sortformer v2 Streaming Speaker Diarization -//! -//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization. -//! -//! Key features: -//! - Streaming inference with ~10s chunks (124 frames at 80ms each) -//! - FIFO buffer for context management -//! - Smart speaker cache compression (keeps important frames, not just recent) -//! - Silence profile tracking -//! - Post-processing: median filtering, hysteresis thresholding -//! - Supports up to 4 speakers -//! -//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 -//! Note that, my ONNX export: -//! CHUNK_LEN = 124 -//! FIFO_LEN = 124 -//! CACHE_LEN = 188 -//! FEAT_DIM = 128 -//! EMB_DIM = 512 -//! Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html - -use crate::error::{Error, Result}; -use crate::execution::ModelConfig; -use ndarray::{s, Array1, Array2, Array3, Axis}; -use ort::session::Session; -use rustfft::{num_complex::Complex, FftPlanner}; -use std::f32::consts::PI; -use std::path::Path; - -// Model constants -const N_FFT: usize = 512; -const WIN_LENGTH: usize = 400; -const HOP_LENGTH: usize = 160; -const N_MELS: usize = 128; -const PREEMPH: f32 = 0.97; -const LOG_ZERO_GUARD: f32 = 5.960464478e-8; -const SAMPLE_RATE: usize = 16000; -const FMIN: f32 = 0.0; -const FMAX: f32 = 8000.0; - -// Streaming constants -const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms) -const FIFO_LEN: usize = 124; // FIFO buffer length -const SPKCACHE_LEN: usize = 188; // Speaker cache length -const SPKCACHE_UPDATE_PERIOD: usize = 124; -const SUBSAMPLING: usize = 8; // Audio frames -> model frames -const EMB_DIM: usize = 512; // Embedding dimension -const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers -const FRAME_DURATION: f32 = 0.08; // 80ms per frame - -// Cache compression params (from NeMo) -const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3; -const PRED_SCORE_THRESHOLD: f32 = 0.25; -const STRONG_BOOST_RATE: f32 = 0.75; -const WEAK_BOOST_RATE: f32 = 1.5; -const MIN_POS_SCORES_RATE: f32 = 0.5; -const SIL_THRESHOLD: f32 = 0.2; -const MAX_INDEX: usize = 99999; - -/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs) -/// -/// Controls how raw model predictions are converted into speaker segments. -/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI). -/// -/// # Parameters -/// - `onset`: Probability threshold to START a speaker segment (higher = more strict) -/// - `offset`: Probability threshold to END a speaker segment (lower = longer segments) -/// - `pad_onset`: Seconds to subtract from segment start times -/// - `pad_offset`: Seconds to add to segment end times -/// - `min_duration_on`: Minimum segment length in seconds (filters short blips) -/// - `min_duration_off`: Minimum gap between segments before merging -/// - `median_window`: Smoothing window size (odd number, higher = smoother) -/// -/// # Pre-tuned Configs -/// - `callhome()` - (default) -/// - `dihard3()` -/// -/// # Custom Config -/// Use `custom(onset, offset)` to create your own config for fine-tuning. -/// -/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer -#[derive(Debug, Clone)] -pub struct DiarizationConfig { - pub onset: f32, - pub offset: f32, - pub pad_onset: f32, - pub pad_offset: f32, - pub min_duration_on: f32, - pub min_duration_off: f32, - pub median_window: usize, -} - -impl Default for DiarizationConfig { - fn default() -> Self { - Self::callhome() - } -} - -impl DiarizationConfig { - /// CallHome dataset config for v2 (default) - /// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml - pub fn callhome() -> Self { - Self { - onset: 0.641, - offset: 0.561, - pad_onset: 0.229, - pad_offset: 0.079, - min_duration_on: 0.511, - min_duration_off: 0.296, - median_window: 11, - } - } - - /// DIHARD3 dataset config for v2 - /// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml - pub fn dihard3() -> Self { - Self { - onset: 0.56, - offset: 1.0, - pad_onset: 0.063, - pad_offset: 0.002, - min_duration_on: 0.007, - min_duration_off: 0.151, - median_window: 11, - } - } - - /// Create a custom config for fine-tuning diarization behavior. - /// - /// # Arguments - /// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7) - /// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6) - /// - /// # Example - /// ```rust - /// use parakeet_rs::sortformer::DiarizationConfig; - /// - /// // More sensitive detection (lower thresholds) - /// let sensitive = DiarizationConfig::custom(0.5, 0.4); - /// - /// // Stricter detection (higher thresholds, fewer false positives) - /// let strict = DiarizationConfig::custom(0.7, 0.6); - /// - /// // Full customization - /// let mut config = DiarizationConfig::custom(0.6, 0.5); - /// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms - /// config.median_window = 15; // More smoothing - /// ``` - pub fn custom(onset: f32, offset: f32) -> Self { - Self { - onset, - offset, - pad_onset: 0.0, - pad_offset: 0.0, - min_duration_on: 0.1, - min_duration_off: 0.1, - median_window: 11, - } - } -} - -/// Speaker segment with start time, end time, and speaker ID -#[derive(Debug, Clone)] -pub struct SpeakerSegment { - pub start: f32, - pub end: f32, - pub speaker_id: usize, -} - -/// Streaming Sortformer v2 speaker diarization engine -pub struct Sortformer { - session: Session, - config: DiarizationConfig, - // Streaming state. note that, Same way as Nemo - spkcache: Array3<f32>, // (1, 0..SPKCACHE_LEN, EMB_DIM) - spkcache_preds: Option<Array3<f32>>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS) - fifo: Array3<f32>, // (1, 0..FIFO_LEN, EMB_DIM) - fifo_preds: Array3<f32>, // (1, 0..FIFO_LEN, NUM_SPEAKERS) - mean_sil_emb: Array2<f32>, // (1, EMB_DIM) - n_sil_frames: usize, - // Mel filterbank (cached) - mel_basis: Array2<f32>, -} - -impl Sortformer { - /// a new Sortformer instance from ONNX model path - pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> { - Self::with_config(model_path, None, DiarizationConfig::default()) - } - - /// Create with custom config - pub fn with_config<P: AsRef<Path>>( - model_path: P, - execution_config: Option<ModelConfig>, - config: DiarizationConfig, - ) -> Result<Self> { - let config_to_use = execution_config.unwrap_or_default(); - - let session = config_to_use - .apply_to_session_builder(Session::builder()?)? - .commit_from_file(model_path.as_ref())?; - - let mel_basis = Self::create_mel_filterbank(); - - let mut instance = Self { - session, - config, - spkcache: Array3::zeros((1, 0, EMB_DIM)), - spkcache_preds: None, - fifo: Array3::zeros((1, 0, EMB_DIM)), - fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)), - mean_sil_emb: Array2::zeros((1, EMB_DIM)), - n_sil_frames: 0, - mel_basis, - }; - instance.reset_state(); - Ok(instance) - } - - /// Reset streaming state - pub fn reset_state(&mut self) { - self.spkcache = Array3::zeros((1, 0, EMB_DIM)); - self.spkcache_preds = None; - self.fifo = Array3::zeros((1, 0, EMB_DIM)); - self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS)); - self.mean_sil_emb = Array2::zeros((1, EMB_DIM)); - self.n_sil_frames = 0; - } - - /// Main diarization entry point - pub fn diarize( - &mut self, - mut audio: Vec<f32>, - sample_rate: u32, - channels: u16, - ) -> Result<Vec<SpeakerSegment>> { - // Resample if needed - if sample_rate != SAMPLE_RATE as u32 { - return Err(Error::Audio(format!( - "Expected {} Hz, got {} Hz", - SAMPLE_RATE, sample_rate - ))); - } - - // Convert to mono - if channels > 1 { - audio = audio - .chunks(channels as usize) - .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) - .collect(); - } - - // Reset state for new audio - self.reset_state(); - - // Extract mel features (B, T, D) - let features = self.extract_mel_features(&audio); - let total_frames = features.shape()[1]; - - // Process in chunks - let chunk_stride = CHUNK_LEN * SUBSAMPLING; - let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; - - let mut all_chunk_preds = Vec::new(); - - for chunk_idx in 0..num_chunks { - let start = chunk_idx * chunk_stride; - let end = (start + chunk_stride).min(total_frames); - let current_len = end - start; - - // Extract chunk features - let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); - - // Pad last chunk if needed - if current_len < chunk_stride { - let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); - padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); - chunk_feat = padded; - } - - // Run streaming update - let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; - all_chunk_preds.push(chunk_preds); - } - - // Concatenate all predictions - let full_preds = Self::concat_predictions(&all_chunk_preds); - - // Apply median filtering - let filtered_preds = if self.config.median_window > 1 { - self.median_filter(&full_preds) - } else { - full_preds - }; - - // Binarize to segments - let segments = self.binarize(&filtered_preds); - - Ok(segments) - } - - /// Streaming diarization that maintains state across calls. - /// - /// Unlike `diarize()`, this method does NOT reset the internal state, - /// allowing speaker embeddings to be preserved across multiple audio chunks. - /// Call `reset_state()` manually when starting a new audio session. - /// - /// This enables consistent speaker identification across long audio streams - /// by maintaining the speaker cache between processing windows. - /// - /// # Arguments - /// * `audio` - Audio samples (will be converted to mono if multi-channel) - /// * `sample_rate` - Must be 16000 Hz - /// * `channels` - Number of audio channels - /// - /// # Example - /// ```ignore - /// // Start of session - /// sortformer.reset_state(); - /// - /// // Process sliding windows - /// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?; - /// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs - /// ``` - pub fn diarize_streaming( - &mut self, - mut audio: Vec<f32>, - sample_rate: u32, - channels: u16, - ) -> Result<Vec<SpeakerSegment>> { - // Resample if needed - if sample_rate != SAMPLE_RATE as u32 { - return Err(Error::Audio(format!( - "Expected {} Hz, got {} Hz", - SAMPLE_RATE, sample_rate - ))); - } - - // Convert to mono - if channels > 1 { - audio = audio - .chunks(channels as usize) - .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) - .collect(); - } - - // NOTE: Unlike diarize(), we do NOT call reset_state() here - // This preserves speaker embeddings across calls - - // Extract mel features (B, T, D) - let features = self.extract_mel_features(&audio); - let total_frames = features.shape()[1]; - - // Process in chunks - let chunk_stride = CHUNK_LEN * SUBSAMPLING; - let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; - - let mut all_chunk_preds = Vec::new(); - - for chunk_idx in 0..num_chunks { - let start = chunk_idx * chunk_stride; - let end = (start + chunk_stride).min(total_frames); - let current_len = end - start; - - // Extract chunk features - let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); - - // Pad last chunk if needed - if current_len < chunk_stride { - let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); - padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); - chunk_feat = padded; - } - - // Run streaming update - let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; - all_chunk_preds.push(chunk_preds); - } - - // Concatenate all predictions - let full_preds = Self::concat_predictions(&all_chunk_preds); - - // Apply median filtering - let filtered_preds = if self.config.median_window > 1 { - self.median_filter(&full_preds) - } else { - full_preds - }; - - // Binarize to segments - let segments = self.binarize(&filtered_preds); - - Ok(segments) - } - - /// NeMo's streaming_update with smart cache compression. - /// Public to allow incremental streaming diarization. - pub fn streaming_update(&mut self, chunk_feat: &Array3<f32>, current_len: usize) -> Result<Array2<f32>> { - let spkcache_len = self.spkcache.shape()[1]; - let fifo_len = self.fifo.shape()[1]; - - // Prepare inputs - let chunk_lengths = Array1::from_vec(vec![current_len as i64]); - let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]); - let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]); - - // Prepare FIFO input - let fifo_input = if fifo_len > 0 { - self.fifo.clone() - } else { - Array3::zeros((1, 0, EMB_DIM)) - }; - - // Prepare spkcache input (may be empty) - let spkcache_input = if spkcache_len > 0 { - self.spkcache.clone() - } else { - Array3::zeros((1, 0, EMB_DIM)) - }; - - // Create input values - let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?; - let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?; - let spkcache_value = ort::value::Value::from_array(spkcache_input)?; - let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?; - let fifo_value = ort::value::Value::from_array(fifo_input)?; - let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?; - - // Run ONNX inference and extract all data in a block to release borrow - let (preds, new_embs, chunk_len) = { - let outputs = self.session.run(ort::inputs!( - "chunk" => chunk_value, - "chunk_lengths" => chunk_lengths_value, - "spkcache" => spkcache_value, - "spkcache_lengths" => spkcache_lengths_value, - "fifo" => fifo_value, - "fifo_lengths" => fifo_lengths_value - ))?; - - // Extract outputs - let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?; - let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?; - - // Convert to ndarray - let preds_dims = preds_shape.as_ref(); - let embs_dims = embs_shape.as_ref(); - - let preds = Array3::from_shape_vec( - (preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize), - preds_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?; - - let new_embs = Array3::from_shape_vec( - (embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize), - embs_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?; - - // Calculate valid frames - let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING; - - (preds, new_embs, valid_frames) - }; - - // Extract predictions for different parts - let fifo_preds = if fifo_len > 0 { - preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned() - } else { - Array2::zeros((0, NUM_SPEAKERS)) - }; - - let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned(); - let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned(); - - // Append chunk embeddings to FIFO - self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0))); - - // Update FIFO predictions - if fifo_len > 0 { - let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds); - self.fifo_preds = combined.insert_axis(Axis(0)); - } else { - self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0)); - } - - let fifo_len_after = self.fifo.shape()[1]; - - // Move from FIFO to cache when FIFO exceeds limit - if fifo_len_after > FIFO_LEN { - let mut pop_out_len = SPKCACHE_UPDATE_PERIOD; - pop_out_len = pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len); - pop_out_len = pop_out_len.min(fifo_len_after); - - let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned(); - let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned(); - - // Update silence profile - self.update_silence_profile(&pop_out_embs, &pop_out_preds); - - // Remove from FIFO - self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned(); - self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned(); - - // Append to cache - self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs); - - if let Some(ref cache_preds) = self.spkcache_preds { - self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds)); - } - - // Smart compression when cache exceeds limit - if self.spkcache.shape()[1] > SPKCACHE_LEN { - if self.spkcache_preds.is_none() { - // Initialize cache predictions from initial output - let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned(); - let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds); - self.spkcache_preds = Some(combined); - } - - // Use smart compression - self.compress_spkcache(); - } - } - - Ok(chunk_preds) - } - - /// Update mean silence embedding - fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) { - let preds_2d = preds.slice(s![0, .., ..]); - - for t in 0..preds_2d.shape()[0] { - let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum(); - if sum < SIL_THRESHOLD { - // This is a silence frame - let emb = embs.slice(s![0, t, ..]); - - // Update running mean - let old_sum: Vec<f32> = self.mean_sil_emb.slice(s![0, ..]).iter() - .map(|&x| x * self.n_sil_frames as f32) - .collect(); - - self.n_sil_frames += 1; - - for i in 0..EMB_DIM { - self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32; - } - } - } - } - - /// Smart cache compression - fn compress_spkcache(&mut self) { - let cache_preds = match &self.spkcache_preds { - Some(p) => p.clone(), - None => return, - }; - - let n_frames = self.spkcache.shape()[1]; - let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK; - let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize; - let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize; - let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize; - - // Calculate quality scores - let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned(); - let mut scores = self.get_log_pred_scores(&preds_2d); - - // Disable low scores - scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk); - - // Boost important frames - scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0); - scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0); - - // Add silence frames placeholder - if SPKCACHE_SIL_FRAMES_PER_SPK > 0 { - let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY); - padded.slice_mut(s![..n_frames, ..]).assign(&scores); - for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK { - for j in 0..NUM_SPEAKERS { - padded[[i, j]] = f32::INFINITY; - } - } - scores = padded; - } - - // Select top frames - let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames); - - // Gather embeddings - let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled); - - self.spkcache = new_embs; - self.spkcache_preds = Some(new_preds); - } - - /// Calculate quality scores - fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> { - let mut scores = Array2::zeros(preds.dim()); - - for t in 0..preds.shape()[0] { - let mut log_1_probs_sum = 0.0f32; - for s in 0..NUM_SPEAKERS { - let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); - let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); - log_1_probs_sum += log_1_p; - } - - for s in 0..NUM_SPEAKERS { - let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); - let log_p = p.ln(); - let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); - scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln(); - } - } - - scores - } - - /// Disable non-speech and overlapped speech - fn disable_low_scores(&self, preds: &Array2<f32>, mut scores: Array2<f32>, min_pos_scores_per_spk: usize) -> Array2<f32> { - // Count positive scores per speaker - let mut pos_count = vec![0usize; NUM_SPEAKERS]; - for t in 0..scores.shape()[0] { - for s in 0..NUM_SPEAKERS { - if scores[[t, s]] > 0.0 { - pos_count[s] += 1; - } - } - } - - for t in 0..preds.shape()[0] { - for s in 0..NUM_SPEAKERS { - let is_speech = preds[[t, s]] > 0.5; - - if !is_speech { - scores[[t, s]] = f32::NEG_INFINITY; - } else { - let is_pos = scores[[t, s]] > 0.0; - if !is_pos && pos_count[s] >= min_pos_scores_per_spk { - scores[[t, s]] = f32::NEG_INFINITY; - } - } - } - } - - scores - } - - /// Boost top K frames per speaker - fn boost_topk_scores(&self, mut scores: Array2<f32>, n_boost_per_spk: usize, scale_factor: f32) -> Array2<f32> { - for s in 0..NUM_SPEAKERS { - // Get column for this speaker - let col: Vec<(usize, f32)> = (0..scores.shape()[0]) - .map(|t| (t, scores[[t, s]])) - .collect(); - - // Sort by score descending - let mut sorted = col.clone(); - sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Boost top K - for i in 0..n_boost_per_spk.min(sorted.len()) { - let t = sorted[i].0; - if scores[[t, s]] != f32::NEG_INFINITY { - scores[[t, s]] -= scale_factor * 0.5f32.ln(); - } - } - } - - scores - } - - /// Get indices of top frames - fn get_topk_indices(&self, scores: &Array2<f32>, n_frames_no_sil: usize) -> (Vec<usize>, Vec<bool>) { - let n_frames = scores.shape()[0]; - - // Flatten scores as (S, T) then reshape to (S*T,) - // This means we iterate: speaker 0 all times, then speaker 1 all times, etc. - // flat_index = speaker * n_frames + time - let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS); - for s in 0..NUM_SPEAKERS { - for t in 0..n_frames { - let flat_idx = s * n_frames + t; - flat_scores.push((flat_idx, scores[[t, s]])); - } - } - - // Sort by score descending to get top-K - flat_scores.sort_by(|a, b| { - b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) - }); - - // Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX - let mut topk_flat: Vec<usize> = flat_scores - .iter() - .take(SPKCACHE_LEN) - .map(|(idx, score)| { - if *score == f32::NEG_INFINITY { - MAX_INDEX - } else { - *idx - } - }) - .collect(); - - // Sort flat indices ascending (this puts MAX_INDEX at the end) - topk_flat.sort(); - - // Compute is_disabled and convert to frame indices - let mut is_disabled = vec![false; SPKCACHE_LEN]; - let mut frame_indices = vec![0usize; SPKCACHE_LEN]; - - for (i, &flat_idx) in topk_flat.iter().enumerate() { - if flat_idx == MAX_INDEX { - // Invalid entries are disabled - is_disabled[i] = true; - frame_indices[i] = 0; // We set disabled to 0 - } else { - // convert to frame index - let frame_idx = flat_idx % n_frames; - - // check if frame is beyond valid range - if frame_idx >= n_frames_no_sil { - is_disabled[i] = true; - frame_indices[i] = 0; // same as abov: set disabled to 0 - } else { - frame_indices[i] = frame_idx; - } - } - } - - (frame_indices, is_disabled) - } - - /// Gather selected frames - fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3<f32>, Array3<f32>) { - let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM)); - let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS)); - - let cache_preds = self.spkcache_preds.as_ref().unwrap(); - - for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() { - if i >= SPKCACHE_LEN { - break; - } - - if disabled { - // Use silence embedding - new_embs.slice_mut(s![0, i, ..]).assign(&self.mean_sil_emb.slice(s![0, ..])); - // Predictions stay zero - } else if idx < self.spkcache.shape()[1] { - new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..])); - new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..])); - } - } - - (new_embs, new_preds) - } - - /// Concatenate along axis 1 for 3D arrays - fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> { - if a.shape()[1] == 0 { - return b.clone(); - } - if b.shape()[1] == 0 { - return a.clone(); - } - ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap() - } - - /// Concatenate along axis 0 for 2D arrays - fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> { - if a.shape()[0] == 0 { - return b.clone(); - } - if b.shape()[0] == 0 { - return a.clone(); - } - ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap() - } - - /// Concatenate predictions - fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> { - if preds.is_empty() { - return Array2::zeros((0, NUM_SPEAKERS)); - } - if preds.len() == 1 { - return preds[0].clone(); - } - - let views: Vec<_> = preds.iter().map(|p| p.view()).collect(); - ndarray::concatenate(Axis(0), &views).unwrap() - } - - /// Apply median filter to predictions - fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> { - let window = self.config.median_window; - let half = window / 2; - let mut filtered = preds.clone(); - - for spk in 0..NUM_SPEAKERS { - for t in 0..preds.shape()[0] { - let start = t.saturating_sub(half); - let end = (t + half + 1).min(preds.shape()[0]); - - let mut values: Vec<f32> = (start..end) - .map(|i| preds[[i, spk]]) - .collect(); - values.sort_by(|a, b| a.partial_cmp(b).unwrap()); - - filtered[[t, spk]] = values[values.len() / 2]; - } - } - - filtered - } - - /// Binarize predictions to segments (padding applied during thresholding) - fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> { - let mut segments = Vec::new(); - let num_frames = preds.shape()[0]; - - for spk in 0..NUM_SPEAKERS { - let mut in_seg = false; - let mut seg_start = 0; - let mut temp_segments = Vec::new(); - - for t in 0..num_frames { - let p = preds[[t, spk]]; - - if p >= self.config.onset && !in_seg { - in_seg = true; - seg_start = t; - } else if p < self.config.offset && in_seg { - in_seg = false; - - // Apply padding during conversion - let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); - let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset; - - if end_t - start_t >= self.config.min_duration_on { - temp_segments.push(SpeakerSegment { - start: start_t, - end: end_t, - speaker_id: spk, - }); - } - } - } - - // Handle segment at end - if in_seg { - let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); - let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset; - - if end_t - start_t >= self.config.min_duration_on { - temp_segments.push(SpeakerSegment { - start: start_t, - end: end_t, - speaker_id: spk, - }); - } - } - - // Merge close segments (min_duration_off) - if temp_segments.len() > 1 { - let mut filtered = vec![temp_segments[0].clone()]; - for seg in temp_segments.into_iter().skip(1) { - let last = filtered.last_mut().unwrap(); - let gap = seg.start - last.end; - if gap < self.config.min_duration_off { - last.end = seg.end; // Merge - } else { - filtered.push(seg); - } - } - segments.extend(filtered); - } else { - segments.extend(temp_segments); - } - } - - // Sort by start time - segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap()); - segments - } - - - fn apply_preemphasis(audio: &[f32]) -> Vec<f32> { - let mut result = Vec::with_capacity(audio.len()); - result.push(audio[0]); - for i in 1..audio.len() { - result.push(audio[i] - PREEMPH * audio[i - 1]); - } - result - } - - fn hann_window(window_length: usize) -> Vec<f32> { - // Librosa uses periodic window (fftbins=True): divide by N, not N-1 - (0..window_length) - .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos()) - .collect() - } - - fn stft(audio: &[f32]) -> Array2<f32> { - let mut planner = FftPlanner::<f32>::new(); - let fft = planner.plan_fft_forward(N_FFT); - - // Create Hann window of length win_length, then zero-pad to n_fft (centered) - // This is exactly what librosa does: util.pad_center(fft_window, size=n_fft) - let hann = Self::hann_window(WIN_LENGTH); - let win_offset = (N_FFT - WIN_LENGTH) / 2; - let mut fft_window = vec![0.0f32; N_FFT]; - for i in 0..WIN_LENGTH { - fft_window[win_offset + i] = hann[i]; - } - - // Pad signal for center=True (like librosa/torch.stft) - // Padding is n_fft // 2 on each side - let pad_amount = N_FFT / 2; - let mut padded_audio = vec![0.0; pad_amount]; - padded_audio.extend_from_slice(audio); - padded_audio.extend(vec![0.0; pad_amount]); - - let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1; - let freq_bins = N_FFT / 2 + 1; - let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames)); - - for frame_idx in 0..num_frames { - let start = frame_idx * HOP_LENGTH; - let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT]; - - // Extract n_fft samples and multiply by zero-padded window - for i in 0..N_FFT { - if start + i < padded_audio.len() { - frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0); - } - } - - fft.process(&mut frame); - for k in 0..freq_bins { - let magnitude = frame[k].norm(); - // Power spectrum (magnitude^2) - NeMo uses mag_power=2.0 - spectrogram[[k, frame_idx]] = magnitude * magnitude; - } - } - - spectrogram - } - - // Librosa's Slaney mel scale (htk=False, which is the default) - fn hz_to_mel_slaney(hz: f64) -> f64 { - let f_min = 0.0; - let f_sp = 200.0 / 3.0; - let min_log_hz = 1000.0; - let min_log_mel = (min_log_hz - f_min) / f_sp; - let logstep = (6.4f64).ln() / 27.0; - - if hz >= min_log_hz { - min_log_mel + (hz / min_log_hz).ln() / logstep - } else { - (hz - f_min) / f_sp - } - } - - fn mel_to_hz_slaney(mel: f64) -> f64 { - let f_min = 0.0; - let f_sp = 200.0 / 3.0; - let min_log_hz = 1000.0; - let min_log_mel = (min_log_hz - f_min) / f_sp; - let logstep = (6.4f64).ln() / 27.0; - - if mel >= min_log_mel { - min_log_hz * (logstep * (mel - min_log_mel)).exp() - } else { - f_min + f_sp * mel - } - } - - fn create_mel_filterbank() -> Array2<f32> { - // lets use f64 for intermediate calculations to avoid precision loss - let freq_bins = N_FFT / 2 + 1; - let mut filterbank = Array2::<f32>::zeros((N_MELS, freq_bins)); - - // FFT frequencies: fftfreqs[k] = k * sr / n_fft - let fftfreqs: Vec<f64> = (0..freq_bins) - .map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64) - .collect(); - - // Mel center frequencies using Slaney scale (librosa default, htk=False) - let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64); - let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64); - let mel_f: Vec<f64> = (0..=N_MELS + 1) - .map(|i| { - let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64; - Self::mel_to_hz_slaney(mel) - }) - .collect(); - - // Differences between consecutive mel frequencies - let fdiff: Vec<f64> = mel_f.windows(2).map(|w| w[1] - w[0]).collect(); - - // Compute filterbank weights (reference: librosa's ramp method) - // https://librosa.org/doc/main/generated/librosa.stft.html - for i in 0..N_MELS { - for k in 0..freq_bins { - // Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i] - let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i]; - // Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1] - let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1]; - // Weight is max(0, min(lower, upper)) - filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32; - } - } - - // Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i]) - for i in 0..N_MELS { - let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]); - for k in 0..freq_bins { - filterbank[[i, k]] *= enorm as f32; - } - } - - filterbank - } - - fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> { - // 1. Add dither (small random noise to prevent log(0)) - // NeMo uses dither=1e-5, but for determinism we skip random noise - // The log_zero_guard handles zero values - - // 2. Apply preemphasis (NeMo uses preemph=0.97) - let preemphasized = Self::apply_preemphasis(audio); - - // 3. STFT - let spectrogram = Self::stft(&preemphasized); - - // 4. Apply mel filterbank (with Slaney normalization) - let mel_spec = self.mel_basis.dot(&spectrogram); - - // 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24) - // NeMo uses normalize='NA' which means NO normalization - let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln()); - - let num_frames = log_mel_spec.shape()[1]; - let mut features = Array3::<f32>::zeros((1, num_frames, N_MELS)); - - // Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D) - for t in 0..num_frames { - for m in 0..N_MELS { - features[[0, t, m]] = log_mel_spec[[m, t]]; - } - } - - features - } -} |
