diff options
Diffstat (limited to 'vendor/parakeet-rs/src/parakeet.rs')
| -rw-r--r-- | vendor/parakeet-rs/src/parakeet.rs | 210 |
1 files changed, 210 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/src/parakeet.rs b/vendor/parakeet-rs/src/parakeet.rs new file mode 100644 index 0000000..d2aabdd --- /dev/null +++ b/vendor/parakeet-rs/src/parakeet.rs @@ -0,0 +1,210 @@ +use crate::audio; +use crate::config::PreprocessorConfig; +use crate::decoder::{ParakeetDecoder, TranscriptionResult}; +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use crate::model::ParakeetModel; +use crate::timestamps::{process_timestamps, TimestampMode}; +use std::path::{Path, PathBuf}; + +pub struct Parakeet { + model: ParakeetModel, + decoder: ParakeetDecoder, + preprocessor_config: PreprocessorConfig, + model_dir: PathBuf, +} + +impl Parakeet { + /// Load Parakeet model from path with optional configuration. + /// + /// # Arguments + /// * `path` - Directory containing model files, or path to specific model file + /// * `config` - Optional execution configuration (defaults to CPU if None) + /// + /// # Examples + /// ```no_run + /// use parakeet_rs::Parakeet; + /// + /// // Load from directory with CPU (default) + /// let parakeet = Parakeet::from_pretrained(".", None)?; + /// + /// // Or load from specific model file + /// let parakeet = Parakeet::from_pretrained("model_q4.onnx", None)?; + /// # Ok::<(), Box<dyn std::error::Error>>(()) + /// ``` + /// + /// For GPU acceleration, enable the corresponding feature (cuda, tensorrt, webgpu, etc.) + /// and pass an `ExecutionConfig` with the desired execution provider. + pub fn from_pretrained<P: AsRef<Path>>( + path: P, + config: Option<ExecutionConfig>, + ) -> Result<Self> { + let path = path.as_ref(); + + // Determine if path is a directory or file + let (model_path, tokenizer_path, model_dir) = if path.is_dir() { + // Directory mode: auto-detect model file + let model_path = Self::find_model_file(path)?; + let tokenizer_path = path.join("tokenizer.json"); + (model_path, tokenizer_path, path.to_path_buf()) + } else if path.is_file() { + // File mode: path points directly to model file + let model_dir = path + .parent() + .ok_or_else(|| Error::Config("Invalid model path".to_string()))?; + let tokenizer_path = model_dir.join("tokenizer.json"); + (path.to_path_buf(), tokenizer_path, model_dir.to_path_buf()) + } else { + return Err(Error::Config(format!( + "Path does not exist: {}", + path.display() + ))); + }; + + // Check tokenizer exists + if !tokenizer_path.exists() { + return Err(Error::Config(format!( + "Required file 'tokenizer.json' not found in {}", + model_dir.display() + ))); + } + + let preprocessor_config = PreprocessorConfig::default(); + let exec_config = config.unwrap_or_default(); + + let model = ParakeetModel::from_pretrained_with_config(&model_path, exec_config)?; + let decoder = ParakeetDecoder::from_pretrained(&tokenizer_path)?; + + Ok(Self { + model, + decoder, + preprocessor_config, + model_dir, + }) + } + + fn find_model_file(dir: &Path) -> Result<PathBuf> { + // Priority order: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx + let candidates = [ + "model.onnx", + "model_fp16.onnx", + "model_int8.onnx", + "model_q4.onnx", + ]; + + for candidate in &candidates { + let path = dir.join(candidate); + if path.exists() { + return Ok(path); + } + } + + // If none of the standard names found, search for any .onnx file + if let Ok(entries) = std::fs::read_dir(dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("onnx") { + return Ok(path); + } + } + } + + Err(Error::Config(format!( + "No model file (*.onnx) found in directory: {}", + dir.display() + ))) + } + + /// Transcribe audio samples. + /// + /// # Arguments + /// + /// * `audio` - Audio samples as f32 values + /// * `sample_rate` - Sample rate in Hz + /// * `channels` - Number of audio channels + /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) + /// + /// # Returns + /// + /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested level. + pub fn transcribe_samples( + &mut self, + audio: Vec<f32>, + sample_rate: u32, + channels: u16, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; + let logits = self.model.forward(features)?; + + let mut result = self.decoder.decode_with_timestamps( + &logits, + self.preprocessor_config.hop_length, + self.preprocessor_config.sampling_rate, + )?; + + // Process timestamps to requested output mode + let mode = mode.unwrap_or(TimestampMode::Tokens); + result.tokens = process_timestamps(&result.tokens, mode); + + // Rebuild full text from processed tokens to ensure consistency + result.text = result.tokens.iter() + .map(|t| t.text.as_str()) + .collect::<Vec<_>>() + .join(" "); + + Ok(result) + } + + /// Transcribe an audio file with timestamps + /// + /// # Arguments + /// + /// * `audio_path` - A path to the audio file that needs to be transcribed. + /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level. + pub fn transcribe_file<P: AsRef<Path>>( + &mut self, + audio_path: P, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let audio_path = audio_path.as_ref(); + let (audio, spec) = audio::load_audio(audio_path)?; + + self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) + } + + /// Transcribes multiple audio files in batch. + /// + /// # Arguments + /// + /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. + /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level. + pub fn transcribe_file_batch<P: AsRef<Path>>( + &mut self, + audio_paths: &[P], + mode: Option<TimestampMode>, + ) -> Result<Vec<TranscriptionResult>> { + let mut results = Vec::with_capacity(audio_paths.len()); + for path in audio_paths { + let result = self.transcribe_file(path, mode)?; + results.push(result); + } + Ok(results) + } + + pub fn model_dir(&self) -> &Path { + &self.model_dir + } + + pub fn preprocessor_config(&self) -> &PreprocessorConfig { + &self.preprocessor_config + } +} |
