use crate::error::{Error, Result}; use crate::execution::ModelConfig as ExecutionConfig; use ndarray::{Array1, Array2, Array3, Array4}; use ort::session::Session; use std::path::Path; /// Encoder cache state for streaming inference /// The cache maintains temporal context across chunks pub struct EncoderCache { /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback pub cache_last_channel: Array4, /// time cache: [1, 1, 512, 8] - batch=1, fixed 8 time steps pub cache_last_time: Array4, /// cache length: [1] with value 0 initially pub cache_last_channel_len: Array1, } impl EncoderCache { /// 17 layers, batch=1, 70 frame lookback, 512 features pub fn new() -> Self { Self { cache_last_channel: Array4::zeros((17, 1, 70, 512)), cache_last_time: Array4::zeros((17, 1, 512, 8)), cache_last_channel_len: Array1::from_vec(vec![0i64]), } } } pub struct ParakeetEOUModel { encoder: Session, decoder_joint: Session, } impl ParakeetEOUModel { pub fn from_pretrained>( model_dir: P, exec_config: ExecutionConfig, ) -> Result { let model_dir = model_dir.as_ref(); let encoder_path = model_dir.join("encoder.onnx"); let decoder_path = model_dir.join("decoder_joint.onnx"); if !encoder_path.exists() || !decoder_path.exists() { return Err(Error::Config(format!( "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx", model_dir.display() ))); } // Load encoder let builder = Session::builder()?; let builder = exec_config.apply_to_session_builder(builder)?; let encoder = builder.commit_from_file(&encoder_path)?; // Load decoder let builder = Session::builder()?; let builder = exec_config.apply_to_session_builder(builder)?; let decoder_joint = builder.commit_from_file(&decoder_path)?; Ok(Self { encoder, decoder_joint, }) } /// Run the stateful encoder with cache /// Input: features [1, 128, T], cache state /// Output: (encoded [1, 512, T], new_cache) pub fn run_encoder( &mut self, features: &Array3, length: i64, cache: &EncoderCache ) -> Result<(Array3, EncoderCache)> { let length_arr = Array1::from_vec(vec![length]); let outputs = self.encoder.run(ort::inputs![ "audio_signal" => ort::value::Value::from_array(features.clone())?, "length" => ort::value::Value::from_array(length_arr)?, "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?, "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?, "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())? ])?; // Extract encoder output [1, 512, T] let (shape, data) = outputs["outputs"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; let shape_dims = shape.as_ref(); let b = shape_dims[0] as usize; let d = shape_dims[1] as usize; let t = shape_dims[2] as usize; let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec()) .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?; // Extract new cache states let (ch_shape, ch_data) = outputs["new_cache_last_channel"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?; let (tm_shape, tm_data) = outputs["new_cache_last_time"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?; let (len_shape, len_data) = outputs["new_cache_last_channel_len"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?; // Build new cache with extracted shapes let new_cache = EncoderCache { cache_last_channel: Array4::from_shape_vec( (ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize), ch_data.to_vec() ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?, cache_last_time: Array4::from_shape_vec( (tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize), tm_data.to_vec() ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?, cache_last_channel_len: Array1::from_shape_vec( len_shape[0] as usize, len_data.to_vec() ).map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?, }; Ok((encoder_out, new_cache)) } /// Run the stateful decoder /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c) pub fn run_decoder( &mut self, encoder_frame: &Array3, // [1, 512, 1] last_token: &Array2, // [1, 1] state_h: &Array3, // [1, 1, 640] state_c: &Array3, // [1, 1, 640] ) -> Result<(Array3, Array3, Array3)> { // Target length is always 1 for single step let target_len = Array1::from_vec(vec![1i32]); let outputs = self.decoder_joint.run(ort::inputs![ "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?, "targets" => ort::value::Value::from_array(last_token.clone())?, "target_length" => ort::value::Value::from_array(target_len)?, "input_states_1" => ort::value::Value::from_array(state_h.clone())?, "input_states_2" => ort::value::Value::from_array(state_c.clone())? ])?; // 1. Extract Logits let (l_shape, l_data) = outputs["outputs"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; // 2. Extract States (output_states_1, output_states_2) let (_h_shape, h_data) = outputs["output_states_1"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?; let (_c_shape, c_data) = outputs["output_states_2"] .try_extract_tensor::() .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?; // Reconstruct Arrays // Logits: I simplify to [1, 1, vocab] let vocab_size = l_shape[3] as usize; let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec()) .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?; // States: [1, 1, 640] let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec()) .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?; let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec()) .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?; Ok((logits, new_h, new_c)) } }