//! Voice loading utilities for TTS voice cloning.
//!
//! Loads voice manifests and reference audio from the `voices/` directory.
//! Each voice is a directory containing:
//! - `manifest.json` — voice metadata (name, sample rate, backend, etc.)
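//! - `reference.wav` — reference audio clip for voice cloning (5-15s, 24kHz mono).
//!   The file name can be changed with the manifest's `reference_audio` field, and
//!   the decoded audio is resampled to 24kHz mono on load.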
use serde::Deserialize;
use std::path::{Path, PathBuf};
use crate::tts::{resample_to_24k, SAMPLE_RATE};
/// Default voice ID used when no voice is specified.
///
/// Resolved like any other id: it must name a subdirectory of `voices/` that
/// contains a `manifest.json` (see `load_manifest`).
pub const DEFAULT_VOICE_ID: &str = "makima";
/// Voice manifest loaded from `voices/{voice_id}/manifest.json`.
///
/// Only `name` and `id` are required; every other field has a serde default.
#[derive(Debug, Clone, Deserialize)]
pub struct VoiceManifest {
/// Human-readable voice name (logged as `voice_name` when audio is loaded).
pub name: String,
/// Identifier recorded in the manifest (logged as `voice_id` when audio is loaded).
/// NOTE(review): not checked against the directory name the manifest was read from.
pub id: String,
/// Optional free-form description.
#[serde(default)]
pub description: Option<String>,
/// Voice language; defaults to `"en"` when absent.
#[serde(default = "default_language")]
pub language: String,
/// Optional accent label.
#[serde(default)]
pub accent: Option<String>,
/// Declared sample rate of the reference audio; defaults to 24000.
/// NOTE(review): the loader in this module ignores this value and resamples
/// based on the rate reported by the decoder instead.
#[serde(default = "default_sample_rate")]
pub sample_rate: u32,
/// Optional audio format label (not read by this module).
#[serde(default)]
pub format: Option<String>,
/// Optional model backend name (presumably the TTS backend this voice targets;
/// not read by this module).
#[serde(default)]
pub model_backend: Option<String>,
/// File name of the reference audio, relative to the voice directory;
/// defaults to `reference.wav`.
#[serde(default = "default_reference_audio")]
pub reference_audio: String,
/// Optional free-form notes.
#[serde(default)]
pub notes: Option<String>,
}
/// Serde default for `VoiceManifest::language`: English.
fn default_language() -> String {
    String::from("en")
}

/// Serde default for `VoiceManifest::sample_rate`: 24 kHz.
fn default_sample_rate() -> u32 {
    24000
}

/// Serde default for `VoiceManifest::reference_audio`: the conventional clip name.
fn default_reference_audio() -> String {
    String::from("reference.wav")
}
/// Loaded voice reference: manifest + decoded PCM samples at 24kHz.
///
/// Built by `load_reference_audio`, which decodes the manifest's reference
/// audio and resamples it to `SAMPLE_RATE` when the decoded rate differs.
#[derive(Debug, Clone)]
pub struct VoiceReference {
/// Manifest of the voice these samples belong to.
pub manifest: VoiceManifest,
/// PCM f32 samples resampled to 24kHz mono.
pub samples: Vec<f32>,
/// Always set to `SAMPLE_RATE` (24kHz per this module's docs) after resampling.
pub sample_rate: u32,
}
/// Locate the `voices/` directory that holds one subdirectory per voice.
///
/// Candidates are probed in order and the first existing directory wins:
/// 1. `<cwd>/voices`
/// 2. `<exe dir>/voices`
/// 3. `<exe dir>/../voices`
/// 4. `<exe dir>/../../voices` (the project root when running from `target/debug`)
///
/// If none exists, `<cwd>/voices` is returned anyway so callers get a
/// predictable path for their "not found" errors. If the current directory
/// cannot be determined, the relative path `voices` stands in for it.
fn voices_base_dir() -> PathBuf {
    let fallback = std::env::current_dir().unwrap_or_default().join("voices");
    if fallback.is_dir() {
        return fallback;
    }

    if let Ok(exe) = std::env::current_exe() {
        // `ancestors()` yields the exe path itself first; skip it, then try the
        // exe's directory, its parent, and its grandparent.
        if let Some(found) = exe
            .ancestors()
            .skip(1)
            .take(3)
            .map(|dir| dir.join("voices"))
            .find(|candidate| candidate.is_dir())
        {
            return found;
        }
    }

    fallback
}
/// Load a voice manifest from `voices/{voice_id}/manifest.json`.
///
/// The `voices/` base directory is resolved by `voices_base_dir`.
///
/// # Errors
///
/// - [`VoiceLoadError::NotFound`] if `voice_id` is not a single plain path
///   component (empty, `.`/`..`, absolute, or containing a separator) or if the
///   manifest file does not exist.
/// - [`VoiceLoadError::Io`] if the manifest exists but cannot be read.
/// - [`VoiceLoadError::InvalidManifest`] if the file is not a valid
///   [`VoiceManifest`] JSON document.
pub fn load_manifest(voice_id: &str) -> Result<VoiceManifest, VoiceLoadError> {
    // Voice ids may come from callers we do not control; reject anything that
    // could resolve outside `voices/` (e.g. `../secret`, `/etc`) before joining.
    if !is_plain_voice_id(voice_id) {
        return Err(VoiceLoadError::NotFound(voice_id.to_string()));
    }
    let manifest_path = voices_base_dir().join(voice_id).join("manifest.json");

    // Read directly rather than checking `exists()` first. This avoids a race
    // between the check and the read, and lets permission errors surface as `Io`
    // instead of being misreported as a missing voice.
    let data = std::fs::read_to_string(&manifest_path).map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            VoiceLoadError::NotFound(voice_id.to_string())
        } else {
            VoiceLoadError::Io(format!(
                "failed to read manifest at {}: {e}",
                manifest_path.display()
            ))
        }
    })?;

    serde_json::from_str(&data).map_err(|e| {
        VoiceLoadError::InvalidManifest(format!(
            "failed to parse manifest at {}: {e}",
            manifest_path.display()
        ))
    })
}

/// Returns `true` if `voice_id` is exactly one normal path component, so joining
/// it onto the voices directory cannot escape that directory.
fn is_plain_voice_id(voice_id: &str) -> bool {
    let mut components = Path::new(voice_id).components();
    matches!(
        (components.next(), components.next()),
        (Some(std::path::Component::Normal(_)), None)
    )
}
/// Load a voice's manifest together with its reference audio as 24 kHz f32 PCM.
///
/// The audio file is `voices/{voice_id}/{manifest.reference_audio}`. Decoding
/// (symphonia, via `crate::audio`) and resampling (`tts::resample_to_24k`) are
/// delegated to `load_reference_audio_from_path`.
///
/// # Errors
///
/// Propagates any error from [`load_manifest`]. Returns
/// [`VoiceLoadError::MissingAudio`] when the referenced audio file does not
/// exist, and [`VoiceLoadError::AudioDecode`] when it cannot be decoded.
pub fn load_reference_audio(voice_id: &str) -> Result<VoiceReference, VoiceLoadError> {
    let manifest = load_manifest(voice_id)?;
    let audio_path = voices_base_dir()
        .join(voice_id)
        .join(&manifest.reference_audio);

    if audio_path.exists() {
        return load_reference_audio_from_path(&audio_path, manifest);
    }

    Err(VoiceLoadError::MissingAudio(format!(
        "reference audio not found at {}. See voices/{}/README.md for instructions.",
        audio_path.display(),
        voice_id,
    )))
}
/// Decode `audio_path` and pair it with `manifest` as a [`VoiceReference`].
///
/// Decoding uses the symphonia-based `crate::audio::to_16k_mono_from_path`. The
/// result is resampled to `SAMPLE_RATE` with `resample_to_24k` unless the
/// decoder already reports that rate.
///
/// NOTE(review): the decoder path targets 16 kHz, so content above ~8 kHz is
/// presumably lost before upsampling to 24 kHz. That may reduce cloning
/// quality; consider decoding at the file's native rate.
///
/// # Errors
///
/// Returns [`VoiceLoadError::AudioDecode`] if decoding fails or yields no samples.
fn load_reference_audio_from_path(
    audio_path: &Path,
    manifest: VoiceManifest,
) -> Result<VoiceReference, VoiceLoadError> {
    let pcm = crate::audio::to_16k_mono_from_path(audio_path).map_err(|e| {
        VoiceLoadError::AudioDecode(format!("failed to decode {}: {e}", audio_path.display()))
    })?;

    // An empty clip is useless as a cloning reference. Fail here with a clear
    // message (and before resampling) instead of handing zero samples to the
    // TTS backend.
    if pcm.samples.is_empty() {
        return Err(VoiceLoadError::AudioDecode(format!(
            "decoded no audio samples from {}",
            audio_path.display()
        )));
    }

    // The decoder reports its own output rate; only resample when it differs.
    let samples = if pcm.sample_rate == SAMPLE_RATE {
        pcm.samples
    } else {
        resample_to_24k(&pcm.samples, pcm.sample_rate)
    };

    tracing::info!(
        voice_id = %manifest.id,
        voice_name = %manifest.name,
        samples_len = samples.len(),
        duration_secs = samples.len() as f32 / SAMPLE_RATE as f32,
        "Loaded voice reference audio"
    );

    Ok(VoiceReference {
        manifest,
        samples,
        sample_rate: SAMPLE_RATE,
    })
}
/// Failure modes of [`load_manifest`] and [`load_reference_audio`].
///
/// Each variant carries a `String` payload that the `Display` impl embeds in
/// its message.
#[derive(Debug)]
pub enum VoiceLoadError {
    /// No `manifest.json` was found for the voice; holds the requested voice id.
    NotFound(String),
    /// A voice file could not be read.
    Io(String),
    /// `manifest.json` could not be parsed into a `VoiceManifest`.
    InvalidManifest(String),
    /// The manifest loaded, but the reference audio file it names does not exist.
    MissingAudio(String),
    /// The reference audio file exists but could not be decoded.
    AudioDecode(String),
}

impl std::fmt::Display for VoiceLoadError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `NotFound` holds a bare id and gets a bespoke message; every other
        // variant renders as "<prefix>: <detail>".
        let (prefix, detail) = match self {
            Self::NotFound(id) => {
                return write!(f, "voice '{id}' not found (no voices/{id}/manifest.json)");
            }
            Self::Io(detail) => ("voice IO error", detail),
            Self::InvalidManifest(detail) => ("invalid voice manifest", detail),
            Self::MissingAudio(detail) => ("missing reference audio", detail),
            Self::AudioDecode(detail) => ("audio decode error", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for VoiceLoadError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_voice_id() {
        assert_eq!(DEFAULT_VOICE_ID, "makima");
    }

    #[test]
    fn test_manifest_deserialize() {
        let json = r#"{
            "name": "Test Voice",
            "id": "test",
            "sample_rate": 24000,
            "reference_audio": "reference.wav"
        }"#;
        let manifest: VoiceManifest = serde_json::from_str(json).unwrap();
        assert_eq!(manifest.name, "Test Voice");
        assert_eq!(manifest.id, "test");
        assert_eq!(manifest.sample_rate, 24000);
        assert_eq!(manifest.reference_audio, "reference.wav");
        assert_eq!(manifest.language, "en");
    }

    #[test]
    fn test_manifest_deserialize_defaults() {
        let json = r#"{"name": "Minimal", "id": "min"}"#;
        let manifest: VoiceManifest = serde_json::from_str(json).unwrap();
        assert_eq!(manifest.language, "en");
        assert_eq!(manifest.sample_rate, 24000);
        assert_eq!(manifest.reference_audio, "reference.wav");
        assert!(manifest.description.is_none());
        assert!(manifest.accent.is_none());
        assert!(manifest.format.is_none());
        assert!(manifest.model_backend.is_none());
        assert!(manifest.notes.is_none());
    }

    #[test]
    fn test_manifest_deserialize_all_fields() {
        let json = r#"{
            "name": "Full",
            "id": "full",
            "description": "desc",
            "language": "ja",
            "accent": "tokyo",
            "sample_rate": 44100,
            "format": "wav",
            "model_backend": "backend",
            "reference_audio": "clip.wav",
            "notes": "n"
        }"#;
        let manifest: VoiceManifest = serde_json::from_str(json).unwrap();
        assert_eq!(manifest.description.as_deref(), Some("desc"));
        assert_eq!(manifest.language, "ja");
        assert_eq!(manifest.accent.as_deref(), Some("tokyo"));
        assert_eq!(manifest.sample_rate, 44100);
        assert_eq!(manifest.format.as_deref(), Some("wav"));
        assert_eq!(manifest.model_backend.as_deref(), Some("backend"));
        assert_eq!(manifest.reference_audio, "clip.wav");
        assert_eq!(manifest.notes.as_deref(), Some("n"));
    }

    #[test]
    fn test_manifest_requires_name_and_id() {
        // `name` and `id` have no serde default, so omitting either must fail.
        assert!(serde_json::from_str::<VoiceManifest>(r#"{"id": "x"}"#).is_err());
        assert!(serde_json::from_str::<VoiceManifest>(r#"{"name": "X"}"#).is_err());
    }

    #[test]
    fn test_load_nonexistent_voice() {
        match load_manifest("nonexistent_voice_xyz") {
            Err(VoiceLoadError::NotFound(id)) => assert_eq!(id, "nonexistent_voice_xyz"),
            Err(other) => panic!("Expected NotFound, got: {other}"),
            Ok(_) => panic!("Expected NotFound, got a manifest"),
        }
    }

    #[test]
    fn test_error_display_not_found() {
        let err = VoiceLoadError::NotFound("abc".to_string());
        assert_eq!(
            err.to_string(),
            "voice 'abc' not found (no voices/abc/manifest.json)"
        );
    }
}