summaryrefslogtreecommitdiff
path: root/parakeet-rs/src/execution.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-21 00:40:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit55cacf6e1a087c0fa6950a1ddeb09060f787e541 (patch)
tree0b8e754eb16c829fc0ee7c8f4ba66fe75b4f3ebf /parakeet-rs/src/execution.rs
parent84fee5ce2ae30fb2381c99b9b223b8235b962869 (diff)
downloadsoryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.tar.gz
soryu-55cacf6e1a087c0fa6950a1ddeb09060f787e541.zip
Add EOU detection and streaming diarization
Diffstat (limited to 'parakeet-rs/src/execution.rs')
-rw-r--r--parakeet-rs/src/execution.rs141
1 files changed, 141 insertions, 0 deletions
diff --git a/parakeet-rs/src/execution.rs b/parakeet-rs/src/execution.rs
new file mode 100644
index 0000000..e29aa1d
--- /dev/null
+++ b/parakeet-rs/src/execution.rs
@@ -0,0 +1,141 @@
+use crate::error::Result;
+use ort::session::builder::SessionBuilder;
+
+// Hardware acceleration options. CPU is default and most reliable.
+// GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware.
+// All GPU providers automatically fall back to CPU if they fail.
+//
+// Note: CoreML currently fails with this model due to unsupported operations.
+// WebGPU is experimental and may produce incorrect results.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+pub enum ExecutionProvider {
+ #[default]
+ Cpu,
+ #[cfg(feature = "cuda")]
+ Cuda,
+ #[cfg(feature = "tensorrt")]
+ TensorRT,
+ #[cfg(feature = "coreml")]
+ CoreML,
+ #[cfg(feature = "directml")]
+ DirectML,
+ #[cfg(feature = "rocm")]
+ ROCm,
+ #[cfg(feature = "openvino")]
+ OpenVINO,
+ #[cfg(feature = "webgpu")]
+ WebGPU,
+}
+
+#[derive(Debug, Clone)]
+pub struct ModelConfig {
+ pub execution_provider: ExecutionProvider,
+ pub intra_threads: usize,
+ pub inter_threads: usize,
+}
+
+impl Default for ModelConfig {
+ fn default() -> Self {
+ Self {
+ execution_provider: ExecutionProvider::default(),
+ intra_threads: 4,
+ inter_threads: 1,
+ }
+ }
+}
+
+impl ModelConfig {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
+ self.execution_provider = provider;
+ self
+ }
+
+ pub fn with_intra_threads(mut self, threads: usize) -> Self {
+ self.intra_threads = threads;
+ self
+ }
+
+ pub fn with_inter_threads(mut self, threads: usize) -> Self {
+ self.inter_threads = threads;
+ self
+ }
+
+ pub(crate) fn apply_to_session_builder(
+ &self,
+ builder: SessionBuilder,
+ ) -> Result<SessionBuilder> {
+ use ort::session::builder::GraphOptimizationLevel;
+ #[cfg(any(
+ feature = "cuda",
+ feature = "tensorrt",
+ feature = "coreml",
+ feature = "directml",
+ feature = "rocm",
+ feature = "openvino",
+ feature = "webgpu"
+ ))]
+ use ort::execution_providers::CPUExecutionProvider;
+
+ let mut builder = builder
+ .with_optimization_level(GraphOptimizationLevel::Level3)?
+ .with_intra_threads(self.intra_threads)?
+ .with_inter_threads(self.inter_threads)?;
+
+ builder = match self.execution_provider {
+ ExecutionProvider::Cpu => builder,
+
+ #[cfg(feature = "cuda")]
+ ExecutionProvider::Cuda => builder.with_execution_providers([
+ ort::execution_providers::CUDAExecutionProvider::default().build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?,
+
+ #[cfg(feature = "tensorrt")]
+ ExecutionProvider::TensorRT => builder.with_execution_providers([
+ ort::execution_providers::TensorRTExecutionProvider::default().build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?,
+
+ #[cfg(feature = "coreml")]
+ ExecutionProvider::CoreML => {
+ use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider};
+ builder.with_execution_providers([
+ CoreMLExecutionProvider::default()
+ .with_compute_units(CoreMLComputeUnits::CPUAndGPU)
+ .build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?
+ }
+
+ #[cfg(feature = "directml")]
+ ExecutionProvider::DirectML => builder.with_execution_providers([
+ ort::execution_providers::DirectMLExecutionProvider::default().build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?,
+
+ #[cfg(feature = "rocm")]
+ ExecutionProvider::ROCm => builder.with_execution_providers([
+ ort::execution_providers::ROCMExecutionProvider::default().build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?,
+
+ #[cfg(feature = "openvino")]
+ ExecutionProvider::OpenVINO => builder.with_execution_providers([
+ ort::execution_providers::OpenVINOExecutionProvider::default().build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?,
+
+ #[cfg(feature = "webgpu")]
+ ExecutionProvider::WebGPU => builder.with_execution_providers([
+ ort::execution_providers::WebGPUExecutionProvider::default().build(),
+ CPUExecutionProvider::default().build().error_on_failure(),
+ ])?,
+ };
+
+ Ok(builder)
+ }
+}