summaryrefslogtreecommitdiff
path: root/vendor/parakeet-rs/examples
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/parakeet-rs/examples')
-rw-r--r--vendor/parakeet-rs/examples/diarization.rs137
-rw-r--r--vendor/parakeet-rs/examples/raw.rs86
-rw-r--r--vendor/parakeet-rs/examples/streaming.rs129
-rw-r--r--vendor/parakeet-rs/examples/transcribe.rs106
4 files changed, 458 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/examples/diarization.rs b/vendor/parakeet-rs/examples/diarization.rs
new file mode 100644
index 0000000..5982ecb
--- /dev/null
+++ b/vendor/parakeet-rs/examples/diarization.rs
@@ -0,0 +1,137 @@
+/*
+Speaker Diarization with NVIDIA Sortformer v2 (Streaming)
+
+Download the Sortformer v2 model:
+https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx
+Download test audio:
+wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav
+
+Usage:
+cargo run --example diarization --features sortformer 6_speakers.wav
+
+NOTE: This example combines two NVIDIA models:
+- Parakeet-TDT: Provides transcription with sentence-level timestamps
+- Sortformer v2: Provides streaming speaker identification (4 speakers max)
+- We use TDT's sentence timestamps + Sortformer's speaker IDs
+- Even if Sortformer can't detect a segment, we still get the transcription (marked UNKNOWN)
+- For more information:
+https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
+
+WARNING: Sortformer handles long audio natively (streaming), but TDT has sequence
+length limitations (~8-10 minutes max). For production use with long audio files,
+run Sortformer on the full audio for diarization, then chunk the audio into
+~5-minute segments for TDT transcription, and map the results back together.
+*/
+
+#[cfg(feature = "sortformer")]
+use parakeet_rs::sortformer::{DiarizationConfig, Sortformer};
+#[cfg(feature = "sortformer")]
+use parakeet_rs::TimestampMode;
+#[cfg(feature = "sortformer")]
+use hound;
+#[cfg(feature = "sortformer")]
+use std::env;
+#[cfg(feature = "sortformer")]
+use std::time::Instant;
+
+#[allow(unreachable_code)]
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+ #[cfg(not(feature = "sortformer"))]
+ {
+ eprintln!("Error: This example requires the 'sortformer' feature.");
+ eprintln!("Please run with: cargo run --example diarization --features sortformer <audio.wav>");
+ return Err("sortformer feature not enabled".into());
+ }
+
+ #[cfg(feature = "sortformer")]
+ {
+ let start_time = Instant::now();
+ let args: Vec<String> = env::args().collect();
+ let audio_path = args.get(1)
+ .expect("Please specify audio file: cargo run --example diarization --features sortformer <audio.wav>");
+
+ println!("{}", "=".repeat(80));
+ println!("Step 1/3: Loading audio...");
+
+ let mut reader = hound::WavReader::open(audio_path)?;
+ let spec = reader.spec();
+
+ let audio: Vec<f32> = match spec.sample_format {
+ hound::SampleFormat::Float => reader
+ .samples::<f32>()
+ .collect::<Result<Vec<_>, _>>()?,
+ hound::SampleFormat::Int => reader
+ .samples::<i16>()
+ .map(|s| s.map(|s| s as f32 / 32768.0))
+ .collect::<Result<Vec<_>, _>>()?,
+ };
+
+ let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32;
+ println!("Loaded {} samples ({} Hz, {} channels, {:.1}s)",
+ audio.len(), spec.sample_rate, spec.channels, duration);
+
+ println!("{}", "=".repeat(80));
+ println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)...");
+
+ // Create Sortformer with default config (callhome)
+ let mut sortformer = Sortformer::with_config(
+ "diar_streaming_sortformer_4spk-v2.onnx",
+ None, // default exec config
+ DiarizationConfig::callhome(),
+ )?;
+
+ let speaker_segments = sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?;
+
+ println!("Found {} speaker segments from Sortformer", speaker_segments.len());
+
+ // Print raw diarization segments
+ println!("\nRaw diarization segments:");
+ for seg in &speaker_segments {
+ println!(" [{:06.2}s - {:06.2}s] Speaker {}", seg.start, seg.end, seg.speaker_id);
+ }
+
+ println!("\n{}", "=".repeat(80));
+ println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n");
+
+ // Use TDT for transcription with sentence-level timestamps
+ let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?;
+
+ // Transcribe with Sentences mode (TDT provides punctuation for proper segmentation)
+ if let Ok(result) = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences)) {
+ // For each sentence from TDT, find the corresponding speaker from Sortformer
+ for segment in &result.tokens {
+ // Find speaker with maximum overlap
+ let speaker = speaker_segments
+ .iter()
+ .filter_map(|s| {
+ // Calculate overlap between transcription and diarization segment
+ let overlap_start = segment.start.max(s.start);
+ let overlap_end = segment.end.min(s.end);
+ let overlap = (overlap_end - overlap_start).max(0.0);
+ if overlap > 0.0 {
+ Some((s.speaker_id, overlap))
+ } else {
+ None
+ }
+ })
+ .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
+ .map(|(id, _)| format!("Speaker {}", id))
+ .unwrap_or_else(|| "UNKNOWN".to_string());
+
+ println!("[{:.2}s - {:.2}s] {}: {}",
+ segment.start, segment.end, speaker, segment.text);
+ }
+ }
+
+ println!("\n{}", "=".repeat(80));
+ let elapsed = start_time.elapsed();
+ println!("\n✓ Diarization and transcription completed in {:.2}s", elapsed.as_secs_f32());
+ println!("• UNKNOWN: Segments where no speaker was detected by Sortformer");
+ println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)");
+
+ Ok(())
+ }
+
+ #[cfg(not(feature = "sortformer"))]
+ unreachable!()
+}
diff --git a/vendor/parakeet-rs/examples/raw.rs b/vendor/parakeet-rs/examples/raw.rs
new file mode 100644
index 0000000..a1a2adc
--- /dev/null
+++ b/vendor/parakeet-rs/examples/raw.rs
@@ -0,0 +1,86 @@
+/*
+Demonstrates using transcribe_samples()
+
+This example shows manual audio loading and calling transcribe_samples() directly
+with sample_rate and channels instead of using transcribe_file()
+
+Usage:
+cargo run --example raw 6_speakers.wav
+cargo run --example raw 6_speakers.wav tdt
+
+WARNING: TDT model has sequence length limitations (~8-10 minutes max).
+For longer audio files, you must split into chunks (e.g., 5-minute segments)
+and transcribe each chunk separately. Attempting to transcribe 25+ minute
+audio files in one call will cause ONNX runtime errors.
+Otherwise you will likely get a error like:
+"Error: Ort(Error { code: RuntimeException, msg: "Non-zero status code returned while running Add node. Name:'/layers.0/self_attn/Add_2' Status Message: /Users/runner/work/ort-artifacts/ort-artifacts/onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. })"
+*/
+
+use parakeet_rs::{Parakeet, ParakeetTDT, TimestampMode};
+use std::env;
+use std::time::Instant;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+ let start_time = Instant::now();
+ let args: Vec<String> = env::args().collect();
+ let audio_path = if args.len() > 1 {
+ &args[1]
+ } else {
+ "6_speakers.wav"
+ };
+
+ let use_tdt = args.len() > 2 && args[2] == "tdt";
+
+ // Load audio manually using hound (or any other audio library)
+ // remember if you use raw audio API, you need to handle audio preprocessing yourself!
+ let mut reader = hound::WavReader::open(audio_path)?;
+ let spec = reader.spec();
+
+ println!("Audio info: {}Hz, {} channel(s)", spec.sample_rate, spec.channels);
+
+ let audio: Vec<f32> = match spec.sample_format {
+ hound::SampleFormat::Float => reader
+ .samples::<f32>()
+ .collect::<Result<Vec<_>, _>>()?,
+ hound::SampleFormat::Int => reader
+ .samples::<i16>()
+ .map(|s| s.map(|s| s as f32 / 32768.0))
+ .collect::<Result<Vec<_>, _>>()?,
+ };
+
+ if use_tdt {
+ println!("Loading TDT model...");
+ let mut parakeet = ParakeetTDT::from_pretrained("./tdt", None)?;
+
+ // Use transcribe_samples() with raw parameters and timestamp mode
+ let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences))?;
+
+ println!("{}", result.text);
+ println!("\nSentencess:");
+ for segment in result.tokens.iter() {
+ println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
+ }
+ } else {
+ println!("Loading CTC model...");
+ let mut parakeet = Parakeet::from_pretrained(".", None)?;
+
+ // CTC model doesn't predict punctuation (lowercase alphabet only)
+ // This means no sentence boundaries. we use Words mode instead of Sentences
+ let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Words))?;
+
+ println!("{}", result.text);
+
+ // Access word-level timestamps (showing first 10 for brevity)
+ // Note: CTC generates word-level timestamps but cannot segment into sentences
+ // due to lack of punctuation prediction - this is a model limitation if I not mistake
+ println!("\nWords (first 10):");
+ for word in result.tokens.iter().take(10) {
+ println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text);
+ }
+ }
+
+ let elapsed = start_time.elapsed();
+ println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
+
+ Ok(())
+}
diff --git a/vendor/parakeet-rs/examples/streaming.rs b/vendor/parakeet-rs/examples/streaming.rs
new file mode 100644
index 0000000..f5d36c9
--- /dev/null
+++ b/vendor/parakeet-rs/examples/streaming.rs
@@ -0,0 +1,129 @@
+/*
+Demonstrates streaming ASR with Parakeet RealTime EOU
+
+Download models files from:
+https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx
+
+This example
+- Maintains 4-second ring buffer for feature extraction context
+- Processes 160ms chunks (2560 samples at 16kHz)
+- Extracts features from full buffer, then slices last 25 frames
+- Encoder receives: 9 frames (pre-encode cache) + 16 frames (new) = 25 total
+- Cache states (cache_last_channel/time) maintain temporal context
+
+Model files required in ./fullstr/:
+ - encoder.onnx (cache_aware_stream_step export)
+ - decoder_joint.onnx
+ - tokenizer.json
+
+Additional notes:
+let reset_on_eou: bool = false;
+I must admit that this is not work very well on my real world tests :/
+
+
+Usage:
+cargo run --release --example streaming <audio.wav>
+*/
+
+use hound;
+use parakeet_rs::ParakeetEOU;
+use std::env;
+use std::time::Instant;
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+ let start_time = Instant::now();
+
+ let args: Vec<String> = env::args().collect();
+ let audio_path = args
+ .get(1)
+ .expect("Usage: cargo run --release --example streaming <audio.wav>");
+
+ println!("Loading model from ./fullstr...");
+ let mut parakeet = ParakeetEOU::from_pretrained("./fullstr", None)?;
+
+ println!("Loading audio: {}", audio_path);
+ let mut reader = hound::WavReader::open(audio_path)?;
+ let spec = reader.spec();
+
+ let mut audio: Vec<f32> = match spec.sample_format {
+ hound::SampleFormat::Float => reader
+ .samples::<f32>()
+ .collect::<Result<Vec<_>, _>>()?,
+ hound::SampleFormat::Int => reader
+ .samples::<i16>()
+ .map(|s| s.map(|s| s as f32 / 32768.0))
+ .collect::<Result<Vec<_>, _>>()?,
+ };
+
+ if spec.sample_rate != 16000 {
+ return Err(format!(
+ "Expected 16kHz audio, got {}Hz. Please resample first.",
+ spec.sample_rate
+ )
+ .into());
+ }
+
+ if spec.channels > 1 {
+ audio = audio
+ .chunks(spec.channels as usize)
+ .map(|chunk| chunk.iter().sum::<f32>() / spec.channels as f32)
+ .collect();
+ }
+
+ let max_val = audio.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
+ if max_val > 1e-6 {
+ let norm_factor = max_val + 1e-5;
+ for sample in &mut audio {
+ *sample /= norm_factor;
+ }
+ }
+
+ let duration = audio.len() as f32 / 16000.0;
+ // 160ms at 16kHz
+ const CHUNK_SIZE: usize = 2560;
+ let reset_on_eou: bool = false;
+
+ println!("Streaming transcription (160ms chunks with 4s buffer)...\n");
+
+ let mut full_text = String::new();
+
+ for chunk in audio.chunks(CHUNK_SIZE) {
+ let chunk_vec = if chunk.len() < CHUNK_SIZE {
+ let mut padded = chunk.to_vec();
+ padded.resize(CHUNK_SIZE, 0.0);
+ padded
+ } else {
+ chunk.to_vec()
+ };
+
+ let text = parakeet.transcribe(&chunk_vec, reset_on_eou)?;
+ if !text.is_empty() {
+ print!("{}", text);
+ std::io::Write::flush(&mut std::io::stdout())?;
+ full_text.push_str(&text);
+ }
+ }
+
+ println!("\n\nFlushing decoder...");
+ let silence = vec![0.0f32; CHUNK_SIZE];
+ for _ in 0..3 {
+ let text = parakeet.transcribe(&silence, reset_on_eou)?;
+ if !text.is_empty() {
+ print!("{}", text);
+ std::io::Write::flush(&mut std::io::stdout())?;
+ full_text.push_str(&text);
+ }
+ }
+
+ println!("\n\nFinal Transcription:\n{}", full_text.trim());
+
+ let elapsed = start_time.elapsed();
+ println!(
+ "\nTranscription completed in {:.2}s (audio: {:.2}s, RTF: {:.2}x)",
+ elapsed.as_secs_f32(),
+ duration,
+ duration / elapsed.as_secs_f32()
+ );
+
+ Ok(())
+}
diff --git a/vendor/parakeet-rs/examples/transcribe.rs b/vendor/parakeet-rs/examples/transcribe.rs
new file mode 100644
index 0000000..685e8de
--- /dev/null
+++ b/vendor/parakeet-rs/examples/transcribe.rs
@@ -0,0 +1,106 @@
+/*
+transcribes entire audio, no diarization
+wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav
+
+CTC (English-only):
+cargo run --example transcribe 6_speakers.wav
+
+TDT (Multilingual):
+cargo run --example transcribe 6_speakers.wav tdt
+
+NOTE: For manual audio loading without using transcribe_file(), see examples/raw.rs
+- Shows transcribe_samples(audio, sample_rate, channels, timestamps) usage
+
+WARNING: This may fail on very long audio files (>8 min).
+For longer audio, use the pyannote example which processes segments, or split your audio into chunks.
+
+Note: The coreml feature flag is only for reproducing a known ONNX Runtime bug.
+Just ignore it :). See: https://github.com/microsoft/onnxruntime/issues/26355
+*/
+use parakeet_rs::{Parakeet, TimestampMode};
+use std::env;
+use std::time::Instant;
+
+#[cfg(feature = "coreml")]
+use parakeet_rs::{ExecutionConfig, ExecutionProvider};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+ let start_time = Instant::now();
+ let args: Vec<String> = env::args().collect();
+ let audio_path = if args.len() > 1 {
+ &args[1]
+ } else {
+ "6_speakers.wav"
+ };
+
+ let use_tdt = args.len() > 2 && args[2] == "tdt";
+
+ // TDT model (multilingual, 25 languages)
+ if use_tdt {
+ #[cfg(feature = "coreml")]
+ {
+ let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML);
+ let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", Some(config))?;
+ let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?;
+ println!("{}", result.text);
+
+ println!("\nSentencess:");
+ for segment in result.tokens.iter() {
+ println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
+ }
+
+ let elapsed = start_time.elapsed();
+ println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
+ return Ok(());
+ }
+
+ #[cfg(not(feature = "coreml"))]
+ {
+ let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?;
+ let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?;
+ println!("{}", result.text);
+
+ println!("\nSentencess:");
+ for segment in result.tokens.iter() {
+ println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
+ }
+
+ let elapsed = start_time.elapsed();
+ println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
+ return Ok(());
+ }
+ }
+
+ // CTC model (English-only)
+ #[cfg(feature = "coreml")]
+ let mut parakeet = {
+ let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML);
+ Parakeet::from_pretrained(".", Some(config))?
+ };
+
+ // Default: CPU execution provider (works correctly)
+ // Auto-detects model with priority: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx
+ // Or specify exact model: Parakeet::from_pretrained("model_q4.onnx", None)?
+ #[cfg(not(feature = "coreml"))]
+ let mut parakeet = Parakeet::from_pretrained(".", None)?;
+
+ // CTC model doesn't predict punctuation (lowercase alphabet only)
+ // This means no sentence boundaries - use Words mode instead of Sentences
+ let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Words))?;
+
+ // Print transcription
+ println!("{}", result.text);
+
+ // Access word-level timestamps (showing first 10 for brevity)
+ // Note: CTC generates word-level timestamps but cannot segment into sentences
+ // due to lack of punctuation prediction - this is a model limitation
+ println!("\nWords (first 10):");
+ for word in result.tokens.iter().take(10) {
+ println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text);
+ }
+
+ let elapsed = start_time.elapsed();
+ println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
+
+ Ok(())
+}