summaryrefslogtreecommitdiff
path: root/vendor/parakeet-rs/src/parakeet_tdt.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/parakeet-rs/src/parakeet_tdt.rs')
-rw-r--r--vendor/parakeet-rs/src/parakeet_tdt.rs167
1 files changed, 167 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/src/parakeet_tdt.rs b/vendor/parakeet-rs/src/parakeet_tdt.rs
new file mode 100644
index 0000000..719ae75
--- /dev/null
+++ b/vendor/parakeet-rs/src/parakeet_tdt.rs
@@ -0,0 +1,167 @@
+use crate::audio;
+use crate::config::PreprocessorConfig;
+use crate::decoder::TranscriptionResult;
+use crate::decoder_tdt::ParakeetTDTDecoder;
+use crate::error::{Error, Result};
+use crate::execution::ModelConfig as ExecutionConfig;
+use crate::model_tdt::ParakeetTDTModel;
+use crate::timestamps::{process_timestamps, TimestampMode};
+use crate::vocab::Vocabulary;
+use std::path::{Path, PathBuf};
+
+/// Parakeet TDT model for multilingual ASR
+pub struct ParakeetTDT {
+ model: ParakeetTDTModel,
+ decoder: ParakeetTDTDecoder,
+ preprocessor_config: PreprocessorConfig,
+ model_dir: PathBuf,
+}
+
+impl ParakeetTDT {
+ /// Load Parakeet TDT model from path with optional configuration.
+ ///
+ /// # Arguments
+ /// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt
+ /// * `config` - Optional execution configuration (defaults to CPU if None)
+ pub fn from_pretrained<P: AsRef<Path>>(
+ path: P,
+ config: Option<ExecutionConfig>,
+ ) -> Result<Self> {
+ let path = path.as_ref();
+
+ if !path.is_dir() {
+ return Err(Error::Config(format!(
+ "TDT model path must be a directory: {}",
+ path.display()
+ )));
+ }
+
+ let vocab_path = path.join("vocab.txt");
+ if !vocab_path.exists() {
+ return Err(Error::Config(format!(
+ "vocab.txt not found in {}",
+ path.display()
+ )));
+ }
+
+ // TDT-specific preprocessor config (128 features instead of 80)
+ let preprocessor_config = PreprocessorConfig {
+ feature_extractor_type: "ParakeetFeatureExtractor".to_string(),
+ feature_size: 128,
+ hop_length: 160,
+ n_fft: 512,
+ padding_side: "right".to_string(),
+ padding_value: 0.0,
+ preemphasis: 0.97,
+ processor_class: "ParakeetProcessor".to_string(),
+ return_attention_mask: true,
+ sampling_rate: 16000,
+ win_length: 400,
+ };
+
+ let exec_config = config.unwrap_or_default();
+
+ let model = ParakeetTDTModel::from_pretrained(path, exec_config)?;
+ let vocab = Vocabulary::from_file(&vocab_path)?;
+ let decoder = ParakeetTDTDecoder::from_vocab(vocab);
+
+ Ok(Self {
+ model,
+ decoder,
+ preprocessor_config,
+ model_dir: path.to_path_buf(),
+ })
+ }
+
+ /// Transcribe audio samples.
+ ///
+ /// # Arguments
+ ///
+ /// * `audio` - Audio samples as f32 values
+ /// * `sample_rate` - Sample rate in Hz
+ /// * `channels` - Number of audio channels
+ /// * `mode` - Optional timestamp mode (Token, Word, or Segment)
+ ///
+ /// # Returns
+ ///
+ /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode.
+ pub fn transcribe_samples(
+ &mut self,
+ audio: Vec<f32>,
+ sample_rate: u32,
+ channels: u16,
+ mode: Option<TimestampMode>,
+ ) -> Result<TranscriptionResult> {
+ let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
+ let (tokens, frame_indices, durations) = self.model.forward(features)?;
+
+ let mut result = self.decoder.decode_with_timestamps(
+ &tokens,
+ &frame_indices,
+ &durations,
+ self.preprocessor_config.hop_length,
+ self.preprocessor_config.sampling_rate,
+ )?;
+
+ // Apply timestamp mode conversion
+ let mode = mode.unwrap_or(TimestampMode::Tokens);
+ result.tokens = process_timestamps(&result.tokens, mode);
+
+ // Rebuild full text from processed tokens
+ result.text = result.tokens.iter()
+ .map(|t| t.text.as_str())
+ .collect::<Vec<_>>()
+ .join(" ");
+
+ Ok(result)
+ }
+
+ /// Transcribe an audio file with timestamps
+ ///
+ /// # Arguments
+ ///
+ /// * `audio_path` - A path to the audio file that needs to be transcribed.
+ /// * `mode` - Optional timestamp mode (Token, Word, or Segment)
+ ///
+ /// # Returns
+ ///
+ /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode.
+ pub fn transcribe_file<P: AsRef<Path>>(
+ &mut self,
+ audio_path: P,
+ mode: Option<TimestampMode>,
+ ) -> Result<TranscriptionResult> {
+ let audio_path = audio_path.as_ref();
+ let (audio, spec) = audio::load_audio(audio_path)?;
+
+ self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode)
+ }
+
+ /// Transcribes multiple audio files in batch.
+ ///
+ /// # Arguments
+ ///
+ /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed.
+ /// * `mode` - Optional timestamp mode (Token, Word, or Segment)
+ ///
+ /// # Returns
+ ///
+ /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode.
+ pub fn transcribe_file_batch<P: AsRef<Path>>(
+ &mut self,
+ audio_paths: &[P],
+ mode: Option<TimestampMode>,
+ ) -> Result<Vec<TranscriptionResult>> {
+ let mut results = Vec::with_capacity(audio_paths.len());
+ for path in audio_paths {
+ let result = self.transcribe_file(path, mode)?;
+ results.push(result);
+ }
+ Ok(results)
+ }
+
+ /// Get model directory path
+ pub fn model_dir(&self) -> &Path {
+ &self.model_dir
+ }
+}