diff options
| author | soryu <soryu@soryu.co> | 2025-12-21 00:40:04 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch) | |
| tree | 0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/src/execution.rs | |
| parent | 84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff) | |
| download | soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip | |
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/src/execution.rs')
| -rw-r--r-- | parakeet-rs/src/execution.rs | 141 |
1 files changed, 141 insertions, 0 deletions
diff --git a/parakeet-rs/src/execution.rs b/parakeet-rs/src/execution.rs new file mode 100644 index 0000000..e29aa1d --- /dev/null +++ b/parakeet-rs/src/execution.rs @@ -0,0 +1,141 @@ +use crate::error::Result; +use ort::session::builder::SessionBuilder; + +// Hardware acceleration options. CPU is default and most reliable. +// GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware. +// All GPU providers automatically fall back to CPU if they fail. +// +// Note: CoreML currently fails with this model due to unsupported operations. +// WebGPU is experimental and may produce incorrect results. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ExecutionProvider { + #[default] + Cpu, + #[cfg(feature = "cuda")] + Cuda, + #[cfg(feature = "tensorrt")] + TensorRT, + #[cfg(feature = "coreml")] + CoreML, + #[cfg(feature = "directml")] + DirectML, + #[cfg(feature = "rocm")] + ROCm, + #[cfg(feature = "openvino")] + OpenVINO, + #[cfg(feature = "webgpu")] + WebGPU, +} + +#[derive(Debug, Clone)] +pub struct ModelConfig { + pub execution_provider: ExecutionProvider, + pub intra_threads: usize, + pub inter_threads: usize, +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + execution_provider: ExecutionProvider::default(), + intra_threads: 4, + inter_threads: 1, + } + } +} + +impl ModelConfig { + pub fn new() -> Self { + Self::default() + } + + pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self { + self.execution_provider = provider; + self + } + + pub fn with_intra_threads(mut self, threads: usize) -> Self { + self.intra_threads = threads; + self + } + + pub fn with_inter_threads(mut self, threads: usize) -> Self { + self.inter_threads = threads; + self + } + + pub(crate) fn apply_to_session_builder( + &self, + builder: SessionBuilder, + ) -> Result<SessionBuilder> { + use ort::session::builder::GraphOptimizationLevel; + #[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "rocm", + feature = "openvino", + feature = "webgpu" + ))] + use ort::execution_providers::CPUExecutionProvider; + + let mut builder = builder + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(self.intra_threads)? + .with_inter_threads(self.inter_threads)?; + + builder = match self.execution_provider { + ExecutionProvider::Cpu => builder, + + #[cfg(feature = "cuda")] + ExecutionProvider::Cuda => builder.with_execution_providers([ + ort::execution_providers::CUDAExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "tensorrt")] + ExecutionProvider::TensorRT => builder.with_execution_providers([ + ort::execution_providers::TensorRTExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "coreml")] + ExecutionProvider::CoreML => { + use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider}; + builder.with_execution_providers([ + CoreMLExecutionProvider::default() + .with_compute_units(CoreMLComputeUnits::CPUAndGPU) + .build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])? + } + + #[cfg(feature = "directml")] + ExecutionProvider::DirectML => builder.with_execution_providers([ + ort::execution_providers::DirectMLExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "rocm")] + ExecutionProvider::ROCm => builder.with_execution_providers([ + ort::execution_providers::ROCMExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "openvino")] + ExecutionProvider::OpenVINO => builder.with_execution_providers([ + ort::execution_providers::OpenVINOExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "webgpu")] + ExecutionProvider::WebGPU => builder.with_execution_providers([ + ort::execution_providers::WebGPUExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + }; + + Ok(builder) + } +} |
