summaryrefslogblamecommitdiff
path: root/vendor/parakeet-rs/src/decoder_tdt.rs
blob: 65f576d594940e815f3f3ce0dc8c2499786be89c (plain) (tree)






























































                                                                                                         
use crate::decoder::TranscriptionResult;
use crate::error::Result;
use crate::vocab::Vocabulary;

/// TDT greedy decoder for Parakeet TDT models
///
/// Turns token ids produced by the model into readable text plus per-token
/// timestamps. The decoder itself only holds the vocabulary it needs for the
/// id-to-text lookup.
#[derive(Debug)]
pub struct ParakeetTDTDecoder {
    // Vocabulary queried (via `id_to_text`) for the text of each token id.
    vocab: Vocabulary,
}

impl ParakeetTDTDecoder {
    /// Create a decoder from an already constructed vocabulary.
    pub fn from_vocab(vocab: Vocabulary) -> Self {
        Self { vocab }
    }

    /// Decode tokens with timestamps
    ///
    /// For TDT models, greedy decoding is done in the model, here we just convert
    /// token ids to text and frame indices to timestamps.
    ///
    /// # Arguments
    ///
    /// * `tokens` - token ids to convert; ids unknown to the vocabulary are skipped.
    /// * `frame_indices` - encoder frame index of each token, parallel to `tokens`.
    ///   If it is shorter than `tokens`, the remaining tokens reuse the last
    ///   available frame index (or frame 0 when the slice is empty) instead of
    ///   panicking.
    /// * `_durations` - currently unused.
    /// * `hop_length`, `sample_rate` - a frame index is converted to a timestamp as
    ///   `frame * 8 * hop_length / sample_rate`, which is in seconds when
    ///   `hop_length` is in samples and `sample_rate` in samples per second.
    ///
    /// # Returns
    ///
    /// The concatenated, trimmed text and one `TimedToken` per emitted token.
    /// A token ends where the next frame index begins; a token with no following
    /// frame index ends `0.01` after its start. Tokens of the form `<...>` are
    /// treated as special and dropped, except `<unk>`, which is kept.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn decode_with_timestamps(
        &self,
        tokens: &[usize],
        frame_indices: &[usize],
        _durations: &[usize],
        hop_length: usize,
        sample_rate: usize,
    ) -> Result<TranscriptionResult> {
        // TDT encoder does 8x subsampling
        const ENCODER_STRIDE: usize = 8;

        // Same arithmetic as before (integer product, then one float division) so
        // timestamps stay bit-identical for well-formed input.
        let frame_to_seconds = |frame: usize| -> f32 {
            (frame * ENCODER_STRIDE * hop_length) as f32 / sample_rate as f32
        };

        let mut result_tokens = Vec::with_capacity(tokens.len());
        let mut full_text = String::new();

        for (i, &token_id) in tokens.iter().enumerate() {
            let token_text = match self.vocab.id_to_text(token_id) {
                Some(text) => text,
                None => continue,
            };

            // Skip special tokens (but keep `<unk>`) before doing any further work.
            let is_special = token_text.starts_with('<')
                && token_text.ends_with('>')
                && token_text != "<unk>";
            if is_special {
                continue;
            }

            // Checked lookup: a short `frame_indices` must not panic the decoder.
            let frame = frame_indices
                .get(i)
                .or_else(|| frame_indices.last())
                .copied()
                .unwrap_or(0);
            let start = frame_to_seconds(frame);
            let end = match frame_indices.get(i + 1) {
                Some(&next_frame) => frame_to_seconds(next_frame),
                None => start + 0.01,
            };

            // Handle SentencePiece format (▁ prefix for word start)
            let display_text = token_text.replace('▁', " ");
            full_text.push_str(&display_text);

            result_tokens.push(crate::decoder::TimedToken {
                text: display_text,
                start,
                end,
            });
        }

        Ok(TranscriptionResult {
            text: full_text.trim().to_string(),
            tokens: result_tokens,
        })
    }
}