summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/model_eou.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 01:27:02 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch)
tree497bffd67001501a003739cfe0bb790502ffd50a /parakeet-rs/src/model_eou.rs
parent55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff)
downloadsoryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz
soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'parakeet-rs/src/model_eou.rs')
-rw-r--r--parakeet-rs/src/model_eou.rs183
1 files changed, 0 insertions, 183 deletions
diff --git a/parakeet-rs/src/model_eou.rs b/parakeet-rs/src/model_eou.rs
deleted file mode 100644
index 5b56e6d..0000000
--- a/parakeet-rs/src/model_eou.rs
+++ /dev/null
@@ -1,183 +0,0 @@
-use crate::error::{Error, Result};
-use crate::execution::ModelConfig as ExecutionConfig;
-use ndarray::{Array1, Array2, Array3, Array4};
-use ort::session::Session;
-use std::path::Path;
-
-/// Encoder cache state for streaming inference
-/// The cache maintains temporal context across chunks
-pub struct EncoderCache {
- /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback
- pub cache_last_channel: Array4<f32>,
- /// time cache: [1, 1, 512, 8] - batch=1, fixed 8 time steps
- pub cache_last_time: Array4<f32>,
- /// cache length: [1] with value 0 initially
- pub cache_last_channel_len: Array1<i64>,
-}
-
-impl EncoderCache {
- /// 17 layers, batch=1, 70 frame lookback, 512 features
- pub fn new() -> Self {
- Self {
- cache_last_channel: Array4::zeros((17, 1, 70, 512)),
- cache_last_time: Array4::zeros((17, 1, 512, 8)),
- cache_last_channel_len: Array1::from_vec(vec![0i64]),
- }
- }
-}
-
-pub struct ParakeetEOUModel {
- encoder: Session,
- decoder_joint: Session,
-}
-
-impl ParakeetEOUModel {
- pub fn from_pretrained<P: AsRef<Path>>(
- model_dir: P,
- exec_config: ExecutionConfig,
- ) -> Result<Self> {
- let model_dir = model_dir.as_ref();
-
- let encoder_path = model_dir.join("encoder.onnx");
- let decoder_path = model_dir.join("decoder_joint.onnx");
-
- if !encoder_path.exists() || !decoder_path.exists() {
- return Err(Error::Config(format!(
- "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx",
- model_dir.display()
- )));
- }
-
- // Load encoder
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let encoder = builder.commit_from_file(&encoder_path)?;
-
- // Load decoder
- let builder = Session::builder()?;
- let builder = exec_config.apply_to_session_builder(builder)?;
- let decoder_joint = builder.commit_from_file(&decoder_path)?;
-
- Ok(Self {
- encoder,
- decoder_joint,
- })
- }
-
- /// Run the stateful encoder with cache
- /// Input: features [1, 128, T], cache state
- /// Output: (encoded [1, 512, T], new_cache)
- pub fn run_encoder(
- &mut self,
- features: &Array3<f32>,
- length: i64,
- cache: &EncoderCache
- ) -> Result<(Array3<f32>, EncoderCache)> {
- let length_arr = Array1::from_vec(vec![length]);
-
- let outputs = self.encoder.run(ort::inputs![
- "audio_signal" => ort::value::Value::from_array(features.clone())?,
- "length" => ort::value::Value::from_array(length_arr)?,
- "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
- "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
- "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
- ])?;
-
- // Extract encoder output [1, 512, T]
- let (shape, data) = outputs["outputs"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
-
- let shape_dims = shape.as_ref();
- let b = shape_dims[0] as usize;
- let d = shape_dims[1] as usize;
- let t = shape_dims[2] as usize;
-
- let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
- .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;
-
- // Extract new cache states
- let (ch_shape, ch_data) = outputs["new_cache_last_channel"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;
-
- let (tm_shape, tm_data) = outputs["new_cache_last_time"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;
-
- let (len_shape, len_data) = outputs["new_cache_last_channel_len"]
- .try_extract_tensor::<i64>()
- .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;
-
- // Build new cache with extracted shapes
- let new_cache = EncoderCache {
- cache_last_channel: Array4::from_shape_vec(
- (ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize),
- ch_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,
-
- cache_last_time: Array4::from_shape_vec(
- (tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize),
- tm_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,
-
- cache_last_channel_len: Array1::from_shape_vec(
- len_shape[0] as usize,
- len_data.to_vec()
- ).map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
- };
-
- Ok((encoder_out, new_cache))
- }
-
- /// Run the stateful decoder
- /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c)
- pub fn run_decoder(
- &mut self,
- encoder_frame: &Array3<f32>, // [1, 512, 1]
- last_token: &Array2<i32>, // [1, 1]
- state_h: &Array3<f32>, // [1, 1, 640]
- state_c: &Array3<f32>, // [1, 1, 640]
- ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> {
-
- // Target length is always 1 for single step
- let target_len = Array1::from_vec(vec![1i32]);
-
- let outputs = self.decoder_joint.run(ort::inputs![
- "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
- "targets" => ort::value::Value::from_array(last_token.clone())?,
- "target_length" => ort::value::Value::from_array(target_len)?,
- "input_states_1" => ort::value::Value::from_array(state_h.clone())?,
- "input_states_2" => ort::value::Value::from_array(state_c.clone())?
- ])?;
-
- // 1. Extract Logits
- let (l_shape, l_data) = outputs["outputs"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
-
- // 2. Extract States (output_states_1, output_states_2)
- let (_h_shape, h_data) = outputs["output_states_1"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?;
-
- let (_c_shape, c_data) = outputs["output_states_2"]
- .try_extract_tensor::<f32>()
- .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?;
-
- // Reconstruct Arrays
- // Logits: I simplify to [1, 1, vocab]
- let vocab_size = l_shape[3] as usize;
- let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec())
- .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?;
-
- // States: [1, 1, 640]
- let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec())
- .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?;
-
- let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec())
- .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?;
-
- Ok((logits, new_h, new_c))
- }
-} \ No newline at end of file