diff options
| author | soryu <soryu@soryu.co> | 2025-12-21 01:27:02 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch) | |
| tree | 497bffd67001501a003739cfe0bb790502ffd50a /parakeet-rs/src/parakeet_tdt.rs | |
| parent | 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff) | |
| download | soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip | |
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'parakeet-rs/src/parakeet_tdt.rs')
| -rw-r--r-- | parakeet-rs/src/parakeet_tdt.rs | 167 |
1 files changed, 0 insertions, 167 deletions
diff --git a/parakeet-rs/src/parakeet_tdt.rs b/parakeet-rs/src/parakeet_tdt.rs deleted file mode 100644 index 719ae75..0000000 --- a/parakeet-rs/src/parakeet_tdt.rs +++ /dev/null @@ -1,167 +0,0 @@ -use crate::audio; -use crate::config::PreprocessorConfig; -use crate::decoder::TranscriptionResult; -use crate::decoder_tdt::ParakeetTDTDecoder; -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use crate::model_tdt::ParakeetTDTModel; -use crate::timestamps::{process_timestamps, TimestampMode}; -use crate::vocab::Vocabulary; -use std::path::{Path, PathBuf}; - -/// Parakeet TDT model for multilingual ASR -pub struct ParakeetTDT { - model: ParakeetTDTModel, - decoder: ParakeetTDTDecoder, - preprocessor_config: PreprocessorConfig, - model_dir: PathBuf, -} - -impl ParakeetTDT { - /// Load Parakeet TDT model from path with optional configuration. - /// - /// # Arguments - /// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt - /// * `config` - Optional execution configuration (defaults to CPU if None) - pub fn from_pretrained<P: AsRef<Path>>( - path: P, - config: Option<ExecutionConfig>, - ) -> Result<Self> { - let path = path.as_ref(); - - if !path.is_dir() { - return Err(Error::Config(format!( - "TDT model path must be a directory: {}", - path.display() - ))); - } - - let vocab_path = path.join("vocab.txt"); - if !vocab_path.exists() { - return Err(Error::Config(format!( - "vocab.txt not found in {}", - path.display() - ))); - } - - // TDT-specific preprocessor config (128 features instead of 80) - let preprocessor_config = PreprocessorConfig { - feature_extractor_type: "ParakeetFeatureExtractor".to_string(), - feature_size: 128, - hop_length: 160, - n_fft: 512, - padding_side: "right".to_string(), - padding_value: 0.0, - preemphasis: 0.97, - processor_class: "ParakeetProcessor".to_string(), - return_attention_mask: true, - sampling_rate: 16000, - win_length: 400, - }; - - let exec_config = config.unwrap_or_default(); - - let model = ParakeetTDTModel::from_pretrained(path, exec_config)?; - let vocab = Vocabulary::from_file(&vocab_path)?; - let decoder = ParakeetTDTDecoder::from_vocab(vocab); - - Ok(Self { - model, - decoder, - preprocessor_config, - model_dir: path.to_path_buf(), - }) - } - - /// Transcribe audio samples. - /// - /// # Arguments - /// - /// * `audio` - Audio samples as f32 values - /// * `sample_rate` - Sample rate in Hz - /// * `channels` - Number of audio channels - /// * `mode` - Optional timestamp mode (Token, Word, or Segment) - /// - /// # Returns - /// - /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode. - pub fn transcribe_samples( - &mut self, - audio: Vec<f32>, - sample_rate: u32, - channels: u16, - mode: Option<TimestampMode>, - ) -> Result<TranscriptionResult> { - let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; - let (tokens, frame_indices, durations) = self.model.forward(features)?; - - let mut result = self.decoder.decode_with_timestamps( - &tokens, - &frame_indices, - &durations, - self.preprocessor_config.hop_length, - self.preprocessor_config.sampling_rate, - )?; - - // Apply timestamp mode conversion - let mode = mode.unwrap_or(TimestampMode::Tokens); - result.tokens = process_timestamps(&result.tokens, mode); - - // Rebuild full text from processed tokens - result.text = result.tokens.iter() - .map(|t| t.text.as_str()) - .collect::<Vec<_>>() - .join(" "); - - Ok(result) - } - - /// Transcribe an audio file with timestamps - /// - /// # Arguments - /// - /// * `audio_path` - A path to the audio file that needs to be transcribed. - /// * `mode` - Optional timestamp mode (Token, Word, or Segment) - /// - /// # Returns - /// - /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. - pub fn transcribe_file<P: AsRef<Path>>( - &mut self, - audio_path: P, - mode: Option<TimestampMode>, - ) -> Result<TranscriptionResult> { - let audio_path = audio_path.as_ref(); - let (audio, spec) = audio::load_audio(audio_path)?; - - self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) - } - - /// Transcribes multiple audio files in batch. - /// - /// # Arguments - /// - /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. - /// * `mode` - Optional timestamp mode (Token, Word, or Segment) - /// - /// # Returns - /// - /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. - pub fn transcribe_file_batch<P: AsRef<Path>>( - &mut self, - audio_paths: &[P], - mode: Option<TimestampMode>, - ) -> Result<Vec<TranscriptionResult>> { - let mut results = Vec::with_capacity(audio_paths.len()); - for path in audio_paths { - let result = self.transcribe_file(path, mode)?; - results.push(result); - } - Ok(results) - } - - /// Get model directory path - pub fn model_dir(&self) -> &Path { - &self.model_dir - } -} |
