diff options
Diffstat (limited to 'makima/src/server/handlers/tts.rs')
| -rw-r--r-- | makima/src/server/handlers/tts.rs | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/makima/src/server/handlers/tts.rs b/makima/src/server/handlers/tts.rs new file mode 100644 index 0000000..94a261f --- /dev/null +++ b/makima/src/server/handlers/tts.rs @@ -0,0 +1,185 @@ +//! HTTP handler for text-to-speech synthesis. + +use axum::{ + body::Body, + extract::{Multipart, State}, + http::{header, StatusCode}, + response::Response, + Json, +}; +use std::io::Cursor; + +use crate::audio::to_16k_mono_from_reader; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; +use crate::tts::SAMPLE_RATE; + +/// POST /api/v1/tts/synthesize +/// +/// Synthesize speech from text using voice cloning. +/// +/// Accepts multipart form data with: +/// - `text`: The text to synthesize (required) +/// - `voice`: Audio file for voice cloning reference (required) +/// +/// Returns: WAV audio file (24kHz mono) +#[utoipa::path( + post, + path = "/api/v1/tts/synthesize", + request_body(content_type = "multipart/form-data", description = "Text and voice audio for synthesis"), + responses( + (status = 200, description = "Generated audio file", content_type = "audio/wav"), + (status = 400, description = "Bad request", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + tag = "TTS" +)] +pub async fn synthesize_handler( + State(state): State<SharedState>, + mut multipart: Multipart, +) -> Result<Response, (StatusCode, Json<ApiError>)> { + let mut text: Option<String> = None; + let mut voice_samples: Option<Vec<f32>> = None; + let mut voice_sample_rate: u32 = 16_000; + + // Parse multipart fields + while let Some(field) = multipart.next_field().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MULTIPART_ERROR", e.to_string())), + ) + })? { + let name = field.name().unwrap_or("").to_string(); + + match name.as_str() { + "text" => { + text = Some(field.text().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("TEXT_FIELD_ERROR", e.to_string())), + ) + })?); + } + "voice" => { + let data = field.bytes().await.map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("VOICE_FIELD_ERROR", e.to_string())), + ) + })?; + + // Decode audio file to PCM samples + let cursor = Cursor::new(data.to_vec()); + let pcm = to_16k_mono_from_reader(cursor).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "AUDIO_DECODE_ERROR", + format!("Failed to decode voice audio: {}", e), + )), + ) + })?; + + voice_samples = Some(pcm.samples); + voice_sample_rate = pcm.sample_rate; + } + _ => { + // Ignore unknown fields + tracing::debug!("Ignoring unknown field: {}", name); + } + } + } + + // Validate required fields + let text = text.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MISSING_TEXT", "Text field is required")), + ) + })?; + + let samples = voice_samples.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "MISSING_VOICE", + "Voice audio file is required", + )), + ) + })?; + + tracing::info!( + text_len = text.len(), + voice_samples = samples.len(), + "Generating TTS" + ); + + // Generate TTS with voice cloning + let mut chatterbox = state.chatterbox.lock().await; + let audio = chatterbox + .generate_tts_with_samples(&text, &samples, voice_sample_rate) + .map_err(|e| { + tracing::error!(error = %e, "TTS generation failed"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("TTS_ERROR", e.to_string())), + ) + })?; + + tracing::info!(samples = audio.len(), "TTS generation complete"); + + // Encode as WAV + let wav_data = encode_wav(&audio, SAMPLE_RATE); + + // Return WAV response + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "audio/wav") + .header( + header::CONTENT_DISPOSITION, + "attachment; filename=\"output.wav\"", + ) + .body(Body::from(wav_data)) + .unwrap()) +} + +/// Encode f32 samples as a WAV file in memory. +fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> { + let mut buf = Vec::new(); + + let num_samples = samples.len() as u32; + let num_channels: u16 = 1; + let bits_per_sample: u16 = 16; + let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8; + let block_align = num_channels * bits_per_sample / 8; + let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8; + let file_size = 36 + data_size; + + // RIFF header + buf.extend_from_slice(b"RIFF"); + buf.extend_from_slice(&file_size.to_le_bytes()); + buf.extend_from_slice(b"WAVE"); + + // fmt chunk + buf.extend_from_slice(b"fmt "); + buf.extend_from_slice(&16u32.to_le_bytes()); + buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format + buf.extend_from_slice(&num_channels.to_le_bytes()); + buf.extend_from_slice(&sample_rate.to_le_bytes()); + buf.extend_from_slice(&byte_rate.to_le_bytes()); + buf.extend_from_slice(&block_align.to_le_bytes()); + buf.extend_from_slice(&bits_per_sample.to_le_bytes()); + + // data chunk + buf.extend_from_slice(b"data"); + buf.extend_from_slice(&data_size.to_le_bytes()); + + // Convert f32 samples to i16 PCM + for &sample in samples { + let clamped = sample.clamp(-1.0, 1.0); + let int_sample = (clamped * 32767.0) as i16; + buf.extend_from_slice(&int_sample.to_le_bytes()); + } + + buf +} |
