1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
|
use crate::error::{Error, Result};
use crate::execution::ModelConfig as ExecutionConfig;
use ndarray::{Array1, Array2, Array3, Array4};
use ort::session::Session;
use std::path::Path;
/// Encoder cache state for streaming inference.
///
/// The cache carries temporal context across audio chunks:
/// [`ParakeetEOUModel::run_encoder`] takes the current cache and returns an updated one that
/// should be passed to the next call. [`EncoderCache::new`] (or [`Default::default`]) creates
/// the zeroed state for the start of a stream. Caches returned by `run_encoder` are rebuilt
/// from the model outputs, so they take whatever shapes the model reports.
#[derive(Debug, Clone)]
pub struct EncoderCache {
    /// Channel cache. Created by [`EncoderCache::new`] as
    /// `[NUM_LAYERS, 1, CHANNEL_LOOKBACK, D_MODEL]` = `[17, 1, 70, 512]`
    /// (layers, batch = 1, 70-frame lookback, 512 features).
    pub cache_last_channel: Array4<f32>,
    /// Time cache. Created by [`EncoderCache::new`] as
    /// `[NUM_LAYERS, 1, D_MODEL, TIME_STEPS]` = `[17, 1, 512, 8]`
    /// (layers, batch = 1, 512 features, 8 time steps).
    pub cache_last_time: Array4<f32>,
    /// Cache length tensor of shape `[1]`; holds `0` in a fresh cache.
    pub cache_last_channel_len: Array1<i64>,
}

impl EncoderCache {
    /// Number of encoder layers that carry cache state (leading axis of both caches).
    pub const NUM_LAYERS: usize = 17;
    /// Lookback frames held in `cache_last_channel` (third axis).
    pub const CHANNEL_LOOKBACK: usize = 70;
    /// Feature width of the cached encoder activations.
    pub const D_MODEL: usize = 512;
    /// Length of the last axis of `cache_last_time`.
    pub const TIME_STEPS: usize = 8;

    /// Creates a zero-filled cache for the first chunk of a new stream.
    pub fn new() -> Self {
        Self {
            cache_last_channel: Array4::zeros((
                Self::NUM_LAYERS,
                1,
                Self::CHANNEL_LOOKBACK,
                Self::D_MODEL,
            )),
            cache_last_time: Array4::zeros((
                Self::NUM_LAYERS,
                1,
                Self::D_MODEL,
                Self::TIME_STEPS,
            )),
            cache_last_channel_len: Array1::from_vec(vec![0i64]),
        }
    }
}

impl Default for EncoderCache {
    /// Equivalent to [`EncoderCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Model wrapper holding the two ONNX Runtime sessions loaded by `from_pretrained`:
/// a stateful encoder and a combined decoder/joint network.
pub struct ParakeetEOUModel {
/// Session loaded from `encoder.onnx`; `run_encoder` drives it with explicit cache tensors.
encoder: Session,
/// Session loaded from `decoder_joint.onnx`; `run_decoder` drives it one step at a time
/// with an encoder frame, the previous token, and two recurrent state tensors.
decoder_joint: Session,
}
impl ParakeetEOUModel {
pub fn from_pretrained<P: AsRef<Path>>(
model_dir: P,
exec_config: ExecutionConfig,
) -> Result<Self> {
let model_dir = model_dir.as_ref();
let encoder_path = model_dir.join("encoder.onnx");
let decoder_path = model_dir.join("decoder_joint.onnx");
if !encoder_path.exists() || !decoder_path.exists() {
return Err(Error::Config(format!(
"Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx",
model_dir.display()
)));
}
// Load encoder
let builder = Session::builder()?;
let builder = exec_config.apply_to_session_builder(builder)?;
let encoder = builder.commit_from_file(&encoder_path)?;
// Load decoder
let builder = Session::builder()?;
let builder = exec_config.apply_to_session_builder(builder)?;
let decoder_joint = builder.commit_from_file(&decoder_path)?;
Ok(Self {
encoder,
decoder_joint,
})
}
/// Run the stateful encoder with cache
/// Input: features [1, 128, T], cache state
/// Output: (encoded [1, 512, T], new_cache)
pub fn run_encoder(
&mut self,
features: &Array3<f32>,
length: i64,
cache: &EncoderCache
) -> Result<(Array3<f32>, EncoderCache)> {
let length_arr = Array1::from_vec(vec![length]);
let outputs = self.encoder.run(ort::inputs![
"audio_signal" => ort::value::Value::from_array(features.clone())?,
"length" => ort::value::Value::from_array(length_arr)?,
"cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?,
"cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?,
"cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())?
])?;
// Extract encoder output [1, 512, T]
let (shape, data) = outputs["outputs"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?;
let shape_dims = shape.as_ref();
let b = shape_dims[0] as usize;
let d = shape_dims[1] as usize;
let t = shape_dims[2] as usize;
let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec())
.map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?;
// Extract new cache states
let (ch_shape, ch_data) = outputs["new_cache_last_channel"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?;
let (tm_shape, tm_data) = outputs["new_cache_last_time"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?;
let (len_shape, len_data) = outputs["new_cache_last_channel_len"]
.try_extract_tensor::<i64>()
.map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?;
// Build new cache with extracted shapes
let new_cache = EncoderCache {
cache_last_channel: Array4::from_shape_vec(
(ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize),
ch_data.to_vec()
).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?,
cache_last_time: Array4::from_shape_vec(
(tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize),
tm_data.to_vec()
).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?,
cache_last_channel_len: Array1::from_shape_vec(
len_shape[0] as usize,
len_data.to_vec()
).map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?,
};
Ok((encoder_out, new_cache))
}
/// Run the stateful decoder
/// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c)
pub fn run_decoder(
&mut self,
encoder_frame: &Array3<f32>, // [1, 512, 1]
last_token: &Array2<i32>, // [1, 1]
state_h: &Array3<f32>, // [1, 1, 640]
state_c: &Array3<f32>, // [1, 1, 640]
) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> {
// Target length is always 1 for single step
let target_len = Array1::from_vec(vec![1i32]);
let outputs = self.decoder_joint.run(ort::inputs![
"encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?,
"targets" => ort::value::Value::from_array(last_token.clone())?,
"target_length" => ort::value::Value::from_array(target_len)?,
"input_states_1" => ort::value::Value::from_array(state_h.clone())?,
"input_states_2" => ort::value::Value::from_array(state_c.clone())?
])?;
// 1. Extract Logits
let (l_shape, l_data) = outputs["outputs"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?;
// 2. Extract States (output_states_1, output_states_2)
let (_h_shape, h_data) = outputs["output_states_1"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?;
let (_c_shape, c_data) = outputs["output_states_2"]
.try_extract_tensor::<f32>()
.map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?;
// Reconstruct Arrays
// Logits: I simplify to [1, 1, vocab]
let vocab_size = l_shape[3] as usize;
let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec())
.map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?;
// States: [1, 1, 640]
let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec())
.map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?;
let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec())
.map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?;
Ok((logits, new_h, new_c))
}
}
|