//! NVIDIA Sortformer v2 Streaming Speaker Diarization
//!
//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization.
//!
//! Key features:
//! - Streaming inference with ~10s chunks (124 frames at 80ms each)
//! - FIFO buffer for context management
//! - Smart speaker cache compression (keeps important frames, not just recent)
//! - Silence profile tracking
//! - Post-processing: median filtering, hysteresis thresholding
//! - Supports up to 4 speakers
//!
//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
//! Note that, my ONNX export:
//! CHUNK_LEN = 124
//! FIFO_LEN = 124
//! CACHE_LEN = 188
//! FEAT_DIM = 128
//! EMB_DIM = 512
//! Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html
use crate::error::{Error, Result};
use crate::execution::ModelConfig;
use ndarray::{s, Array1, Array2, Array3, Axis};
use ort::session::Session;
use rustfft::{num_complex::Complex, FftPlanner};
use std::f32::consts::PI;
use std::path::Path;
// Model constants
const N_FFT: usize = 512;
const WIN_LENGTH: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128;
const PREEMPH: f32 = 0.97;
const LOG_ZERO_GUARD: f32 = 5.960464478e-8;
const SAMPLE_RATE: usize = 16000;
const FMIN: f32 = 0.0;
const FMAX: f32 = 8000.0;
// Streaming constants
const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms)
const FIFO_LEN: usize = 124; // FIFO buffer length
const SPKCACHE_LEN: usize = 188; // Speaker cache length
const SPKCACHE_UPDATE_PERIOD: usize = 124;
const SUBSAMPLING: usize = 8; // Audio frames -> model frames
const EMB_DIM: usize = 512; // Embedding dimension
const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers
const FRAME_DURATION: f32 = 0.08; // 80ms per frame
// Cache compression params (from NeMo)
const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3;
const PRED_SCORE_THRESHOLD: f32 = 0.25;
const STRONG_BOOST_RATE: f32 = 0.75;
const WEAK_BOOST_RATE: f32 = 1.5;
const MIN_POS_SCORES_RATE: f32 = 0.5;
const SIL_THRESHOLD: f32 = 0.2;
const MAX_INDEX: usize = 99999;
/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs)
///
/// Controls how raw model predictions are converted into speaker segments.
/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI).
///
/// # Parameters
/// - `onset`: Probability threshold to START a speaker segment (higher = more strict)
/// - `offset`: Probability threshold to END a speaker segment (lower = longer segments)
/// - `pad_onset`: Seconds to subtract from segment start times
/// - `pad_offset`: Seconds to add to segment end times
/// - `min_duration_on`: Minimum segment length in seconds (filters short blips)
/// - `min_duration_off`: Minimum gap between segments before merging
/// - `median_window`: Smoothing window size (odd number, higher = smoother)
///
/// # Pre-tuned Configs
/// - `callhome()` - (default)
/// - `dihard3()`
///
/// # Custom Config
/// Use `custom(onset, offset)` to create your own config for fine-tuning.
///
/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer
#[derive(Debug, Clone)]
pub struct DiarizationConfig {
pub onset: f32,
pub offset: f32,
pub pad_onset: f32,
pub pad_offset: f32,
pub min_duration_on: f32,
pub min_duration_off: f32,
pub median_window: usize,
}
impl Default for DiarizationConfig {
fn default() -> Self {
Self::callhome()
}
}
impl DiarizationConfig {
/// CallHome dataset config for v2 (default)
/// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml
pub fn callhome() -> Self {
Self {
onset: 0.641,
offset: 0.561,
pad_onset: 0.229,
pad_offset: 0.079,
min_duration_on: 0.511,
min_duration_off: 0.296,
median_window: 11,
}
}
/// DIHARD3 dataset config for v2
/// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml
pub fn dihard3() -> Self {
Self {
onset: 0.56,
offset: 1.0,
pad_onset: 0.063,
pad_offset: 0.002,
min_duration_on: 0.007,
min_duration_off: 0.151,
median_window: 11,
}
}
/// Create a custom config for fine-tuning diarization behavior.
///
/// # Arguments
/// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7)
/// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6)
///
/// # Example
/// ```rust
/// use parakeet_rs::sortformer::DiarizationConfig;
///
/// // More sensitive detection (lower thresholds)
/// let sensitive = DiarizationConfig::custom(0.5, 0.4);
///
/// // Stricter detection (higher thresholds, fewer false positives)
/// let strict = DiarizationConfig::custom(0.7, 0.6);
///
/// // Full customization
/// let mut config = DiarizationConfig::custom(0.6, 0.5);
/// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms
/// config.median_window = 15; // More smoothing
/// ```
pub fn custom(onset: f32, offset: f32) -> Self {
Self {
onset,
offset,
pad_onset: 0.0,
pad_offset: 0.0,
min_duration_on: 0.1,
min_duration_off: 0.1,
median_window: 11,
}
}
}
/// Speaker segment with start time, end time, and speaker ID
#[derive(Debug, Clone)]
pub struct SpeakerSegment {
pub start: f32,
pub end: f32,
pub speaker_id: usize,
}
/// Streaming Sortformer v2 speaker diarization engine
pub struct Sortformer {
session: Session,
config: DiarizationConfig,
// Streaming state. note that, Same way as Nemo
spkcache: Array3<f32>, // (1, 0..SPKCACHE_LEN, EMB_DIM)
spkcache_preds: Option<Array3<f32>>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS)
fifo: Array3<f32>, // (1, 0..FIFO_LEN, EMB_DIM)
fifo_preds: Array3<f32>, // (1, 0..FIFO_LEN, NUM_SPEAKERS)
mean_sil_emb: Array2<f32>, // (1, EMB_DIM)
n_sil_frames: usize,
// Mel filterbank (cached)
mel_basis: Array2<f32>,
}
impl Sortformer {
/// a new Sortformer instance from ONNX model path
pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
Self::with_config(model_path, None, DiarizationConfig::default())
}
/// Create with custom config
pub fn with_config<P: AsRef<Path>>(
model_path: P,
execution_config: Option<ModelConfig>,
config: DiarizationConfig,
) -> Result<Self> {
let config_to_use = execution_config.unwrap_or_default();
let session = config_to_use
.apply_to_session_builder(Session::builder()?)?
.commit_from_file(model_path.as_ref())?;
let mel_basis = Self::create_mel_filterbank();
let mut instance = Self {
session,
config,
spkcache: Array3::zeros((1, 0, EMB_DIM)),
spkcache_preds: None,
fifo: Array3::zeros((1, 0, EMB_DIM)),
fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)),
mean_sil_emb: Array2::zeros((1, EMB_DIM)),
n_sil_frames: 0,
mel_basis,
};
instance.reset_state();
Ok(instance)
}
/// Reset streaming state
pub fn reset_state(&mut self) {
self.spkcache = Array3::zeros((1, 0, EMB_DIM));
self.spkcache_preds = None;
self.fifo = Array3::zeros((1, 0, EMB_DIM));
self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS));
self.mean_sil_emb = Array2::zeros((1, EMB_DIM));
self.n_sil_frames = 0;
}
/// Main diarization entry point
pub fn diarize(
&mut self,
mut audio: Vec<f32>,
sample_rate: u32,
channels: u16,
) -> Result<Vec<SpeakerSegment>> {
// Resample if needed
if sample_rate != SAMPLE_RATE as u32 {
return Err(Error::Audio(format!(
"Expected {} Hz, got {} Hz",
SAMPLE_RATE, sample_rate
)));
}
// Convert to mono
if channels > 1 {
audio = audio
.chunks(channels as usize)
.map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
.collect();
}
// Reset state for new audio
self.reset_state();
// Extract mel features (B, T, D)
let features = self.extract_mel_features(&audio);
let total_frames = features.shape()[1];
// Process in chunks
let chunk_stride = CHUNK_LEN * SUBSAMPLING;
let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride;
let mut all_chunk_preds = Vec::new();
for chunk_idx in 0..num_chunks {
let start = chunk_idx * chunk_stride;
let end = (start + chunk_stride).min(total_frames);
let current_len = end - start;
// Extract chunk features
let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
// Pad last chunk if needed
if current_len < chunk_stride {
let mut padded = Array3::zeros((1, chunk_stride, N_MELS));
padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat);
chunk_feat = padded;
}
// Run streaming update
let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
all_chunk_preds.push(chunk_preds);
}
// Concatenate all predictions
let full_preds = Self::concat_predictions(&all_chunk_preds);
// Apply median filtering
let filtered_preds = if self.config.median_window > 1 {
self.median_filter(&full_preds)
} else {
full_preds
};
// Binarize to segments
let segments = self.binarize(&filtered_preds);
Ok(segments)
}
/// Streaming diarization that maintains state across calls.
///
/// Unlike `diarize()`, this method does NOT reset the internal state,
/// allowing speaker embeddings to be preserved across multiple audio chunks.
/// Call `reset_state()` manually when starting a new audio session.
///
/// This enables consistent speaker identification across long audio streams
/// by maintaining the speaker cache between processing windows.
///
/// # Arguments
/// * `audio` - Audio samples (will be converted to mono if multi-channel)
/// * `sample_rate` - Must be 16000 Hz
/// * `channels` - Number of audio channels
///
/// # Example
/// ```ignore
/// // Start of session
/// sortformer.reset_state();
///
/// // Process sliding windows
/// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?;
/// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs
/// ```
pub fn diarize_streaming(
&mut self,
mut audio: Vec<f32>,
sample_rate: u32,
channels: u16,
) -> Result<Vec<SpeakerSegment>> {
// Resample if needed
if sample_rate != SAMPLE_RATE as u32 {
return Err(Error::Audio(format!(
"Expected {} Hz, got {} Hz",
SAMPLE_RATE, sample_rate
)));
}
// Convert to mono
if channels > 1 {
audio = audio
.chunks(channels as usize)
.map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
.collect();
}
// NOTE: Unlike diarize(), we do NOT call reset_state() here
// This preserves speaker embeddings across calls
// Extract mel features (B, T, D)
let features = self.extract_mel_features(&audio);
let total_frames = features.shape()[1];
// Process in chunks
let chunk_stride = CHUNK_LEN * SUBSAMPLING;
let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride;
let mut all_chunk_preds = Vec::new();
for chunk_idx in 0..num_chunks {
let start = chunk_idx * chunk_stride;
let end = (start + chunk_stride).min(total_frames);
let current_len = end - start;
// Extract chunk features
let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned();
// Pad last chunk if needed
if current_len < chunk_stride {
let mut padded = Array3::zeros((1, chunk_stride, N_MELS));
padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat);
chunk_feat = padded;
}
// Run streaming update
let chunk_preds = self.streaming_update(&chunk_feat, current_len)?;
all_chunk_preds.push(chunk_preds);
}
// Concatenate all predictions
let full_preds = Self::concat_predictions(&all_chunk_preds);
// Apply median filtering
let filtered_preds = if self.config.median_window > 1 {
self.median_filter(&full_preds)
} else {
full_preds
};
// Binarize to segments
let segments = self.binarize(&filtered_preds);
Ok(segments)
}
/// NeMo's streaming_update with smart cache compression.
/// Public to allow incremental streaming diarization.
pub fn streaming_update(&mut self, chunk_feat: &Array3<f32>, current_len: usize) -> Result<Array2<f32>> {
let spkcache_len = self.spkcache.shape()[1];
let fifo_len = self.fifo.shape()[1];
// Prepare inputs
let chunk_lengths = Array1::from_vec(vec![current_len as i64]);
let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]);
let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]);
// Prepare FIFO input
let fifo_input = if fifo_len > 0 {
self.fifo.clone()
} else {
Array3::zeros((1, 0, EMB_DIM))
};
// Prepare spkcache input (may be empty)
let spkcache_input = if spkcache_len > 0 {
self.spkcache.clone()
} else {
Array3::zeros((1, 0, EMB_DIM))
};
// Create input values
let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?;
let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?;
let spkcache_value = ort::value::Value::from_array(spkcache_input)?;
let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?;
let fifo_value = ort::value::Value::from_array(fifo_input)?;
let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?;
// Run ONNX inference and extract all data in a block to release borrow
let (preds, new_embs, chunk_len) = {
let outputs = self.session.run(ort::inputs!(
"chunk" => chunk_value,
"chunk_lengths" => chunk_lengths_value,
"spkcache" => spkcache_value,
"spkcache_lengths" => spkcache_lengths_value,
"fifo" => fifo_value,
"fifo_lengths" => fifo_lengths_value
))?;
// Extract outputs
let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?;
let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?;
// Convert to ndarray
let preds_dims = preds_shape.as_ref();
let embs_dims = embs_shape.as_ref();
let preds = Array3::from_shape_vec(
(preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize),
preds_data.to_vec()
).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?;
let new_embs = Array3::from_shape_vec(
(embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize),
embs_data.to_vec()
).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?;
// Calculate valid frames
let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING;
(preds, new_embs, valid_frames)
};
// Extract predictions for different parts
let fifo_preds = if fifo_len > 0 {
preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned()
} else {
Array2::zeros((0, NUM_SPEAKERS))
};
let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned();
let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned();
// Append chunk embeddings to FIFO
self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0)));
// Update FIFO predictions
if fifo_len > 0 {
let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds);
self.fifo_preds = combined.insert_axis(Axis(0));
} else {
self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0));
}
let fifo_len_after = self.fifo.shape()[1];
// Move from FIFO to cache when FIFO exceeds limit
if fifo_len_after > FIFO_LEN {
let mut pop_out_len = SPKCACHE_UPDATE_PERIOD;
pop_out_len = pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len);
pop_out_len = pop_out_len.min(fifo_len_after);
let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned();
let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned();
// Update silence profile
self.update_silence_profile(&pop_out_embs, &pop_out_preds);
// Remove from FIFO
self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned();
self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned();
// Append to cache
self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs);
if let Some(ref cache_preds) = self.spkcache_preds {
self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds));
}
// Smart compression when cache exceeds limit
if self.spkcache.shape()[1] > SPKCACHE_LEN {
if self.spkcache_preds.is_none() {
// Initialize cache predictions from initial output
let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned();
let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds);
self.spkcache_preds = Some(combined);
}
// Use smart compression
self.compress_spkcache();
}
}
Ok(chunk_preds)
}
/// Update mean silence embedding
fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) {
let preds_2d = preds.slice(s![0, .., ..]);
for t in 0..preds_2d.shape()[0] {
let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum();
if sum < SIL_THRESHOLD {
// This is a silence frame
let emb = embs.slice(s![0, t, ..]);
// Update running mean
let old_sum: Vec<f32> = self.mean_sil_emb.slice(s![0, ..]).iter()
.map(|&x| x * self.n_sil_frames as f32)
.collect();
self.n_sil_frames += 1;
for i in 0..EMB_DIM {
self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32;
}
}
}
}
/// Smart cache compression
fn compress_spkcache(&mut self) {
let cache_preds = match &self.spkcache_preds {
Some(p) => p.clone(),
None => return,
};
let n_frames = self.spkcache.shape()[1];
let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK;
let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize;
let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize;
let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize;
// Calculate quality scores
let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned();
let mut scores = self.get_log_pred_scores(&preds_2d);
// Disable low scores
scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk);
// Boost important frames
scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0);
scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0);
// Add silence frames placeholder
if SPKCACHE_SIL_FRAMES_PER_SPK > 0 {
let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY);
padded.slice_mut(s![..n_frames, ..]).assign(&scores);
for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK {
for j in 0..NUM_SPEAKERS {
padded[[i, j]] = f32::INFINITY;
}
}
scores = padded;
}
// Select top frames
let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames);
// Gather embeddings
let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled);
self.spkcache = new_embs;
self.spkcache_preds = Some(new_preds);
}
/// Calculate quality scores
fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> {
let mut scores = Array2::zeros(preds.dim());
for t in 0..preds.shape()[0] {
let mut log_1_probs_sum = 0.0f32;
for s in 0..NUM_SPEAKERS {
let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
log_1_probs_sum += log_1_p;
}
for s in 0..NUM_SPEAKERS {
let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD);
let log_p = p.ln();
let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln();
scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln();
}
}
scores
}
/// Disable non-speech and overlapped speech
fn disable_low_scores(&self, preds: &Array2<f32>, mut scores: Array2<f32>, min_pos_scores_per_spk: usize) -> Array2<f32> {
// Count positive scores per speaker
let mut pos_count = vec![0usize; NUM_SPEAKERS];
for t in 0..scores.shape()[0] {
for s in 0..NUM_SPEAKERS {
if scores[[t, s]] > 0.0 {
pos_count[s] += 1;
}
}
}
for t in 0..preds.shape()[0] {
for s in 0..NUM_SPEAKERS {
let is_speech = preds[[t, s]] > 0.5;
if !is_speech {
scores[[t, s]] = f32::NEG_INFINITY;
} else {
let is_pos = scores[[t, s]] > 0.0;
if !is_pos && pos_count[s] >= min_pos_scores_per_spk {
scores[[t, s]] = f32::NEG_INFINITY;
}
}
}
}
scores
}
/// Boost top K frames per speaker
fn boost_topk_scores(&self, mut scores: Array2<f32>, n_boost_per_spk: usize, scale_factor: f32) -> Array2<f32> {
for s in 0..NUM_SPEAKERS {
// Get column for this speaker
let col: Vec<(usize, f32)> = (0..scores.shape()[0])
.map(|t| (t, scores[[t, s]]))
.collect();
// Sort by score descending
let mut sorted = col.clone();
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
// Boost top K
for i in 0..n_boost_per_spk.min(sorted.len()) {
let t = sorted[i].0;
if scores[[t, s]] != f32::NEG_INFINITY {
scores[[t, s]] -= scale_factor * 0.5f32.ln();
}
}
}
scores
}
/// Get indices of top frames
fn get_topk_indices(&self, scores: &Array2<f32>, n_frames_no_sil: usize) -> (Vec<usize>, Vec<bool>) {
let n_frames = scores.shape()[0];
// Flatten scores as (S, T) then reshape to (S*T,)
// This means we iterate: speaker 0 all times, then speaker 1 all times, etc.
// flat_index = speaker * n_frames + time
let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS);
for s in 0..NUM_SPEAKERS {
for t in 0..n_frames {
let flat_idx = s * n_frames + t;
flat_scores.push((flat_idx, scores[[t, s]]));
}
}
// Sort by score descending to get top-K
flat_scores.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
// Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX
let mut topk_flat: Vec<usize> = flat_scores
.iter()
.take(SPKCACHE_LEN)
.map(|(idx, score)| {
if *score == f32::NEG_INFINITY {
MAX_INDEX
} else {
*idx
}
})
.collect();
// Sort flat indices ascending (this puts MAX_INDEX at the end)
topk_flat.sort();
// Compute is_disabled and convert to frame indices
let mut is_disabled = vec![false; SPKCACHE_LEN];
let mut frame_indices = vec![0usize; SPKCACHE_LEN];
for (i, &flat_idx) in topk_flat.iter().enumerate() {
if flat_idx == MAX_INDEX {
// Invalid entries are disabled
is_disabled[i] = true;
frame_indices[i] = 0; // We set disabled to 0
} else {
// convert to frame index
let frame_idx = flat_idx % n_frames;
// check if frame is beyond valid range
if frame_idx >= n_frames_no_sil {
is_disabled[i] = true;
frame_indices[i] = 0; // same as abov: set disabled to 0
} else {
frame_indices[i] = frame_idx;
}
}
}
(frame_indices, is_disabled)
}
/// Gather selected frames
fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3<f32>, Array3<f32>) {
let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM));
let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS));
let cache_preds = self.spkcache_preds.as_ref().unwrap();
for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() {
if i >= SPKCACHE_LEN {
break;
}
if disabled {
// Use silence embedding
new_embs.slice_mut(s![0, i, ..]).assign(&self.mean_sil_emb.slice(s![0, ..]));
// Predictions stay zero
} else if idx < self.spkcache.shape()[1] {
new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..]));
new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..]));
}
}
(new_embs, new_preds)
}
/// Concatenate along axis 1 for 3D arrays
fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> {
if a.shape()[1] == 0 {
return b.clone();
}
if b.shape()[1] == 0 {
return a.clone();
}
ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap()
}
/// Concatenate along axis 0 for 2D arrays
fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> {
if a.shape()[0] == 0 {
return b.clone();
}
if b.shape()[0] == 0 {
return a.clone();
}
ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap()
}
/// Concatenate predictions
fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> {
if preds.is_empty() {
return Array2::zeros((0, NUM_SPEAKERS));
}
if preds.len() == 1 {
return preds[0].clone();
}
let views: Vec<_> = preds.iter().map(|p| p.view()).collect();
ndarray::concatenate(Axis(0), &views).unwrap()
}
/// Apply median filter to predictions
fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> {
let window = self.config.median_window;
let half = window / 2;
let mut filtered = preds.clone();
for spk in 0..NUM_SPEAKERS {
for t in 0..preds.shape()[0] {
let start = t.saturating_sub(half);
let end = (t + half + 1).min(preds.shape()[0]);
let mut values: Vec<f32> = (start..end)
.map(|i| preds[[i, spk]])
.collect();
values.sort_by(|a, b| a.partial_cmp(b).unwrap());
filtered[[t, spk]] = values[values.len() / 2];
}
}
filtered
}
/// Binarize predictions to segments (padding applied during thresholding)
fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> {
let mut segments = Vec::new();
let num_frames = preds.shape()[0];
for spk in 0..NUM_SPEAKERS {
let mut in_seg = false;
let mut seg_start = 0;
let mut temp_segments = Vec::new();
for t in 0..num_frames {
let p = preds[[t, spk]];
if p >= self.config.onset && !in_seg {
in_seg = true;
seg_start = t;
} else if p < self.config.offset && in_seg {
in_seg = false;
// Apply padding during conversion
let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0);
let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset;
if end_t - start_t >= self.config.min_duration_on {
temp_segments.push(SpeakerSegment {
start: start_t,
end: end_t,
speaker_id: spk,
});
}
}
}
// Handle segment at end
if in_seg {
let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0);
let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset;
if end_t - start_t >= self.config.min_duration_on {
temp_segments.push(SpeakerSegment {
start: start_t,
end: end_t,
speaker_id: spk,
});
}
}
// Merge close segments (min_duration_off)
if temp_segments.len() > 1 {
let mut filtered = vec![temp_segments[0].clone()];
for seg in temp_segments.into_iter().skip(1) {
let last = filtered.last_mut().unwrap();
let gap = seg.start - last.end;
if gap < self.config.min_duration_off {
last.end = seg.end; // Merge
} else {
filtered.push(seg);
}
}
segments.extend(filtered);
} else {
segments.extend(temp_segments);
}
}
// Sort by start time
segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap());
segments
}
fn apply_preemphasis(audio: &[f32]) -> Vec<f32> {
let mut result = Vec::with_capacity(audio.len());
result.push(audio[0]);
for i in 1..audio.len() {
result.push(audio[i] - PREEMPH * audio[i - 1]);
}
result
}
fn hann_window(window_length: usize) -> Vec<f32> {
// Librosa uses periodic window (fftbins=True): divide by N, not N-1
(0..window_length)
.map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos())
.collect()
}
fn stft(audio: &[f32]) -> Array2<f32> {
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(N_FFT);
// Create Hann window of length win_length, then zero-pad to n_fft (centered)
// This is exactly what librosa does: util.pad_center(fft_window, size=n_fft)
let hann = Self::hann_window(WIN_LENGTH);
let win_offset = (N_FFT - WIN_LENGTH) / 2;
let mut fft_window = vec![0.0f32; N_FFT];
for i in 0..WIN_LENGTH {
fft_window[win_offset + i] = hann[i];
}
// Pad signal for center=True (like librosa/torch.stft)
// Padding is n_fft // 2 on each side
let pad_amount = N_FFT / 2;
let mut padded_audio = vec![0.0; pad_amount];
padded_audio.extend_from_slice(audio);
padded_audio.extend(vec![0.0; pad_amount]);
let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1;
let freq_bins = N_FFT / 2 + 1;
let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames));
for frame_idx in 0..num_frames {
let start = frame_idx * HOP_LENGTH;
let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT];
// Extract n_fft samples and multiply by zero-padded window
for i in 0..N_FFT {
if start + i < padded_audio.len() {
frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0);
}
}
fft.process(&mut frame);
for k in 0..freq_bins {
let magnitude = frame[k].norm();
// Power spectrum (magnitude^2) - NeMo uses mag_power=2.0
spectrogram[[k, frame_idx]] = magnitude * magnitude;
}
}
spectrogram
}
// Librosa's Slaney mel scale (htk=False, which is the default)
fn hz_to_mel_slaney(hz: f64) -> f64 {
let f_min = 0.0;
let f_sp = 200.0 / 3.0;
let min_log_hz = 1000.0;
let min_log_mel = (min_log_hz - f_min) / f_sp;
let logstep = (6.4f64).ln() / 27.0;
if hz >= min_log_hz {
min_log_mel + (hz / min_log_hz).ln() / logstep
} else {
(hz - f_min) / f_sp
}
}
fn mel_to_hz_slaney(mel: f64) -> f64 {
let f_min = 0.0;
let f_sp = 200.0 / 3.0;
let min_log_hz = 1000.0;
let min_log_mel = (min_log_hz - f_min) / f_sp;
let logstep = (6.4f64).ln() / 27.0;
if mel >= min_log_mel {
min_log_hz * (logstep * (mel - min_log_mel)).exp()
} else {
f_min + f_sp * mel
}
}
fn create_mel_filterbank() -> Array2<f32> {
// lets use f64 for intermediate calculations to avoid precision loss
let freq_bins = N_FFT / 2 + 1;
let mut filterbank = Array2::<f32>::zeros((N_MELS, freq_bins));
// FFT frequencies: fftfreqs[k] = k * sr / n_fft
let fftfreqs: Vec<f64> = (0..freq_bins)
.map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64)
.collect();
// Mel center frequencies using Slaney scale (librosa default, htk=False)
let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64);
let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64);
let mel_f: Vec<f64> = (0..=N_MELS + 1)
.map(|i| {
let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64;
Self::mel_to_hz_slaney(mel)
})
.collect();
// Differences between consecutive mel frequencies
let fdiff: Vec<f64> = mel_f.windows(2).map(|w| w[1] - w[0]).collect();
// Compute filterbank weights (reference: librosa's ramp method)
// https://librosa.org/doc/main/generated/librosa.stft.html
for i in 0..N_MELS {
for k in 0..freq_bins {
// Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i]
let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i];
// Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1]
let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1];
// Weight is max(0, min(lower, upper))
filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32;
}
}
// Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i])
for i in 0..N_MELS {
let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]);
for k in 0..freq_bins {
filterbank[[i, k]] *= enorm as f32;
}
}
filterbank
}
fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> {
// 1. Add dither (small random noise to prevent log(0))
// NeMo uses dither=1e-5, but for determinism we skip random noise
// The log_zero_guard handles zero values
// 2. Apply preemphasis (NeMo uses preemph=0.97)
let preemphasized = Self::apply_preemphasis(audio);
// 3. STFT
let spectrogram = Self::stft(&preemphasized);
// 4. Apply mel filterbank (with Slaney normalization)
let mel_spec = self.mel_basis.dot(&spectrogram);
// 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24)
// NeMo uses normalize='NA' which means NO normalization
let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln());
let num_frames = log_mel_spec.shape()[1];
let mut features = Array3::<f32>::zeros((1, num_frames, N_MELS));
// Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D)
for t in 0..num_frames {
for m in 0..N_MELS {
features[[0, t, m]] = log_mel_spec[[m, t]];
}
}
features
}
}