1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
|
use crate::audio;
use crate::config::PreprocessorConfig;
use crate::decoder::TranscriptionResult;
use crate::decoder_tdt::ParakeetTDTDecoder;
use crate::error::{Error, Result};
use crate::execution::ModelConfig as ExecutionConfig;
use crate::model_tdt::ParakeetTDTModel;
use crate::timestamps::{process_timestamps, TimestampMode};
use crate::vocab::Vocabulary;
use std::path::{Path, PathBuf};
/// Parakeet TDT model for multilingual ASR
///
/// Bundles everything needed to go from raw audio to a `TranscriptionResult`:
/// the inference model, the token decoder, the feature-extraction settings and
/// the directory the model was loaded from. Instances are created with
/// `ParakeetTDT::from_pretrained`.
pub struct ParakeetTDT {
// Inference model; `transcribe_samples` calls its `forward` on the extracted
// features and receives tokens, frame indices and durations.
model: ParakeetTDTModel,
// Decoder built from the vocabulary read from `vocab.txt`; turns the model
// output into a timestamped transcription.
decoder: ParakeetTDTDecoder,
// Feature-extraction settings passed to `audio::extract_features_raw`; its
// `hop_length` and `sampling_rate` are also handed to the decoder.
preprocessor_config: PreprocessorConfig,
// Directory given to `from_pretrained`, exposed through `model_dir()`.
model_dir: PathBuf,
}
impl ParakeetTDT {
/// Load Parakeet TDT model from path with optional configuration.
///
/// # Arguments
/// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt
/// * `config` - Optional execution configuration (defaults to CPU if None)
pub fn from_pretrained<P: AsRef<Path>>(
path: P,
config: Option<ExecutionConfig>,
) -> Result<Self> {
let path = path.as_ref();
if !path.is_dir() {
return Err(Error::Config(format!(
"TDT model path must be a directory: {}",
path.display()
)));
}
let vocab_path = path.join("vocab.txt");
if !vocab_path.exists() {
return Err(Error::Config(format!(
"vocab.txt not found in {}",
path.display()
)));
}
// TDT-specific preprocessor config (128 features instead of 80)
let preprocessor_config = PreprocessorConfig {
feature_extractor_type: "ParakeetFeatureExtractor".to_string(),
feature_size: 128,
hop_length: 160,
n_fft: 512,
padding_side: "right".to_string(),
padding_value: 0.0,
preemphasis: 0.97,
processor_class: "ParakeetProcessor".to_string(),
return_attention_mask: true,
sampling_rate: 16000,
win_length: 400,
};
let exec_config = config.unwrap_or_default();
let model = ParakeetTDTModel::from_pretrained(path, exec_config)?;
let vocab = Vocabulary::from_file(&vocab_path)?;
let decoder = ParakeetTDTDecoder::from_vocab(vocab);
Ok(Self {
model,
decoder,
preprocessor_config,
model_dir: path.to_path_buf(),
})
}
/// Transcribe audio samples.
///
/// # Arguments
///
/// * `audio` - Audio samples as f32 values
/// * `sample_rate` - Sample rate in Hz
/// * `channels` - Number of audio channels
/// * `mode` - Optional timestamp mode (Token, Word, or Segment)
///
/// # Returns
///
/// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode.
pub fn transcribe_samples(
&mut self,
audio: Vec<f32>,
sample_rate: u32,
channels: u16,
mode: Option<TimestampMode>,
) -> Result<TranscriptionResult> {
let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?;
let (tokens, frame_indices, durations) = self.model.forward(features)?;
let mut result = self.decoder.decode_with_timestamps(
&tokens,
&frame_indices,
&durations,
self.preprocessor_config.hop_length,
self.preprocessor_config.sampling_rate,
)?;
// Apply timestamp mode conversion
let mode = mode.unwrap_or(TimestampMode::Tokens);
result.tokens = process_timestamps(&result.tokens, mode);
// Rebuild full text from processed tokens
result.text = result.tokens.iter()
.map(|t| t.text.as_str())
.collect::<Vec<_>>()
.join(" ");
Ok(result)
}
/// Transcribe an audio file with timestamps
///
/// # Arguments
///
/// * `audio_path` - A path to the audio file that needs to be transcribed.
/// * `mode` - Optional timestamp mode (Token, Word, or Segment)
///
/// # Returns
///
/// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode.
pub fn transcribe_file<P: AsRef<Path>>(
&mut self,
audio_path: P,
mode: Option<TimestampMode>,
) -> Result<TranscriptionResult> {
let audio_path = audio_path.as_ref();
let (audio, spec) = audio::load_audio(audio_path)?;
self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode)
}
/// Transcribes multiple audio files in batch.
///
/// # Arguments
///
/// * `audio_paths`: A slice of paths to the audio files that need to be transcribed.
/// * `mode` - Optional timestamp mode (Token, Word, or Segment)
///
/// # Returns
///
/// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode.
pub fn transcribe_file_batch<P: AsRef<Path>>(
&mut self,
audio_paths: &[P],
mode: Option<TimestampMode>,
) -> Result<Vec<TranscriptionResult>> {
let mut results = Vec::with_capacity(audio_paths.len());
for path in audio_paths {
let result = self.transcribe_file(path, mode)?;
results.push(result);
}
Ok(results)
}
/// Get model directory path
pub fn model_dir(&self) -> &Path {
&self.model_dir
}
}
|