diff options
Diffstat (limited to 'parakeet-rs/src/model.rs')
| -rw-r--r-- | parakeet-rs/src/model.rs | 93 |
1 files changed, 0 insertions, 93 deletions
diff --git a/parakeet-rs/src/model.rs b/parakeet-rs/src/model.rs deleted file mode 100644 index b3cd131..0000000 --- a/parakeet-rs/src/model.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::config::ModelConfig; -use crate::error::{Error, Result}; -use crate::execution::ModelConfig as ExecutionConfig; -use ndarray::Array2; -use ort::session::Session; -use std::path::Path; - -pub struct ParakeetModel { - session: Session, - config: ModelConfig, -} - -impl ParakeetModel { - pub fn from_pretrained<P: AsRef<Path>>(model_path: P) -> Result<Self> { - Self::from_pretrained_with_config(model_path, ExecutionConfig::default()) - } - - pub fn from_pretrained_with_config<P: AsRef<Path>>( - model_path: P, - exec_config: ExecutionConfig, - ) -> Result<Self> { - let model_path = model_path.as_ref(); - - // Use default config (hardcoded constants for Parakeet-CTC-0.6b: please see: json files https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main) - let config = ModelConfig::default(); - - let builder = Session::builder()?; - let builder = exec_config.apply_to_session_builder(builder)?; - let session = builder.commit_from_file(model_path)?; - - Ok(Self { session, config }) - } - pub fn forward(&mut self, features: Array2<f32>) -> Result<Array2<f32>> { - let batch_size = 1; - let time_steps = features.shape()[0]; - let feature_size = features.shape()[1]; - - let input = features - .to_shape((batch_size, time_steps, feature_size)) - .map_err(|e| Error::Model(format!("Failed to reshape input: {e}")))? - .to_owned(); - - use ndarray::Array2; - let attention_mask = Array2::<i64>::ones((batch_size, time_steps)); - - let input_value = ort::value::Value::from_array(input)?; - let attention_mask_value = ort::value::Value::from_array(attention_mask)?; - - let outputs = self.session.run(ort::inputs!( - "input_features" => input_value, - "attention_mask" => attention_mask_value - ))?; - - let logits_value = &outputs["logits"]; - let (shape, data) = logits_value - .try_extract_tensor::<f32>() - .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; - - let shape_dims = shape.as_ref(); - if shape_dims.len() != 3 { - return Err(Error::Model(format!( - "Expected 3D logits, got shape: {shape_dims:?}" - ))); - } - - let batch_size = shape_dims[0] as usize; - let time_steps_out = shape_dims[1] as usize; - let vocab_size = shape_dims[2] as usize; - - if batch_size != 1 { - return Err(Error::Model(format!( - "Expected batch size 1, got {batch_size}" - ))); - } - - let logits_2d = Array2::from_shape_vec((time_steps_out, vocab_size), data.to_vec()) - .map_err(|e| Error::Model(format!("Failed to create array: {e}")))?; - - Ok(logits_2d) - } - - pub fn config(&self) -> &ModelConfig { - &self.config - } - - pub fn vocab_size(&self) -> usize { - self.config.vocab_size - } - - pub fn pad_token_id(&self) -> usize { - self.config.pad_token_id - } -} |
