summaryrefslogtreecommitdiff
path: root/vendor/parakeet-rs/src/vocab.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/parakeet-rs/src/vocab.rs')
-rw-r--r--vendor/parakeet-rs/src/vocab.rs63
1 files changed, 63 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/src/vocab.rs b/vendor/parakeet-rs/src/vocab.rs
new file mode 100644
index 0000000..888568e
--- /dev/null
+++ b/vendor/parakeet-rs/src/vocab.rs
@@ -0,0 +1,63 @@
+use crate::error::{Error, Result};
+use std::fs::File;
+use std::io::{BufRead, BufReader};
+use std::path::Path;
+
+/// Vocabulary parser for vocab.txt format used by TDT models
+#[derive(Debug, Clone)]
+pub struct Vocabulary {
+ pub id_to_token: Vec<String>,
+ pub _blank_id: usize,
+}
+
+impl Vocabulary {
+ /// Load vocabulary from vocab.txt file
+ pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
+ let file = File::open(path.as_ref()).map_err(|e| {
+ Error::Config(format!("Failed to open vocab file: {}", e))
+ })?;
+
+ let reader = BufReader::new(file);
+ let mut id_to_token = Vec::new();
+ let mut blank_id = 0;
+
+ for line in reader.lines() {
+ let line = line.map_err(|e| {
+ Error::Config(format!("Failed to read vocab file: {}", e))
+ })?;
+
+ let parts: Vec<&str> = line.splitn(2, ' ').collect();
+ if parts.len() == 2 {
+ let token = parts[0].to_string();
+ let id: usize = parts[1].parse().map_err(|e| {
+ Error::Config(format!("Invalid token ID in vocab: {}", e))
+ })?;
+
+ if id >= id_to_token.len() {
+ id_to_token.resize(id + 1, String::new());
+ }
+ id_to_token[id] = token.clone();
+
+ // Track blank token
+ if token == "<blk>" || token == "<blank>" {
+ blank_id = id;
+ }
+ }
+ }
+
+ // Default to last token if no blank found
+ if blank_id == 0 && !id_to_token.is_empty() {
+ blank_id = id_to_token.len() - 1;
+ }
+
+ Ok(Self {
+ id_to_token,
+ _blank_id: blank_id,
+ })
+ }
+
+ /// Get token by ID
+ pub fn id_to_text(&self, id: usize) -> Option<&str> {
+ self.id_to_token.get(id).map(|s| s.as_str())
+ }
+}