// summary | refs | log | blame | commit | diff
// path: root/makima/src/server/handlers/voice.rs
// blob: 91b650d489400970d3cd4bf498efe8011cfe002f (plain) (tree)



























































































































































































































































                                                                                               
//! Voice loading utilities for TTS voice cloning.
//!
//! Loads voice manifests and reference audio from the `voices/` directory.
//! Each voice is a directory containing:
//! - `manifest.json` — voice metadata (name, sample rate, backend, etc.)
//! - `reference.wav` — reference audio clip for voice cloning (5-15s, 24kHz mono)

use serde::Deserialize;
use std::path::{Path, PathBuf};

use crate::tts::{resample_to_24k, SAMPLE_RATE};

/// Default voice ID used when no voice is specified.
///
/// Voice IDs are resolved as directory names under `voices/`, so this
/// presumably refers to a bundled `voices/makima/` directory — TODO confirm
/// that directory ships with the project.
pub const DEFAULT_VOICE_ID: &str = "makima";

/// Voice manifest loaded from `voices/{voice_id}/manifest.json`.
///
/// Only `name` and `id` are required in the JSON. Every other field falls back
/// to a default when absent: `None` for the optional fields, or the value
/// returned by the `default_*` function named in its `serde` attribute.
#[derive(Debug, Clone, Deserialize)]
pub struct VoiceManifest {
    /// Voice name (required). Logged as `voice_name` when reference audio is loaded.
    pub name: String,
    /// Voice identifier (required). Logged as `voice_id` when reference audio is loaded.
    ///
    /// NOTE(review): the loaders in this file never compare this value with the
    /// directory name used for the lookup — confirm whether they must match.
    pub id: String,
    /// Optional description text; `None` when the manifest omits it.
    #[serde(default)]
    pub description: Option<String>,
    /// Language of the voice; `"en"` when the manifest omits it.
    /// Presumably a language code — the exact format is not validated here.
    #[serde(default = "default_language")]
    pub language: String,
    /// Optional accent label; `None` when the manifest omits it.
    #[serde(default)]
    pub accent: Option<String>,
    /// Sample rate declared by the manifest; `24_000` when omitted.
    ///
    /// NOTE(review): the loaders in this file do not read this field; the
    /// loaded audio reports `crate::tts::SAMPLE_RATE` instead.
    #[serde(default = "default_sample_rate")]
    pub sample_rate: u32,
    /// Optional audio format label; `None` when the manifest omits it.
    #[serde(default)]
    pub format: Option<String>,
    /// Optional model backend name; `None` when the manifest omits it.
    #[serde(default)]
    pub model_backend: Option<String>,
    /// Path of the reference clip, resolved relative to the voice's own
    /// directory; `"reference.wav"` when the manifest omits it.
    #[serde(default = "default_reference_audio")]
    pub reference_audio: String,
    /// Optional free-form notes; `None` when the manifest omits them.
    #[serde(default)]
    pub notes: Option<String>,
}

/// Serde fallback for `VoiceManifest::language` when the manifest omits it.
fn default_language() -> String {
    String::from("en")
}

/// Serde fallback for `VoiceManifest::sample_rate` when the manifest omits it.
fn default_sample_rate() -> u32 {
    const FALLBACK_SAMPLE_RATE: u32 = 24_000;
    FALLBACK_SAMPLE_RATE
}

/// Serde fallback for `VoiceManifest::reference_audio` when the manifest omits it.
fn default_reference_audio() -> String {
    "reference.wav".to_owned()
}

/// Loaded voice reference: manifest + decoded PCM samples at 24kHz.
///
/// Produced by `load_reference_audio`; owns both the parsed manifest and the
/// decoded audio so it can be cloned and passed around independently of the
/// files on disk.
#[derive(Debug, Clone)]
pub struct VoiceReference {
    /// Manifest of the voice this audio belongs to.
    pub manifest: VoiceManifest,
    /// PCM f32 samples resampled to 24kHz mono.
    pub samples: Vec<f32>,
    /// Sample rate of `samples`. Always 24000 after resampling (the loader
    /// stores `crate::tts::SAMPLE_RATE` here, not the manifest's declared rate).
    pub sample_rate: u32,
}

/// Locate the directory that holds per-voice data.
///
/// Candidates are probed in this order and the first one that is an existing
/// directory wins:
/// 1. `voices` under the current working directory
/// 2. `voices` next to the running executable
/// 3. `voices` one level above the executable's directory
/// 4. `voices` two levels above it (e.g. the project root for `target/debug`)
///
/// When no candidate exists, the working-directory path is returned anyway so
/// callers get a stable location to report in their errors.
fn voices_base_dir() -> PathBuf {
    // An unreadable working directory degrades to the relative path `voices`.
    let fallback = std::env::current_dir()
        .unwrap_or_default()
        .join("voices");
    if fallback.is_dir() {
        return fallback;
    }

    std::env::current_exe()
        .ok()
        .and_then(|exe| {
            // `ancestors()` starts with the executable path itself; skip it and
            // look at its directory plus the two directories above that.
            exe.ancestors()
                .skip(1)
                .take(3)
                .map(|dir| dir.join("voices"))
                .find(|candidate| candidate.is_dir())
        })
        .unwrap_or(fallback)
}

/// Load a voice manifest from `voices/{voice_id}/manifest.json`.
pub fn load_manifest(voice_id: &str) -> Result<VoiceManifest, VoiceLoadError> {
    let base = voices_base_dir();
    let manifest_path = base.join(voice_id).join("manifest.json");

    if !manifest_path.exists() {
        return Err(VoiceLoadError::NotFound(voice_id.to_string()));
    }

    let data = std::fs::read_to_string(&manifest_path).map_err(|e| {
        VoiceLoadError::Io(format!(
            "failed to read manifest at {}: {e}",
            manifest_path.display()
        ))
    })?;

    let manifest: VoiceManifest = serde_json::from_str(&data).map_err(|e| {
        VoiceLoadError::InvalidManifest(format!("failed to parse manifest: {e}"))
    })?;

    Ok(manifest)
}

/// Load the reference clip for `voice_id` as f32 PCM samples at 24kHz.
///
/// The manifest is read first; its `reference_audio` entry is then resolved
/// inside the voice's own directory and handed to
/// `load_reference_audio_from_path` for decoding and resampling.
///
/// # Errors
/// Propagates any error from [`load_manifest`], returns
/// [`VoiceLoadError::MissingAudio`] when the reference file does not exist, and
/// [`VoiceLoadError::AudioDecode`] when it cannot be decoded.
pub fn load_reference_audio(voice_id: &str) -> Result<VoiceReference, VoiceLoadError> {
    let manifest = load_manifest(voice_id)?;
    let audio_path = voices_base_dir()
        .join(voice_id)
        .join(&manifest.reference_audio);

    if audio_path.exists() {
        return load_reference_audio_from_path(&audio_path, manifest);
    }

    Err(VoiceLoadError::MissingAudio(format!(
        "reference audio not found at {}. See voices/{}/README.md for instructions.",
        audio_path.display(),
        voice_id,
    )))
}

/// Decode the clip at `audio_path` and pair it with an already-loaded manifest.
///
/// Decoding goes through `crate::audio::to_16k_mono_from_path`; the result is
/// passed through `resample_to_24k` unless it already reports `SAMPLE_RATE`.
/// A summary of the loaded clip is logged at info level.
///
/// NOTE(review): decoding to 16kHz and then resampling up to 24kHz presumably
/// drops content above 8kHz from the reference clip — confirm this is intended.
///
/// # Errors
/// Returns [`VoiceLoadError::AudioDecode`] when the decoder fails.
fn load_reference_audio_from_path(
    audio_path: &Path,
    manifest: VoiceManifest,
) -> Result<VoiceReference, VoiceLoadError> {
    let decoded = match crate::audio::to_16k_mono_from_path(audio_path) {
        Ok(decoded) => decoded,
        Err(e) => {
            return Err(VoiceLoadError::AudioDecode(format!(
                "failed to decode {}: {e}",
                audio_path.display()
            )));
        }
    };

    // Only resample when the decoder's output rate differs from the TTS rate.
    let samples = if decoded.sample_rate != SAMPLE_RATE {
        resample_to_24k(&decoded.samples, decoded.sample_rate)
    } else {
        decoded.samples
    };

    let seconds = samples.len() as f32 / SAMPLE_RATE as f32;
    tracing::info!(
        voice_id = %manifest.id,
        voice_name = %manifest.name,
        samples_len = samples.len(),
        duration_secs = seconds,
        "Loaded voice reference audio"
    );

    let reference = VoiceReference {
        manifest,
        samples,
        sample_rate: SAMPLE_RATE,
    };
    Ok(reference)
}

/// Errors that can occur when loading a voice.
///
/// Every variant carries a `String`: the requested voice ID for `NotFound`,
/// and a preformatted human-readable message for all other variants.
#[derive(Debug)]
pub enum VoiceLoadError {
    /// No manifest could be found for the voice; holds the requested voice ID.
    NotFound(String),
    /// A file exists but could not be read; holds the path and OS error text.
    Io(String),
    /// The manifest is not valid JSON for the expected schema.
    InvalidManifest(String),
    /// The reference audio file named by the manifest does not exist.
    MissingAudio(String),
    /// The reference audio file could not be decoded.
    AudioDecode(String),
}

impl std::fmt::Display for VoiceLoadError {
    /// Renders a one-line, user-facing description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use VoiceLoadError::{AudioDecode, InvalidManifest, Io, MissingAudio, NotFound};

        // `NotFound` has its own sentence shape; the rest are "<label>: <detail>".
        let (label, detail) = match self {
            NotFound(id) => {
                return write!(f, "voice '{id}' not found (no voices/{id}/manifest.json)");
            }
            Io(detail) => ("voice IO error", detail),
            InvalidManifest(detail) => ("invalid voice manifest", detail),
            MissingAudio(detail) => ("missing reference audio", detail),
            AudioDecode(detail) => ("audio decode error", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for VoiceLoadError {}

#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes an inline JSON manifest, panicking if it is malformed.
    fn parse(json: &str) -> VoiceManifest {
        serde_json::from_str(json).unwrap()
    }

    /// The default voice ID is a fixed, known value.
    #[test]
    fn test_default_voice_id() {
        assert_eq!(DEFAULT_VOICE_ID, "makima");
    }

    /// Explicitly provided fields are kept; the omitted language falls back.
    #[test]
    fn test_manifest_deserialize() {
        let parsed = parse(
            r#"{
            "name": "Test Voice",
            "id": "test",
            "sample_rate": 24000,
            "reference_audio": "reference.wav"
        }"#,
        );
        assert_eq!(parsed.name, "Test Voice");
        assert_eq!(parsed.id, "test");
        assert_eq!(parsed.sample_rate, 24000);
        assert_eq!(parsed.reference_audio, "reference.wav");
        assert_eq!(parsed.language, "en");
    }

    /// A manifest with only the required fields picks up every default.
    #[test]
    fn test_manifest_deserialize_defaults() {
        let parsed = parse(r#"{"name": "Minimal", "id": "min"}"#);
        assert_eq!(parsed.language, "en");
        assert_eq!(parsed.sample_rate, 24000);
        assert_eq!(parsed.reference_audio, "reference.wav");
    }

    /// Looking up a voice with no directory reports `NotFound` with the ID.
    #[test]
    fn test_load_nonexistent_voice() {
        let missing_id = "nonexistent_voice_xyz";
        let outcome = load_manifest(missing_id);
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            VoiceLoadError::NotFound(id) => assert_eq!(id, missing_id),
            other => panic!("Expected NotFound, got: {other}"),
        }
    }
}