summaryrefslogtreecommitdiff
path: root/parakeet-rs/examples
diff options
context:
space:
mode:
Diffstat (limited to 'parakeet-rs/examples')
-rw-r--r--parakeet-rs/examples/diarization.rs137
-rw-r--r--parakeet-rs/examples/raw.rs86
-rw-r--r--parakeet-rs/examples/streaming.rs129
-rw-r--r--parakeet-rs/examples/transcribe.rs106
4 files changed, 0 insertions, 458 deletions
diff --git a/parakeet-rs/examples/diarization.rs b/parakeet-rs/examples/diarization.rs
deleted file mode 100644
index 5982ecb..0000000
--- a/parakeet-rs/examples/diarization.rs
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
-Speaker Diarization with NVIDIA Sortformer v2 (Streaming)
-
-Download the Sortformer v2 model:
-https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx
-Download test audio:
-wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav
-
-Usage:
-cargo run --example diarization --features sortformer 6_speakers.wav
-
-NOTE: This example combines two NVIDIA models:
-- Parakeet-TDT: Provides transcription with sentence-level timestamps
-- Sortformer v2: Provides streaming speaker identification (4 speakers max)
-- We use TDT's sentence timestamps + Sortformer's speaker IDs
-- Even if Sortformer can't detect a segment, we still get the transcription (marked UNKNOWN)
-- For more information:
-https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
-
-WARNING: Sortformer handles long audio natively (streaming), but TDT has sequence
-length limitations (~8-10 minutes max). For production use with long audio files,
-run Sortformer on the full audio for diarization, then chunk the audio into
-~5-minute segments for TDT transcription, and map the results back together.
-*/
-
-#[cfg(feature = "sortformer")]
-use parakeet_rs::sortformer::{DiarizationConfig, Sortformer};
-#[cfg(feature = "sortformer")]
-use parakeet_rs::TimestampMode;
-#[cfg(feature = "sortformer")]
-use hound;
-#[cfg(feature = "sortformer")]
-use std::env;
-#[cfg(feature = "sortformer")]
-use std::time::Instant;
-
-#[allow(unreachable_code)]
-fn main() -> Result<(), Box<dyn std::error::Error>> {
- #[cfg(not(feature = "sortformer"))]
- {
- eprintln!("Error: This example requires the 'sortformer' feature.");
- eprintln!("Please run with: cargo run --example diarization --features sortformer <audio.wav>");
- return Err("sortformer feature not enabled".into());
- }
-
- #[cfg(feature = "sortformer")]
- {
- let start_time = Instant::now();
- let args: Vec<String> = env::args().collect();
- let audio_path = args.get(1)
- .expect("Please specify audio file: cargo run --example diarization --features sortformer <audio.wav>");
-
- println!("{}", "=".repeat(80));
- println!("Step 1/3: Loading audio...");
-
- let mut reader = hound::WavReader::open(audio_path)?;
- let spec = reader.spec();
-
- let audio: Vec<f32> = match spec.sample_format {
- hound::SampleFormat::Float => reader
- .samples::<f32>()
- .collect::<Result<Vec<_>, _>>()?,
- hound::SampleFormat::Int => reader
- .samples::<i16>()
- .map(|s| s.map(|s| s as f32 / 32768.0))
- .collect::<Result<Vec<_>, _>>()?,
- };
-
- let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32;
- println!("Loaded {} samples ({} Hz, {} channels, {:.1}s)",
- audio.len(), spec.sample_rate, spec.channels, duration);
-
- println!("{}", "=".repeat(80));
- println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)...");
-
- // Create Sortformer with default config (callhome)
- let mut sortformer = Sortformer::with_config(
- "diar_streaming_sortformer_4spk-v2.onnx",
- None, // default exec config
- DiarizationConfig::callhome(),
- )?;
-
- let speaker_segments = sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?;
-
- println!("Found {} speaker segments from Sortformer", speaker_segments.len());
-
- // Print raw diarization segments
- println!("\nRaw diarization segments:");
- for seg in &speaker_segments {
- println!(" [{:06.2}s - {:06.2}s] Speaker {}", seg.start, seg.end, seg.speaker_id);
- }
-
- println!("\n{}", "=".repeat(80));
- println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n");
-
- // Use TDT for transcription with sentence-level timestamps
- let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?;
-
- // Transcribe with Sentences mode (TDT provides punctuation for proper segmentation)
- if let Ok(result) = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences)) {
- // For each sentence from TDT, find the corresponding speaker from Sortformer
- for segment in &result.tokens {
- // Find speaker with maximum overlap
- let speaker = speaker_segments
- .iter()
- .filter_map(|s| {
- // Calculate overlap between transcription and diarization segment
- let overlap_start = segment.start.max(s.start);
- let overlap_end = segment.end.min(s.end);
- let overlap = (overlap_end - overlap_start).max(0.0);
- if overlap > 0.0 {
- Some((s.speaker_id, overlap))
- } else {
- None
- }
- })
- .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
- .map(|(id, _)| format!("Speaker {}", id))
- .unwrap_or_else(|| "UNKNOWN".to_string());
-
- println!("[{:.2}s - {:.2}s] {}: {}",
- segment.start, segment.end, speaker, segment.text);
- }
- }
-
- println!("\n{}", "=".repeat(80));
- let elapsed = start_time.elapsed();
- println!("\n✓ Diarization and transcription completed in {:.2}s", elapsed.as_secs_f32());
- println!("• UNKNOWN: Segments where no speaker was detected by Sortformer");
- println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)");
-
- Ok(())
- }
-
- #[cfg(not(feature = "sortformer"))]
- unreachable!()
-}
diff --git a/parakeet-rs/examples/raw.rs b/parakeet-rs/examples/raw.rs
deleted file mode 100644
index a1a2adc..0000000
--- a/parakeet-rs/examples/raw.rs
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
-Demonstrates using transcribe_samples()
-
-This example shows manual audio loading and calling transcribe_samples() directly
-with sample_rate and channels instead of using transcribe_file()
-
-Usage:
-cargo run --example raw 6_speakers.wav
-cargo run --example raw 6_speakers.wav tdt
-
-WARNING: TDT model has sequence length limitations (~8-10 minutes max).
-For longer audio files, you must split into chunks (e.g., 5-minute segments)
-and transcribe each chunk separately. Attempting to transcribe 25+ minute
-audio files in one call will cause ONNX runtime errors.
-Otherwise you will likely get a error like:
-"Error: Ort(Error { code: RuntimeException, msg: "Non-zero status code returned while running Add node. Name:'/layers.0/self_attn/Add_2' Status Message: /Users/runner/work/ort-artifacts/ort-artifacts/onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. })"
-*/
-
-use parakeet_rs::{Parakeet, ParakeetTDT, TimestampMode};
-use std::env;
-use std::time::Instant;
-
-fn main() -> Result<(), Box<dyn std::error::Error>> {
- let start_time = Instant::now();
- let args: Vec<String> = env::args().collect();
- let audio_path = if args.len() > 1 {
- &args[1]
- } else {
- "6_speakers.wav"
- };
-
- let use_tdt = args.len() > 2 && args[2] == "tdt";
-
- // Load audio manually using hound (or any other audio library)
- // remember if you use raw audio API, you need to handle audio preprocessing yourself!
- let mut reader = hound::WavReader::open(audio_path)?;
- let spec = reader.spec();
-
- println!("Audio info: {}Hz, {} channel(s)", spec.sample_rate, spec.channels);
-
- let audio: Vec<f32> = match spec.sample_format {
- hound::SampleFormat::Float => reader
- .samples::<f32>()
- .collect::<Result<Vec<_>, _>>()?,
- hound::SampleFormat::Int => reader
- .samples::<i16>()
- .map(|s| s.map(|s| s as f32 / 32768.0))
- .collect::<Result<Vec<_>, _>>()?,
- };
-
- if use_tdt {
- println!("Loading TDT model...");
- let mut parakeet = ParakeetTDT::from_pretrained("./tdt", None)?;
-
- // Use transcribe_samples() with raw parameters and timestamp mode
- let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences))?;
-
- println!("{}", result.text);
- println!("\nSentencess:");
- for segment in result.tokens.iter() {
- println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
- }
- } else {
- println!("Loading CTC model...");
- let mut parakeet = Parakeet::from_pretrained(".", None)?;
-
- // CTC model doesn't predict punctuation (lowercase alphabet only)
- // This means no sentence boundaries. we use Words mode instead of Sentences
- let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Words))?;
-
- println!("{}", result.text);
-
- // Access word-level timestamps (showing first 10 for brevity)
- // Note: CTC generates word-level timestamps but cannot segment into sentences
- // due to lack of punctuation prediction - this is a model limitation if I not mistake
- println!("\nWords (first 10):");
- for word in result.tokens.iter().take(10) {
- println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text);
- }
- }
-
- let elapsed = start_time.elapsed();
- println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
-
- Ok(())
-}
diff --git a/parakeet-rs/examples/streaming.rs b/parakeet-rs/examples/streaming.rs
deleted file mode 100644
index f5d36c9..0000000
--- a/parakeet-rs/examples/streaming.rs
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
-Demonstrates streaming ASR with Parakeet RealTime EOU
-
-Download models files from:
-https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx
-
-This example
-- Maintains 4-second ring buffer for feature extraction context
-- Processes 160ms chunks (2560 samples at 16kHz)
-- Extracts features from full buffer, then slices last 25 frames
-- Encoder receives: 9 frames (pre-encode cache) + 16 frames (new) = 25 total
-- Cache states (cache_last_channel/time) maintain temporal context
-
-Model files required in ./fullstr/:
- - encoder.onnx (cache_aware_stream_step export)
- - decoder_joint.onnx
- - tokenizer.json
-
-Additional notes:
-let reset_on_eou: bool = false;
-I must admit that this is not work very well on my real world tests :/
-
-
-Usage:
-cargo run --release --example streaming <audio.wav>
-*/
-
-use hound;
-use parakeet_rs::ParakeetEOU;
-use std::env;
-use std::time::Instant;
-
-fn main() -> Result<(), Box<dyn std::error::Error>> {
- let start_time = Instant::now();
-
- let args: Vec<String> = env::args().collect();
- let audio_path = args
- .get(1)
- .expect("Usage: cargo run --release --example streaming <audio.wav>");
-
- println!("Loading model from ./fullstr...");
- let mut parakeet = ParakeetEOU::from_pretrained("./fullstr", None)?;
-
- println!("Loading audio: {}", audio_path);
- let mut reader = hound::WavReader::open(audio_path)?;
- let spec = reader.spec();
-
- let mut audio: Vec<f32> = match spec.sample_format {
- hound::SampleFormat::Float => reader
- .samples::<f32>()
- .collect::<Result<Vec<_>, _>>()?,
- hound::SampleFormat::Int => reader
- .samples::<i16>()
- .map(|s| s.map(|s| s as f32 / 32768.0))
- .collect::<Result<Vec<_>, _>>()?,
- };
-
- if spec.sample_rate != 16000 {
- return Err(format!(
- "Expected 16kHz audio, got {}Hz. Please resample first.",
- spec.sample_rate
- )
- .into());
- }
-
- if spec.channels > 1 {
- audio = audio
- .chunks(spec.channels as usize)
- .map(|chunk| chunk.iter().sum::<f32>() / spec.channels as f32)
- .collect();
- }
-
- let max_val = audio.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
- if max_val > 1e-6 {
- let norm_factor = max_val + 1e-5;
- for sample in &mut audio {
- *sample /= norm_factor;
- }
- }
-
- let duration = audio.len() as f32 / 16000.0;
- // 160ms at 16kHz
- const CHUNK_SIZE: usize = 2560;
- let reset_on_eou: bool = false;
-
- println!("Streaming transcription (160ms chunks with 4s buffer)...\n");
-
- let mut full_text = String::new();
-
- for chunk in audio.chunks(CHUNK_SIZE) {
- let chunk_vec = if chunk.len() < CHUNK_SIZE {
- let mut padded = chunk.to_vec();
- padded.resize(CHUNK_SIZE, 0.0);
- padded
- } else {
- chunk.to_vec()
- };
-
- let text = parakeet.transcribe(&chunk_vec, reset_on_eou)?;
- if !text.is_empty() {
- print!("{}", text);
- std::io::Write::flush(&mut std::io::stdout())?;
- full_text.push_str(&text);
- }
- }
-
- println!("\n\nFlushing decoder...");
- let silence = vec![0.0f32; CHUNK_SIZE];
- for _ in 0..3 {
- let text = parakeet.transcribe(&silence, reset_on_eou)?;
- if !text.is_empty() {
- print!("{}", text);
- std::io::Write::flush(&mut std::io::stdout())?;
- full_text.push_str(&text);
- }
- }
-
- println!("\n\nFinal Transcription:\n{}", full_text.trim());
-
- let elapsed = start_time.elapsed();
- println!(
- "\nTranscription completed in {:.2}s (audio: {:.2}s, RTF: {:.2}x)",
- elapsed.as_secs_f32(),
- duration,
- duration / elapsed.as_secs_f32()
- );
-
- Ok(())
-}
diff --git a/parakeet-rs/examples/transcribe.rs b/parakeet-rs/examples/transcribe.rs
deleted file mode 100644
index 685e8de..0000000
--- a/parakeet-rs/examples/transcribe.rs
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
-transcribes entire audio, no diarization
-wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav
-
-CTC (English-only):
-cargo run --example transcribe 6_speakers.wav
-
-TDT (Multilingual):
-cargo run --example transcribe 6_speakers.wav tdt
-
-NOTE: For manual audio loading without using transcribe_file(), see examples/raw.rs
-- Shows transcribe_samples(audio, sample_rate, channels, timestamps) usage
-
-WARNING: This may fail on very long audio files (>8 min).
-For longer audio, use the pyannote example which processes segments, or split your audio into chunks.
-
-Note: The coreml feature flag is only for reproducing a known ONNX Runtime bug.
-Just ignore it :). See: https://github.com/microsoft/onnxruntime/issues/26355
-*/
-use parakeet_rs::{Parakeet, TimestampMode};
-use std::env;
-use std::time::Instant;
-
-#[cfg(feature = "coreml")]
-use parakeet_rs::{ExecutionConfig, ExecutionProvider};
-
-fn main() -> Result<(), Box<dyn std::error::Error>> {
- let start_time = Instant::now();
- let args: Vec<String> = env::args().collect();
- let audio_path = if args.len() > 1 {
- &args[1]
- } else {
- "6_speakers.wav"
- };
-
- let use_tdt = args.len() > 2 && args[2] == "tdt";
-
- // TDT model (multilingual, 25 languages)
- if use_tdt {
- #[cfg(feature = "coreml")]
- {
- let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML);
- let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", Some(config))?;
- let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?;
- println!("{}", result.text);
-
- println!("\nSentencess:");
- for segment in result.tokens.iter() {
- println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
- }
-
- let elapsed = start_time.elapsed();
- println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
- return Ok(());
- }
-
- #[cfg(not(feature = "coreml"))]
- {
- let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?;
- let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?;
- println!("{}", result.text);
-
- println!("\nSentencess:");
- for segment in result.tokens.iter() {
- println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text);
- }
-
- let elapsed = start_time.elapsed();
- println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
- return Ok(());
- }
- }
-
- // CTC model (English-only)
- #[cfg(feature = "coreml")]
- let mut parakeet = {
- let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML);
- Parakeet::from_pretrained(".", Some(config))?
- };
-
- // Default: CPU execution provider (works correctly)
- // Auto-detects model with priority: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx
- // Or specify exact model: Parakeet::from_pretrained("model_q4.onnx", None)?
- #[cfg(not(feature = "coreml"))]
- let mut parakeet = Parakeet::from_pretrained(".", None)?;
-
- // CTC model doesn't predict punctuation (lowercase alphabet only)
- // This means no sentence boundaries - use Words mode instead of Sentences
- let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Words))?;
-
- // Print transcription
- println!("{}", result.text);
-
- // Access word-level timestamps (showing first 10 for brevity)
- // Note: CTC generates word-level timestamps but cannot segment into sentences
- // due to lack of punctuation prediction - this is a model limitation
- println!("\nWords (first 10):");
- for word in result.tokens.iter().take(10) {
- println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text);
- }
-
- let elapsed = start_time.elapsed();
- println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32());
-
- Ok(())
-}