summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/sortformer.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 00:40:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch)
tree0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/src/sortformer.rs
parent84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff)
downloadsoryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz
soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/src/sortformer.rs')
-rw-r--r--parakeet-rs/src/sortformer.rs1062
1 files changed, 1062 insertions, 0 deletions
diff --git a/parakeet-rs/src/sortformer.rs b/parakeet-rs/src/sortformer.rs
new file mode 100644
index 0000000..2b1e5a3
--- /dev/null
+++ b/parakeet-rs/src/sortformer.rs
@@ -0,0 +1,1062 @@
+//! NVIDIA Sortformer v2 Streaming Speaker Diarization
+//!
+//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization.
+//!
+//! Key features:
+//! - Streaming inference with ~10s chunks (124 frames at 80ms each)
+//! - FIFO buffer for context management
+//! - Smart speaker cache compression (keeps important frames, not just recent)
+//! - Silence profile tracking
+//! - Post-processing: median filtering, hysteresis thresholding
+//! - Supports up to 4 speakers
+//!
+//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
+//! Note that, my ONNX export:
+//! CHUNK_LEN = 124
+//! FIFO_LEN = 124
+//! CACHE_LEN = 188
+//! FEAT_DIM = 128
+//! EMB_DIM = 512
+//! Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html
+
+use crate::error::{Error, Result};
+use crate::execution::ModelConfig;
+use ndarray::{s, Array1, Array2, Array3, Axis};
+use ort::session::Session;
+use rustfft::{num_complex::Complex, FftPlanner};
+use std::f32::consts::PI;
+use std::path::Path;
+
+// Model constants
+const N_FFT: usize = 512;
+const WIN_LENGTH: usize = 400;
+const HOP_LENGTH: usize = 160;
+const N_MELS: usize = 128;
+const PREEMPH: f32 = 0.97;
+const LOG_ZERO_GUARD: f32 = 5.960464478e-8;
+const SAMPLE_RATE: usize = 16000;
+const FMIN: f32 = 0.0;
+const FMAX: f32 = 8000.0;
+
+// Streaming constants
+const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms)
+const FIFO_LEN: usize = 124; // FIFO buffer length
+const SPKCACHE_LEN: usize = 188; // Speaker cache length
+const SPKCACHE_UPDATE_PERIOD: usize = 124;
+const SUBSAMPLING: usize = 8; // Audio frames -> model frames
+const EMB_DIM: usize = 512; // Embedding dimension
+const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers
+const FRAME_DURATION: f32 = 0.08; // 80ms per frame
+
+// Cache compression params (from NeMo)
+const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3;
+const PRED_SCORE_THRESHOLD: f32 = 0.25;
+const STRONG_BOOST_RATE: f32 = 0.75;
+const WEAK_BOOST_RATE: f32 = 1.5;
+const MIN_POS_SCORES_RATE: f32 = 0.5;
+const SIL_THRESHOLD: f32 = 0.2;
+const MAX_INDEX: usize = 99999;
+
+/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs)
+///
+/// Controls how raw model predictions are converted into speaker segments.
+/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI).
+///
+/// # Parameters
+/// - `onset`: Probability threshold to START a speaker segment (higher = more strict)
+/// - `offset`: Probability threshold to END a speaker segment (lower = longer segments)
+/// - `pad_onset`: Seconds to subtract from segment start times
+/// - `pad_offset`: Seconds to add to segment end times
+/// - `min_duration_on`: Minimum segment length in seconds (filters short blips)
+/// - `min_duration_off`: Minimum gap between segments before merging
+/// - `median_window`: Smoothing window size (odd number, higher = smoother)
+///
+/// # Pre-tuned Configs
+/// - `callhome()` - (default)
+/// - `dihard3()`
+///
+/// # Custom Config
+/// Use `custom(onset, offset)` to create your own config for fine-tuning.
+///
+/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer
+#[derive(Debug, Clone)]
+pub struct DiarizationConfig {
+ pub onset: f32,
+ pub offset: f32,
+ pub pad_onset: f32,
+ pub pad_offset: f32,
+ pub min_duration_on: f32,
+ pub min_duration_off: f32,
+ pub median_window: usize,
+}
+
+impl Default for DiarizationConfig {
+ fn default() -> Self {
+ Self::callhome()
+ }
+}
+
+impl DiarizationConfig {
+ /// CallHome dataset config for v2 (default)
+ /// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml
+ pub fn callhome() -> Self {
+ Self {
+ onset: 0.641,
+ offset: 0.561,
+ pad_onset: 0.229,
+ pad_offset: 0.079,
+ min_duration_on: 0.511,
+ min_duration_off: 0.296,
+ median_window: 11,
+ }
+ }
+
+ /// DIHARD3 dataset config for v2
+ /// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml
+ pub fn dihard3() -> Self {
+ Self {
+ onset: 0.56,
+ offset: 1.0,
+ pad_onset: 0.063,
+ pad_offset: 0.002,
+ min_duration_on: 0.007,
+ min_duration_off: 0.151,
+ median_window: 11,
+ }
+ }
+
+ /// Create a custom config for fine-tuning diarization behavior.
+ ///
+ /// # Arguments
+ /// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7)
+ /// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6)
+ ///
+ /// # Example
+ /// ```rust
+ /// use parakeet_rs::sortformer::DiarizationConfig;
+ ///
+ /// // More sensitive detection (lower thresholds)
+ /// let sensitive = DiarizationConfig::custom(0.5, 0.4);
+ ///
+ /// // Stricter detection (higher thresholds, fewer false positives)
+ /// let strict = DiarizationConfig::custom(0.7, 0.6);
+ ///
+ /// // Full customization
+ /// let mut config = DiarizationConfig::custom(0.6, 0.5);
+ /// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms
+ /// config.median_window = 15; // More smoothing
+ /// ```
+ pub fn custom(onset: f32, offset: f32) -> Self {
+ Self {
+ onset,
+ offset,
+ pad_onset: 0.0,
+ pad_offset: 0.0,
+ min_duration_on: 0.1,
+ min_duration_off: 0.1,
+ median_window: 11,
+ }
+ }
+}
+
+/// Speaker segment with start time, end time, and speaker ID
+#[derive(Debug, Clone)]
+pub struct SpeakerSegment {
+ pub start: f32,
+ pub end: f32,
+ pub speaker_id: usize,
+}
+
+/// Streaming Sortformer v2 speaker diarization engine
+pub struct Sortformer {
+ session: Session,
+ config: DiarizationConfig,
+ // Streaming state. note that, Same way as Nemo
+ spkcache: Array3<f32>, // (1, 0..SPKCACHE_LEN, EMB_DIM)
+ spkcache_preds: Option<Array3<f32>>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS)
+ fifo: Array3<f32>, // (1, 0..FIFO_LEN, EMB_DIM)
+ fifo_preds: Array3<f32>, // (1, 0..FIFO_LEN, NUM_SPEAKERS)
+ mean_sil_emb: Array2<f32>, // (1, EMB_DIM)
+ n_sil_frames: usize,
+ // Mel filterbank (cached)
+ mel_basis: Array2<f32>,
+}
+
+impl Sortformer {
+ /// a new Sortformer instance from ONNX model path
+ pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
+ Self::with_config(model_path, None, DiarizationConfig::default())
+ }
+
+ /// Create with custom config
+ pub fn with_config<P: AsRef<Path>>(
+ model_path: P,
+ execution_config: Option<ModelConfig>,
+ config: DiarizationConfig,
+ ) -> Result<Self> {
+ let config_to_use = execution_config.unwrap_or_default();
+
+ let session = config_to_use
+ .apply_to_session_builder(Session::builder()?)?
+ .commit_from_file(model_path.as_ref())?;
+
+ let mel_basis = Self::create_mel_filterbank();
+
+ let mut instance = Self {
+ session,
+ config,
+ spkcache: Array3::zeros((1, 0, EMB_DIM)),
+ spkcache_preds: None,
+ fifo: Array3::zeros((1, 0, EMB_DIM)),
+ fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)),
+ mean_sil_emb: Array2::zeros((1, EMB_DIM)),
+ n_sil_frames: 0,
+ mel_basis,
+ };
+ instance.reset_state();
+ Ok(instance)
+ }
+
+ /// Reset streaming state
+ pub fn reset_state(&mut self) {
+ self.spkcache = Array3::zeros((1, 0, EMB_DIM));
+ self.spkcache_preds = None;
+ self.fifo = Array3::zeros((1, 0, EMB_DIM));
+ self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS));
+ self.mean_sil_emb = Array2::zeros((1, EMB_DIM));
+ self.n_sil_frames = 0;
+ }
+
+ /// Main diarization entry point
+ pub fn diarize(
+ &mut self,
+ mut audio: Vec<f32>,
+ sample_rate: u32,
+ channels: u16,
+ ) -> Result<Vec<SpeakerSegment>> {
+ // Resample if needed
+ if sample_rate != SAMPLE_RATE as u32 {
+ return Err(Error::Audio(format!(
+ "Expected {} Hz, got {} Hz",
+ SAMPLE_RATE, sample_rate
+ )));
+ }
+
+ // Convert to mono
+ if channels > 1 {
+ audio = audio
+ .chunks(channels as usize)
+ .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
+ .collect();
+ }
+
+ // Reset state for new audio
+ self.reset_state();
+
+ // Extract mel features (B, T, D)
+ let features = self.extract_mel_features(&audio);
+ let total_frames = features.shape()[1];
+
+ // Process in chunks
+ let chunk_stride = CHUNK_LEN * SUBSAMPLING;
+ let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride;
+
+ let mut all_chunk_preds = Vec::new();
+
+ for chunk_idx in 0..num_chunks {
+ let start = chunk_idx * chunk_stride;
+ let end = (start + chunk_stride).min(total_frames);
+ let current_len = end - start;
+
+ // Extract chunk features
+ let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
+
+ // Pad last chunk if needed
+ if current_len < chunk_stride {
+ let mut padded = Array3::zeros((1, chunk_stride, N_MELS));
+ padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat);
+ chunk_feat = padded;
+ }
+
+ // Run streaming update
+ let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
+ all_chunk_preds.push(chunk_preds);
+ }
+
+ // Concatenate all predictions
+ let full_preds = Self::concat_predictions(&all_chunk_preds);
+
+ // Apply median filtering
+ let filtered_preds = if self.config.median_window > 1 {
+ self.median_filter(&full_preds)
+ } else {
+ full_preds
+ };
+
+ // Binarize to segments
+ let segments = self.binarize(&filtered_preds);
+
+ Ok(segments)
+ }
+
+ /// Streaming diarization that maintains state across calls.
+ ///
+ /// Unlike `diarize()`, this method does NOT reset the internal state,
+ /// allowing speaker embeddings to be preserved across multiple audio chunks.
+ /// Call `reset_state()` manually when starting a new audio session.
+ ///
+ /// This enables consistent speaker identification across long audio streams
+ /// by maintaining the speaker cache between processing windows.
+ ///
+ /// # Arguments
+ /// * `audio` - Audio samples (will be converted to mono if multi-channel)
+ /// * `sample_rate` - Must be 16000 Hz
+ /// * `channels` - Number of audio channels
+ ///
+ /// # Example
+ /// ```ignore
+ /// // Start of session
+ /// sortformer.reset_state();
+ ///
+ /// // Process sliding windows
+ /// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?;
+ /// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs
+ /// ```
+ pub fn diarize_streaming(
+ &mut self,
+ mut audio: Vec<f32>,
+ sample_rate: u32,
+ channels: u16,
+ ) -> Result<Vec<SpeakerSegment>> {
+ // Resample if needed
+ if sample_rate != SAMPLE_RATE as u32 {
+ return Err(Error::Audio(format!(
+ "Expected {} Hz, got {} Hz",
+ SAMPLE_RATE, sample_rate
+ )));
+ }
+
+ // Convert to mono
+ if channels > 1 {
+ audio = audio
+ .chunks(channels as usize)
+ .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
+ .collect();
+ }
+
+ // NOTE: Unlike diarize(), we do NOT call reset_state() here
+ // This preserves speaker embeddings across calls
+
+ // Extract mel features (B, T, D)
+ let features = self.extract_mel_features(&audio);
+ let total_frames = features.shape()[1];
+
+ // Process in chunks
+ let chunk_stride = CHUNK_LEN * SUBSAMPLING;
+ let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride;
+
+ let mut all_chunk_preds = Vec::new();
+
+ for chunk_idx in 0..num_chunks {
+ let start = chunk_idx * chunk_stride;
+ let end = (start + chunk_stride).min(total_frames);
+ let current_len = end - start;
+
+ // Extract chunk features
+ let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
+
+ // Pad last chunk if needed
+ if current_len < chunk_stride {
+ let mut padded = Array3::zeros((1, chunk_stride, N_MELS));
+ padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat);
+ chunk_feat = padded;
+ }
+
+ // Run streaming update
+ let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
+ all_chunk_preds.push(chunk_preds);
+ }
+
+ // Concatenate all predictions
+ let full_preds = Self::concat_predictions(&all_chunk_preds);
+
+ // Apply median filtering
+ let filtered_preds = if self.config.median_window > 1 {
+ self.median_filter(&full_preds)
+ } else {
+ full_preds
+ };
+
+ // Binarize to segments
+ let segments = self.binarize(&filtered_preds);
+
+ Ok(segments)
+ }
+
+ /// NeMo's streaming_update with smart cache compression.
+ /// Public to allow incremental streaming diarization.
+ pub fn streaming_update(&mut self, chunk_feat: &Array3<f32>, current_len: usize) -> Result<Array2<f32>> {
+ let spkcache_len = self.spkcache.shape()[1];
+ let fifo_len = self.fifo.shape()[1];
+
+ // Prepare inputs
+ let chunk_lengths = Array1::from_vec(vec![current_len as i64]);
+ let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]);
+ let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]);
+
+ // Prepare FIFO input
+ let fifo_input = if fifo_len > 0 {
+ self.fifo.clone()
+ } else {
+ Array3::zeros((1, 0, EMB_DIM))
+ };
+
+ // Prepare spkcache input (may be empty)
+ let spkcache_input = if spkcache_len > 0 {
+ self.spkcache.clone()
+ } else {
+ Array3::zeros((1, 0, EMB_DIM))
+ };
+
+ // Create input values
+ let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?;
+ let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?;
+ let spkcache_value = ort::value::Value::from_array(spkcache_input)?;
+ let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?;
+ let fifo_value = ort::value::Value::from_array(fifo_input)?;
+ let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?;
+
+ // Run ONNX inference and extract all data in a block to release borrow
+ let (preds, new_embs, chunk_len) = {
+ let outputs = self.session.run(ort::inputs!(
+ "chunk" => chunk_value,
+ "chunk_lengths" => chunk_lengths_value,
+ "spkcache" => spkcache_value,
+ "spkcache_lengths" => spkcache_lengths_value,
+ "fifo" => fifo_value,
+ "fifo_lengths" => fifo_lengths_value
+ ))?;
+
+ // Extract outputs
+ let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"]
+ .try_extract_tensor::<f32>()
+ .map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?;
+ let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"]
+ .try_extract_tensor::<f32>()
+ .map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?;
+
+ // Convert to ndarray
+ let preds_dims = preds_shape.as_ref();
+ let embs_dims = embs_shape.as_ref();
+
+ let preds = Array3::from_shape_vec(
+ (preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize),
+ preds_data.to_vec()
+ ).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?;
+
+ let new_embs = Array3::from_shape_vec(
+ (embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize),
+ embs_data.to_vec()
+ ).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?;
+
+ // Calculate valid frames
+ let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING;
+
+ (preds, new_embs, valid_frames)
+ };
+
+ // Extract predictions for different parts
+ let fifo_preds = if fifo_len > 0 {
+ preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned()
+ } else {
+ Array2::zeros((0, NUM_SPEAKERS))
+ };
+
+ let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned();
+ let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned();
+
+ // Append chunk embeddings to FIFO
+ self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0)));
+
+ // Update FIFO predictions
+ if fifo_len > 0 {
+ let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds);
+ self.fifo_preds = combined.insert_axis(Axis(0));
+ } else {
+ self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0));
+ }
+
+ let fifo_len_after = self.fifo.shape()[1];
+
+ // Move from FIFO to cache when FIFO exceeds limit
+ if fifo_len_after > FIFO_LEN {
+ let mut pop_out_len = SPKCACHE_UPDATE_PERIOD;
+ pop_out_len = pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len);
+ pop_out_len = pop_out_len.min(fifo_len_after);
+
+ let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned();
+ let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned();
+
+ // Update silence profile
+ self.update_silence_profile(&pop_out_embs, &pop_out_preds);
+
+ // Remove from FIFO
+ self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned();
+ self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned();
+
+ // Append to cache
+ self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs);
+
+ if let Some(ref cache_preds) = self.spkcache_preds {
+ self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds));
+ }
+
+ // Smart compression when cache exceeds limit
+ if self.spkcache.shape()[1] > SPKCACHE_LEN {
+ if self.spkcache_preds.is_none() {
+ // Initialize cache predictions from initial output
+ let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned();
+ let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds);
+ self.spkcache_preds = Some(combined);
+ }
+
+ // Use smart compression
+ self.compress_spkcache();
+ }
+ }
+
+ Ok(chunk_preds)
+ }
+
+ /// Update mean silence embedding
+ fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) {
+ let preds_2d = preds.slice(s![0, .., ..]);
+
+ for t in 0..preds_2d.shape()[0] {
+ let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum();
+ if sum < SIL_THRESHOLD {
+ // This is a silence frame
+ let emb = embs.slice(s![0, t, ..]);
+
+ // Update running mean
+ let old_sum: Vec<f32> = self.mean_sil_emb.slice(s![0, ..]).iter()
+ .map(|&x| x * self.n_sil_frames as f32)
+ .collect();
+
+ self.n_sil_frames += 1;
+
+ for i in 0..EMB_DIM {
+ self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32;
+ }
+ }
+ }
+ }
+
+ /// Smart cache compression
+ fn compress_spkcache(&mut self) {
+ let cache_preds = match &self.spkcache_preds {
+ Some(p) => p.clone(),
+ None => return,
+ };
+
+ let n_frames = self.spkcache.shape()[1];
+ let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK;
+ let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize;
+ let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize;
+ let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize;
+
+ // Calculate quality scores
+ let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned();
+ let mut scores = self.get_log_pred_scores(&preds_2d);
+
+ // Disable low scores
+ scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk);
+
+ // Boost important frames
+ scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0);
+ scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0);
+
+ // Add silence frames placeholder
+ if SPKCACHE_SIL_FRAMES_PER_SPK > 0 {
+ let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY);
+ padded.slice_mut(s![..n_frames, ..]).assign(&scores);
+ for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK {
+ for j in 0..NUM_SPEAKERS {
+ padded[[i, j]] = f32::INFINITY;
+ }
+ }
+ scores = padded;
+ }
+
+ // Select top frames
+ let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames);
+
+ // Gather embeddings
+ let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled);
+
+ self.spkcache = new_embs;
+ self.spkcache_preds = Some(new_preds);
+ }
+
+ /// Calculate quality scores
+ fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> {
+ let mut scores = Array2::zeros(preds.dim());
+
+ for t in 0..preds.shape()[0] {
+ let mut log_1_probs_sum = 0.0f32;
+ for s in 0..NUM_SPEAKERS {
+ let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
+ let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
+ log_1_probs_sum += log_1_p;
+ }
+
+ for s in 0..NUM_SPEAKERS {
+ let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
+ let log_p = p.ln();
+ let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
+ scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln();
+ }
+ }
+
+ scores
+ }
+
+ /// Disable non-speech and overlapped speech
+ fn disable_low_scores(&self, preds: &Array2<f32>, mut scores: Array2<f32>, min_pos_scores_per_spk: usize) -> Array2<f32> {
+ // Count positive scores per speaker
+ let mut pos_count = vec![0usize; NUM_SPEAKERS];
+ for t in 0..scores.shape()[0] {
+ for s in 0..NUM_SPEAKERS {
+ if scores[[t, s]] > 0.0 {
+ pos_count[s] += 1;
+ }
+ }
+ }
+
+ for t in 0..preds.shape()[0] {
+ for s in 0..NUM_SPEAKERS {
+ let is_speech = preds[[t, s]] > 0.5;
+
+ if !is_speech {
+ scores[[t, s]] = f32::NEG_INFINITY;
+ } else {
+ let is_pos = scores[[t, s]] > 0.0;
+ if !is_pos && pos_count[s] >= min_pos_scores_per_spk {
+ scores[[t, s]] = f32::NEG_INFINITY;
+ }
+ }
+ }
+ }
+
+ scores
+ }
+
+ /// Boost top K frames per speaker
+ fn boost_topk_scores(&self, mut scores: Array2<f32>, n_boost_per_spk: usize, scale_factor: f32) -> Array2<f32> {
+ for s in 0..NUM_SPEAKERS {
+ // Get column for this speaker
+ let col: Vec<(usize, f32)> = (0..scores.shape()[0])
+ .map(|t| (t, scores[[t, s]]))
+ .collect();
+
+ // Sort by score descending
+ let mut sorted = col.clone();
+ sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+
+ // Boost top K
+ for i in 0..n_boost_per_spk.min(sorted.len()) {
+ let t = sorted[i].0;
+ if scores[[t, s]] != f32::NEG_INFINITY {
+ scores[[t, s]] -= scale_factor * 0.5f32.ln();
+ }
+ }
+ }
+
+ scores
+ }
+
+ /// Get indices of top frames
+ fn get_topk_indices(&self, scores: &Array2<f32>, n_frames_no_sil: usize) -> (Vec<usize>, Vec<bool>) {
+ let n_frames = scores.shape()[0];
+
+ // Flatten scores as (S, T) then reshape to (S*T,)
+ // This means we iterate: speaker 0 all times, then speaker 1 all times, etc.
+ // flat_index = speaker * n_frames + time
+ let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS);
+ for s in 0..NUM_SPEAKERS {
+ for t in 0..n_frames {
+ let flat_idx = s * n_frames + t;
+ flat_scores.push((flat_idx, scores[[t, s]]));
+ }
+ }
+
+ // Sort by score descending to get top-K
+ flat_scores.sort_by(|a, b| {
+ b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
+ });
+
+ // Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX
+ let mut topk_flat: Vec<usize> = flat_scores
+ .iter()
+ .take(SPKCACHE_LEN)
+ .map(|(idx, score)| {
+ if *score == f32::NEG_INFINITY {
+ MAX_INDEX
+ } else {
+ *idx
+ }
+ })
+ .collect();
+
+ // Sort flat indices ascending (this puts MAX_INDEX at the end)
+ topk_flat.sort();
+
+ // Compute is_disabled and convert to frame indices
+ let mut is_disabled = vec![false; SPKCACHE_LEN];
+ let mut frame_indices = vec![0usize; SPKCACHE_LEN];
+
+ for (i, &flat_idx) in topk_flat.iter().enumerate() {
+ if flat_idx == MAX_INDEX {
+ // Invalid entries are disabled
+ is_disabled[i] = true;
+ frame_indices[i] = 0; // We set disabled to 0
+ } else {
+ // convert to frame index
+ let frame_idx = flat_idx % n_frames;
+
+ // check if frame is beyond valid range
+ if frame_idx >= n_frames_no_sil {
+ is_disabled[i] = true;
+ frame_indices[i] = 0; // same as abov: set disabled to 0
+ } else {
+ frame_indices[i] = frame_idx;
+ }
+ }
+ }
+
+ (frame_indices, is_disabled)
+ }
+
+ /// Gather selected frames
+ fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3<f32>, Array3<f32>) {
+ let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM));
+ let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS));
+
+ let cache_preds = self.spkcache_preds.as_ref().unwrap();
+
+ for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() {
+ if i >= SPKCACHE_LEN {
+ break;
+ }
+
+ if disabled {
+ // Use silence embedding
+ new_embs.slice_mut(s![0, i, ..]).assign(&self.mean_sil_emb.slice(s![0, ..]));
+ // Predictions stay zero
+ } else if idx < self.spkcache.shape()[1] {
+ new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..]));
+ new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..]));
+ }
+ }
+
+ (new_embs, new_preds)
+ }
+
+ /// Concatenate along axis 1 for 3D arrays
+ fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
+ if a.shape()[1] == 0 {
+ return b.clone();
+ }
+ if b.shape()[1] == 0 {
+ return a.clone();
+ }
+ ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap()
+ }
+
+ /// Concatenate along axis 0 for 2D arrays
+ fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> {
+ if a.shape()[0] == 0 {
+ return b.clone();
+ }
+ if b.shape()[0] == 0 {
+ return a.clone();
+ }
+ ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
+ }
+
+ /// Concatenate predictions
+ fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> {
+ if preds.is_empty() {
+ return Array2::zeros((0, NUM_SPEAKERS));
+ }
+ if preds.len() == 1 {
+ return preds[0].clone();
+ }
+
+ let views: Vec<_> = preds.iter().map(|p| p.view()).collect();
+ ndarray::concatenate(Axis(0), &views).unwrap()
+ }
+
+ /// Apply median filter to predictions
+ fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> {
+ let window = self.config.median_window;
+ let half = window / 2;
+ let mut filtered = preds.clone();
+
+ for spk in 0..NUM_SPEAKERS {
+ for t in 0..preds.shape()[0] {
+ let start = t.saturating_sub(half);
+ let end = (t + half + 1).min(preds.shape()[0]);
+
+ let mut values: Vec<f32> = (start..end)
+ .map(|i| preds[[i, spk]])
+ .collect();
+ values.sort_by(|a, b| a.partial_cmp(b).unwrap());
+
+ filtered[[t, spk]] = values[values.len() / 2];
+ }
+ }
+
+ filtered
+ }
+
+ /// Binarize predictions to segments (padding applied during thresholding)
+ fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> {
+ let mut segments = Vec::new();
+ let num_frames = preds.shape()[0];
+
+ for spk in 0..NUM_SPEAKERS {
+ let mut in_seg = false;
+ let mut seg_start = 0;
+ let mut temp_segments = Vec::new();
+
+ for t in 0..num_frames {
+ let p = preds[[t, spk]];
+
+ if p >= self.config.onset && !in_seg {
+ in_seg = true;
+ seg_start = t;
+ } else if p < self.config.offset && in_seg {
+ in_seg = false;
+
+ // Apply padding during conversion
+ let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0);
+ let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset;
+
+ if end_t - start_t >= self.config.min_duration_on {
+ temp_segments.push(SpeakerSegment {
+ start: start_t,
+ end: end_t,
+ speaker_id: spk,
+ });
+ }
+ }
+ }
+
+ // Handle segment at end
+ if in_seg {
+ let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0);
+ let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset;
+
+ if end_t - start_t >= self.config.min_duration_on {
+ temp_segments.push(SpeakerSegment {
+ start: start_t,
+ end: end_t,
+ speaker_id: spk,
+ });
+ }
+ }
+
+ // Merge close segments (min_duration_off)
+ if temp_segments.len() > 1 {
+ let mut filtered = vec![temp_segments[0].clone()];
+ for seg in temp_segments.into_iter().skip(1) {
+ let last = filtered.last_mut().unwrap();
+ let gap = seg.start - last.end;
+ if gap < self.config.min_duration_off {
+ last.end = seg.end; // Merge
+ } else {
+ filtered.push(seg);
+ }
+ }
+ segments.extend(filtered);
+ } else {
+ segments.extend(temp_segments);
+ }
+ }
+
+ // Sort by start time
+ segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap());
+ segments
+ }
+
+
+ fn apply_preemphasis(audio: &[f32]) -> Vec<f32> {
+ let mut result = Vec::with_capacity(audio.len());
+ result.push(audio[0]);
+ for i in 1..audio.len() {
+ result.push(audio[i] - PREEMPH * audio[i - 1]);
+ }
+ result
+ }
+
+ fn hann_window(window_length: usize) -> Vec<f32> {
+ // Librosa uses periodic window (fftbins=True): divide by N, not N-1
+ (0..window_length)
+ .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos())
+ .collect()
+ }
+
+ fn stft(audio: &[f32]) -> Array2<f32> {
+ let mut planner = FftPlanner::<f32>::new();
+ let fft = planner.plan_fft_forward(N_FFT);
+
+ // Create Hann window of length win_length, then zero-pad to n_fft (centered)
+ // This is exactly what librosa does: util.pad_center(fft_window, size=n_fft)
+ let hann = Self::hann_window(WIN_LENGTH);
+ let win_offset = (N_FFT - WIN_LENGTH) / 2;
+ let mut fft_window = vec![0.0f32; N_FFT];
+ for i in 0..WIN_LENGTH {
+ fft_window[win_offset + i] = hann[i];
+ }
+
+ // Pad signal for center=True (like librosa/torch.stft)
+ // Padding is n_fft // 2 on each side
+ let pad_amount = N_FFT / 2;
+ let mut padded_audio = vec![0.0; pad_amount];
+ padded_audio.extend_from_slice(audio);
+ padded_audio.extend(vec![0.0; pad_amount]);
+
+ let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1;
+ let freq_bins = N_FFT / 2 + 1;
+ let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames));
+
+ for frame_idx in 0..num_frames {
+ let start = frame_idx * HOP_LENGTH;
+ let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT];
+
+ // Extract n_fft samples and multiply by zero-padded window
+ for i in 0..N_FFT {
+ if start + i < padded_audio.len() {
+ frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0);
+ }
+ }
+
+ fft.process(&mut frame);
+ for k in 0..freq_bins {
+ let magnitude = frame[k].norm();
+ // Power spectrum (magnitude^2) - NeMo uses mag_power=2.0
+ spectrogram[[k, frame_idx]] = magnitude * magnitude;
+ }
+ }
+
+ spectrogram
+ }
+
+ // Librosa's Slaney mel scale (htk=False, which is the default)
+ fn hz_to_mel_slaney(hz: f64) -> f64 {
+ let f_min = 0.0;
+ let f_sp = 200.0 / 3.0;
+ let min_log_hz = 1000.0;
+ let min_log_mel = (min_log_hz - f_min) / f_sp;
+ let logstep = (6.4f64).ln() / 27.0;
+
+ if hz >= min_log_hz {
+ min_log_mel + (hz / min_log_hz).ln() / logstep
+ } else {
+ (hz - f_min) / f_sp
+ }
+ }
+
+ fn mel_to_hz_slaney(mel: f64) -> f64 {
+ let f_min = 0.0;
+ let f_sp = 200.0 / 3.0;
+ let min_log_hz = 1000.0;
+ let min_log_mel = (min_log_hz - f_min) / f_sp;
+ let logstep = (6.4f64).ln() / 27.0;
+
+ if mel >= min_log_mel {
+ min_log_hz * (logstep * (mel - min_log_mel)).exp()
+ } else {
+ f_min + f_sp * mel
+ }
+ }
+
+ fn create_mel_filterbank() -> Array2<f32> {
+ // lets use f64 for intermediate calculations to avoid precision loss
+ let freq_bins = N_FFT / 2 + 1;
+ let mut filterbank = Array2::<f32>::zeros((N_MELS, freq_bins));
+
+ // FFT frequencies: fftfreqs[k] = k * sr / n_fft
+ let fftfreqs: Vec<f64> = (0..freq_bins)
+ .map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64)
+ .collect();
+
+ // Mel center frequencies using Slaney scale (librosa default, htk=False)
+ let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64);
+ let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64);
+ let mel_f: Vec<f64> = (0..=N_MELS + 1)
+ .map(|i| {
+ let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64;
+ Self::mel_to_hz_slaney(mel)
+ })
+ .collect();
+
+ // Differences between consecutive mel frequencies
+ let fdiff: Vec<f64> = mel_f.windows(2).map(|w| w[1] - w[0]).collect();
+
+ // Compute filterbank weights (reference: librosa's ramp method)
+ // https://librosa.org/doc/main/generated/librosa.stft.html
+ for i in 0..N_MELS {
+ for k in 0..freq_bins {
+ // Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i]
+ let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i];
+ // Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1]
+ let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1];
+ // Weight is max(0, min(lower, upper))
+ filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32;
+ }
+ }
+
+ // Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i])
+ for i in 0..N_MELS {
+ let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]);
+ for k in 0..freq_bins {
+ filterbank[[i, k]] *= enorm as f32;
+ }
+ }
+
+ filterbank
+ }
+
+ fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> {
+ // 1. Add dither (small random noise to prevent log(0))
+ // NeMo uses dither=1e-5, but for determinism we skip random noise
+ // The log_zero_guard handles zero values
+
+ // 2. Apply preemphasis (NeMo uses preemph=0.97)
+ let preemphasized = Self::apply_preemphasis(audio);
+
+ // 3. STFT
+ let spectrogram = Self::stft(&preemphasized);
+
+ // 4. Apply mel filterbank (with Slaney normalization)
+ let mel_spec = self.mel_basis.dot(&spectrogram);
+
+ // 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24)
+ // NeMo uses normalize='NA' which means NO normalization
+ let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln());
+
+ let num_frames = log_mel_spec.shape()[1];
+ let mut features = Array3::<f32>::zeros((1, num_frames, N_MELS));
+
+ // Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D)
+ for t in 0..num_frames {
+ for m in 0..N_MELS {
+ features[[0, t, m]] = log_mel_spec[[m, t]];
+ }
+ }
+
+ features
+ }
+}