summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/parakeet.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 00:40:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch)
tree0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/src/parakeet.rs
parent84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff)
downloadsoryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz
soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/src/parakeet.rs')
-rw-r--r--parakeet-rs/src/parakeet.rs210
1 files changed, 210 insertions, 0 deletions
diff --git a/parakeet-rs/src/parakeet.rs b/parakeet-rs/src/parakeet.rs
new file mode 100644
index 0000000..d2aabdd
--- /dev/null
+++ b/parakeet-rs/src/parakeet.rs
@@ -0,0 +1,210 @@
+use crate::audio;
+use crate::config::PreprocessorConfig;
+use crate::decoder::{ParakeetDecoder, TranscriptionResult};
+use crate::error::{Error, Result};
+use crate::execution::ModelConfig as ExecutionConfig;
+use crate::model::ParakeetModel;
+use crate::timestamps::{process_timestamps, TimestampMode};
+use std::path::{Path, PathBuf};
+
+/// Speech-to-text pipeline bundling a loaded model, its token decoder and the
+/// audio preprocessing settings used to turn samples into model features.
+pub struct Parakeet {
+ /// Model loaded from the resolved model file; produces logits from audio features.
+ model: ParakeetModel,
+ /// Decoder built from `tokenizer.json`; turns logits into text and timestamps.
+ decoder: ParakeetDecoder,
+ /// Feature-extraction settings (initialised from `PreprocessorConfig::default()`).
+ preprocessor_config: PreprocessorConfig,
+ /// Directory the model and tokenizer files were resolved against.
+ model_dir: PathBuf,
+}
+
+impl Parakeet {
+    /// Load a Parakeet model from `path` with an optional execution configuration.
+    ///
+    /// # Arguments
+    /// * `path` - Directory containing model files, or path to a specific model file
+    /// * `config` - Optional execution configuration (defaults to CPU if None)
+    ///
+    /// # Errors
+    /// Returns `Error::Config` when `path` is neither an existing directory nor an
+    /// existing file, when no model file can be located in the directory, or when
+    /// `tokenizer.json` is missing from the model directory. Errors raised while
+    /// loading the model or the tokenizer are propagated unchanged.
+    ///
+    /// # Examples
+    /// ```no_run
+    /// use parakeet_rs::Parakeet;
+    ///
+    /// // Load from directory with CPU (default)
+    /// let parakeet = Parakeet::from_pretrained(".", None)?;
+    ///
+    /// // Or load from specific model file
+    /// let parakeet = Parakeet::from_pretrained("model_q4.onnx", None)?;
+    /// # Ok::<(), Box<dyn std::error::Error>>(())
+    /// ```
+    ///
+    /// For GPU acceleration, enable the corresponding feature (cuda, tensorrt, webgpu, etc.)
+    /// and pass an `ExecutionConfig` with the desired execution provider.
+    pub fn from_pretrained<P: AsRef<Path>>(
+        path: P,
+        config: Option<ExecutionConfig>,
+    ) -> Result<Self> {
+        let path = path.as_ref();
+
+        // Resolve the model file and the directory that holds its companion files.
+        let (model_path, model_dir) = if path.is_dir() {
+            // Directory mode: auto-detect the model file.
+            (Self::find_model_file(path)?, path.to_path_buf())
+        } else if path.is_file() {
+            // File mode: `path` points directly at the model file.
+            let parent = path
+                .parent()
+                .ok_or_else(|| Error::Config("Invalid model path".to_string()))?;
+            // A bare file name such as "model.onnx" has an empty parent. Use "." instead
+            // so `model_dir` is a usable directory path and error messages name a location.
+            let model_dir = if parent.as_os_str().is_empty() {
+                PathBuf::from(".")
+            } else {
+                parent.to_path_buf()
+            };
+            (path.to_path_buf(), model_dir)
+        } else {
+            return Err(Error::Config(format!(
+                "Path does not exist: {}",
+                path.display()
+            )));
+        };
+
+        // The tokenizer always lives next to the model file.
+        let tokenizer_path = model_dir.join("tokenizer.json");
+        if !tokenizer_path.exists() {
+            return Err(Error::Config(format!(
+                "Required file 'tokenizer.json' not found in {}",
+                model_dir.display()
+            )));
+        }
+
+        let preprocessor_config = PreprocessorConfig::default();
+        let exec_config = config.unwrap_or_default();
+
+        let model = ParakeetModel::from_pretrained_with_config(&model_path, exec_config)?;
+        let decoder = ParakeetDecoder::from_pretrained(&tokenizer_path)?;
+
+        Ok(Self {
+            model,
+            decoder,
+            preprocessor_config,
+            model_dir,
+        })
+    }
+
+    /// Locate the ONNX model file inside `dir`.
+    ///
+    /// Well-known file names are tried first, in priority order:
+    /// `model.onnx` > `model_fp16.onnx` > `model_int8.onnx` > `model_q4.onnx`.
+    /// If none of them is a file, the directory is scanned for files with an `onnx`
+    /// extension and the path that sorts first is returned, so the result does not
+    /// depend on the platform-specific order in which directory entries are listed.
+    ///
+    /// # Errors
+    /// Returns `Error::Config` when no `.onnx` file is found. A directory that cannot
+    /// be read is treated the same way (best effort: the I/O error is not surfaced).
+    fn find_model_file(dir: &Path) -> Result<PathBuf> {
+        const CANDIDATES: [&str; 4] = [
+            "model.onnx",
+            "model_fp16.onnx",
+            "model_int8.onnx",
+            "model_q4.onnx",
+        ];
+
+        // `is_file` (rather than `exists`) keeps a directory that happens to be
+        // named like a model from being handed to the model loader.
+        if let Some(found) = CANDIDATES
+            .iter()
+            .map(|name| dir.join(name))
+            .find(|candidate| candidate.is_file())
+        {
+            return Ok(found);
+        }
+
+        // Fallback: any other `*.onnx` file. `min()` makes the pick deterministic.
+        let fallback = std::fs::read_dir(dir)
+            .into_iter()
+            .flatten() // iterate the entries if the directory could be opened
+            .flatten() // skip entries that failed to read
+            .map(|entry| entry.path())
+            .filter(|path| {
+                path.extension().and_then(|ext| ext.to_str()) == Some("onnx") && path.is_file()
+            })
+            .min();
+
+        fallback.ok_or_else(|| {
+            Error::Config(format!(
+                "No model file (*.onnx) found in directory: {}",
+                dir.display()
+            ))
+        })
+    }
+
+    /// Run the full transcription pipeline on in-memory audio samples.
+    ///
+    /// The samples are converted to features, passed through the model, decoded into
+    /// timestamped tokens, and finally regrouped according to `mode`.
+    ///
+    /// # Arguments
+    ///
+    /// * `audio` - Audio samples as f32 values
+    /// * `sample_rate` - Sample rate in Hz
+    /// * `channels` - Number of audio channels
+    /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences);
+    ///   `TimestampMode::Tokens` is used when `None`
+    ///
+    /// # Returns
+    ///
+    /// A `TranscriptionResult` whose `tokens` are at the requested level and whose
+    /// `text` is those tokens' text joined with single spaces.
+    pub fn transcribe_samples(
+        &mut self,
+        audio: Vec<f32>,
+        sample_rate: u32,
+        channels: u16,
+        mode: Option<TimestampMode>,
+    ) -> Result<TranscriptionResult> {
+        let features =
+            audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
+        let logits = self.model.forward(features)?;
+
+        let hop_length = self.preprocessor_config.hop_length;
+        let sampling_rate = self.preprocessor_config.sampling_rate;
+        let mut result = self
+            .decoder
+            .decode_with_timestamps(&logits, hop_length, sampling_rate)?;
+
+        // Regroup the decoded tokens to the requested output level.
+        let mode = match mode {
+            Some(requested) => requested,
+            None => TimestampMode::Tokens,
+        };
+        result.tokens = process_timestamps(&result.tokens, mode);
+
+        // Rebuild the full text from the regrouped tokens so both fields agree.
+        let mut text = String::new();
+        for (index, token) in result.tokens.iter().enumerate() {
+            if index > 0 {
+                text.push(' ');
+            }
+            text.push_str(token.text.as_str());
+        }
+        result.text = text;
+
+        Ok(result)
+    }
+
+    /// Load an audio file and transcribe it with timestamps.
+    ///
+    /// # Arguments
+    ///
+    /// * `audio_path` - Path to the audio file to transcribe.
+    /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
+    ///
+    /// # Returns
+    ///
+    /// The `TranscriptionResult` produced by `transcribe_samples` for the loaded
+    /// samples, using the sample rate and channel count reported by the loader.
+    pub fn transcribe_file<P: AsRef<Path>>(
+        &mut self,
+        audio_path: P,
+        mode: Option<TimestampMode>,
+    ) -> Result<TranscriptionResult> {
+        let (samples, spec) = audio::load_audio(audio_path.as_ref())?;
+        let (sample_rate, channels) = (spec.sample_rate, spec.channels);
+
+        self.transcribe_samples(samples, sample_rate, channels, mode)
+    }
+
+    /// Transcribe several audio files one after another.
+    ///
+    /// # Arguments
+    ///
+    /// * `audio_paths` - A slice of paths to the audio files to transcribe.
+    /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences),
+    ///   applied to every file
+    ///
+    /// # Returns
+    ///
+    /// One `TranscriptionResult` per input path, in the same order as `audio_paths`.
+    /// Processing stops at the first file that fails and that error is returned.
+    pub fn transcribe_file_batch<P: AsRef<Path>>(
+        &mut self,
+        audio_paths: &[P],
+        mode: Option<TimestampMode>,
+    ) -> Result<Vec<TranscriptionResult>> {
+        // Collecting into `Result` short-circuits on the first error.
+        audio_paths
+            .iter()
+            .map(|audio_path| self.transcribe_file(audio_path, mode))
+            .collect()
+    }
+
+    /// Directory the model files were resolved against: the directory given to
+    /// `from_pretrained`, or the parent directory of the model file it was given.
+    pub fn model_dir(&self) -> &Path {
+        self.model_dir.as_path()
+    }
+
+ /// Borrow the preprocessor configuration this instance passes to feature
+ /// extraction and timestamp decoding.
+ pub fn preprocessor_config(&self) -> &PreprocessorConfig {
+ &self.preprocessor_config
+ }
+}