summaryrefslogtreecommitdiff
path: root/makima/src/server
diff options
context:
space:
mode:
author: soryu <soryu@soryu.co> 2025-12-20 15:36:04 +0000
committer: soryu <soryu@soryu.co> 2025-12-23 14:47:18 +0000
commit: 01088f4f1915e36a7d0d8d8756f62f8207a48911 (patch)
tree: 8fdbba900f3f4bba32bae76e2e0378848a90cf93 /makima/src/server
parent: ab9166170043ba5e0ce974e5b7accf0939d686e3 (diff)
download: soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.tar.gz
download: soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.zip
Implement makima listen websockets server
Diffstat (limited to 'makima/src/server')
-rw-r--r-- makima/src/server/handlers/listen.rs 328
-rw-r--r-- makima/src/server/handlers/mod.rs 4
-rw-r--r-- makima/src/server/handlers/tts.rs 185
-rw-r--r-- makima/src/server/messages.rs 92
-rw-r--r-- makima/src/server/mod.rs 88
-rw-r--r-- makima/src/server/openapi.rs 34
-rw-r--r-- makima/src/server/state.rs 50
7 files changed, 781 insertions, 0 deletions
diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs
new file mode 100644
index 0000000..b1c1ad9
--- /dev/null
+++ b/makima/src/server/handlers/listen.rs
@@ -0,0 +1,328 @@
+//! WebSocket handler for streaming speech-to-text.
+
+use axum::{
+ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
+ response::Response,
+};
+use futures::{SinkExt, StreamExt};
+use tokio::sync::mpsc;
+use uuid::Uuid;
+
+use crate::audio::{resample_and_mixdown, TARGET_CHANNELS, TARGET_SAMPLE_RATE};
+use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode, Transcriber};
+use crate::server::messages::{
+ AudioEncoding, ClientMessage, ServerMessage, StartMessage, TranscriptMessage,
+};
+use crate::server::state::SharedState;
+
+/// Chunk size in milliseconds for streaming transcription.
+///
+/// The socket handler waits until at least this much new audio has been
+/// buffered since the previous pass before it runs transcription again.
+const STREAM_CHUNK_MS: u32 = 5_000;
+
+/// WebSocket upgrade handler for STT streaming.
+///
+/// This endpoint accepts WebSocket connections for real-time speech-to-text
+/// transcription with speaker diarization.
+#[utoipa::path(
+    get,
+    path = "/api/v1/listen",
+    responses(
+        (status = 101, description = "WebSocket connection established"),
+    ),
+    tag = "STT"
+)]
+pub async fn websocket_handler(
+    ws: WebSocketUpgrade,
+    State(state): State<SharedState>,
+) -> Response {
+    // The shared state is moved into the per-connection future, which takes
+    // over once the HTTP upgrade has completed.
+    ws.on_upgrade(move |upgraded| async move { handle_socket(upgraded, state).await })
+}
+
+/// Drive one speech-to-text WebSocket session from connection to close.
+///
+/// Flow implemented here:
+/// 1. A `Ready` message carrying a fresh session id is sent immediately.
+/// 2. A `Start` text message sets the audio format and resets all buffered
+///    state. A start message with a zero sample rate or channel count is
+///    rejected with `INVALID_FORMAT`; binary frames received before any
+///    accepted start are rejected with `NO_FORMAT`.
+/// 3. Binary frames are decoded and appended to one growing buffer. Each time
+///    about `STREAM_CHUNK_MS` of new audio has arrived the whole buffer is
+///    transcribed again, and segments ending after the last one already sent
+///    are emitted as interim (`is_final: false`) transcripts.
+/// 4. `Stop` transcribes the buffer one last time, sends unsent segments as
+///    interim results, re-sends every segment as final, sends `Stopped` and
+///    ends the session.
+///
+/// Failures to queue an outgoing message are ignored on purpose: the queue
+/// only closes after the writer task has exited, which happens when writing
+/// to the socket failed.
+async fn handle_socket(socket: WebSocket, state: SharedState) {
+    let session_id = Uuid::new_v4().to_string();
+    tracing::info!(session_id = %session_id, "New WebSocket connection");
+
+    // Split the socket so a dedicated task can write while this one reads.
+    let (mut sender, mut receiver) = socket.split();
+
+    // Outgoing messages are queued here and serialized by the writer task.
+    let (response_tx, mut response_rx) = mpsc::channel::<ServerMessage>(32);
+
+    let sender_task = tokio::spawn(async move {
+        while let Some(msg) = response_rx.recv().await {
+            let json = match serde_json::to_string(&msg) {
+                Ok(j) => j,
+                Err(e) => {
+                    tracing::error!("Failed to serialize message: {}", e);
+                    continue;
+                }
+            };
+            if sender.send(Message::Text(json.into())).await.is_err() {
+                break;
+            }
+        }
+    });
+
+    let _ = response_tx
+        .send(ServerMessage::Ready {
+            session_id: session_id.clone(),
+        })
+        .await;
+
+    // Format announced by the most recent accepted start message.
+    let mut audio_format: Option<StartMessage> = None;
+
+    // All decoded audio of the session, interleaved across channels.
+    let mut audio_buffer: Vec<f32> = Vec::new();
+    // End time of the last segment already sent as an interim result.
+    let mut last_sent_end_time: f32 = 0.0;
+    // Buffer length at the time of the last streaming transcription attempt.
+    let mut last_processed_len: usize = 0;
+
+    while let Some(msg_result) = receiver.next().await {
+        let msg = match msg_result {
+            Ok(m) => m,
+            Err(e) => {
+                tracing::error!("WebSocket error: {}", e);
+                break;
+            }
+        };
+
+        match msg {
+            // Text frames carry JSON control messages.
+            Message::Text(text) => match serde_json::from_str::<ClientMessage>(&text) {
+                Ok(ClientMessage::Start(start)) => {
+                    // A zero rate or channel count cannot describe real audio
+                    // and would collapse the chunk threshold below to zero.
+                    if start.sample_rate == 0 || start.channels == 0 {
+                        tracing::warn!(
+                            session_id = %session_id,
+                            sample_rate = start.sample_rate,
+                            channels = start.channels,
+                            "Rejected start message with invalid audio format"
+                        );
+                        let _ = response_tx
+                            .send(ServerMessage::Error {
+                                code: "INVALID_FORMAT".into(),
+                                message: "sampleRate and channels must be greater than zero"
+                                    .into(),
+                            })
+                            .await;
+                        continue;
+                    }
+
+                    tracing::info!(
+                        session_id = %session_id,
+                        sample_rate = start.sample_rate,
+                        channels = start.channels,
+                        encoding = ?start.encoding,
+                        "Session started"
+                    );
+                    audio_format = Some(start);
+                    audio_buffer.clear();
+                    last_sent_end_time = 0.0;
+                    last_processed_len = 0;
+                }
+                Ok(ClientMessage::Stop(stop)) => {
+                    tracing::info!(
+                        session_id = %session_id,
+                        reason = ?stop.reason,
+                        audio_buffer_len = audio_buffer.len(),
+                        "Session stopped by client"
+                    );
+
+                    if let Some(ref format) = audio_format {
+                        if !audio_buffer.is_empty() {
+                            tracing::debug!(
+                                session_id = %session_id,
+                                samples = audio_buffer.len(),
+                                "Processing final audio buffer"
+                            );
+                            match process_audio(&audio_buffer, format, &state).await {
+                                Ok(segments) => {
+                                    tracing::debug!(
+                                        session_id = %session_id,
+                                        total_segments = segments.len(),
+                                        last_sent_end = last_sent_end_time,
+                                        "Final transcription complete"
+                                    );
+
+                                    // Step 1: segments not yet seen by the
+                                    // client go out as interim results.
+                                    for seg in &segments {
+                                        if seg.end > last_sent_end_time {
+                                            send_transcript(&response_tx, seg, false).await;
+                                        }
+                                    }
+
+                                    // Step 2: the complete authoritative
+                                    // transcript, every segment marked final.
+                                    for seg in &segments {
+                                        send_transcript(&response_tx, seg, true).await;
+                                    }
+                                }
+                                Err(e) => {
+                                    tracing::error!(
+                                        session_id = %session_id,
+                                        error = %e,
+                                        "Final transcription failed"
+                                    );
+                                    let _ = response_tx
+                                        .send(ServerMessage::Error {
+                                            code: "TRANSCRIPTION_ERROR".into(),
+                                            message: e.to_string(),
+                                        })
+                                        .await;
+                                }
+                            }
+                        }
+                    }
+
+                    let _ = response_tx
+                        .send(ServerMessage::Stopped {
+                            reason: stop.reason.unwrap_or_else(|| "client_requested".into()),
+                        })
+                        .await;
+                    break;
+                }
+                Err(e) => {
+                    tracing::warn!(session_id = %session_id, error = %e, "Failed to parse message");
+                    let _ = response_tx
+                        .send(ServerMessage::Error {
+                            code: "PARSE_ERROR".into(),
+                            message: format!("Failed to parse message: {}", e),
+                        })
+                        .await;
+                }
+            },
+            Message::Binary(data) => {
+                let Some(ref format) = audio_format else {
+                    let _ = response_tx
+                        .send(ServerMessage::Error {
+                            code: "NO_FORMAT".into(),
+                            message: "Received audio before start message".into(),
+                        })
+                        .await;
+                    continue;
+                };
+
+                audio_buffer.extend(decode_audio_chunk(&data, format));
+
+                // `samples_per_chunk` is given only the sample rate, while the
+                // buffer holds interleaved samples for every channel, so the
+                // threshold has to be scaled by the channel count to cover
+                // the intended duration.
+                let chunk_samples = samples_per_chunk(format.sample_rate, STREAM_CHUNK_MS)
+                    .saturating_mul(usize::from(format.channels));
+                let new_audio_len = audio_buffer.len() - last_processed_len;
+
+                if new_audio_len >= chunk_samples {
+                    tracing::debug!(
+                        session_id = %session_id,
+                        total_samples = audio_buffer.len(),
+                        new_samples = new_audio_len,
+                        "Processing audio chunk"
+                    );
+
+                    match process_audio(&audio_buffer, format, &state).await {
+                        Ok(segments) => {
+                            tracing::debug!(
+                                session_id = %session_id,
+                                total_segments = segments.len(),
+                                last_sent_end = last_sent_end_time,
+                                "Transcription produced segments"
+                            );
+
+                            // Only segments ending after the last one sent
+                            // are new; this copes with the model
+                            // re-segmenting earlier audio.
+                            for seg in &segments {
+                                if seg.end > last_sent_end_time {
+                                    send_transcript(&response_tx, seg, false).await;
+                                    last_sent_end_time = seg.end;
+                                }
+                            }
+                        }
+                        Err(e) => {
+                            tracing::error!(session_id = %session_id, error = %e, "Transcription error");
+                            let _ = response_tx
+                                .send(ServerMessage::Error {
+                                    code: "TRANSCRIPTION_ERROR".into(),
+                                    message: e.to_string(),
+                                })
+                                .await;
+                        }
+                    }
+
+                    // Advance even after a failure: otherwise every following
+                    // frame would re-run inference on the whole buffer and
+                    // report the same error again.
+                    last_processed_len = audio_buffer.len();
+                }
+            }
+            Message::Close(_) => {
+                tracing::info!(session_id = %session_id, "WebSocket closed by client");
+                break;
+            }
+            _ => {}
+        }
+    }
+
+    // Closing the queue lets the writer task drain what is left and exit.
+    drop(response_tx);
+    let _ = sender_task.await;
+    tracing::info!(session_id = %session_id, "WebSocket connection closed");
+}
+
+/// Queue one aligned segment for delivery to the client as a transcript.
+///
+/// A failed send is ignored; see `handle_socket` for why that is safe.
+async fn send_transcript(
+    tx: &mpsc::Sender<ServerMessage>,
+    seg: &DialogueSegment,
+    is_final: bool,
+) {
+    let _ = tx
+        .send(ServerMessage::Transcript(TranscriptMessage {
+            speaker: seg.speaker.clone(),
+            start: seg.start,
+            end: seg.end,
+            text: seg.text.clone(),
+            is_final,
+        }))
+        .await;
+}
+
+/// Convert one binary frame of little-endian PCM into `f32` samples.
+///
+/// `Pcm32f` data is read as 4-byte floats and passed through unchanged.
+/// `Pcm16` and `Raw` data are read as 2-byte signed integers and divided by
+/// 32768. Trailing bytes that do not make up a whole sample are ignored.
+fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> {
+    match format.encoding {
+        AudioEncoding::Pcm32f => {
+            let mut decoded = Vec::with_capacity(data.len() / 4);
+            decoded.extend(
+                data.chunks_exact(4)
+                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
+            );
+            decoded
+        }
+        AudioEncoding::Pcm16 | AudioEncoding::Raw => {
+            let mut decoded = Vec::with_capacity(data.len() / 2);
+            decoded.extend(
+                data.chunks_exact(2)
+                    .map(|b| f32::from(i16::from_le_bytes([b[0], b[1]])) / 32768.0),
+            );
+            decoded
+        }
+    }
+}
+
+/// Process accumulated audio through STT and diarization models.
+///
+/// The audio is brought to `TARGET_SAMPLE_RATE` / `TARGET_CHANNELS` when the
+/// client format differs, diarized, transcribed with sentence timestamps, and
+/// the two results are merged by `align_speakers`.
+///
+/// All of that is CPU-bound, so it runs on Tokio's blocking thread pool
+/// instead of an async worker thread; other connections keep being served
+/// while inference is in progress. Each model lock is held only for the call
+/// into that model.
+///
+/// # Errors
+/// Returns the error of the failing model call, or the join error if the
+/// blocking task panicked or was cancelled.
+async fn process_audio(
+    samples: &[f32],
+    format: &StartMessage,
+    state: &SharedState,
+) -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error + Send + Sync>> {
+    // The blocking task must own everything it touches.
+    let input = samples.to_vec();
+    let sample_rate = format.sample_rate;
+    let channels = format.channels;
+    let state = state.clone();
+
+    let segments = tokio::task::spawn_blocking(
+        move || -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error + Send + Sync>> {
+            let resampled = if sample_rate != TARGET_SAMPLE_RATE || channels != TARGET_CHANNELS {
+                resample_and_mixdown(&input[..], sample_rate, channels)
+            } else {
+                input
+            };
+
+            // `blocking_lock` is the synchronous counterpart of `lock().await`
+            // and is allowed here because this closure runs outside the async
+            // execution context.
+            let mut sortformer = state.sortformer.blocking_lock();
+            let diarization_segments =
+                sortformer.diarize(resampled.clone(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?;
+            drop(sortformer);
+
+            let mut parakeet = state.parakeet.blocking_lock();
+            let transcription = parakeet.transcribe_samples(
+                resampled,
+                TARGET_SAMPLE_RATE,
+                TARGET_CHANNELS,
+                Some(TimestampMode::Sentences),
+            )?;
+
+            Ok(align_speakers(&transcription.tokens, &diarization_segments))
+        },
+    )
+    .await??;
+
+    Ok(segments)
+}
diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs
new file mode 100644
index 0000000..90798f9
--- /dev/null
+++ b/makima/src/server/handlers/mod.rs
@@ -0,0 +1,4 @@
+//! HTTP and WebSocket request handlers.
+
+pub mod listen;
+pub mod tts;
diff --git a/makima/src/server/handlers/tts.rs b/makima/src/server/handlers/tts.rs
new file mode 100644
index 0000000..94a261f
--- /dev/null
+++ b/makima/src/server/handlers/tts.rs
@@ -0,0 +1,185 @@
+//! HTTP handler for text-to-speech synthesis.
+
+use axum::{
+ body::Body,
+ extract::{Multipart, State},
+ http::{header, StatusCode},
+ response::Response,
+ Json,
+};
+use std::io::Cursor;
+
+use crate::audio::to_16k_mono_from_reader;
+use crate::server::messages::ApiError;
+use crate::server::state::SharedState;
+use crate::tts::SAMPLE_RATE;
+
+/// POST /api/v1/tts/synthesize
+///
+/// Synthesize speech from text using voice cloning.
+///
+/// Accepts multipart form data with:
+/// - `text`: The text to synthesize (required)
+/// - `voice`: Audio file for voice cloning reference (required)
+///
+/// Returns: WAV audio file (24kHz mono)
+#[utoipa::path(
+    post,
+    path = "/api/v1/tts/synthesize",
+    request_body(content_type = "multipart/form-data", description = "Text and voice audio for synthesis"),
+    responses(
+        (status = 200, description = "Generated audio file", content_type = "audio/wav"),
+        (status = 400, description = "Bad request", body = ApiError),
+        (status = 500, description = "Internal server error", body = ApiError),
+    ),
+    tag = "TTS"
+)]
+pub async fn synthesize_handler(
+    State(state): State<SharedState>,
+    mut multipart: Multipart,
+) -> Result<Response, (StatusCode, Json<ApiError>)> {
+    let mut text: Option<String> = None;
+    let mut voice_samples: Option<Vec<f32>> = None;
+    // Overwritten with the decoder's reported rate once a voice file is read.
+    let mut voice_sample_rate: u32 = 16_000;
+
+    // Walk the multipart fields; unknown ones are skipped.
+    while let Some(field) = multipart
+        .next_field()
+        .await
+        .map_err(|e| api_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string()))?
+    {
+        let name = field.name().unwrap_or("").to_string();
+
+        match name.as_str() {
+            "text" => {
+                let value = field.text().await.map_err(|e| {
+                    api_error(StatusCode::BAD_REQUEST, "TEXT_FIELD_ERROR", e.to_string())
+                })?;
+                text = Some(value);
+            }
+            "voice" => {
+                let data = field.bytes().await.map_err(|e| {
+                    api_error(StatusCode::BAD_REQUEST, "VOICE_FIELD_ERROR", e.to_string())
+                })?;
+
+                // Decode the uploaded audio file to PCM samples.
+                let cursor = Cursor::new(data.to_vec());
+                let pcm = to_16k_mono_from_reader(cursor).map_err(|e| {
+                    api_error(
+                        StatusCode::BAD_REQUEST,
+                        "AUDIO_DECODE_ERROR",
+                        format!("Failed to decode voice audio: {}", e),
+                    )
+                })?;
+
+                voice_samples = Some(pcm.samples);
+                voice_sample_rate = pcm.sample_rate;
+            }
+            _ => {
+                tracing::debug!("Ignoring unknown field: {}", name);
+            }
+        }
+    }
+
+    // Both fields are mandatory.
+    let text = text.ok_or_else(|| {
+        api_error(
+            StatusCode::BAD_REQUEST,
+            "MISSING_TEXT",
+            "Text field is required",
+        )
+    })?;
+    let samples = voice_samples.ok_or_else(|| {
+        api_error(
+            StatusCode::BAD_REQUEST,
+            "MISSING_VOICE",
+            "Voice audio file is required",
+        )
+    })?;
+
+    tracing::info!(
+        text_len = text.len(),
+        voice_samples = samples.len(),
+        "Generating TTS"
+    );
+
+    // Synthesis is CPU-bound, so it runs on the blocking thread pool rather
+    // than on an async worker thread. The model error is turned into a string
+    // inside the task so only plain owned data crosses the thread boundary.
+    let wav_data = tokio::task::spawn_blocking(move || -> Result<Vec<u8>, String> {
+        let mut chatterbox = state.chatterbox.blocking_lock();
+        let audio = chatterbox
+            .generate_tts_with_samples(&text, &samples, voice_sample_rate)
+            .map_err(|e| e.to_string())?;
+
+        tracing::info!(samples = audio.len(), "TTS generation complete");
+
+        Ok(encode_wav(&audio, SAMPLE_RATE))
+    })
+    .await
+    .map_err(|e| {
+        // The blocking task panicked or was cancelled.
+        tracing::error!(error = %e, "TTS generation failed");
+        api_error(StatusCode::INTERNAL_SERVER_ERROR, "TTS_ERROR", e.to_string())
+    })?
+    .map_err(|message| {
+        tracing::error!(error = %message, "TTS generation failed");
+        api_error(StatusCode::INTERNAL_SERVER_ERROR, "TTS_ERROR", message)
+    })?;
+
+    Ok(Response::builder()
+        .status(StatusCode::OK)
+        .header(header::CONTENT_TYPE, "audio/wav")
+        .header(
+            header::CONTENT_DISPOSITION,
+            "attachment; filename=\"output.wav\"",
+        )
+        .body(Body::from(wav_data))
+        .expect("status and header values are static and valid"))
+}
+
+/// Pair an HTTP status with a JSON `ApiError` body, in the shape the handler
+/// returns on failure.
+fn api_error(
+    status: StatusCode,
+    code: &str,
+    message: impl Into<String>,
+) -> (StatusCode, Json<ApiError>) {
+    (status, Json(ApiError::new(code, message)))
+}
+
+/// Encode f32 samples as a WAV file in memory.
+///
+/// The output is a 44-byte RIFF/WAVE header followed by mono 16-bit
+/// little-endian PCM. Each input sample is clamped to `[-1.0, 1.0]` and
+/// scaled by 32767.
+///
+/// Header size fields are 32-bit; for inputs too long to describe they
+/// saturate at `u32::MAX` instead of wrapping or panicking.
+fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> {
+    const NUM_CHANNELS: u16 = 1;
+    const BITS_PER_SAMPLE: u16 = 16;
+    const HEADER_LEN: usize = 44;
+
+    let block_align = NUM_CHANNELS * BITS_PER_SAMPLE / 8;
+    let bytes_per_frame = u32::from(block_align);
+    let byte_rate = sample_rate.saturating_mul(bytes_per_frame);
+    // Cap the count before narrowing so the cast cannot truncate.
+    let num_samples = samples.len().min(u32::MAX as usize) as u32;
+    let data_size = num_samples.saturating_mul(bytes_per_frame);
+    // RIFF chunk size: everything after the first 8 bytes of the file.
+    let file_size = data_size.saturating_add(36);
+
+    let mut buf = Vec::with_capacity(HEADER_LEN + samples.len() * usize::from(block_align));
+
+    // RIFF header
+    buf.extend_from_slice(b"RIFF");
+    buf.extend_from_slice(&file_size.to_le_bytes());
+    buf.extend_from_slice(b"WAVE");
+
+    // fmt chunk
+    buf.extend_from_slice(b"fmt ");
+    buf.extend_from_slice(&16u32.to_le_bytes());
+    buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
+    buf.extend_from_slice(&NUM_CHANNELS.to_le_bytes());
+    buf.extend_from_slice(&sample_rate.to_le_bytes());
+    buf.extend_from_slice(&byte_rate.to_le_bytes());
+    buf.extend_from_slice(&block_align.to_le_bytes());
+    buf.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());
+
+    // data chunk
+    buf.extend_from_slice(b"data");
+    buf.extend_from_slice(&data_size.to_le_bytes());
+
+    // Convert f32 samples to i16 PCM
+    for &sample in samples {
+        let int_sample = (sample.clamp(-1.0, 1.0) * 32767.0) as i16;
+        buf.extend_from_slice(&int_sample.to_le_bytes());
+    }
+
+    buf
+}
diff --git a/makima/src/server/messages.rs b/makima/src/server/messages.rs
new file mode 100644
index 0000000..0c92447
--- /dev/null
+++ b/makima/src/server/messages.rs
@@ -0,0 +1,92 @@
+//! WebSocket and API message types for the makima server.
+
+use serde::{Deserialize, Serialize};
+use utoipa::ToSchema;
+
+// Wire names: with `rename_all = "lowercase"` serde reads and writes the
+// variants as "pcm32f", "pcm16" and "raw".
+/// Audio encoding format for WebSocket streaming.
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "lowercase")]
+pub enum AudioEncoding {
+ /// 32-bit floating point PCM samples
+ Pcm32f,
+ /// 16-bit signed integer PCM samples
+ Pcm16,
+ /// Raw bytes (will be interpreted as PCM16)
+ Raw,
+}
+
+// Wire names: `rename_all = "camelCase"` turns `sample_rate` into
+// `sampleRate`; `channels` and `encoding` are unchanged.
+/// Initial handshake message from client specifying audio format.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct StartMessage {
+ /// Audio sample rate in Hz (e.g., 16000, 44100, 48000)
+ pub sample_rate: u32,
+ /// Number of audio channels (1 for mono, 2 for stereo)
+ pub channels: u16,
+ /// Audio encoding format
+ pub encoding: AudioEncoding,
+}
+
+// `reason` is an `Option`, so a stop message may leave the key out; serde's
+// derive then deserializes it as `None`.
+/// Stop message to terminate the session.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct StopMessage {
+ /// Optional reason for stopping
+ pub reason: Option<String>,
+}
+
+/// Wrapper for all WebSocket messages from client to server.
+///
+/// Internally tagged: the JSON object's `type` key selects the variant
+/// (`"start"` or `"stop"` after the camelCase rename) and the remaining keys
+/// are the fields of the wrapped message struct.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum ClientMessage {
+ /// Announces the audio format of the binary frames that follow.
+ Start(StartMessage),
+ /// Asks the server to finish the session.
+ Stop(StopMessage),
+}
+
+// Wire names: `rename_all = "camelCase"` turns `is_final` into `isFinal`;
+// the other field names are single words and stay as written.
+/// Transcription result message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TranscriptMessage {
+ /// Speaker identifier (e.g., "Speaker 0", "Speaker 1")
+ pub speaker: String,
+ /// Segment start time in seconds
+ pub start: f32,
+ /// Segment end time in seconds
+ pub end: f32,
+ /// Transcribed text
+ pub text: String,
+ /// Whether this is a final or interim result
+ pub is_final: bool,
+}
+
+/// Wrapper for all WebSocket messages from server to client.
+///
+/// Internally tagged: every message is a JSON object whose `type` key holds
+/// the camelCase variant name (`"ready"`, `"transcript"`, `"error"`,
+/// `"stopped"`).
+///
+/// NOTE(review): `rename_all` on an enum renames the variants only, not the
+/// fields of struct variants, so `Ready` is emitted with the key
+/// `session_id` while `Transcript` payloads use camelCase keys such as
+/// `isFinal`. Making them consistent (for example with serde's
+/// `rename_all_fields`) would change the wire format, so confirm with client
+/// authors first.
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum ServerMessage {
+ /// Session is ready for audio streaming
+ Ready { session_id: String },
+ /// Transcription result
+ Transcript(TranscriptMessage),
+ /// Error occurred during processing
+ Error { code: String, message: String },
+ /// Session has been stopped
+ Stopped { reason: String },
+}
+
+/// Error response for HTTP API endpoints.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+pub struct ApiError {
+ /// Error code for programmatic handling
+ pub code: String,
+ /// Human-readable error message
+ pub message: String,
+}
+
+impl ApiError {
+ /// Build an error from anything convertible into a `String`, so callers
+ /// can pass string literals and owned strings alike.
+ pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
+ Self {
+ code: code.into(),
+ message: message.into(),
+ }
+ }
+}
diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs
new file mode 100644
index 0000000..c33eeef
--- /dev/null
+++ b/makima/src/server/mod.rs
@@ -0,0 +1,88 @@
+//! Web server module for the makima audio API.
+
+pub mod handlers;
+pub mod messages;
+pub mod openapi;
+pub mod state;
+
+use axum::{
+ routing::{get, post},
+ Router,
+};
+use tower_http::cors::{Any, CorsLayer};
+use tower_http::trace::TraceLayer;
+use utoipa::OpenApi;
+use utoipa_swagger_ui::SwaggerUi;
+
+use crate::server::handlers::{listen, tts};
+use crate::server::openapi::ApiDoc;
+use crate::server::state::SharedState;
+
+/// Build the application router.
+///
+/// Mounts the versioned API under `/api/v1`, serves Swagger UI at
+/// `/swagger-ui` with the OpenAPI document at `/api-docs/openapi.json`, and
+/// wraps everything in a CORS layer and an HTTP trace layer.
+pub fn make_router(state: SharedState) -> Router {
+    // Any origin, method and header is allowed.
+    let cors = CorsLayer::new()
+        .allow_origin(Any)
+        .allow_methods(Any)
+        .allow_headers(Any);
+
+    let docs = SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi());
+
+    let versioned_api = Router::new()
+        .route("/listen", get(listen::websocket_handler))
+        .route("/tts/synthesize", post(tts::synthesize_handler))
+        .with_state(state);
+
+    Router::new()
+        .nest("/api/v1", versioned_api)
+        .merge(docs)
+        .layer(cors)
+        .layer(TraceLayer::new_for_http())
+}
+
+/// Bind to `addr` and serve the application until a shutdown signal arrives.
+///
+/// # Arguments
+/// * `state` - Shared application state containing ML models
+/// * `addr` - Address to bind to (e.g., "0.0.0.0:8080")
+///
+/// # Errors
+/// Fails if the address cannot be bound or if serving ends with an error.
+pub async fn run_server(state: SharedState, addr: &str) -> anyhow::Result<()> {
+    let router = make_router(state);
+    let socket = tokio::net::TcpListener::bind(addr).await?;
+
+    tracing::info!("Server listening on {}", addr);
+    tracing::info!("Swagger UI available at http://{}/swagger-ui", addr);
+
+    axum::serve(socket, router)
+        .with_graceful_shutdown(shutdown_signal())
+        .await
+        .map_err(anyhow::Error::from)
+}
+
+/// Resolve once the process is asked to stop.
+///
+/// Completes on Ctrl+C on every platform, and additionally on SIGTERM on
+/// Unix. On other platforms the SIGTERM branch is a future that never
+/// resolves.
+///
+/// # Panics
+/// Panics if a signal handler cannot be installed.
+async fn shutdown_signal() {
+    let interrupt = async {
+        tokio::signal::ctrl_c()
+            .await
+            .expect("Failed to install Ctrl+C handler");
+    };
+
+    #[cfg(unix)]
+    let terminate = async {
+        let mut sigterm =
+            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
+                .expect("Failed to install signal handler");
+        sigterm.recv().await;
+    };
+
+    #[cfg(not(unix))]
+    let terminate = std::future::pending::<()>();
+
+    // Whichever signal arrives first ends the wait.
+    tokio::select! {
+        _ = interrupt => {},
+        _ = terminate => {},
+    }
+
+    tracing::info!("Shutdown signal received, starting graceful shutdown");
+}
diff --git a/makima/src/server/openapi.rs b/makima/src/server/openapi.rs
new file mode 100644
index 0000000..363d348
--- /dev/null
+++ b/makima/src/server/openapi.rs
@@ -0,0 +1,34 @@
+//! OpenAPI documentation configuration using utoipa.
+
+use utoipa::OpenApi;
+
+use crate::server::handlers::{listen, tts};
+use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage, TranscriptMessage};
+
+// Root of the generated OpenAPI document. It registers the two handler
+// paths, the schema types listed under `components`, and the STT/TTS tags.
+// `ClientMessage` and `ServerMessage` are not registered as schemas here.
+#[derive(OpenApi)]
+#[openapi(
+ info(
+ title = "Makima Audio API",
+ version = "1.0.0",
+ description = "Streaming audio APIs for speech-to-text and text-to-speech with voice cloning.",
+ license(name = "MIT"),
+ ),
+ paths(
+ listen::websocket_handler,
+ tts::synthesize_handler,
+ ),
+ components(
+ schemas(
+ ApiError,
+ AudioEncoding,
+ StartMessage,
+ StopMessage,
+ TranscriptMessage,
+ )
+ ),
+ tags(
+ (name = "STT", description = "Speech-to-text streaming endpoints"),
+ (name = "TTS", description = "Text-to-speech synthesis endpoints"),
+ )
+)]
+pub struct ApiDoc;
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
new file mode 100644
index 0000000..8eaf788
--- /dev/null
+++ b/makima/src/server/state.rs
@@ -0,0 +1,50 @@
+//! Application state holding shared ML models.
+
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+use crate::listen::{DiarizationConfig, ParakeetTDT, Sortformer};
+use crate::tts::ChatterboxTTS;
+
+/// Shared application state containing ML models.
+///
+/// Models are wrapped in `Mutex` for thread-safe mutable access during inference.
+/// The `Mutex` used here is `tokio::sync::Mutex`, so each model is used by
+/// one caller at a time and waiting for a lock is an `.await`.
+pub struct AppState {
+ /// Speech-to-text model (Parakeet)
+ pub parakeet: Mutex<ParakeetTDT>,
+ /// Speaker diarization model (Sortformer)
+ pub sortformer: Mutex<Sortformer>,
+ /// Text-to-speech model (ChatterboxTTS)
+ pub chatterbox: Mutex<ChatterboxTTS>,
+}
+
+impl AppState {
+    /// Load the three models and wrap each one in its own mutex.
+    ///
+    /// Loading happens in the order Parakeet, Sortformer, ChatterboxTTS; the
+    /// first failure is returned and nothing after it is loaded.
+    ///
+    /// # Arguments
+    /// * `parakeet_model_dir` - Path to the Parakeet STT model directory
+    /// * `sortformer_model_path` - Path to the Sortformer diarization model file
+    /// * `tts_model_dir` - Optional path to the ChatterboxTTS model directory
+    pub fn new(
+        parakeet_model_dir: &str,
+        sortformer_model_path: &str,
+        tts_model_dir: Option<&str>,
+    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
+        let stt = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?;
+
+        let diarization_config = DiarizationConfig::callhome();
+        let diarizer = Sortformer::with_config(sortformer_model_path, None, diarization_config)?;
+
+        let tts = ChatterboxTTS::from_pretrained(tts_model_dir)?;
+
+        let loaded = Self {
+            parakeet: Mutex::new(stt),
+            sortformer: Mutex::new(diarizer),
+            chatterbox: Mutex::new(tts),
+        };
+        Ok(loaded)
+    }
+}
+
+/// Type alias for the shared application state.
+///
+/// Cloning it copies the `Arc` pointer only; every clone refers to the same
+/// `AppState`.
+pub type SharedState = Arc<AppState>;