summaryrefslogtreecommitdiff
path: root/makima/src/server/state.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-20 15:36:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit01088f4f1915e36a7d0d8d8756f62f8207a48911 (patch)
tree8fdbba900f3f4bba32bae76e2e0378848a90cf93 /makima/src/server/state.rs
parentab9166170043ba5e0ce974e5b7accf0939d686e3 (diff)
downloadsoryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.tar.gz
soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.zip
Implement makima listen websockets server
Diffstat (limited to 'makima/src/server/state.rs')
-rw-r--r--makima/src/server/state.rs50
1 files changed, 50 insertions, 0 deletions
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
new file mode 100644
index 0000000..8eaf788
--- /dev/null
+++ b/makima/src/server/state.rs
@@ -0,0 +1,50 @@
+//! Application state holding shared ML models.
+
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+use crate::listen::{DiarizationConfig, ParakeetTDT, Sortformer};
+use crate::tts::ChatterboxTTS;
+
+/// Shared application state containing ML models.
+///
+/// Models are wrapped in `Mutex` for thread-safe mutable access during inference.
+pub struct AppState {
+ /// Speech-to-text model (Parakeet)
+ pub parakeet: Mutex<ParakeetTDT>,
+ /// Speaker diarization model (Sortformer)
+ pub sortformer: Mutex<Sortformer>,
+ /// Text-to-speech model (ChatterboxTTS)
+ pub chatterbox: Mutex<ChatterboxTTS>,
+}
+
+impl AppState {
+ /// Load all ML models from the specified directories.
+ ///
+ /// # Arguments
+ /// * `parakeet_model_dir` - Path to the Parakeet STT model directory
+ /// * `sortformer_model_path` - Path to the Sortformer diarization model file
+ /// * `tts_model_dir` - Optional path to the ChatterboxTTS model directory
+ pub fn new(
+ parakeet_model_dir: &str,
+ sortformer_model_path: &str,
+ tts_model_dir: Option<&str>,
+ ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
+ let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?;
+ let sortformer = Sortformer::with_config(
+ sortformer_model_path,
+ None,
+ DiarizationConfig::callhome(),
+ )?;
+ let chatterbox = ChatterboxTTS::from_pretrained(tts_model_dir)?;
+
+ Ok(Self {
+ parakeet: Mutex::new(parakeet),
+ sortformer: Mutex::new(sortformer),
+ chatterbox: Mutex::new(chatterbox),
+ })
+ }
+}
+
+/// Type alias for the shared application state.
+pub type SharedState = Arc<AppState>;