diff options
Diffstat (limited to 'vendor/parakeet-rs/src/parakeet_tdt.rs')
| -rw-r--r-- | vendor/parakeet-rs/src/parakeet_tdt.rs | 167 |
1 files changed, 167 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/src/parakeet_tdt.rs b/vendor/parakeet-rs/src/parakeet_tdt.rs new file mode 100644 index 0000000..719ae75 --- /dev/null +++ b/vendor/parakeet-rs/src/parakeet_tdt.rs @@ -0,0 +1,167 @@ +use crate::audio; +use crate::config::PreprocessorConfig; +use crate::decoder::TranscriptionResult; +use crate::decoder_tdt::ParakeetTDTDecoder; +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use crate::model_tdt::ParakeetTDTModel; +use crate::timestamps::{process_timestamps, TimestampMode}; +use crate::vocab::Vocabulary; +use std::path::{Path, PathBuf}; + +/// Parakeet TDT model for multilingual ASR +pub struct ParakeetTDT { + model: ParakeetTDTModel, + decoder: ParakeetTDTDecoder, + preprocessor_config: PreprocessorConfig, + model_dir: PathBuf, +} + +impl ParakeetTDT { + /// Load Parakeet TDT model from path with optional configuration. + /// + /// # Arguments + /// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt + /// * `config` - Optional execution configuration (defaults to CPU if None) + pub fn from_pretrained<P: AsRef<Path>>( + path: P, + config: Option<ExecutionConfig>, + ) -> Result<Self> { + let path = path.as_ref(); + + if !path.is_dir() { + return Err(Error::Config(format!( + "TDT model path must be a directory: {}", + path.display() + ))); + } + + let vocab_path = path.join("vocab.txt"); + if !vocab_path.exists() { + return Err(Error::Config(format!( + "vocab.txt not found in {}", + path.display() + ))); + } + + // TDT-specific preprocessor config (128 features instead of 80) + let preprocessor_config = PreprocessorConfig { + feature_extractor_type: "ParakeetFeatureExtractor".to_string(), + feature_size: 128, + hop_length: 160, + n_fft: 512, + padding_side: "right".to_string(), + padding_value: 0.0, + preemphasis: 0.97, + processor_class: "ParakeetProcessor".to_string(), + return_attention_mask: true, + sampling_rate: 16000, + win_length: 400, + }; + + let exec_config = config.unwrap_or_default(); + + let model = ParakeetTDTModel::from_pretrained(path, exec_config)?; + let vocab = Vocabulary::from_file(&vocab_path)?; + let decoder = ParakeetTDTDecoder::from_vocab(vocab); + + Ok(Self { + model, + decoder, + preprocessor_config, + model_dir: path.to_path_buf(), + }) + } + + /// Transcribe audio samples. + /// + /// # Arguments + /// + /// * `audio` - Audio samples as f32 values + /// * `sample_rate` - Sample rate in Hz + /// * `channels` - Number of audio channels + /// * `mode` - Optional timestamp mode (Token, Word, or Segment) + /// + /// # Returns + /// + /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode. + pub fn transcribe_samples( + &mut self, + audio: Vec<f32>, + sample_rate: u32, + channels: u16, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; + let (tokens, frame_indices, durations) = self.model.forward(features)?; + + let mut result = self.decoder.decode_with_timestamps( + &tokens, + &frame_indices, + &durations, + self.preprocessor_config.hop_length, + self.preprocessor_config.sampling_rate, + )?; + + // Apply timestamp mode conversion + let mode = mode.unwrap_or(TimestampMode::Tokens); + result.tokens = process_timestamps(&result.tokens, mode); + + // Rebuild full text from processed tokens + result.text = result.tokens.iter() + .map(|t| t.text.as_str()) + .collect::<Vec<_>>() + .join(" "); + + Ok(result) + } + + /// Transcribe an audio file with timestamps + /// + /// # Arguments + /// + /// * `audio_path` - A path to the audio file that needs to be transcribed. + /// * `mode` - Optional timestamp mode (Token, Word, or Segment) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. + pub fn transcribe_file<P: AsRef<Path>>( + &mut self, + audio_path: P, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let audio_path = audio_path.as_ref(); + let (audio, spec) = audio::load_audio(audio_path)?; + + self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) + } + + /// Transcribes multiple audio files in batch. + /// + /// # Arguments + /// + /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. + /// * `mode` - Optional timestamp mode (Token, Word, or Segment) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. + pub fn transcribe_file_batch<P: AsRef<Path>>( + &mut self, + audio_paths: &[P], + mode: Option<TimestampMode>, + ) -> Result<Vec<TranscriptionResult>> { + let mut results = Vec::with_capacity(audio_paths.len()); + for path in audio_paths { + let result = self.transcribe_file(path, mode)?; + results.push(result); + } + Ok(results) + } + + /// Get model directory path + pub fn model_dir(&self) -> &Path { + &self.model_dir + } +} |
