summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/tts.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server/handlers/tts.rs')
-rw-r--r--makima/src/server/handlers/tts.rs185
1 files changed, 185 insertions, 0 deletions
diff --git a/makima/src/server/handlers/tts.rs b/makima/src/server/handlers/tts.rs
new file mode 100644
index 0000000..94a261f
--- /dev/null
+++ b/makima/src/server/handlers/tts.rs
@@ -0,0 +1,185 @@
+//! HTTP handler for text-to-speech synthesis.
+
+use axum::{
+ body::Body,
+ extract::{Multipart, State},
+ http::{header, StatusCode},
+ response::Response,
+ Json,
+};
+use std::io::Cursor;
+
+use crate::audio::to_16k_mono_from_reader;
+use crate::server::messages::ApiError;
+use crate::server::state::SharedState;
+use crate::tts::SAMPLE_RATE;
+
+/// POST /api/v1/tts/synthesize
+///
+/// Synthesize speech from text using voice cloning.
+///
+/// Accepts multipart form data with:
+/// - `text`: The text to synthesize (required)
+/// - `voice`: Audio file for voice cloning reference (required)
+///
+/// Returns: WAV audio file (24kHz mono)
+#[utoipa::path(
+ post,
+ path = "/api/v1/tts/synthesize",
+ request_body(content_type = "multipart/form-data", description = "Text and voice audio for synthesis"),
+ responses(
+ (status = 200, description = "Generated audio file", content_type = "audio/wav"),
+ (status = 400, description = "Bad request", body = ApiError),
+ (status = 500, description = "Internal server error", body = ApiError),
+ ),
+ tag = "TTS"
+)]
+pub async fn synthesize_handler(
+ State(state): State<SharedState>,
+ mut multipart: Multipart,
+) -> Result<Response, (StatusCode, Json<ApiError>)> {
+ let mut text: Option<String> = None;
+ let mut voice_samples: Option<Vec<f32>> = None;
+ let mut voice_sample_rate: u32 = 16_000;
+
+ // Parse multipart fields
+ while let Some(field) = multipart.next_field().await.map_err(|e| {
+ (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new("MULTIPART_ERROR", e.to_string())),
+ )
+ })? {
+ let name = field.name().unwrap_or("").to_string();
+
+ match name.as_str() {
+ "text" => {
+ text = Some(field.text().await.map_err(|e| {
+ (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new("TEXT_FIELD_ERROR", e.to_string())),
+ )
+ })?);
+ }
+ "voice" => {
+ let data = field.bytes().await.map_err(|e| {
+ (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new("VOICE_FIELD_ERROR", e.to_string())),
+ )
+ })?;
+
+ // Decode audio file to PCM samples
+ let cursor = Cursor::new(data.to_vec());
+ let pcm = to_16k_mono_from_reader(cursor).map_err(|e| {
+ (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new(
+ "AUDIO_DECODE_ERROR",
+ format!("Failed to decode voice audio: {}", e),
+ )),
+ )
+ })?;
+
+ voice_samples = Some(pcm.samples);
+ voice_sample_rate = pcm.sample_rate;
+ }
+ _ => {
+ // Ignore unknown fields
+ tracing::debug!("Ignoring unknown field: {}", name);
+ }
+ }
+ }
+
+ // Validate required fields
+ let text = text.ok_or_else(|| {
+ (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new("MISSING_TEXT", "Text field is required")),
+ )
+ })?;
+
+ let samples = voice_samples.ok_or_else(|| {
+ (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new(
+ "MISSING_VOICE",
+ "Voice audio file is required",
+ )),
+ )
+ })?;
+
+ tracing::info!(
+ text_len = text.len(),
+ voice_samples = samples.len(),
+ "Generating TTS"
+ );
+
+ // Generate TTS with voice cloning
+ let mut chatterbox = state.chatterbox.lock().await;
+ let audio = chatterbox
+ .generate_tts_with_samples(&text, &samples, voice_sample_rate)
+ .map_err(|e| {
+ tracing::error!(error = %e, "TTS generation failed");
+ (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(ApiError::new("TTS_ERROR", e.to_string())),
+ )
+ })?;
+
+ tracing::info!(samples = audio.len(), "TTS generation complete");
+
+ // Encode as WAV
+ let wav_data = encode_wav(&audio, SAMPLE_RATE);
+
+ // Return WAV response
+ Ok(Response::builder()
+ .status(StatusCode::OK)
+ .header(header::CONTENT_TYPE, "audio/wav")
+ .header(
+ header::CONTENT_DISPOSITION,
+ "attachment; filename=\"output.wav\"",
+ )
+ .body(Body::from(wav_data))
+ .unwrap())
+}
+
+/// Encode f32 samples as a WAV file in memory.
+fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> {
+ let mut buf = Vec::new();
+
+ let num_samples = samples.len() as u32;
+ let num_channels: u16 = 1;
+ let bits_per_sample: u16 = 16;
+ let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
+ let block_align = num_channels * bits_per_sample / 8;
+ let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8;
+ let file_size = 36 + data_size;
+
+ // RIFF header
+ buf.extend_from_slice(b"RIFF");
+ buf.extend_from_slice(&file_size.to_le_bytes());
+ buf.extend_from_slice(b"WAVE");
+
+ // fmt chunk
+ buf.extend_from_slice(b"fmt ");
+ buf.extend_from_slice(&16u32.to_le_bytes());
+ buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
+ buf.extend_from_slice(&num_channels.to_le_bytes());
+ buf.extend_from_slice(&sample_rate.to_le_bytes());
+ buf.extend_from_slice(&byte_rate.to_le_bytes());
+ buf.extend_from_slice(&block_align.to_le_bytes());
+ buf.extend_from_slice(&bits_per_sample.to_le_bytes());
+
+ // data chunk
+ buf.extend_from_slice(b"data");
+ buf.extend_from_slice(&data_size.to_le_bytes());
+
+ // Convert f32 samples to i16 PCM
+ for &sample in samples {
+ let clamped = sample.clamp(-1.0, 1.0);
+ let int_sample = (clamped * 32767.0) as i16;
+ buf.extend_from_slice(&int_sample.to_le_bytes());
+ }
+
+ buf
+}