summaryrefslogblamecommitdiff
path: root/parakeet-rs/src/decoder.rs
blob: 6da6d6503bba9580e7e61a60b6a2a899bd1eea5e (plain) (tree)


















































































































































































































                                                                                                                                                                                          
use crate::error::{Error, Result};
use ndarray::Array2;
use std::path::Path;

/// A piece of decoded text together with its timestamp information.
///
/// `start` and `end` are expressed in seconds.
#[derive(Debug, Clone)]
pub struct TimedToken {
    /// Text contributed by this token to the decoded transcript.
    pub text: String,
    /// Start time of the token, in seconds.
    pub start: f32,
    /// End time of the token, in seconds.
    pub end: f32,
}

/// Output of `ParakeetDecoder::decode_with_timestamps`.
#[derive(Debug, Clone)]
pub struct TranscriptionResult {
    /// Full decoded transcript.
    pub text: String,
    /// Timed entries in decoding order. Tokens that add no visible text to the
    /// transcript are not included.
    pub tokens: Vec<TimedToken>,
}

/// CTC decoder for the parakeet-ctc-0.6b model with token-level timestamps.
pub struct ParakeetDecoder {
    // Tokenizer used to turn collapsed token ids back into text.
    tokenizer: tokenizers::Tokenizer,
    // Token id that the CTC collapse routines skip (it plays the role of the CTC blank).
    pad_token_id: usize,
}

impl ParakeetDecoder {
    pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> {
        let tokenizer_path = tokenizer_path.as_ref();

        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?;

        // Hardcoded pad_token_id for Parakeet-CTC-0.6b (constant across all models: please see def configs jsons: https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main)
        let pad_token_id = 1024;

        Ok(Self {
            tokenizer,
            pad_token_id,
        })
    }

    pub fn decode(&self, logits: &Array2<f32>) -> Result<String> {
        let time_steps = logits.shape()[0];

        let mut token_ids = Vec::new();
        for t in 0..time_steps {
            let logits_t = logits.row(t);
            let max_idx = logits_t
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            token_ids.push(max_idx as u32);
        }

        let collapsed = self.ctc_collapse(&token_ids);

        let text = self
            .tokenizer
            .decode(&collapsed, true)
            .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;

        Ok(text)
    }

    /// Standard CTC collapse: merge runs of identical consecutive ids, then drop
    /// every occurrence of the pad (blank) id.
    ///
    /// Because runs are merged before blanks are removed, a token repeated on both
    /// sides of a blank is emitted twice.
    fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> {
        let blank = self.pad_token_id as u32;

        let mut collapsed = token_ids.to_vec();
        // Keep only the first id of each run of equal neighbours.
        collapsed.dedup();
        // Blanks only act as separators and never reach the output.
        collapsed.retain(|&id| id != blank);
        collapsed
    }

    // CTC collapse with frame tracking for timestamps
    /// CTC collapse that also records the frame span covered by every emitted token.
    ///
    /// `token_ids` holds one `(token_id, frame)` pair per time step, in time order.
    /// The result holds `(token_id, start_frame, end_frame)` for each token that
    /// survives the collapse, where `end_frame` is exclusive: it is one past the last
    /// frame on which the token was emitted.
    ///
    /// A token's span ends as soon as a pad (blank) frame or a different token
    /// follows it. Blank frames, including trailing ones after the last token, are
    /// never attributed to any token.
    fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> {
        let blank = self.pad_token_id as u32;
        let mut result: Vec<(u32, usize, usize)> = Vec::new();
        let mut prev_token: Option<u32> = None;

        for &(token_id, frame) in token_ids {
            if token_id != blank {
                if prev_token == Some(token_id) {
                    // Same token on consecutive frames: the run started with a push,
                    // so the last entry is this token; extend its span.
                    if let Some(last) = result.last_mut() {
                        last.2 = frame + 1;
                    }
                } else {
                    // First frame of a new run (after a blank, a different token, or
                    // at the very beginning).
                    result.push((token_id, frame, frame + 1));
                }
            }

            // Blanks are remembered too, so a token repeated after a blank starts a
            // new entry instead of extending the previous one.
            prev_token = Some(token_id);
        }

        result
    }

    /// Greedy CTC decoding with token-level timestamps.
    ///
    /// Each row of `logits` is one frame. Frame indices are converted to seconds as
    /// `frame * hop_length / sample_rate`, so `hop_length` is the number of samples per
    /// frame and `sample_rate` the number of samples per second. A `sample_rate` of 0
    /// yields non-finite timestamps rather than an error.
    ///
    /// Per-token text is obtained by decoding progressively longer prefixes of the
    /// collapsed id sequence and taking what each new token appended; tokens that
    /// append nothing are left out of `tokens`.
    ///
    /// # Errors
    /// Returns `Error::Tokenizer` when decoding the full id sequence fails. A failure
    /// while decoding an intermediate prefix is not reported; that token is skipped.
    pub fn decode_with_timestamps(
        &self,
        logits: &Array2<f32>,
        hop_length: usize,
        sample_rate: usize,
    ) -> Result<TranscriptionResult> {
        let time_steps = logits.shape()[0];

        // Greedy argmax per frame, remembering which frame each id came from.
        let mut token_ids_with_frames = Vec::with_capacity(time_steps);
        for t in 0..time_steps {
            let logits_t = logits.row(t);
            let max_idx = logits_t
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            token_ids_with_frames.push((max_idx as u32, t));
        }

        // CTC collapse with frame tracking
        let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames);

        // Extract just token IDs for decoding
        let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|&(id, _, _)| id).collect();

        // Decode full text
        let full_text = self
            .tokenizer
            .decode(&token_ids, true)
            .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;

        let frame_to_seconds = |frame: usize| (frame * hop_length) as f32 / sample_rate as f32;

        // Progressive decode to detect word boundaries
        // BPE tokenizers only add spaces when decoding sequences, not individual tokens
        let mut timed_tokens = Vec::new();
        let mut prev_decode = String::new();

        for (i, &(_, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() {
            // Decode from the start up to and including the current token; the prefix
            // is a slice of the id list built above, so nothing is re-collected.
            if let Ok(curr_decode) = self.tokenizer.decode(&token_ids[..=i], true) {
                // Whatever lies past the previous decode's length is what this token
                // added. `get` returns `None` when the new text is shorter or when the
                // offset is not a UTF-8 char boundary (the previous prefix may have
                // ended in a partial character), where plain slicing would panic.
                let added_text = curr_decode.get(prev_decode.len()..).unwrap_or("");

                if !added_text.is_empty() {
                    timed_tokens.push(TimedToken {
                        text: added_text.to_string(),
                        start: frame_to_seconds(start_frame),
                        end: frame_to_seconds(end_frame),
                    });
                }

                prev_decode = curr_decode;
            }
        }

        Ok(TranscriptionResult {
            text: full_text,
            tokens: timed_tokens,
        })
    }

    // Stub - falls back to greedy decoding. Full beam search with language model is TODO.
    pub fn decode_with_beam_search(
        &self,
        logits: &Array2<f32>,
        _beam_width: usize,
    ) -> Result<String> {
        self.decode(logits)
    }

    pub fn pad_token_id(&self) -> usize {
        self.pad_token_id
    }
}