diff options
| author | soryu <soryu@soryu.co> | 2025-12-20 15:36:04 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 01088f4f1915e36a7d0d8d8756f62f8207a48911 (patch) | |
| tree | 8fdbba900f3f4bba32bae76e2e0378848a90cf93 /makima | |
| parent | ab9166170043ba5e0ce974e5b7accf0939d686e3 (diff) | |
| download | soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.tar.gz soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.zip | |
Implement makima listen websockets server
Diffstat (limited to 'makima')
| -rw-r--r-- | makima/Cargo.toml | 25 | ||||
| -rw-r--r-- | makima/src/audio.rs | 6 | ||||
| -rw-r--r-- | makima/src/bin/server.rs | 40 | ||||
| -rw-r--r-- | makima/src/lib.rs | 4 | ||||
| -rw-r--r-- | makima/src/listen.rs | 12 | ||||
| -rw-r--r-- | makima/src/server/handlers/listen.rs | 328 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 4 | ||||
| -rw-r--r-- | makima/src/server/handlers/tts.rs | 185 | ||||
| -rw-r--r-- | makima/src/server/messages.rs | 92 | ||||
| -rw-r--r-- | makima/src/server/mod.rs | 88 | ||||
| -rw-r--r-- | makima/src/server/openapi.rs | 34 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 50 |
12 files changed, 864 insertions, 4 deletions
diff --git a/makima/Cargo.toml b/makima/Cargo.toml index df7384f..5142963 100644 --- a/makima/Cargo.toml +++ b/makima/Cargo.toml @@ -3,10 +3,35 @@ name = "makima" version = "0.1.0" edition = "2024" +[[bin]] +name = "makima-server" +path = "src/bin/server.rs" + [dependencies] +# ML/Audio (existing) parakeet-rs = { version = "0.2.5", features = ["sortformer"] } symphonia = { version = "0.5", features = ["mp3", "aac", "flac", "ogg", "vorbis", "wav", "pcm"] } ort = "2.0.0-rc.10" tokenizers = "0.21" hf-hub = "0.4" ndarray = "0.16" + +# Web server +axum = { version = "0.8", features = ["ws", "multipart"] } +tokio = { version = "1.0", features = ["full", "signal"] } +tower-http = { version = "0.6", features = ["cors", "trace"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +futures = "0.3" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +bytes = "1.0" +uuid = { version = "1.0", features = ["v4"] } + +# OpenAPI +utoipa = { version = "5", features = ["axum_extras"] } +utoipa-swagger-ui = { version = "9", features = ["axum"] } + +# Error handling +thiserror = "2.0" +anyhow = "1.0" diff --git a/makima/src/audio.rs b/makima/src/audio.rs index acfe7ce..8c969be 100644 --- a/makima/src/audio.rs +++ b/makima/src/audio.rs @@ -239,6 +239,12 @@ fn mixdown_to_mono(interleaved: &[f32], channels: u16) -> Vec<f32> { mono } +/// Resample and mixdown audio to 16kHz mono for STT processing. +pub fn resample_and_mixdown(samples: &[f32], sample_rate: u32, channels: u16) -> Vec<f32> { + let mono = mixdown_to_mono(samples, channels); + resample_sinc(&mono, sample_rate, TARGET_SAMPLE_RATE) +} + fn resample_sinc(input: &[f32], input_rate: u32, output_rate: u32) -> Vec<f32> { if input_rate == output_rate { return input.to_vec(); diff --git a/makima/src/bin/server.rs b/makima/src/bin/server.rs new file mode 100644 index 0000000..7117cfe --- /dev/null +++ b/makima/src/bin/server.rs @@ -0,0 +1,40 @@ +//! Makima Audio API Server binary. +//! +//! This server provides WebSocket-based speech-to-text streaming +//! and HTTP-based text-to-speech synthesis with voice cloning. + +use std::sync::Arc; + +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use makima::server::{run_server, state::AppState}; + +/// Default model paths relative to the working directory. +const PARAKEET_MODEL_DIR: &str = "models/parakeet-tdt-0.6b-v3"; +const SORTFORMER_MODEL_PATH: &str = "models/diarization/diar_streaming_sortformer_4spk-v2.onnx"; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing subscriber with environment filter + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "makima=debug,tower_http=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + tracing::info!("Starting Makima Audio API Server"); + tracing::info!("Loading ML models..."); + + // Load all ML models + let state = Arc::new( + AppState::new(PARAKEET_MODEL_DIR, SORTFORMER_MODEL_PATH, None) + .map_err(|e| anyhow::anyhow!("Failed to load models: {}", e))?, + ); + + tracing::info!("Models loaded successfully"); + + // Run the server + run_server(state, "0.0.0.0:8080").await +} diff --git a/makima/src/lib.rs b/makima/src/lib.rs new file mode 100644 index 0000000..1e95d95 --- /dev/null +++ b/makima/src/lib.rs @@ -0,0 +1,4 @@ +pub mod audio; +pub mod listen; +pub mod server; +pub mod tts; diff --git a/makima/src/listen.rs b/makima/src/listen.rs index cd0a394..24bf9da 100644 --- a/makima/src/listen.rs +++ b/makima/src/listen.rs @@ -1,13 +1,15 @@ use std::cmp::Ordering; use std::path::Path; -use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment}; -use parakeet_rs::{ParakeetTDT, TimedToken, TimestampMode}; +pub use parakeet_rs::sortformer::{DiarizationConfig, Sortformer, SpeakerSegment}; +pub use parakeet_rs::{ParakeetTDT, TimedToken, TimestampMode, Transcriber}; use crate::audio; const STREAM_CHUNK_MS: u32 = 5_000; +/// A segment of dialogue with speaker identification and timing. +#[derive(Debug, Clone)] pub struct DialogueSegment { pub speaker: String, pub start: f32, @@ -66,7 +68,8 @@ pub(crate) fn listen() -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error Ok(final_segments) } -fn align_speakers(tokens: &[TimedToken], speakers: &[SpeakerSegment]) -> Vec<DialogueSegment> { +/// Align transcription tokens with speaker diarization segments. +pub fn align_speakers(tokens: &[TimedToken], speakers: &[SpeakerSegment]) -> Vec<DialogueSegment> { tokens .iter() .map(|token| { @@ -82,7 +85,8 @@ fn align_speakers(tokens: &[TimedToken], speakers: &[SpeakerSegment]) -> Vec<Dia .collect() } -fn samples_per_chunk(sample_rate: u32, chunk_ms: u32) -> usize { +/// Calculate the number of samples in a chunk of given duration. +pub fn samples_per_chunk(sample_rate: u32, chunk_ms: u32) -> usize { let samples = (sample_rate as u64) .saturating_mul(chunk_ms as u64) .saturating_div(1_000); diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs new file mode 100644 index 0000000..b1c1ad9 --- /dev/null +++ b/makima/src/server/handlers/listen.rs @@ -0,0 +1,328 @@ +//! WebSocket handler for streaming speech-to-text. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use tokio::sync::mpsc; +use uuid::Uuid; + +use crate::audio::{resample_and_mixdown, TARGET_CHANNELS, TARGET_SAMPLE_RATE}; +use crate::listen::{align_speakers, samples_per_chunk, DialogueSegment, TimestampMode, Transcriber}; +use crate::server::messages::{ + AudioEncoding, ClientMessage, ServerMessage, StartMessage, TranscriptMessage, +}; +use crate::server::state::SharedState; + +/// Chunk size in milliseconds for streaming transcription. +const STREAM_CHUNK_MS: u32 = 5_000; + +/// WebSocket upgrade handler for STT streaming. +/// +/// This endpoint accepts WebSocket connections for real-time speech-to-text +/// transcription with speaker diarization. +#[utoipa::path( + get, + path = "/api/v1/listen", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "STT" +)] +pub async fn websocket_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, +) -> Response { + ws.on_upgrade(|socket| handle_socket(socket, state)) +} + +async fn handle_socket(socket: WebSocket, state: SharedState) { + let session_id = Uuid::new_v4().to_string(); + tracing::info!(session_id = %session_id, "New WebSocket connection"); + + // Split socket for concurrent read/write + let (mut sender, mut receiver) = socket.split(); + + // Channel for sending responses back to client + let (response_tx, mut response_rx) = mpsc::channel::<ServerMessage>(32); + + // Spawn task to forward responses to WebSocket + let sender_task = tokio::spawn(async move { + while let Some(msg) = response_rx.recv().await { + let json = match serde_json::to_string(&msg) { + Ok(j) => j, + Err(e) => { + tracing::error!("Failed to serialize message: {}", e); + continue; + } + }; + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + }); + + // Send ready message + let _ = response_tx + .send(ServerMessage::Ready { + session_id: session_id.clone(), + }) + .await; + + // Audio format state + let mut audio_format: Option<StartMessage> = None; + + // Audio buffer for accumulating samples + let mut audio_buffer: Vec<f32> = Vec::new(); + let mut last_sent_end_time: f32 = 0.0; // Track the end time of last sent segment + let mut last_processed_len: usize = 0; // Track how much audio we've processed + + // Process incoming messages + while let Some(msg_result) = receiver.next().await { + let msg = match msg_result { + Ok(m) => m, + Err(e) => { + tracing::error!("WebSocket error: {}", e); + break; + } + }; + + match msg { + Message::Text(text) => { + // Parse JSON control messages + match serde_json::from_str::<ClientMessage>(&text) { + Ok(ClientMessage::Start(start)) => { + tracing::info!( + session_id = %session_id, + sample_rate = start.sample_rate, + channels = start.channels, + encoding = ?start.encoding, + "Session started" + ); + audio_format = Some(start); + audio_buffer.clear(); + last_sent_end_time = 0.0; + last_processed_len = 0; + } + Ok(ClientMessage::Stop(stop)) => { + tracing::info!( + session_id = %session_id, + reason = ?stop.reason, + audio_buffer_len = audio_buffer.len(), + "Session stopped by client" + ); + + if let Some(ref format) = audio_format { + if !audio_buffer.is_empty() { + tracing::debug!( + session_id = %session_id, + samples = audio_buffer.len(), + "Processing final audio buffer" + ); + match process_audio(&audio_buffer, format, &state).await { + Ok(segments) => { + tracing::debug!( + session_id = %session_id, + total_segments = segments.len(), + last_sent_end = last_sent_end_time, + "Final transcription complete" + ); + + // Step 1: Send any NEW segments as interim (is_final: false) + // These are segments that weren't sent during streaming + for seg in &segments { + if seg.end > last_sent_end_time { + let _ = response_tx + .send(ServerMessage::Transcript(TranscriptMessage { + speaker: seg.speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text.clone(), + is_final: false, + })) + .await; + } + } + + // Step 2: Send ALL segments as final (is_final: true) + // This is the complete authoritative transcript + for seg in &segments { + let _ = response_tx + .send(ServerMessage::Transcript(TranscriptMessage { + speaker: seg.speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text.clone(), + is_final: true, + })) + .await; + } + } + Err(e) => { + tracing::error!( + session_id = %session_id, + error = %e, + "Final transcription failed" + ); + let _ = response_tx + .send(ServerMessage::Error { + code: "TRANSCRIPTION_ERROR".into(), + message: e.to_string(), + }) + .await; + } + } + } + } + + let _ = response_tx + .send(ServerMessage::Stopped { + reason: stop.reason.unwrap_or_else(|| "client_requested".into()), + }) + .await; + break; + } + Err(e) => { + tracing::warn!(session_id = %session_id, error = %e, "Failed to parse message"); + let _ = response_tx + .send(ServerMessage::Error { + code: "PARSE_ERROR".into(), + message: format!("Failed to parse message: {}", e), + }) + .await; + } + } + } + Message::Binary(data) => { + let Some(ref format) = audio_format else { + let _ = response_tx + .send(ServerMessage::Error { + code: "NO_FORMAT".into(), + message: "Received audio before start message".into(), + }) + .await; + continue; + }; + + // Decode binary audio data to f32 samples + let samples = decode_audio_chunk(&data, format); + audio_buffer.extend(samples); + + // Process when we have accumulated another chunk's worth of NEW audio + let chunk_samples = samples_per_chunk(format.sample_rate, STREAM_CHUNK_MS); + let new_audio_len = audio_buffer.len() - last_processed_len; + + if new_audio_len >= chunk_samples { + tracing::debug!( + session_id = %session_id, + total_samples = audio_buffer.len(), + new_samples = new_audio_len, + "Processing audio chunk" + ); + + match process_audio(&audio_buffer, format, &state).await { + Ok(segments) => { + tracing::debug!( + session_id = %session_id, + total_segments = segments.len(), + last_sent_end = last_sent_end_time, + "Transcription produced segments" + ); + + // Send segments that end after our last sent time + // This handles re-segmentation by the model + for seg in &segments { + if seg.end > last_sent_end_time { + let _ = response_tx + .send(ServerMessage::Transcript(TranscriptMessage { + speaker: seg.speaker.clone(), + start: seg.start, + end: seg.end, + text: seg.text.clone(), + is_final: false, + })) + .await; + last_sent_end_time = seg.end; + } + } + last_processed_len = audio_buffer.len(); + } + Err(e) => { + tracing::error!(session_id = %session_id, error = %e, "Transcription error"); + let _ = response_tx + .send(ServerMessage::Error { + code: "TRANSCRIPTION_ERROR".into(), + message: e.to_string(), + }) + .await; + } + } + } + } + Message::Close(_) => { + tracing::info!(session_id = %session_id, "WebSocket closed by client"); + break; + } + _ => {} + } + } + + // Cleanup + drop(response_tx); + let _ = sender_task.await; + tracing::info!(session_id = %session_id, "WebSocket connection closed"); +} + +/// Decode binary audio chunk to f32 samples based on encoding format. +fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> { + match format.encoding { + AudioEncoding::Pcm32f => data + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect(), + AudioEncoding::Pcm16 | AudioEncoding::Raw => data + .chunks_exact(2) + .map(|chunk| { + let sample = i16::from_le_bytes([chunk[0], chunk[1]]); + sample as f32 / 32768.0 + }) + .collect(), + } +} + +/// Process accumulated audio through STT and diarization models. +async fn process_audio( + samples: &[f32], + format: &StartMessage, + state: &SharedState, +) -> Result<Vec<DialogueSegment>, Box<dyn std::error::Error + Send + Sync>> { + // Resample to 16kHz mono if needed + let resampled = if format.sample_rate != TARGET_SAMPLE_RATE || format.channels != TARGET_CHANNELS + { + resample_and_mixdown(samples, format.sample_rate, format.channels) + } else { + samples.to_vec() + }; + + // Acquire model locks and run inference + let mut parakeet = state.parakeet.lock().await; + let mut sortformer = state.sortformer.lock().await; + + // Run diarization + let diarization_segments = + sortformer.diarize(resampled.clone(), TARGET_SAMPLE_RATE, TARGET_CHANNELS)?; + + // Run transcription + let transcription = parakeet.transcribe_samples( + resampled, + TARGET_SAMPLE_RATE, + TARGET_CHANNELS, + Some(TimestampMode::Sentences), + )?; + + // Align speakers with transcription + let aligned = align_speakers(&transcription.tokens, &diarization_segments); + + Ok(aligned) +} diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs new file mode 100644 index 0000000..90798f9 --- /dev/null +++ b/makima/src/server/handlers/mod.rs @@ -0,0 +1,4 @@ +//! HTTP and WebSocket request handlers. + +pub mod listen; +pub mod tts; diff --git a/makima/src/server/handlers/tts.rs b/makima/src/server/handlers/tts.rs new file mode 100644 index 0000000..94a261f --- /dev/null +++ b/makima/src/server/handlers/tts.rs @@ -0,0 +1,185 @@ +//! HTTP handler for text-to-speech synthesis. + +use axum::{ + body::Body, + extract::{Multipart, State}, + http::{header, StatusCode}, + response::Response, + Json, +}; +use std::io::Cursor; + +use crate::audio::to_16k_mono_from_reader; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; +use crate::tts::SAMPLE_RATE; + +/// POST /api/v1/tts/synthesize +/// +/// Synthesize speech from text using voice cloning. +/// +/// Accepts multipart form data with: +/// - `text`: The text to synthesize (required) +/// - `voice`: Audio file for voice cloning reference (required) +/// +/// Returns: WAV audio file (24kHz mono) +#[utoipa::path( + post, + path = "/api/v1/tts/synthesize", + request_body(content_type = "multipart/form-data", description = "Text and voice audio for synthesis"), + responses( + (status = 200, description = "Generated audio file", content_type = "audio/wav"), + (status = 400, description = "Bad request", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + tag = "TTS" +)] +pub async fn synthesize_handler( + State(state): State<SharedState>, + mut multipart: Multipart, +) -> Result<Response, (StatusCode, Json<ApiError>)> { + let mut text: Option<String> = None; + let mut voice_samples: Option<Vec<f32>> = None; + let mut voice_sample_rate: u32 = 16_000; + + // Parse multipart fields + while let Some(field) = multipart.next_field().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MULTIPART_ERROR", e.to_string())), + ) + })? { + let name = field.name().unwrap_or("").to_string(); + + match name.as_str() { + "text" => { + text = Some(field.text().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("TEXT_FIELD_ERROR", e.to_string())), + ) + })?); + } + "voice" => { + let data = field.bytes().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("VOICE_FIELD_ERROR", e.to_string())), + ) + })?; + + // Decode audio file to PCM samples + let cursor = Cursor::new(data.to_vec()); + let pcm = to_16k_mono_from_reader(cursor).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "AUDIO_DECODE_ERROR", + format!("Failed to decode voice audio: {}", e), + )), + ) + })?; + + voice_samples = Some(pcm.samples); + voice_sample_rate = pcm.sample_rate; + } + _ => { + // Ignore unknown fields + tracing::debug!("Ignoring unknown field: {}", name); + } + } + } + + // Validate required fields + let text = text.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MISSING_TEXT", "Text field is required")), + ) + })?; + + let samples = voice_samples.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "MISSING_VOICE", + "Voice audio file is required", + )), + ) + })?; + + tracing::info!( + text_len = text.len(), + voice_samples = samples.len(), + "Generating TTS" + ); + + // Generate TTS with voice cloning + let mut chatterbox = state.chatterbox.lock().await; + let audio = chatterbox + .generate_tts_with_samples(&text, &samples, voice_sample_rate) + .map_err(|e| { + tracing::error!(error = %e, "TTS generation failed"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("TTS_ERROR", e.to_string())), + ) + })?; + + tracing::info!(samples = audio.len(), "TTS generation complete"); + + // Encode as WAV + let wav_data = encode_wav(&audio, SAMPLE_RATE); + + // Return WAV response + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "audio/wav") + .header( + header::CONTENT_DISPOSITION, + "attachment; filename=\"output.wav\"", + ) + .body(Body::from(wav_data)) + .unwrap()) +} + +/// Encode f32 samples as a WAV file in memory. +fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> { + let mut buf = Vec::new(); + + let num_samples = samples.len() as u32; + let num_channels: u16 = 1; + let bits_per_sample: u16 = 16; + let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; + let block_align = num_channels * bits_per_sample / 8; + let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8; + let file_size = 36 + data_size; + + // RIFF header + buf.extend_from_slice(b"RIFF"); + buf.extend_from_slice(&file_size.to_le_bytes()); + buf.extend_from_slice(b"WAVE"); + + // fmt chunk + buf.extend_from_slice(b"fmt "); + buf.extend_from_slice(&16u32.to_le_bytes()); + buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format + buf.extend_from_slice(&num_channels.to_le_bytes()); + buf.extend_from_slice(&sample_rate.to_le_bytes()); + buf.extend_from_slice(&byte_rate.to_le_bytes()); + buf.extend_from_slice(&block_align.to_le_bytes()); + buf.extend_from_slice(&bits_per_sample.to_le_bytes()); + + // data chunk + buf.extend_from_slice(b"data"); + buf.extend_from_slice(&data_size.to_le_bytes()); + + // Convert f32 samples to i16 PCM + for &sample in samples { + let clamped = sample.clamp(-1.0, 1.0); + let int_sample = (clamped * 32767.0) as i16; + buf.extend_from_slice(&int_sample.to_le_bytes()); + } + + buf +} diff --git a/makima/src/server/messages.rs b/makima/src/server/messages.rs new file mode 100644 index 0000000..0c92447 --- /dev/null +++ b/makima/src/server/messages.rs @@ -0,0 +1,92 @@ +//! WebSocket and API message types for the makima server. + +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +/// Audio encoding format for WebSocket streaming. +#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum AudioEncoding { + /// 32-bit floating point PCM samples + Pcm32f, + /// 16-bit signed integer PCM samples + Pcm16, + /// Raw bytes (will be interpreted as PCM16) + Raw, +} + +/// Initial handshake message from client specifying audio format. +#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct StartMessage { + /// Audio sample rate in Hz (e.g., 16000, 44100, 48000) + pub sample_rate: u32, + /// Number of audio channels (1 for mono, 2 for stereo) + pub channels: u16, + /// Audio encoding format + pub encoding: AudioEncoding, +} + +/// Stop message to terminate the session. +#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct StopMessage { + /// Optional reason for stopping + pub reason: Option<String>, +} + +/// Wrapper for all WebSocket messages from client to server. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ClientMessage { + Start(StartMessage), + Stop(StopMessage), +} + +/// Transcription result message sent from server to client. +#[derive(Debug, Clone, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TranscriptMessage { + /// Speaker identifier (e.g., "Speaker 0", "Speaker 1") + pub speaker: String, + /// Segment start time in seconds + pub start: f32, + /// Segment end time in seconds + pub end: f32, + /// Transcribed text + pub text: String, + /// Whether this is a final or interim result + pub is_final: bool, +} + +/// Wrapper for all WebSocket messages from server to client. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ServerMessage { + /// Session is ready for audio streaming + Ready { session_id: String }, + /// Transcription result + Transcript(TranscriptMessage), + /// Error occurred during processing + Error { code: String, message: String }, + /// Session has been stopped + Stopped { reason: String }, +} + +/// Error response for HTTP API endpoints. +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct ApiError { + /// Error code for programmatic handling + pub code: String, + /// Human-readable error message + pub message: String, +} + +impl ApiError { + pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self { + Self { + code: code.into(), + message: message.into(), + } + } +} diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs new file mode 100644 index 0000000..c33eeef --- /dev/null +++ b/makima/src/server/mod.rs @@ -0,0 +1,88 @@ +//! Web server module for the makima audio API. + +pub mod handlers; +pub mod messages; +pub mod openapi; +pub mod state; + +use axum::{ + routing::{get, post}, + Router, +}; +use tower_http::cors::{Any, CorsLayer}; +use tower_http::trace::TraceLayer; +use utoipa::OpenApi; +use utoipa_swagger_ui::SwaggerUi; + +use crate::server::handlers::{listen, tts}; +use crate::server::openapi::ApiDoc; +use crate::server::state::SharedState; + +/// Create the axum Router with all routes configured. +pub fn make_router(state: SharedState) -> Router { + // API v1 routes + let api_v1 = Router::new() + .route("/listen", get(listen::websocket_handler)) + .route("/tts/synthesize", post(tts::synthesize_handler)) + .with_state(state); + + let swagger = SwaggerUi::new("/swagger-ui") + .url("/api-docs/openapi.json", ApiDoc::openapi()); + + Router::new() + .nest("/api/v1", api_v1) + .merge(swagger) + .layer( + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any), + ) + .layer(TraceLayer::new_for_http()) +} + +/// Run the HTTP server with graceful shutdown support. +/// +/// # Arguments +/// * `state` - Shared application state containing ML models +/// * `addr` - Address to bind to (e.g., "0.0.0.0:8080") +pub async fn run_server(state: SharedState, addr: &str) -> anyhow::Result<()> { + let app = make_router(state); + let listener = tokio::net::TcpListener::bind(addr).await?; + + tracing::info!("Server listening on {}", addr); + tracing::info!("Swagger UI available at http://{}/swagger-ui", addr); + + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +/// Wait for shutdown signals (Ctrl+C or SIGTERM). +async fn shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("Failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + tracing::info!("Shutdown signal received, starting graceful shutdown"); +} diff --git a/makima/src/server/openapi.rs b/makima/src/server/openapi.rs new file mode 100644 index 0000000..363d348 --- /dev/null +++ b/makima/src/server/openapi.rs @@ -0,0 +1,34 @@ +//! OpenAPI documentation configuration using utoipa. + +use utoipa::OpenApi; + +use crate::server::handlers::{listen, tts}; +use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage, TranscriptMessage}; + +#[derive(OpenApi)] +#[openapi( + info( + title = "Makima Audio API", + version = "1.0.0", + description = "Streaming audio APIs for speech-to-text and text-to-speech with voice cloning.", + license(name = "MIT"), + ), + paths( + listen::websocket_handler, + tts::synthesize_handler, + ), + components( + schemas( + ApiError, + AudioEncoding, + StartMessage, + StopMessage, + TranscriptMessage, + ) + ), + tags( + (name = "STT", description = "Speech-to-text streaming endpoints"), + (name = "TTS", description = "Text-to-speech synthesis endpoints"), + ) +)] +pub struct ApiDoc; diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs new file mode 100644 index 0000000..8eaf788 --- /dev/null +++ b/makima/src/server/state.rs @@ -0,0 +1,50 @@ +//! Application state holding shared ML models. + +use std::sync::Arc; +use tokio::sync::Mutex; + +use crate::listen::{DiarizationConfig, ParakeetTDT, Sortformer}; +use crate::tts::ChatterboxTTS; + +/// Shared application state containing ML models. +/// +/// Models are wrapped in `Mutex` for thread-safe mutable access during inference. +pub struct AppState { + /// Speech-to-text model (Parakeet) + pub parakeet: Mutex<ParakeetTDT>, + /// Speaker diarization model (Sortformer) + pub sortformer: Mutex<Sortformer>, + /// Text-to-speech model (ChatterboxTTS) + pub chatterbox: Mutex<ChatterboxTTS>, +} + +impl AppState { + /// Load all ML models from the specified directories. + /// + /// # Arguments + /// * `parakeet_model_dir` - Path to the Parakeet STT model directory + /// * `sortformer_model_path` - Path to the Sortformer diarization model file + /// * `tts_model_dir` - Optional path to the ChatterboxTTS model directory + pub fn new( + parakeet_model_dir: &str, + sortformer_model_path: &str, + tts_model_dir: Option<&str>, + ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { + let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?; + let sortformer = Sortformer::with_config( + sortformer_model_path, + None, + DiarizationConfig::callhome(), + )?; + let chatterbox = ChatterboxTTS::from_pretrained(tts_model_dir)?; + + Ok(Self { + parakeet: Mutex::new(parakeet), + sortformer: Mutex::new(sortformer), + chatterbox: Mutex::new(chatterbox), + }) + } +} + +/// Type alias for the shared application state. +pub type SharedState = Arc<AppState>; |
