summaryrefslogtreecommitdiff
path: root/vendor/parakeet-rs/src/decoder_tdt.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 01:27:02 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch)
tree497bffd67001501a003739cfe0bb790502ffd50a /vendor/parakeet-rs/src/decoder_tdt.rs
parent55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff)
downloadsoryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz
soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'vendor/parakeet-rs/src/decoder_tdt.rs')
-rw-r--r--vendor/parakeet-rs/src/decoder_tdt.rs63
1 files changed, 63 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/src/decoder_tdt.rs b/vendor/parakeet-rs/src/decoder_tdt.rs
new file mode 100644
index 0000000..65f576d
--- /dev/null
+++ b/vendor/parakeet-rs/src/decoder_tdt.rs
@@ -0,0 +1,63 @@
+use crate::decoder::TranscriptionResult;
+use crate::error::Result;
+use crate::vocab::Vocabulary;
+
+/// TDT greedy decoder for Parakeet TDT models
+#[derive(Debug)]
+pub struct ParakeetTDTDecoder {
+ vocab: Vocabulary,
+}
+
+impl ParakeetTDTDecoder {
+ /// Load decoder from vocab file
+ pub fn from_vocab(vocab: Vocabulary) -> Self {
+ Self { vocab }
+ }
+
+ /// Decode tokens with timestamps
+ /// For TDT models, greedy decoding is done in the model, here we just convert to text
+ pub fn decode_with_timestamps(
+ &self,
+ tokens: &[usize],
+ frame_indices: &[usize],
+ _durations: &[usize],
+ hop_length: usize,
+ sample_rate: usize,
+ ) -> Result<TranscriptionResult> {
+ let mut result_tokens = Vec::new();
+ let mut full_text = String::new();
+ // TDT encoder does 8x subsampling
+ let encoder_stride = 8;
+
+ for (i, &token_id) in tokens.iter().enumerate() {
+ if let Some(token_text) = self.vocab.id_to_text(token_id) {
+ let frame = frame_indices[i];
+ let start = (frame * encoder_stride * hop_length) as f32 / sample_rate as f32;
+ let end = if i + 1 < frame_indices.len() {
+ (frame_indices[i + 1] * encoder_stride * hop_length) as f32 / sample_rate as f32
+ } else {
+ start + 0.01
+ };
+
+ // Handle SentencePiece format (▁ prefix for word start)
+ let display_text = token_text.replace('▁', " ");
+
+ // Skip special tokens
+ if !(token_text.starts_with('<') && token_text.ends_with('>') && token_text != "<unk>") {
+ full_text.push_str(&display_text);
+
+ result_tokens.push(crate::decoder::TimedToken {
+ text: display_text,
+ start,
+ end,
+ });
+ }
+ }
+ }
+
+ Ok(TranscriptionResult {
+ text: full_text.trim().to_string(),
+ tokens: result_tokens,
+ })
+ }
+}