diff options
Diffstat (limited to 'makima/src/server/handlers/voice.rs')
| -rw-r--r-- | makima/src/server/handlers/voice.rs | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/makima/src/server/handlers/voice.rs b/makima/src/server/handlers/voice.rs new file mode 100644 index 0000000..91b650d --- /dev/null +++ b/makima/src/server/handlers/voice.rs @@ -0,0 +1,252 @@ +//! Voice loading utilities for TTS voice cloning. +//! +//! Loads voice manifests and reference audio from the `voices/` directory. +//! Each voice is a directory containing: +//! - `manifest.json` — voice metadata (name, sample rate, backend, etc.) +//! - `reference.wav` — reference audio clip for voice cloning (5-15s, 24kHz mono) + +use serde::Deserialize; +use std::path::{Path, PathBuf}; + +use crate::tts::{resample_to_24k, SAMPLE_RATE}; + +/// Default voice ID used when no voice is specified. +pub const DEFAULT_VOICE_ID: &str = "makima"; + +/// Voice manifest loaded from `voices/{voice_id}/manifest.json`. +#[derive(Debug, Clone, Deserialize)] +pub struct VoiceManifest { + pub name: String, + pub id: String, + #[serde(default)] + pub description: Option<String>, + #[serde(default = "default_language")] + pub language: String, + #[serde(default)] + pub accent: Option<String>, + #[serde(default = "default_sample_rate")] + pub sample_rate: u32, + #[serde(default)] + pub format: Option<String>, + #[serde(default)] + pub model_backend: Option<String>, + #[serde(default = "default_reference_audio")] + pub reference_audio: String, + #[serde(default)] + pub notes: Option<String>, +} + +fn default_language() -> String { + "en".to_string() +} + +fn default_sample_rate() -> u32 { + 24_000 +} + +fn default_reference_audio() -> String { + "reference.wav".to_string() +} + +/// Loaded voice reference: manifest + decoded PCM samples at 24kHz. +#[derive(Debug, Clone)] +pub struct VoiceReference { + pub manifest: VoiceManifest, + /// PCM f32 samples resampled to 24kHz mono. + pub samples: Vec<f32>, + /// Always 24000 after resampling. + pub sample_rate: u32, +} + +/// Resolve the base directory for voice data. +/// +/// Looks for the `voices/` directory relative to the current working directory, +/// or falls back to the executable's directory. +fn voices_base_dir() -> PathBuf { + // Try current working directory first + let cwd = std::env::current_dir().unwrap_or_default(); + let cwd_voices = cwd.join("voices"); + if cwd_voices.is_dir() { + return cwd_voices; + } + + // Try relative to executable + if let Ok(exe) = std::env::current_exe() { + if let Some(exe_dir) = exe.parent() { + let exe_voices = exe_dir.join("voices"); + if exe_voices.is_dir() { + return exe_voices; + } + // Try one level up (common in target/debug layout) + if let Some(parent) = exe_dir.parent() { + let parent_voices = parent.join("voices"); + if parent_voices.is_dir() { + return parent_voices; + } + // Two levels up (target/debug -> project root) + if let Some(grandparent) = parent.parent() { + let gp_voices = grandparent.join("voices"); + if gp_voices.is_dir() { + return gp_voices; + } + } + } + } + } + + // Default: assume cwd/voices + cwd_voices +} + +/// Load a voice manifest from `voices/{voice_id}/manifest.json`. +pub fn load_manifest(voice_id: &str) -> Result<VoiceManifest, VoiceLoadError> { + let base = voices_base_dir(); + let manifest_path = base.join(voice_id).join("manifest.json"); + + if !manifest_path.exists() { + return Err(VoiceLoadError::NotFound(voice_id.to_string())); + } + + let data = std::fs::read_to_string(&manifest_path).map_err(|e| { + VoiceLoadError::Io(format!( + "failed to read manifest at {}: {e}", + manifest_path.display() + )) + })?; + + let manifest: VoiceManifest = serde_json::from_str(&data).map_err(|e| { + VoiceLoadError::InvalidManifest(format!("failed to parse manifest: {e}")) + })?; + + Ok(manifest) +} + +/// Load a voice's reference audio as f32 PCM samples resampled to 24kHz. +/// +/// Uses symphonia (via `crate::audio`) to decode the WAV file, then +/// resamples to 24kHz using `tts::resample_to_24k`. +pub fn load_reference_audio(voice_id: &str) -> Result<VoiceReference, VoiceLoadError> { + let manifest = load_manifest(voice_id)?; + + let base = voices_base_dir(); + let audio_path = base.join(voice_id).join(&manifest.reference_audio); + + if !audio_path.exists() { + return Err(VoiceLoadError::MissingAudio(format!( + "reference audio not found at {}. See voices/{}/README.md for instructions.", + audio_path.display(), + voice_id, + ))); + } + + load_reference_audio_from_path(&audio_path, manifest) +} + +/// Load reference audio from a specific file path with a pre-loaded manifest. +fn load_reference_audio_from_path( + audio_path: &Path, + manifest: VoiceManifest, +) -> Result<VoiceReference, VoiceLoadError> { + // Use symphonia-based decoder from crate::audio to decode the WAV + let pcm = crate::audio::to_16k_mono_from_path(audio_path).map_err(|e| { + VoiceLoadError::AudioDecode(format!("failed to decode {}: {e}", audio_path.display())) + })?; + + // The audio module decodes to 16kHz mono; we need 24kHz for TTS. + // Resample from 16kHz to 24kHz. + let samples = if pcm.sample_rate == SAMPLE_RATE { + pcm.samples + } else { + resample_to_24k(&pcm.samples, pcm.sample_rate) + }; + + tracing::info!( + voice_id = %manifest.id, + voice_name = %manifest.name, + samples_len = samples.len(), + duration_secs = samples.len() as f32 / SAMPLE_RATE as f32, + "Loaded voice reference audio" + ); + + Ok(VoiceReference { + manifest, + samples, + sample_rate: SAMPLE_RATE, + }) +} + +/// Errors that can occur when loading a voice. +#[derive(Debug)] +pub enum VoiceLoadError { + /// Voice directory not found. + NotFound(String), + /// IO error reading files. + Io(String), + /// Manifest JSON is invalid. + InvalidManifest(String), + /// Reference audio file is missing. + MissingAudio(String), + /// Failed to decode audio. + AudioDecode(String), +} + +impl std::fmt::Display for VoiceLoadError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VoiceLoadError::NotFound(id) => { + write!(f, "voice '{id}' not found (no voices/{id}/manifest.json)") + } + VoiceLoadError::Io(msg) => write!(f, "voice IO error: {msg}"), + VoiceLoadError::InvalidManifest(msg) => write!(f, "invalid voice manifest: {msg}"), + VoiceLoadError::MissingAudio(msg) => write!(f, "missing reference audio: {msg}"), + VoiceLoadError::AudioDecode(msg) => write!(f, "audio decode error: {msg}"), + } + } +} + +impl std::error::Error for VoiceLoadError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_voice_id() { + assert_eq!(DEFAULT_VOICE_ID, "makima"); + } + + #[test] + fn test_manifest_deserialize() { + let json = r#"{ + "name": "Test Voice", + "id": "test", + "sample_rate": 24000, + "reference_audio": "reference.wav" + }"#; + let manifest: VoiceManifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.name, "Test Voice"); + assert_eq!(manifest.id, "test"); + assert_eq!(manifest.sample_rate, 24000); + assert_eq!(manifest.reference_audio, "reference.wav"); + assert_eq!(manifest.language, "en"); + } + + #[test] + fn test_manifest_deserialize_defaults() { + let json = r#"{"name": "Minimal", "id": "min"}"#; + let manifest: VoiceManifest = serde_json::from_str(json).unwrap(); + assert_eq!(manifest.language, "en"); + assert_eq!(manifest.sample_rate, 24000); + assert_eq!(manifest.reference_audio, "reference.wav"); + } + + #[test] + fn test_load_nonexistent_voice() { + let result = load_manifest("nonexistent_voice_xyz"); + assert!(result.is_err()); + match result.unwrap_err() { + VoiceLoadError::NotFound(id) => assert_eq!(id, "nonexistent_voice_xyz"), + other => panic!("Expected NotFound, got: {other}"), + } + } +} |
