summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/timestamps.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 00:40:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch)
tree0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/src/timestamps.rs
parent84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff)
downloadsoryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz
soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/src/timestamps.rs')
-rw-r--r--parakeet-rs/src/timestamps.rs280
1 files changed, 280 insertions, 0 deletions
diff --git a/parakeet-rs/src/timestamps.rs b/parakeet-rs/src/timestamps.rs
new file mode 100644
index 0000000..81ea600
--- /dev/null
+++ b/parakeet-rs/src/timestamps.rs
@@ -0,0 +1,280 @@
+use crate::decoder::TimedToken;
+
+/// Timestamp output mode for transcription results
+///
+/// Determines how token-level timestamps are grouped and presented:
+/// - `Tokens`: Raw token-level output from the model (most detailed)
+/// - `Words`: Tokens grouped into individual words
+/// - `Sentences`: Tokens grouped by sentence boundaries (., ?, !)
+///
+/// # Model-Specific Recommendations
+///
+/// - **Parakeet CTC (English)**: Use `Words` mode. The CTC model only outputs lowercase
+/// alphabet without punctuation, so sentence segmentation is not possible.
+/// - **Parakeet TDT (Multilingual)**: Use `Sentences` mode. The TDT model predicts
+/// punctuation, enabling natural sentence boundaries.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum TimestampMode {
+ /// Raw token-level timestamps from the model
+ Tokens,
+ /// Word-level timestamps (groups subword tokens)
+ Words,
+ /// Sentence-level timestamps (groups by punctuation)
+ ///
+ /// Note: Only works with models that predict punctuation (e.g., Parakeet TDT).
+ /// CTC models don't predict punctuation, so use `Words` mode instead.
+ Sentences,
+}
+
+impl Default for TimestampMode {
+ fn default() -> Self {
+ Self::Tokens
+ }
+}
+
+/// Convert token timestamps to the requested output mode
+///
+/// Takes raw token-level timestamps from the model and optionally groups them
+/// into words or sentences while preserving the original timing information.
+///
+/// # Arguments
+///
+/// * `tokens` - Raw token-level timestamps from model output
+/// * `mode` - Desired grouping level (Tokens, Words, or Sentences)
+///
+/// # Returns
+///
+/// Vector of TimedToken with timestamps at the requested granularity
+pub fn process_timestamps(tokens: &[TimedToken], mode: TimestampMode) -> Vec<TimedToken> {
+ match mode {
+ TimestampMode::Tokens => tokens.to_vec(),
+ TimestampMode::Words => group_by_words(tokens),
+ TimestampMode::Sentences => group_by_sentences(tokens),
+ }
+}
+
+// Group tokens into words based on word boundary markers
+fn group_by_words(tokens: &[TimedToken]) -> Vec<TimedToken> {
+ if tokens.is_empty() {
+ return Vec::new();
+ }
+
+ let mut words = Vec::new();
+ let mut current_word_text = String::new();
+ let mut current_word_start = 0.0;
+ let mut last_word_lower = String::new();
+
+ for (i, token) in tokens.iter().enumerate() {
+ // Skip empty tokens
+ if token.text.trim().is_empty() {
+ continue;
+ }
+
+ // Check if this starts a new word (SentencePiece uses ▁ or space prefix)
+ // Also treat PURE punctuation marks (like ".", ",") as separate words
+ // But NOT contractions like "'re" or "'s" which should attach to previous word
+ let is_pure_punctuation = !token.text.is_empty() &&
+ token.text.chars().all(|c| c.is_ascii_punctuation());
+
+ // Check if this is a contraction suffix
+ // These should NOT start a new word - they attach to the previous word
+ let token_without_marker = token.text.trim_start_matches('▁').trim_start_matches(' ');
+ let is_contraction = token_without_marker.starts_with('\'');
+
+ let starts_word = (token.text.starts_with('▁')
+ || token.text.starts_with(' ')
+ || is_pure_punctuation)
+ && !is_contraction
+ || i == 0;
+
+ if starts_word && !current_word_text.is_empty() {
+ // Save previous word (with deduplication)
+ let word_lower = current_word_text.to_lowercase();
+ if word_lower != last_word_lower {
+ words.push(TimedToken {
+ text: current_word_text.clone(),
+ start: current_word_start,
+ end: tokens[i - 1].end,
+ });
+ last_word_lower = word_lower;
+ }
+ current_word_text.clear();
+ }
+
+ // Start new word or append to current
+ if current_word_text.is_empty() {
+ current_word_start = token.start;
+ }
+
+ // Add token text, removing word boundary markers
+ let token_text = token
+ .text
+ .trim_start_matches('▁')
+ .trim_start_matches(' ');
+ current_word_text.push_str(token_text);
+ }
+
+ // Add final word
+ if !current_word_text.is_empty() {
+ let word_lower = current_word_text.to_lowercase();
+ if word_lower != last_word_lower {
+ words.push(TimedToken {
+ text: current_word_text,
+ start: current_word_start,
+ end: tokens.last().unwrap().end,
+ });
+ }
+ }
+
+ words
+}
+
+// Group words into sentences based on punctuation
+fn group_by_sentences(tokens: &[TimedToken]) -> Vec<TimedToken> {
+ // First get word-level grouping
+ let words = group_by_words(tokens);
+ if words.is_empty() {
+ return Vec::new();
+ }
+
+ let mut sentences = Vec::new();
+ let mut current_sentence = Vec::new();
+
+ for word in words {
+ current_sentence.push(word.clone());
+
+ // Check if word ends with sentence terminator
+ let ends_sentence = word.text.contains('.')
+ || word.text.contains('?')
+ || word.text.contains('!');
+
+ if ends_sentence {
+ let sentence_text = format_sentence(&current_sentence);
+ let start = current_sentence.first().unwrap().start;
+ let end = current_sentence.last().unwrap().end;
+
+ if !sentence_text.is_empty() {
+ sentences.push(TimedToken {
+ text: sentence_text,
+ start,
+ end,
+ });
+ }
+ current_sentence.clear();
+ }
+ }
+
+ // Add final sentence if exists
+ if !current_sentence.is_empty() {
+ let sentence_text = format_sentence(&current_sentence);
+ let start = current_sentence.first().unwrap().start;
+ let end = current_sentence.last().unwrap().end;
+
+ if !sentence_text.is_empty() {
+ sentences.push(TimedToken {
+ text: sentence_text,
+ start,
+ end,
+ });
+ }
+ }
+
+ sentences
+}
+
+// Join words with punctuation spacing
+fn format_sentence(words: &[TimedToken]) -> String {
+ let result: Vec<&str> = words.iter().map(|w| w.text.as_str()).collect();
+
+ // Join words, but don't add space before certain punctuation
+ let mut output = String::new();
+ for (i, word) in result.iter().enumerate() {
+ // Check if this word is standalone punctuation that shouldn't have space before it
+ // Contractions like "'re" or "'s" should have spaces before them
+ let is_standalone_punct = word.len() == 1 &&
+ word.chars().all(|c| matches!(c, '.' | ',' | '!' | '?' | ';' | ':' | ')'));
+
+ if i > 0 && !is_standalone_punct {
+ output.push(' ');
+ }
+ output.push_str(word);
+ }
+ output
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_word_grouping() {
+ let tokens = vec![
+ TimedToken {
+ text: "▁Hello".to_string(),
+ start: 0.0,
+ end: 0.5,
+ },
+ TimedToken {
+ text: "▁world".to_string(),
+ start: 0.5,
+ end: 1.0,
+ },
+ ];
+
+ let words = group_by_words(&tokens);
+ assert_eq!(words.len(), 2);
+ assert_eq!(words[0].text, "Hello");
+ assert_eq!(words[1].text, "world");
+ }
+
+ #[test]
+ fn test_sentence_grouping() {
+ let tokens = vec![
+ TimedToken {
+ text: "▁Hello".to_string(),
+ start: 0.0,
+ end: 0.5,
+ },
+ TimedToken {
+ text: "▁world".to_string(),
+ start: 0.5,
+ end: 1.0,
+ },
+ TimedToken {
+ text: ".".to_string(),
+ start: 1.0,
+ end: 1.1,
+ },
+ ];
+
+ let sentences = group_by_sentences(&tokens);
+ assert_eq!(sentences.len(), 1);
+ assert_eq!(sentences[0].text, "Hello world.");
+ assert_eq!(sentences[0].start, 0.0);
+ assert_eq!(sentences[0].end, 1.1);
+ }
+
+ #[test]
+ fn test_repetition_preservation() {
+ let words = vec![
+ TimedToken {
+ text: "uh".to_string(),
+ start: 0.0,
+ end: 0.5,
+ },
+ TimedToken {
+ text: "uh".to_string(),
+ start: 0.5,
+ end: 1.0,
+ },
+ TimedToken {
+ text: "hello".to_string(),
+ start: 1.0,
+ end: 1.5,
+ },
+ ];
+
+ let result = format_sentence(&words);
+ assert_eq!(result, "uh uh hello");
+ }
+}