1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
use crate::decoder::TranscriptionResult;
use crate::error::Result;
use crate::vocab::Vocabulary;
/// TDT greedy decoder for Parakeet TDT models.
///
/// The greedy search itself is done in the model; this type only converts the
/// model's emitted token ids (and their frame positions) into text and timed
/// tokens, using the vocabulary it holds.
#[derive(Debug)]
pub struct ParakeetTDTDecoder {
// Vocabulary used to look up the text piece for each emitted token id.
vocab: Vocabulary,
}
impl ParakeetTDTDecoder {
/// Load decoder from vocab file
pub fn from_vocab(vocab: Vocabulary) -> Self {
Self { vocab }
}
/// Decode tokens with timestamps
/// For TDT models, greedy decoding is done in the model, here we just convert to text
pub fn decode_with_timestamps(
&self,
tokens: &[usize],
frame_indices: &[usize],
_durations: &[usize],
hop_length: usize,
sample_rate: usize,
) -> Result<TranscriptionResult> {
let mut result_tokens = Vec::new();
let mut full_text = String::new();
// TDT encoder does 8x subsampling
let encoder_stride = 8;
for (i, &token_id) in tokens.iter().enumerate() {
if let Some(token_text) = self.vocab.id_to_text(token_id) {
let frame = frame_indices[i];
let start = (frame * encoder_stride * hop_length) as f32 / sample_rate as f32;
let end = if i + 1 < frame_indices.len() {
(frame_indices[i + 1] * encoder_stride * hop_length) as f32 / sample_rate as f32
} else {
start + 0.01
};
// Handle SentencePiece format (▁ prefix for word start)
let display_text = token_text.replace('▁', " ");
// Skip special tokens
if !(token_text.starts_with('<') && token_text.ends_with('>') && token_text != "<unk>") {
full_text.push_str(&display_text);
result_tokens.push(crate::decoder::TimedToken {
text: display_text,
start,
end,
});
}
}
}
Ok(TranscriptionResult {
text: full_text.trim().to_string(),
tokens: result_tokens,
})
}
}
|