diff options
Diffstat (limited to 'parakeet-rs/src/execution.rs')
| -rw-r--r-- | parakeet-rs/src/execution.rs | 141 |
1 files changed, 0 insertions, 141 deletions
diff --git a/parakeet-rs/src/execution.rs b/parakeet-rs/src/execution.rs deleted file mode 100644 index e29aa1d..0000000 --- a/parakeet-rs/src/execution.rs +++ /dev/null @@ -1,141 +0,0 @@ -use crate::error::Result; -use ort::session::builder::SessionBuilder; - -// Hardware acceleration options. CPU is default and most reliable. -// GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware. -// All GPU providers automatically fall back to CPU if they fail. -// -// Note: CoreML currently fails with this model due to unsupported operations. -// WebGPU is experimental and may produce incorrect results. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum ExecutionProvider { - #[default] - Cpu, - #[cfg(feature = "cuda")] - Cuda, - #[cfg(feature = "tensorrt")] - TensorRT, - #[cfg(feature = "coreml")] - CoreML, - #[cfg(feature = "directml")] - DirectML, - #[cfg(feature = "rocm")] - ROCm, - #[cfg(feature = "openvino")] - OpenVINO, - #[cfg(feature = "webgpu")] - WebGPU, -} - -#[derive(Debug, Clone)] -pub struct ModelConfig { - pub execution_provider: ExecutionProvider, - pub intra_threads: usize, - pub inter_threads: usize, -} - -impl Default for ModelConfig { - fn default() -> Self { - Self { - execution_provider: ExecutionProvider::default(), - intra_threads: 4, - inter_threads: 1, - } - } -} - -impl ModelConfig { - pub fn new() -> Self { - Self::default() - } - - pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self { - self.execution_provider = provider; - self - } - - pub fn with_intra_threads(mut self, threads: usize) -> Self { - self.intra_threads = threads; - self - } - - pub fn with_inter_threads(mut self, threads: usize) -> Self { - self.inter_threads = threads; - self - } - - pub(crate) fn apply_to_session_builder( - &self, - builder: SessionBuilder, - ) -> Result<SessionBuilder> { - use ort::session::builder::GraphOptimizationLevel; - #[cfg(any( - feature = "cuda", - feature = "tensorrt", - feature = "coreml", - feature = "directml", - feature = "rocm", - feature = "openvino", - feature = "webgpu" - ))] - use ort::execution_providers::CPUExecutionProvider; - - let mut builder = builder - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(self.intra_threads)? - .with_inter_threads(self.inter_threads)?; - - builder = match self.execution_provider { - ExecutionProvider::Cpu => builder, - - #[cfg(feature = "cuda")] - ExecutionProvider::Cuda => builder.with_execution_providers([ - ort::execution_providers::CUDAExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "tensorrt")] - ExecutionProvider::TensorRT => builder.with_execution_providers([ - ort::execution_providers::TensorRTExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "coreml")] - ExecutionProvider::CoreML => { - use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider}; - builder.with_execution_providers([ - CoreMLExecutionProvider::default() - .with_compute_units(CoreMLComputeUnits::CPUAndGPU) - .build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])? - } - - #[cfg(feature = "directml")] - ExecutionProvider::DirectML => builder.with_execution_providers([ - ort::execution_providers::DirectMLExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "rocm")] - ExecutionProvider::ROCm => builder.with_execution_providers([ - ort::execution_providers::ROCMExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "openvino")] - ExecutionProvider::OpenVINO => builder.with_execution_providers([ - ort::execution_providers::OpenVINOExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - - #[cfg(feature = "webgpu")] - ExecutionProvider::WebGPU => builder.with_execution_providers([ - ort::execution_providers::WebGPUExecutionProvider::default().build(), - CPUExecutionProvider::default().build().error_on_failure(), - ])?, - }; - - Ok(builder) - } -} |
