From 3c696cfc9005e73be5ed46f8941dfc8f0aca7102 Mon Sep 17 00:00:00 2001 From: soryu Date: Sun, 21 Dec 2025 01:27:02 +0000 Subject: Create container image and move parakeet fork to vendor dir --- parakeet-rs/src/model_tdt.rs | 263 ------------------------------------------- 1 file changed, 263 deletions(-) delete mode 100644 parakeet-rs/src/model_tdt.rs (limited to 'parakeet-rs/src/model_tdt.rs') diff --git a/parakeet-rs/src/model_tdt.rs b/parakeet-rs/src/model_tdt.rs deleted file mode 100644 index e00ebdc..0000000 --- a/parakeet-rs/src/model_tdt.rs +++ /dev/null @@ -1,263 +0,0 @@ -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use ndarray::{Array1, Array2, Array3}; -use ort::session::Session; -use std::path::{Path, PathBuf}; - -/// TDT model configs -#[derive(Debug, Clone)] -pub struct TDTModelConfig { - pub vocab_size: usize, -} - -impl Default for TDTModelConfig { - fn default() -> Self { - Self { - vocab_size: 8193, - } - } -} - -pub struct ParakeetTDTModel { - encoder: Session, - decoder_joint: Session, - config: TDTModelConfig, -} - -impl ParakeetTDTModel { - /// Load TDT model from directory containing encoder and decoder_joint ONNX files - pub fn from_pretrained>( - model_dir: P, - exec_config: ExecutionConfig, - ) -> Result { - let model_dir = model_dir.as_ref(); - - // Find encoder and decoder_joint files - let encoder_path = Self::find_encoder(model_dir)?; - let decoder_joint_path = Self::find_decoder_joint(model_dir)?; - - let config = TDTModelConfig::default(); - - // Load encoder - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let encoder = builder.commit_from_file(&encoder_path)?; - - // Load decoder_joint - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let decoder_joint = builder.commit_from_file(&decoder_joint_path)?; - - - Ok(Self { - encoder, - decoder_joint, - config, - }) - } - - fn find_encoder(dir: &Path) -> Result { - let candidates = ["encoder-model.onnx", "encoder.onnx"]; - for candidate in &candidates { - let path = dir.join(candidate); - if path.exists() { - return Ok(path); - } - } - Err(Error::Config(format!( - "No encoder model found in {}", - dir.display() - ))) - } - - fn find_decoder_joint(dir: &Path) -> Result { - let candidates = [ - "decoder_joint-model.onnx", - "decoder_joint.onnx", - "decoder-model.onnx", - ]; - for candidate in &candidates { - let path = dir.join(candidate); - if path.exists() { - return Ok(path); - } - } - Err(Error::Config(format!( - "No decoder_joint model found in {}", - dir.display() - ))) - } - - /// Run greedy decoding - returns (token_ids, frame_indices, durations) - pub fn forward(&mut self, features: Array2) -> Result<(Vec, Vec, Vec)> { - // Run encoder - let (encoder_out, encoder_len) = self.run_encoder(&features)?; - - // Run greedy decoding with decoder_joint - let (tokens, frame_indices, durations) = self.greedy_decode(&encoder_out, encoder_len)?; - - Ok((tokens, frame_indices, durations)) - } - - fn run_encoder(&mut self, features: &Array2) -> Result<(Array3, i64)> { - let batch_size = 1; - let time_steps = features.shape()[0]; - let feature_size = features.shape()[1]; - - // TDT encoder expects (batch, features, time) not (batch, time, features) - let input = features - .t() - .to_shape((batch_size, feature_size, time_steps)) - .map_err(|e| Error::Model(format!("Failed to reshape encoder input: {e}")))? - .to_owned(); - - let input_length = Array1::from_vec(vec![time_steps as i64]); - - let input_value = ort::value::Value::from_array(input)?; - let length_value = ort::value::Value::from_array(input_length)?; - - let outputs = self.encoder.run(ort::inputs!( - "audio_signal" => input_value, - "length" => length_value - ))?; - - let encoder_out = &outputs["outputs"]; - let encoder_lens = &outputs["encoded_lengths"]; - - let (shape, data) = encoder_out - .try_extract_tensor::() - .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; - - let (_, lens_data) = encoder_lens - .try_extract_tensor::() - .map_err(|e| Error::Model(format!("Failed to extract encoder lengths: {e}")))?; - - let shape_dims = shape.as_ref(); - if shape_dims.len() != 3 { - return Err(Error::Model(format!( - "Expected 3D encoder output, got shape: {shape_dims:?}" - ))); - } - - let b = shape_dims[0] as usize; - let t = shape_dims[1] as usize; - let d = shape_dims[2] as usize; - - let encoder_array = Array3::from_shape_vec((b, t, d), data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to create encoder array: {e}")))?; - - // TDT encoder outputs [batch, encoder_dim, time] directly - Ok((encoder_array, lens_data[0])) - } - - fn greedy_decode(&mut self, encoder_out: &Array3, _encoder_len: i64) -> Result<(Vec, Vec, Vec)> { - // encoder_out shape: [batch, encoder_dim, time] - let encoder_dim = encoder_out.shape()[1]; - let time_steps = encoder_out.shape()[2]; - let vocab_size = self.config.vocab_size; - let max_tokens_per_step = 10; - let blank_id = vocab_size - 1; - - // States: (num_layers=2, batch=1, hidden_dim=640) - let mut state_h = Array3::::zeros((2, 1, 640)); - let mut state_c = Array3::::zeros((2, 1, 640)); - - let mut tokens = Vec::new(); - let mut frame_indices = Vec::new(); - let mut durations = Vec::new(); - - let mut t = 0; - let mut emitted_tokens = 0; - let mut last_emitted_token = blank_id as i32; - - // Frame-by-frame RNN-T/TDT greedy decoding - while t < time_steps { - // Get single encoder frame: slice [0, :, t] and reshape to [1, encoder_dim, 1] - let frame = encoder_out.slice(ndarray::s![0, .., t]).to_owned(); - let frame_reshaped = frame - .to_shape((1, encoder_dim, 1)) - .map_err(|e| Error::Model(format!("Failed to reshape frame: {e}")))? - .to_owned(); - - // Current token for prediction network - let targets = Array2::from_shape_vec((1, 1), vec![last_emitted_token]) - .map_err(|e| Error::Model(format!("Failed to create targets: {e}")))?; - - // Run decoder_joint - let outputs = self.decoder_joint.run(ort::inputs!( - "encoder_outputs" => ort::value::Value::from_array(frame_reshaped)?, - "targets" => ort::value::Value::from_array(targets)?, - "target_length" => ort::value::Value::from_array(Array1::from_vec(vec![1i32]))?, - "input_states_1" => ort::value::Value::from_array(state_h.clone())?, - "input_states_2" => ort::value::Value::from_array(state_c.clone())? - ))?; - - // Extract logits - let (_, logits_data) = outputs["outputs"] - .try_extract_tensor::() - .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; - - // TDT outputs vocab_size + 5 durations (8193 + 5 = 8198) - let vocab_logits: Vec = logits_data.iter().take(vocab_size).copied().collect(); - let duration_logits: Vec = logits_data.iter().skip(vocab_size).copied().collect(); - - let token_id = vocab_logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) - .unwrap_or(blank_id); - - let duration_step = if !duration_logits.is_empty() { - duration_logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(idx, _)| idx) - .unwrap_or(0) - } else { - 0 - }; - - // Check if blank token - if token_id != blank_id { - // Update states when we emit a token - if let Ok((h_shape, h_data)) = outputs["output_states_1"].try_extract_tensor::() { - let dims = h_shape.as_ref(); - state_h = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), h_data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to update state_h: {e}")))?; - } - if let Ok((c_shape, c_data)) = outputs["output_states_2"].try_extract_tensor::() { - let dims = c_shape.as_ref(); - state_c = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), c_data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to update state_c: {e}")))?; - } - - tokens.push(token_id); - frame_indices.push(t); - durations.push(duration_step); - last_emitted_token = token_id as i32; - emitted_tokens += 1; - - // Don't advance yet - try to emit more tokens from the same frame - } else { - // Blank token - advance frame pointer - // Duration prediction applies when we finally move to next frame after emitting tokens - if duration_step > 0 && emitted_tokens > 0 { - t += duration_step; - } else { - t += 1; - } - emitted_tokens = 0; - } - - // Safety check: if we've emitted too many tokens from the same frame, advance - if emitted_tokens >= max_tokens_per_step { - t += 1; - emitted_tokens = 0; - } - } - - Ok((tokens, frame_indices, durations)) - } -} -- cgit v1.2.3