use crate::error::Result; use ort::session::builder::SessionBuilder; // Hardware acceleration options. CPU is default and most reliable. // GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware. // All GPU providers automatically fall back to CPU if they fail. // // Note: CoreML currently fails with this model due to unsupported operations. // WebGPU is experimental and may produce incorrect results. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum ExecutionProvider { #[default] Cpu, #[cfg(feature = "cuda")] Cuda, #[cfg(feature = "tensorrt")] TensorRT, #[cfg(feature = "coreml")] CoreML, #[cfg(feature = "directml")] DirectML, #[cfg(feature = "rocm")] ROCm, #[cfg(feature = "openvino")] OpenVINO, #[cfg(feature = "webgpu")] WebGPU, } #[derive(Debug, Clone)] pub struct ModelConfig { pub execution_provider: ExecutionProvider, pub intra_threads: usize, pub inter_threads: usize, } impl Default for ModelConfig { fn default() -> Self { Self { execution_provider: ExecutionProvider::default(), intra_threads: 4, inter_threads: 1, } } } impl ModelConfig { pub fn new() -> Self { Self::default() } pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self { self.execution_provider = provider; self } pub fn with_intra_threads(mut self, threads: usize) -> Self { self.intra_threads = threads; self } pub fn with_inter_threads(mut self, threads: usize) -> Self { self.inter_threads = threads; self } pub(crate) fn apply_to_session_builder( &self, builder: SessionBuilder, ) -> Result { use ort::session::builder::GraphOptimizationLevel; #[cfg(any( feature = "cuda", feature = "tensorrt", feature = "coreml", feature = "directml", feature = "rocm", feature = "openvino", feature = "webgpu" ))] use ort::execution_providers::CPUExecutionProvider; let mut builder = builder .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(self.intra_threads)? .with_inter_threads(self.inter_threads)?; builder = match self.execution_provider { ExecutionProvider::Cpu => builder, #[cfg(feature = "cuda")] ExecutionProvider::Cuda => builder.with_execution_providers([ ort::execution_providers::CUDAExecutionProvider::default().build(), CPUExecutionProvider::default().build().error_on_failure(), ])?, #[cfg(feature = "tensorrt")] ExecutionProvider::TensorRT => builder.with_execution_providers([ ort::execution_providers::TensorRTExecutionProvider::default().build(), CPUExecutionProvider::default().build().error_on_failure(), ])?, #[cfg(feature = "coreml")] ExecutionProvider::CoreML => { use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider}; builder.with_execution_providers([ CoreMLExecutionProvider::default() .with_compute_units(CoreMLComputeUnits::CPUAndGPU) .build(), CPUExecutionProvider::default().build().error_on_failure(), ])? } #[cfg(feature = "directml")] ExecutionProvider::DirectML => builder.with_execution_providers([ ort::execution_providers::DirectMLExecutionProvider::default().build(), CPUExecutionProvider::default().build().error_on_failure(), ])?, #[cfg(feature = "rocm")] ExecutionProvider::ROCm => builder.with_execution_providers([ ort::execution_providers::ROCMExecutionProvider::default().build(), CPUExecutionProvider::default().build().error_on_failure(), ])?, #[cfg(feature = "openvino")] ExecutionProvider::OpenVINO => builder.with_execution_providers([ ort::execution_providers::OpenVINOExecutionProvider::default().build(), CPUExecutionProvider::default().build().error_on_failure(), ])?, #[cfg(feature = "webgpu")] ExecutionProvider::WebGPU => builder.with_execution_providers([ ort::execution_providers::WebGPUExecutionProvider::default().build(), CPUExecutionProvider::default().build().error_on_failure(), ])?, }; Ok(builder) } }