summaryrefslogblamecommitdiff
path: root/makima/src/audio.rs
blob: 8c969be21ec94b7458975849f707afe09d0b303c (plain) (tree)
















































































































































































































































                                                                                                      





                                                                                           





































































































































                                                                                              
use std::fs::File;
use std::io::{self, Read, Seek};
use std::path::Path;

use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::errors::Error as SymphoniaError;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::{MediaSourceStream, ReadOnlySource};
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;

/// Sample rate, in Hz, that audio is resampled to by this module.
pub const TARGET_SAMPLE_RATE: u32 = 16_000;
/// Channel count reported for audio produced by this module (mono).
pub const TARGET_CHANNELS: u16 = 1;

/// PCM audio samples together with the format metadata that describes them.
#[derive(Debug, Clone)]
pub struct PcmAudio {
    /// Sample values as 32-bit floats.
    pub samples: Vec<f32>,
    /// Sample rate of `samples`, in Hz.
    pub sample_rate: u32,
    /// Number of channels represented in `samples`.
    pub channels: u16,
}

/// Errors produced while opening, probing, or decoding audio.
#[derive(Debug)]
pub enum AudioError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// Decoding failed; the payload is a human-readable description.
    Decode(String),
    /// The container or codec is not supported.
    UnsupportedFormat,
    /// The input contains no track that can be decoded.
    NoAudioTrack,
}

impl std::fmt::Display for AudioError {
    /// Writes a short, lower-case description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(source) => write!(f, "io error: {source}"),
            Self::Decode(detail) => write!(f, "decode error: {detail}"),
            Self::UnsupportedFormat => f.write_str("unsupported audio format"),
            Self::NoAudioTrack => f.write_str("no audio track found"),
        }
    }
}

// Uses the trait's default methods only, so no `source()` is reported.
impl std::error::Error for AudioError {}

impl From<io::Error> for AudioError {
    /// Wraps an I/O error so it can be propagated with `?`.
    fn from(err: io::Error) -> Self {
        Self::Io(err)
    }
}

impl From<SymphoniaError> for AudioError {
    /// Maps a symphonia error onto the matching [`AudioError`] variant.
    ///
    /// I/O errors keep their original `io::Error`, unsupported-feature errors become
    /// `UnsupportedFormat`, and every other error is stored as its display text in `Decode`.
    fn from(err: SymphoniaError) -> Self {
        match err {
            SymphoniaError::IoError(io_err) => Self::Io(io_err),
            SymphoniaError::Unsupported(_) => Self::UnsupportedFormat,
            unhandled => Self::Decode(unhandled.to_string()),
        }
    }
}

/// Opens the audio file at `path` and decodes it to 16 kHz mono PCM.
///
/// When the path has a UTF-8 file extension, it is passed to the format probe as a hint.
///
/// # Errors
///
/// Returns [`AudioError::Io`] if the file cannot be opened, or whatever error decoding
/// produces.
pub fn to_16k_mono_from_path(path: impl AsRef<Path>) -> Result<PcmAudio, AudioError> {
    let source_path = path.as_ref();
    let file = File::open(source_path)?;

    let extension = source_path.extension().and_then(std::ffi::OsStr::to_str);
    let mut hint = Hint::new();
    if let Some(ext) = extension {
        hint.with_extension(ext);
    }

    decode_to_16k_mono(file, hint)
}

/// Decodes audio from an arbitrary reader to 16 kHz mono PCM.
///
/// No format hint is supplied, so the container is detected from the stream content alone.
///
/// # Errors
///
/// Returns whatever error decoding produces.
pub fn to_16k_mono_from_reader<R>(reader: R) -> Result<PcmAudio, AudioError>
where
    R: Read + Seek + Send + Sync + 'static,
{
    let hint = Hint::new();
    decode_to_16k_mono(reader, hint)
}

/// Decodes the first decodable track from `reader` and converts it to 16 kHz mono PCM.
///
/// `hint` is forwarded to the symphonia probe to assist container detection. All packets of
/// the selected track are decoded into one interleaved `f32` buffer, which is then mixed down
/// to mono and resampled to [`TARGET_SAMPLE_RATE`].
///
/// # Errors
///
/// - [`AudioError::NoAudioTrack`] if no track has a non-null codec.
/// - [`AudioError::Decode`] if the track does not declare a sample rate.
/// - Any probe, demuxer, or decoder error, converted through `From<SymphoniaError>`.
///   Per-packet `DecodeError`s are skipped rather than returned.
fn decode_to_16k_mono<R: Read + Seek + Send + Sync + 'static>(
    reader: R,
    hint: Hint,
) -> Result<PcmAudio, AudioError> {
    // NOTE(review): `ReadOnlySource` is documented as an unseekable wrapper, so the `Seek`
    // bound on `R` appears to go unused here -- confirm whether any supported container
    // needs a seekable source.
    let source = MediaSourceStream::new(Box::new(ReadOnlySource::new(reader)), Default::default());

    let format_opts = FormatOptions::default();
    let metadata_opts = MetadataOptions::default();

    let probed =
        symphonia::default::get_probe().format(&hint, source, &format_opts, &metadata_opts)?;
    let mut format = probed.format;

    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or(AudioError::NoAudioTrack)?;

    let track_id = track.id;
    let codec_params = track.codec_params.clone();

    let sample_rate = codec_params
        .sample_rate
        .ok_or_else(|| AudioError::Decode("unknown sample rate".to_string()))?;
    // Channel count as declared by the container; it may be absent.
    let declared_channels = codec_params.channels.map(|c| c.count() as u16);

    let decoder_opts = DecoderOptions::default();
    let mut decoder = symphonia::default::get_codecs().make(&codec_params, &decoder_opts)?;

    let mut interleaved: Vec<f32> = Vec::new();
    // Channel count of the buffers the decoder actually produced. `append_samples`
    // interleaves every decoded plane, so this is the layout `interleaved` really has.
    let mut decoded_channels: Option<u16> = None;

    loop {
        let packet = match format.next_packet() {
            Ok(p) => p,
            // The demuxer reports end of stream as an unexpected-EOF I/O error.
            Err(SymphoniaError::IoError(ref e)) if e.kind() == io::ErrorKind::UnexpectedEof => break,
            Err(SymphoniaError::ResetRequired) => {
                decoder.reset();
                continue;
            }
            Err(e) => return Err(e.into()),
        };

        if packet.track_id() != track_id {
            continue;
        }

        let decoded = match decoder.decode(&packet) {
            Ok(d) => d,
            // A malformed packet is skipped so one bad frame does not abort the whole file.
            Err(SymphoniaError::DecodeError(_)) => continue,
            Err(e) => return Err(e.into()),
        };

        if decoded_channels.is_none() {
            decoded_channels = Some(decoded.spec().channels.count() as u16);
        }

        append_samples(&decoded, &mut interleaved);
    }

    // Prefer the decoded layout over container metadata: mixing down with a channel count
    // that differs from the interleaving would blend unrelated samples together.
    let channels = decoded_channels.or(declared_channels).unwrap_or(1);

    let mono = mixdown_to_mono(&interleaved, channels);
    let samples = resample_sinc(&mono, sample_rate, TARGET_SAMPLE_RATE);

    Ok(PcmAudio {
        samples,
        sample_rate: TARGET_SAMPLE_RATE,
        channels: TARGET_CHANNELS,
    })
}

/// Appends the contents of a decoded buffer to `out` as interleaved `f32` samples.
///
/// Samples are written frame by frame, one value per plane (channel) in plane order.
/// Each sample format is converted to `f32` as follows:
///
/// - unsigned integers are re-centred by subtracting half their range, then divided by it;
/// - signed integers are divided by the magnitude of their most negative value;
/// - `f32` is copied unchanged and `f64` is narrowed to `f32`.
fn append_samples(buffer: &AudioBufferRef, out: &mut Vec<f32>) {
    // Expands to the interleaving loop for one buffer variant. `$sample` names the raw
    // sample inside `$to_f32`, the expression that converts it to `f32`.
    macro_rules! interleave {
        ($buf:ident, $sample:ident => $to_f32:expr) => {{
            // Fetch the plane slices once per buffer instead of once per frame.
            let planes = $buf.planes();
            let channels = planes.planes();
            let frames = $buf.frames();
            out.reserve(frames * channels.len());
            for frame in 0..frames {
                for plane in channels {
                    let $sample = plane[frame];
                    out.push($to_f32);
                }
            }
        }};
    }

    match buffer {
        AudioBufferRef::U8(buf) => interleave!(buf, s => (s as f32 - 128.0) / 128.0),
        AudioBufferRef::U16(buf) => interleave!(buf, s => (s as f32 - 32768.0) / 32768.0),
        AudioBufferRef::U24(buf) => {
            interleave!(buf, s => (s.inner() as f32 - 8388608.0) / 8388608.0)
        }
        // Re-centre in f64 first: f32 cannot represent every u32 value exactly.
        AudioBufferRef::U32(buf) => {
            interleave!(buf, s => (s as f64 - 2147483648.0) as f32 / 2147483648.0)
        }
        AudioBufferRef::S8(buf) => interleave!(buf, s => s as f32 / 128.0),
        AudioBufferRef::S16(buf) => interleave!(buf, s => s as f32 / 32768.0),
        AudioBufferRef::S24(buf) => interleave!(buf, s => s.inner() as f32 / 8388608.0),
        AudioBufferRef::S32(buf) => interleave!(buf, s => s as f32 / 2147483648.0),
        AudioBufferRef::F32(buf) => interleave!(buf, s => s),
        AudioBufferRef::F64(buf) => interleave!(buf, s => s as f32),
    }
}

/// Averages interleaved multi-channel samples into a single mono channel.
///
/// With `channels` of 0 or 1 the input is returned as an unchanged copy. Otherwise each
/// complete frame of `channels` samples is replaced by its arithmetic mean; a trailing
/// incomplete frame is dropped.
fn mixdown_to_mono(interleaved: &[f32], channels: u16) -> Vec<f32> {
    let width = usize::from(channels);
    if width < 2 {
        return interleaved.to_vec();
    }

    let divisor = width as f32;
    interleaved
        .chunks_exact(width)
        .map(|frame| frame.iter().fold(0.0f32, |sum, &sample| sum + sample) / divisor)
        .collect()
}

/// Mixes interleaved audio down to mono and resamples it to [`TARGET_SAMPLE_RATE`] for STT
/// processing.
///
/// `samples` is treated as interleaved frames of `channels` samples at `sample_rate` Hz.
pub fn resample_and_mixdown(samples: &[f32], sample_rate: u32, channels: u16) -> Vec<f32> {
    resample_sinc(&mixdown_to_mono(samples, channels), sample_rate, TARGET_SAMPLE_RATE)
}

/// Resamples a mono signal from `input_rate` to `output_rate` with a Hann-windowed sinc kernel.
///
/// Each output sample is a weighted sum of the input samples within 32 positions of the
/// corresponding source position. When downsampling, the sinc cutoff is lowered to
/// `output_rate / input_rate` so the kernel also low-pass filters the signal. Weights are
/// normalised per output sample, so a constant input stays constant even at the edges,
/// where part of the kernel falls outside the input.
///
/// Returns a copy of `input` when the rates are equal. Returns an empty vector when `input`
/// is empty or when either rate is zero (a zero rate has no meaningful time base).
fn resample_sinc(input: &[f32], input_rate: u32, output_rate: u32) -> Vec<f32> {
    if input_rate == output_rate {
        return input.to_vec();
    }
    // A zero input rate would give a ratio of 0 and an infinite output length, which makes
    // the allocation below panic with a capacity overflow.
    if input.is_empty() || input_rate == 0 || output_rate == 0 {
        return Vec::new();
    }

    // Kernel half-width, in input samples.
    const RADIUS: isize = 32;
    let radius_f = RADIUS as f64;
    let pi = std::f64::consts::PI;

    // Input samples advanced per output sample.
    let ratio = f64::from(input_rate) / f64::from(output_rate);
    let output_len = ((input.len() as f64) / ratio).ceil() as usize;

    // Normalised cutoff: below 1.0 only when downsampling.
    let cutoff = (f64::from(output_rate) / f64::from(input_rate)).min(1.0);

    let mut output = Vec::with_capacity(output_len);
    for n in 0..output_len {
        // Position of this output sample on the input time axis.
        let t = n as f64 * ratio;
        let center = t.floor() as isize;
        let frac = t - (center as f64);

        let mut acc = 0.0f64;
        let mut norm = 0.0f64;

        for k in -RADIUS..=RADIUS {
            let idx = center + k;
            if idx < 0 || (idx as usize) >= input.len() {
                continue;
            }

            // Signed distance from the tap to the exact source position.
            let x = (k as f64) - frac;
            let d = x.abs();
            if d > radius_f {
                continue;
            }

            // Hann window, tapering to zero at the kernel radius.
            let window = 0.5 * (1.0 + (pi * d / radius_f).cos());

            let z = x * cutoff;
            let sinc = if z == 0.0 {
                1.0
            } else {
                let pz = pi * z;
                pz.sin() / pz
            };

            let weight = cutoff * sinc * window;
            acc += f64::from(input[idx as usize]) * weight;
            norm += weight;
        }

        let y = if norm == 0.0 { 0.0 } else { acc / norm };
        output.push(y as f32);
    }

    output
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds an in-memory RIFF/WAVE file holding 16-bit PCM from interleaved `samples`.
    fn create_wav_buffer(sample_rate: u32, channels: u16, samples: &[i16]) -> Vec<u8> {
        let data_size = (samples.len() * 2) as u32;
        let file_size = 36 + data_size;
        let byte_rate = sample_rate * channels as u32 * 2;
        let block_align = channels * 2;

        let mut wav = Vec::new();

        // RIFF header.
        wav.extend_from_slice(b"RIFF");
        wav.extend_from_slice(&file_size.to_le_bytes());
        wav.extend_from_slice(b"WAVE");

        // Format chunk: 16-byte body, format tag 1, 16 bits per sample.
        wav.extend_from_slice(b"fmt ");
        wav.extend_from_slice(&16u32.to_le_bytes());
        wav.extend_from_slice(&1u16.to_le_bytes());
        wav.extend_from_slice(&channels.to_le_bytes());
        wav.extend_from_slice(&sample_rate.to_le_bytes());
        wav.extend_from_slice(&byte_rate.to_le_bytes());
        wav.extend_from_slice(&block_align.to_le_bytes());
        wav.extend_from_slice(&16u16.to_le_bytes());

        // Data chunk.
        wav.extend_from_slice(b"data");
        wav.extend_from_slice(&data_size.to_le_bytes());
        wav.extend(samples.iter().flat_map(|sample| sample.to_le_bytes()));

        wav
    }

    /// A stereo stream with one silent channel should average to half the other channel.
    #[test]
    fn converts_stereo_to_mono() {
        let frame_count = TARGET_SAMPLE_RATE / 10;
        let samples: Vec<i16> = (0..frame_count)
            .flat_map(|_| [10_000i16, 0i16])
            .collect();

        let wav = create_wav_buffer(TARGET_SAMPLE_RATE, 2, &samples);
        let normalized = to_16k_mono_from_reader(Cursor::new(wav)).unwrap();

        assert_eq!(normalized.sample_rate, TARGET_SAMPLE_RATE);
        assert_eq!(normalized.channels, TARGET_CHANNELS);

        let total: f32 = normalized.samples.iter().copied().sum::<f32>();
        let mean = total / normalized.samples.len() as f32;
        let expected = (10_000.0 / 32768.0) / 2.0;
        assert!((mean - expected).abs() < 1e-3);
    }

    /// One second of 48 kHz silence should become one second of 16 kHz silence.
    #[test]
    fn resamples_to_16k() {
        let silence: Vec<i16> = vec![0; 48_000];
        let wav = create_wav_buffer(48_000, 1, &silence);

        let normalized = to_16k_mono_from_reader(Cursor::new(wav)).unwrap();

        assert_eq!(normalized.sample_rate, TARGET_SAMPLE_RATE);
        assert_eq!(normalized.channels, TARGET_CHANNELS);
        assert_eq!(normalized.samples.len(), TARGET_SAMPLE_RATE as usize);

        let mut max_abs = 0.0f32;
        for value in normalized.samples.iter().copied() {
            max_abs = max_abs.max(value.abs());
        }
        assert!(max_abs <= 1e-6);
    }
}