summaryrefslogtreecommitdiff
path: root/vendor/parakeet-rs/src/decoder.rs
diff options
context:
space:
mode:
author soryu <soryu@soryu.co> 2025-12-21 01:27:02 +0000
committer soryu <soryu@soryu.co> 2025-12-23 14:47:18 +0000
commit 3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch)
tree 497bffd67001501a003739cfe0bb790502ffd50a /vendor/parakeet-rs/src/decoder.rs
parent 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff)
download soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz
soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'vendor/parakeet-rs/src/decoder.rs')
-rw-r--r-- vendor/parakeet-rs/src/decoder.rs 211
1 files changed, 211 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/src/decoder.rs b/vendor/parakeet-rs/src/decoder.rs
new file mode 100644
index 0000000..6da6d65
--- /dev/null
+++ b/vendor/parakeet-rs/src/decoder.rs
@@ -0,0 +1,211 @@
+use crate::error::{Error, Result};
+use ndarray::Array2;
+use std::path::Path;
+
+/// A piece of decoded text together with its timestamp information.
+///
+/// `start` and `end` are in seconds.
+#[derive(Debug, Clone)]
+pub struct TimedToken {
+ /// Text attributed to this token.
+ pub text: String,
+ /// Time at which the token starts, in seconds.
+ pub start: f32,
+ /// Time at which the token ends, in seconds.
+ pub end: f32,
+}
+
+/// Output of a timestamped decode: the full transcript plus the timed pieces
+/// it was assembled from.
+#[derive(Debug, Clone)]
+pub struct TranscriptionResult {
+ /// Full decoded text.
+ pub text: String,
+ /// Timed tokens, in the order they were decoded.
+ pub tokens: Vec<TimedToken>,
+}
+
+/// CTC decoder for the parakeet-ctc-0.6b model with token-level timestamps.
+pub struct ParakeetDecoder {
+ /// Tokenizer used to turn collapsed token ids back into text.
+ tokenizer: tokenizers::Tokenizer,
+ /// Token id that is skipped (treated as the CTC blank) when collapsing.
+ pad_token_id: usize,
+}
+
+impl ParakeetDecoder {
+    /// Creates a decoder from a serialized tokenizer file.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::Tokenizer`] when the tokenizer file cannot be loaded.
+    pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> {
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path.as_ref())
+            .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?;
+
+        // Hardcoded pad_token_id for Parakeet-CTC-0.6b. It is constant across all
+        // models; see the default config JSON files at
+        // https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main
+        let pad_token_id = 1024;
+
+        Ok(Self {
+            tokenizer,
+            pad_token_id,
+        })
+    }
+
+    /// Returns, for every row (time step) of `logits`, the column index with the
+    /// highest score.
+    ///
+    /// Values that cannot be ordered (NaN) compare as equal, and an empty row
+    /// yields id `0`.
+    fn argmax_per_frame(logits: &Array2<f32>) -> Vec<u32> {
+        logits
+            .outer_iter()
+            .map(|row| {
+                row.iter()
+                    .enumerate()
+                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+                    // Column indices are vocabulary ids, which fit in `u32`.
+                    .map(|(idx, _)| idx as u32)
+                    .unwrap_or(0)
+            })
+            .collect()
+    }
+
+    /// Greedy CTC decode: per-frame argmax, CTC collapse, then tokenizer decode.
+    ///
+    /// `logits` has one row per time step and one column per token id.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::Tokenizer`] when the tokenizer fails to decode the ids.
+    pub fn decode(&self, logits: &Array2<f32>) -> Result<String> {
+        let token_ids = Self::argmax_per_frame(logits);
+        let collapsed = self.ctc_collapse(&token_ids);
+
+        self.tokenizer
+            .decode(&collapsed, true)
+            .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))
+    }
+
+    /// CTC collapse: drops blank (pad) frames and merges consecutive repeats of
+    /// the same token. A repeat separated by a blank is kept as a new token.
+    fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> {
+        let blank = self.pad_token_id as u32;
+        let mut result = Vec::new();
+        let mut prev_token: Option<u32> = None;
+
+        for &token_id in token_ids {
+            if token_id != blank && prev_token != Some(token_id) {
+                result.push(token_id);
+            }
+            prev_token = Some(token_id);
+        }
+
+        result
+    }
+
+    /// CTC collapse with frame tracking for timestamps.
+    ///
+    /// Takes `(token_id, frame)` pairs and returns `(token_id, start, end)`
+    /// triples, where `start` is the first frame on which the token was emitted
+    /// and `end` is one past the last consecutive frame on which it was emitted.
+    ///
+    /// Blank frames are never counted as part of a token, so a token followed by
+    /// blanks ends where the blanks begin instead of keeping `end == start`, and
+    /// the final token is not stretched across trailing blanks.
+    fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> {
+        let blank = self.pad_token_id as u32;
+        let mut result: Vec<(u32, usize, usize)> = Vec::new();
+        let mut prev_token: Option<u32> = None;
+
+        for &(token_id, frame) in token_ids {
+            if token_id != blank {
+                if prev_token == Some(token_id) {
+                    // Same token continues on this frame: extend the open segment.
+                    if let Some(last) = result.last_mut() {
+                        last.2 = frame + 1;
+                    }
+                } else {
+                    // Start a new token covering (for now) just this frame.
+                    result.push((token_id, frame, frame + 1));
+                }
+            }
+
+            prev_token = Some(token_id);
+        }
+
+        result
+    }
+
+    /// Decodes with token-level timestamps.
+    ///
+    /// `hop_length` and `sample_rate` are needed to convert frame indices to
+    /// seconds: `seconds = frame * hop_length / sample_rate`. A `sample_rate` of
+    /// zero therefore produces non-finite timestamps.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::Tokenizer`] when the full token sequence cannot be
+    /// decoded. Failures while decoding an intermediate prefix are skipped, so
+    /// the affected token simply gets no timed entry.
+    pub fn decode_with_timestamps(
+        &self,
+        logits: &Array2<f32>,
+        hop_length: usize,
+        sample_rate: usize,
+    ) -> Result<TranscriptionResult> {
+        let token_ids_with_frames: Vec<(u32, usize)> = Self::argmax_per_frame(logits)
+            .into_iter()
+            .enumerate()
+            .map(|(t, id)| (id, t))
+            .collect();
+
+        // CTC collapse with frame tracking
+        let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames);
+
+        // Extract just token IDs for decoding
+        let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|&(id, _, _)| id).collect();
+
+        // Decode full text
+        let full_text = self
+            .tokenizer
+            .decode(&token_ids, true)
+            .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?;
+
+        let frame_to_seconds = |frame: usize| (frame * hop_length) as f32 / sample_rate as f32;
+
+        // Progressive decode to detect word boundaries.
+        // BPE tokenizers only add spaces when decoding sequences, not individual tokens.
+        let mut timed_tokens = Vec::new();
+        let mut prev_decode = String::new();
+
+        for (i, &(_, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() {
+            // Decode from the start up to and including the current token.
+            if let Ok(curr_decode) = self.tokenizer.decode(&token_ids[..=i], true) {
+                // Find what this token added. `str::get` returns `None` when the
+                // new text is shorter than the previous one or when the offset is
+                // not on a char boundary, so this cannot panic.
+                let added_text = curr_decode.get(prev_decode.len()..).unwrap_or("");
+
+                if !added_text.is_empty() {
+                    timed_tokens.push(TimedToken {
+                        text: added_text.to_string(),
+                        start: frame_to_seconds(start_frame),
+                        end: frame_to_seconds(end_frame),
+                    });
+                }
+
+                prev_decode = curr_decode;
+            }
+        }
+
+        Ok(TranscriptionResult {
+            text: full_text,
+            tokens: timed_tokens,
+        })
+    }
+
+    /// Stub - falls back to greedy decoding. Full beam search with language model is TODO.
+    ///
+    /// `_beam_width` is currently ignored.
+    pub fn decode_with_beam_search(
+        &self,
+        logits: &Array2<f32>,
+        _beam_width: usize,
+    ) -> Result<String> {
+        self.decode(logits)
+    }
+
+    /// Returns the token id that is skipped as the CTC blank during collapse.
+    pub fn pad_token_id(&self) -> usize {
+        self.pad_token_id
+    }
+}