diff options
Diffstat (limited to 'makima/src/server')
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 1 | ||||
| -rw-r--r-- | makima/src/server/handlers/tts.rs | 185 | ||||
| -rw-r--r-- | makima/src/server/mod.rs | 5 | ||||
| -rw-r--r-- | makima/src/server/openapi.rs | 10 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 7 |
5 files changed, 6 insertions, 202 deletions
diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index 90798f9..94b0384 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -1,4 +1,3 @@ //! HTTP and WebSocket request handlers. pub mod listen; -pub mod tts; diff --git a/makima/src/server/handlers/tts.rs b/makima/src/server/handlers/tts.rs deleted file mode 100644 index 94a261f..0000000 --- a/makima/src/server/handlers/tts.rs +++ /dev/null @@ -1,185 +0,0 @@ -//! HTTP handler for text-to-speech synthesis. - -use axum::{ - body::Body, - extract::{Multipart, State}, - http::{header, StatusCode}, - response::Response, - Json, -}; -use std::io::Cursor; - -use crate::audio::to_16k_mono_from_reader; -use crate::server::messages::ApiError; -use crate::server::state::SharedState; -use crate::tts::SAMPLE_RATE; - -/// POST /api/v1/tts/synthesize -/// -/// Synthesize speech from text using voice cloning. -/// -/// Accepts multipart form data with: -/// - `text`: The text to synthesize (required) -/// - `voice`: Audio file for voice cloning reference (required) -/// -/// Returns: WAV audio file (24kHz mono) -#[utoipa::path( - post, - path = "/api/v1/tts/synthesize", - request_body(content_type = "multipart/form-data", description = "Text and voice audio for synthesis"), - responses( - (status = 200, description = "Generated audio file", content_type = "audio/wav"), - (status = 400, description = "Bad request", body = ApiError), - (status = 500, description = "Internal server error", body = ApiError), - ), - tag = "TTS" -)] -pub async fn synthesize_handler( - State(state): State<SharedState>, - mut multipart: Multipart, -) -> Result<Response, (StatusCode, Json<ApiError>)> { - let mut text: Option<String> = None; - let mut voice_samples: Option<Vec<f32>> = None; - let mut voice_sample_rate: u32 = 16_000; - - // Parse multipart fields - while let Some(field) = multipart.next_field().await.map_err(|e| { - ( - StatusCode::BAD_REQUEST, - Json(ApiError::new("MULTIPART_ERROR", e.to_string())), - ) - })? { - let name = field.name().unwrap_or("").to_string(); - - match name.as_str() { - "text" => { - text = Some(field.text().await.map_err(|e| { - ( - StatusCode::BAD_REQUEST, - Json(ApiError::new("TEXT_FIELD_ERROR", e.to_string())), - ) - })?); - } - "voice" => { - let data = field.bytes().await.map_err(|e| { - ( - StatusCode::BAD_REQUEST, - Json(ApiError::new("VOICE_FIELD_ERROR", e.to_string())), - ) - })?; - - // Decode audio file to PCM samples - let cursor = Cursor::new(data.to_vec()); - let pcm = to_16k_mono_from_reader(cursor).map_err(|e| { - ( - StatusCode::BAD_REQUEST, - Json(ApiError::new( - "AUDIO_DECODE_ERROR", - format!("Failed to decode voice audio: {}", e), - )), - ) - })?; - - voice_samples = Some(pcm.samples); - voice_sample_rate = pcm.sample_rate; - } - _ => { - // Ignore unknown fields - tracing::debug!("Ignoring unknown field: {}", name); - } - } - } - - // Validate required fields - let text = text.ok_or_else(|| { - ( - StatusCode::BAD_REQUEST, - Json(ApiError::new("MISSING_TEXT", "Text field is required")), - ) - })?; - - let samples = voice_samples.ok_or_else(|| { - ( - StatusCode::BAD_REQUEST, - Json(ApiError::new( - "MISSING_VOICE", - "Voice audio file is required", - )), - ) - })?; - - tracing::info!( - text_len = text.len(), - voice_samples = samples.len(), - "Generating TTS" - ); - - // Generate TTS with voice cloning - let mut chatterbox = state.chatterbox.lock().await; - let audio = chatterbox - .generate_tts_with_samples(&text, &samples, voice_sample_rate) - .map_err(|e| { - tracing::error!(error = %e, "TTS generation failed"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ApiError::new("TTS_ERROR", e.to_string())), - ) - })?; - - tracing::info!(samples = audio.len(), "TTS generation complete"); - - // Encode as WAV - let wav_data = encode_wav(&audio, SAMPLE_RATE); - - // Return WAV response - Ok(Response::builder() - .status(StatusCode::OK) - .header(header::CONTENT_TYPE, "audio/wav") - .header( - header::CONTENT_DISPOSITION, - "attachment; filename=\"output.wav\"", - ) - .body(Body::from(wav_data)) - .unwrap()) -} - -/// Encode f32 samples as a WAV file in memory. -fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> { - let mut buf = Vec::new(); - - let num_samples = samples.len() as u32; - let num_channels: u16 = 1; - let bits_per_sample: u16 = 16; - let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; - let block_align = num_channels * bits_per_sample / 8; - let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8; - let file_size = 36 + data_size; - - // RIFF header - buf.extend_from_slice(b"RIFF"); - buf.extend_from_slice(&file_size.to_le_bytes()); - buf.extend_from_slice(b"WAVE"); - - // fmt chunk - buf.extend_from_slice(b"fmt "); - buf.extend_from_slice(&16u32.to_le_bytes()); - buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format - buf.extend_from_slice(&num_channels.to_le_bytes()); - buf.extend_from_slice(&sample_rate.to_le_bytes()); - buf.extend_from_slice(&byte_rate.to_le_bytes()); - buf.extend_from_slice(&block_align.to_le_bytes()); - buf.extend_from_slice(&bits_per_sample.to_le_bytes()); - - // data chunk - buf.extend_from_slice(b"data"); - buf.extend_from_slice(&data_size.to_le_bytes()); - - // Convert f32 samples to i16 PCM - for &sample in samples { - let clamped = sample.clamp(-1.0, 1.0); - let int_sample = (clamped * 32767.0) as i16; - buf.extend_from_slice(&int_sample.to_le_bytes()); - } - - buf -} diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs index c33eeef..a6e0525 100644 --- a/makima/src/server/mod.rs +++ b/makima/src/server/mod.rs @@ -6,7 +6,7 @@ pub mod openapi; pub mod state; use axum::{ - routing::{get, post}, + routing::get, Router, }; use tower_http::cors::{Any, CorsLayer}; @@ -14,7 +14,7 @@ use tower_http::trace::TraceLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use crate::server::handlers::{listen, tts}; +use crate::server::handlers::listen; use crate::server::openapi::ApiDoc; use crate::server::state::SharedState; @@ -23,7 +23,6 @@ pub fn make_router(state: SharedState) -> Router { // API v1 routes let api_v1 = Router::new() .route("/listen", get(listen::websocket_handler)) - .route("/tts/synthesize", post(tts::synthesize_handler)) .with_state(state); let swagger = SwaggerUi::new("/swagger-ui") diff --git a/makima/src/server/openapi.rs b/makima/src/server/openapi.rs index 363d348..3e8c06c 100644 --- a/makima/src/server/openapi.rs +++ b/makima/src/server/openapi.rs @@ -2,20 +2,19 @@ use utoipa::OpenApi; -use crate::server::handlers::{listen, tts}; +use crate::server::handlers::listen; use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage, TranscriptMessage}; #[derive(OpenApi)] #[openapi( info( - title = "Makima Audio API", + title = "Makima Listen API", version = "1.0.0", - description = "Streaming audio APIs for speech-to-text and text-to-speech with voice cloning.", + description = "Streaming audio APIs for speech-to-text.", license(name = "MIT"), ), paths( listen::websocket_handler, - tts::synthesize_handler, ), components( schemas( @@ -27,8 +26,7 @@ use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage ) ), tags( - (name = "STT", description = "Speech-to-text streaming endpoints"), - (name = "TTS", description = "Text-to-speech synthesis endpoints"), + (name = "Listen", description = "Speech-to-text streaming endpoints"), ) )] pub struct ApiDoc; diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index 8eaf788..c38359d 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use tokio::sync::Mutex; use crate::listen::{DiarizationConfig, ParakeetTDT, Sortformer}; -use crate::tts::ChatterboxTTS; /// Shared application state containing ML models. /// @@ -14,8 +13,6 @@ pub struct AppState { pub parakeet: Mutex<ParakeetTDT>, /// Speaker diarization model (Sortformer) pub sortformer: Mutex<Sortformer>, - /// Text-to-speech model (ChatterboxTTS) - pub chatterbox: Mutex<ChatterboxTTS>, } impl AppState { @@ -24,11 +21,9 @@ impl AppState { /// # Arguments /// * `parakeet_model_dir` - Path to the Parakeet STT model directory /// * `sortformer_model_path` - Path to the Sortformer diarization model file - /// * `tts_model_dir` - Optional path to the ChatterboxTTS model directory pub fn new( parakeet_model_dir: &str, sortformer_model_path: &str, - tts_model_dir: Option<&str>, ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?; let sortformer = Sortformer::with_config( @@ -36,12 +31,10 @@ impl AppState { None, DiarizationConfig::callhome(), )?; - let chatterbox = ChatterboxTTS::from_pretrained(tts_model_dir)?; Ok(Self { parakeet: Mutex::new(parakeet), sortformer: Mutex::new(sortformer), - chatterbox: Mutex::new(chatterbox), }) } } |
