summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/parakeet.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 01:27:02 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch)
tree497bffd67001501a003739cfe0bb790502ffd50a /parakeet-rs/src/parakeet.rs
parent55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff)
downloadsoryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz
soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'parakeet-rs/src/parakeet.rs')
-rw-r--r--parakeet-rs/src/parakeet.rs210
1 files changed, 0 insertions, 210 deletions
diff --git a/parakeet-rs/src/parakeet.rs b/parakeet-rs/src/parakeet.rs
deleted file mode 100644
index d2aabdd..0000000
--- a/parakeet-rs/src/parakeet.rs
+++ /dev/null
@@ -1,210 +0,0 @@
-use crate::audio;
-use crate::config::PreprocessorConfig;
-use crate::decoder::{ParakeetDecoder, TranscriptionResult};
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use crate::model::ParakeetModel;
-use crate::timestamps::{process_timestamps, TimestampMode};
-use std::path::{Path, PathBuf};
-
-pub struct Parakeet {
- model: ParakeetModel,
- decoder: ParakeetDecoder,
- preprocessor_config: PreprocessorConfig,
- model_dir: PathBuf,
-}
-
-impl Parakeet {
- /// Load Parakeet model from path with optional configuration.
- ///
- /// # Arguments
- /// * `path` - Directory containing model files, or path to specific model file
- /// * `config` - Optional execution configuration (defaults to CPU if None)
- ///
- /// # Examples
- /// ```no_run
- /// use parakeet_rs::Parakeet;
- ///
- /// // Load from directory with CPU (default)
- /// let parakeet = Parakeet::from_pretrained(".", None)?;
- ///
- /// // Or load from specific model file
- /// let parakeet = Parakeet::from_pretrained("model_q4.onnx", None)?;
- /// # Ok::<(), Box<dyn std::error::Error>>(())
- /// ```
- ///
- /// For GPU acceleration, enable the corresponding feature (cuda, tensorrt, webgpu, etc.)
- /// and pass an `ExecutionConfig` with the desired execution provider.
- pub fn from_pretrained<P: AsRef<Path>>(
- path: P,
- config: Option<ExecutionConfig>,
- ) -> Result<Self> {
- let path = path.as_ref();
-
- // Determine if path is a directory or file
- let (model_path, tokenizer_path, model_dir) = if path.is_dir() {
- // Directory mode: auto-detect model file
- let model_path = Self::find_model_file(path)?;
- let tokenizer_path = path.join("tokenizer.json");
- (model_path, tokenizer_path, path.to_path_buf())
- } else if path.is_file() {
- // File mode: path points directly to model file
- let model_dir = path
- .parent()
- .ok_or_else(|| Error::Config("Invalid model path".to_string()))?;
- let tokenizer_path = model_dir.join("tokenizer.json");
- (path.to_path_buf(), tokenizer_path, model_dir.to_path_buf())
- } else {
- return Err(Error::Config(format!(
- "Path does not exist: {}",
- path.display()
- )));
- };
-
- // Check tokenizer exists
- if !tokenizer_path.exists() {
- return Err(Error::Config(format!(
- "Required file 'tokenizer.json' not found in {}",
- model_dir.display()
- )));
- }
-
- let preprocessor_config = PreprocessorConfig::default();
- let exec_config = config.unwrap_or_default();
-
- let model = ParakeetModel::from_pretrained_with_config(&model_path, exec_config)?;
- let decoder = ParakeetDecoder::from_pretrained(&tokenizer_path)?;
-
- Ok(Self {
- model,
- decoder,
- preprocessor_config,
- model_dir,
- })
- }
-
- fn find_model_file(dir: &Path) -> Result<PathBuf> {
- // Priority order: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx
- let candidates = [
- "model.onnx",
- "model_fp16.onnx",
- "model_int8.onnx",
- "model_q4.onnx",
- ];
-
- for candidate in &candidates {
- let path = dir.join(candidate);
- if path.exists() {
- return Ok(path);
- }
- }
-
- // If none of the standard names found, search for any .onnx file
- if let Ok(entries) = std::fs::read_dir(dir) {
- for entry in entries.flatten() {
- let path = entry.path();
- if path.extension().and_then(|s| s.to_str()) == Some("onnx") {
- return Ok(path);
- }
- }
- }
-
- Err(Error::Config(format!(
- "No model file (*.onnx) found in directory: {}",
- dir.display()
- )))
- }
-
- /// Transcribe audio samples.
- ///
- /// # Arguments
- ///
- /// * `audio` - Audio samples as f32 values
- /// * `sample_rate` - Sample rate in Hz
- /// * `channels` - Number of audio channels
- /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
- ///
- /// # Returns
- ///
- /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested level.
- pub fn transcribe_samples(
- &mut self,
- audio: Vec<f32>,
- sample_rate: u32,
- channels: u16,
- mode: Option<TimestampMode>,
- ) -> Result<TranscriptionResult> {
- let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
- let logits = self.model.forward(features)?;
-
- let mut result = self.decoder.decode_with_timestamps(
- &logits,
- self.preprocessor_config.hop_length,
- self.preprocessor_config.sampling_rate,
- )?;
-
- // Process timestamps to requested output mode
- let mode = mode.unwrap_or(TimestampMode::Tokens);
- result.tokens = process_timestamps(&result.tokens, mode);
-
- // Rebuild full text from processed tokens to ensure consistency
- result.text = result.tokens.iter()
- .map(|t| t.text.as_str())
- .collect::<Vec<_>>()
- .join(" ");
-
- Ok(result)
- }
-
- /// Transcribe an audio file with timestamps
- ///
- /// # Arguments
- ///
- /// * `audio_path` - A path to the audio file that needs to be transcribed.
- /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
- ///
- /// # Returns
- ///
- /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level.
- pub fn transcribe_file<P: AsRef<Path>>(
- &mut self,
- audio_path: P,
- mode: Option<TimestampMode>,
- ) -> Result<TranscriptionResult> {
- let audio_path = audio_path.as_ref();
- let (audio, spec) = audio::load_audio(audio_path)?;
-
- self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode)
- }
-
- /// Transcribes multiple audio files in batch.
- ///
- /// # Arguments
- ///
- /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed.
- /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences)
- ///
- /// # Returns
- ///
- /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level.
- pub fn transcribe_file_batch<P: AsRef<Path>>(
- &mut self,
- audio_paths: &[P],
- mode: Option<TimestampMode>,
- ) -> Result<Vec<TranscriptionResult>> {
- let mut results = Vec::with_capacity(audio_paths.len());
- for path in audio_paths {
- let result = self.transcribe_file(path, mode)?;
- results.push(result);
- }
- Ok(results)
- }
-
- pub fn model_dir(&self) -> &Path {
- &self.model_dir
- }
-
- pub fn preprocessor_config(&self) -> &PreprocessorConfig {
- &self.preprocessor_config
- }
-}