diff options
| author | soryu <soryu@soryu.co> | 2025-12-21 01:27:02 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch) | |
| tree | 497bffd67001501a003739cfe0bb790502ffd50a /parakeet-rs/src/model_eou.rs | |
| parent | 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff) | |
| download | soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip | |
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'parakeet-rs/src/model_eou.rs')
| -rw-r--r-- | parakeet-rs/src/model_eou.rs | 183 |
1 files changed, 0 insertions, 183 deletions
diff --git a/parakeet-rs/src/model_eou.rs b/parakeet-rs/src/model_eou.rs deleted file mode 100644 index 5b56e6d..0000000 --- a/parakeet-rs/src/model_eou.rs +++ /dev/null @@ -1,183 +0,0 @@ -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use ndarray::{Array1, Array2, Array3, Array4}; -use ort::session::Session; -use std::path::Path; - -/// Encoder cache state for streaming inference -/// The cache maintains temporal context across chunks -pub struct EncoderCache { - /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback - pub cache_last_channel: Array4<f32>, - /// time cache: [1, 1, 512, 8] - batch=1, fixed 8 time steps - pub cache_last_time: Array4<f32>, - /// cache length: [1] with value 0 initially - pub cache_last_channel_len: Array1<i64>, -} - -impl EncoderCache { - /// 17 layers, batch=1, 70 frame lookback, 512 features - pub fn new() -> Self { - Self { - cache_last_channel: Array4::zeros((17, 1, 70, 512)), - cache_last_time: Array4::zeros((17, 1, 512, 8)), - cache_last_channel_len: Array1::from_vec(vec![0i64]), - } - } -} - -pub struct ParakeetEOUModel { - encoder: Session, - decoder_joint: Session, -} - -impl ParakeetEOUModel { - pub fn from_pretrained<P: AsRef<Path>>( - model_dir: P, - exec_config: ExecutionConfig, - ) -> Result<Self> { - let model_dir = model_dir.as_ref(); - - let encoder_path = model_dir.join("encoder.onnx"); - let decoder_path = model_dir.join("decoder_joint.onnx"); - - if !encoder_path.exists() || !decoder_path.exists() { - return Err(Error::Config(format!( - "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx", - model_dir.display() - ))); - } - - // Load encoder - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let encoder = builder.commit_from_file(&encoder_path)?; - - // Load decoder - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let decoder_joint = builder.commit_from_file(&decoder_path)?; - - Ok(Self { - encoder, - decoder_joint, - }) - } - - /// Run the stateful encoder with cache - /// Input: features [1, 128, T], cache state - /// Output: (encoded [1, 512, T], new_cache) - pub fn run_encoder( - &mut self, - features: &Array3<f32>, - length: i64, - cache: &EncoderCache - ) -> Result<(Array3<f32>, EncoderCache)> { - let length_arr = Array1::from_vec(vec![length]); - - let outputs = self.encoder.run(ort::inputs![ - "audio_signal" => ort::value::Value::from_array(features.clone())?, - "length" => ort::value::Value::from_array(length_arr)?, - "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?, - "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?, - "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())? - ])?; - - // Extract encoder output [1, 512, T] - let (shape, data) = outputs["outputs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; - - let shape_dims = shape.as_ref(); - let b = shape_dims[0] as usize; - let d = shape_dims[1] as usize; - let t = shape_dims[2] as usize; - - let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?; - - // Extract new cache states - let (ch_shape, ch_data) = outputs["new_cache_last_channel"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?; - - let (tm_shape, tm_data) = outputs["new_cache_last_time"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?; - - let (len_shape, len_data) = outputs["new_cache_last_channel_len"] - .try_extract_tensor::<i64>() - .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?; - - // Build new cache with extracted shapes - let new_cache = EncoderCache { - cache_last_channel: Array4::from_shape_vec( - (ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize), - ch_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?, - - cache_last_time: Array4::from_shape_vec( - (tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize), - tm_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?, - - cache_last_channel_len: Array1::from_shape_vec( - len_shape[0] as usize, - len_data.to_vec() - ).map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?, - }; - - Ok((encoder_out, new_cache)) - } - - /// Run the stateful decoder - /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c) - pub fn run_decoder( - &mut self, - encoder_frame: &Array3<f32>, // [1, 512, 1] - last_token: &Array2<i32>, // [1, 1] - state_h: &Array3<f32>, // [1, 1, 640] - state_c: &Array3<f32>, // [1, 1, 640] - ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> { - - // Target length is always 1 for single step - let target_len = Array1::from_vec(vec![1i32]); - - let outputs = self.decoder_joint.run(ort::inputs![ - "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?, - "targets" => ort::value::Value::from_array(last_token.clone())?, - "target_length" => ort::value::Value::from_array(target_len)?, - "input_states_1" => ort::value::Value::from_array(state_h.clone())?, - "input_states_2" => ort::value::Value::from_array(state_c.clone())? - ])?; - - // 1. Extract Logits - let (l_shape, l_data) = outputs["outputs"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; - - // 2. Extract States (output_states_1, output_states_2) - let (_h_shape, h_data) = outputs["output_states_1"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?; - - let (_c_shape, c_data) = outputs["output_states_2"] - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?; - - // Reconstruct Arrays - // Logits: I simplify to [1, 1, vocab] - let vocab_size = l_shape[3] as usize; - let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec()) - .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?; - - // States: [1, 1, 640] - let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec()) - .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?; - - let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec()) - .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?; - - Ok((logits, new_h, new_c)) - } -}
\ No newline at end of file |
