summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/decoder_tdt.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 00:40:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch)
tree0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/src/decoder_tdt.rs
parent84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff)
downloadsoryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz
soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/src/decoder_tdt.rs')
-rw-r--r--parakeet-rs/src/decoder_tdt.rs63
1 files changed, 63 insertions, 0 deletions
diff --git a/parakeet-rs/src/decoder_tdt.rs b/parakeet-rs/src/decoder_tdt.rs
new file mode 100644
index 0000000..65f576d
--- /dev/null
+++ b/parakeet-rs/src/decoder_tdt.rs
@@ -0,0 +1,63 @@
+use crate::decoder::TranscriptionResult;
+use crate::error::Result;
+use crate::vocab::Vocabulary;
+
+/// TDT greedy decoder for Parakeet TDT models
+#[derive(Debug)]
+pub struct ParakeetTDTDecoder {
+ vocab: Vocabulary,
+}
+
+impl ParakeetTDTDecoder {
+ /// Load decoder from vocab file
+ pub fn from_vocab(vocab: Vocabulary) -> Self {
+ Self { vocab }
+ }
+
+ /// Decode tokens with timestamps
+ /// For TDT models, greedy decoding is done in the model, here we just convert to text
+ pub fn decode_with_timestamps(
+ &self,
+ tokens: &[usize],
+ frame_indices: &[usize],
+ _durations: &[usize],
+ hop_length: usize,
+ sample_rate: usize,
+ ) -> Result<TranscriptionResult> {
+ let mut result_tokens = Vec::new();
+ let mut full_text = String::new();
+ // TDT encoder does 8x subsampling
+ let encoder_stride = 8;
+
+ for (i, &token_id) in tokens.iter().enumerate() {
+ if let Some(token_text) = self.vocab.id_to_text(token_id) {
+ let frame = frame_indices[i];
+ let start = (frame * encoder_stride * hop_length) as f32 / sample_rate as f32;
+ let end = if i + 1 < frame_indices.len() {
+ (frame_indices[i + 1] * encoder_stride * hop_length) as f32 / sample_rate as f32
+ } else {
+ start + 0.01
+ };
+
+ // Handle SentencePiece format (▁ prefix for word start)
+ let display_text = token_text.replace('▁', " ");
+
+ // Skip special tokens
+ if !(token_text.starts_with('<') && token_text.ends_with('>') && token_text != "<unk>") {
+ full_text.push_str(&display_text);
+
+ result_tokens.push(crate::decoder::TimedToken {
+ text: display_text,
+ start,
+ end,
+ });
+ }
+ }
+ }
+
+ Ok(TranscriptionResult {
+ text: full_text.trim().to_string(),
+ tokens: result_tokens,
+ })
+ }
+}