summaryrefslogblamecommitdiff
path: root/tools/stt-client/src/main.rs
blob: 8b81b601c83dd0c6109cbbf25c5eb1caa7d55964 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

















                                                                             


                                                                     







                                                                                    






                                                                    













                                                                        











                                         

























































                                                            
                                          




















































































































































































































































































































































                                                                                                
//! STT WebSocket client for testing the Makima STT streaming endpoint.
//!
//! This tool reads an audio file and streams it to the server via WebSocket,
//! printing transcription results as they arrive. Large files are decoded
//! and streamed asynchronously without loading the entire file into memory.

use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use anyhow::Result;
use clap::Parser;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use url::Url;

/// Endpoint used when neither `--url` nor `--local` is given (the remote API).
const DEFAULT_REMOTE_URL: &str = "wss://api.makima.jp/api/v1/listen";
/// Endpoint used with `--local`: a server on this machine listening on port 8080.
const DEFAULT_LOCAL_URL: &str = "ws://localhost:8080/api/v1/listen";

/// Command-line arguments for the STT streaming client.
#[derive(Parser)]
#[command(name = "stt-client")]
#[command(about = "WebSocket client for testing the Makima STT streaming endpoint")]
struct Args {
    /// Audio file to stream (supports MP3, WAV, FLAC, OGG, AAC)
    #[arg(short, long)]
    file: PathBuf,

    /// Server WebSocket URL (overrides --local)
    #[arg(short, long)]
    url: Option<String>,

    /// Use local server (ws://localhost:8080) instead of remote API
    #[arg(long, default_value = "false")]
    local: bool,

    /// Chunk size in milliseconds for streaming (must be at least 1)
    // A zero chunk size would produce zero-sample chunks and an endless decode loop.
    #[arg(
        short,
        long,
        default_value = "100",
        value_parser = clap::value_parser!(u32).range(1..)
    )]
    chunk_ms: u32,

    /// Simulate real-time streaming (add delays between chunks); use
    /// `--realtime false` to stream as fast as possible
    // A plain `bool` flag gets clap's `SetTrue` action, which combined with a
    // `true` default made this impossible to turn off. An explicit `Set` action
    // with an optional value keeps bare `--realtime` working while allowing
    // `--realtime false` / `--realtime=false`.
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    realtime: bool,

    /// Show progress during streaming (may interleave with transcripts)
    #[arg(long, default_value = "false")]
    show_progress: bool,
}

impl Args {
    fn get_url(&self) -> &str {
        if let Some(ref url) = self.url {
            url
        } else if self.local {
            DEFAULT_LOCAL_URL
        } else {
            DEFAULT_REMOTE_URL
        }
    }
}

/// Control messages this client sends to the server as JSON text frames.
///
/// Internally tagged: every message carries a `"type"` field holding the
/// camelCase variant name (`"start"` / `"stop"`). For `Start`, the fields of
/// [`StartMessage`] appear next to the tag at the top level of the object.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
enum ClientMessage {
    /// Opens a session and describes the PCM stream that follows.
    Start(StartMessage),
    /// Ends the session; `reason` is serialized as `null` when `None`.
    Stop { reason: Option<String> },
}

/// Payload of the `start` message, describing the audio that will be streamed.
///
/// Field names are serialized in camelCase (`sampleRate`, `channels`, `encoding`).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct StartMessage {
    /// Sample rate of the streamed audio, in Hz.
    sample_rate: u32,
    /// Number of channels; samples are sent interleaved frame by frame.
    channels: u16,
    /// Sample encoding name. `main` always sends `"pcm32f"` and writes each
    /// sample with `f32::to_le_bytes` (little-endian 32-bit float).
    encoding: String,
}

/// Messages received from the server as JSON text frames.
///
/// Internally tagged by a `"type"` field holding the camelCase variant name
/// (`"ready"`, `"transcript"`, `"error"`, `"stopped"`).
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "camelCase")]
enum ServerMessage {
    /// The session was accepted; carries the server-assigned session id.
    // The enum-level `rename_all` renames only the variants, not the fields of
    // struct variants. Without this attribute the field would be expected as
    // `session_id`, while the rest of the protocol is camelCase (`sampleRate`,
    // `isFinal`). The alias keeps accepting the snake_case spelling.
    #[serde(rename_all = "camelCase")]
    Ready {
        #[serde(alias = "session_id")]
        session_id: String,
    },
    /// A transcription result; its fields sit alongside the `"type"` tag.
    Transcript(TranscriptMessage),
    /// A server-reported error.
    Error { code: String, message: String },
    /// The session has ended; the receiver task stops reading after this.
    Stopped { reason: String },
}

/// A transcription segment sent by the server inside a `transcript` message.
///
/// Field names are expected in camelCase (e.g. `isFinal`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct TranscriptMessage {
    /// Speaker label as provided by the server.
    speaker: String,
    /// Segment start time; the receiver task prints it as seconds.
    start: f32,
    /// Segment end time; the receiver task prints it as seconds.
    end: f32,
    /// Transcribed text.
    text: String,
    /// Whether the server marked this segment final (printed as `[FINAL]`).
    /// NOTE(review): presumably non-final segments may be revised later — confirm with the server protocol.
    is_final: bool,
}

/// Audio format information determined by probing the input file.
///
/// Announced to the server in the `start` message and used to size chunks.
struct AudioFormat {
    /// Sample rate in Hz.
    sample_rate: u32,
    /// Number of channels; samples are streamed interleaved.
    channels: u16,
}

/// A chunk of decoded audio, passed from the decoder thread to the streaming loop.
struct AudioChunk {
    /// Interleaved `f32` samples (all channels of frame 0, then frame 1, ...).
    samples: Vec<f32>,
}

#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    // Probe audio file to get format info without decoding
    eprintln!("[INFO] Probing audio file: {:?}", args.file);
    let format = probe_audio_format(&args.file)?;
    eprintln!(
        "[INFO] Audio format: {}Hz, {} channel(s)",
        format.sample_rate, format.channels
    );

    // Connect to WebSocket
    let url = Url::parse(args.get_url())?;
    eprintln!("[INFO] Connecting to {}...", url);
    let (ws_stream, _) = connect_async(url.as_str()).await?;
    let (mut write, mut read) = ws_stream.split();

    // Send start message
    let start_msg = ClientMessage::Start(StartMessage {
        sample_rate: format.sample_rate,
        channels: format.channels,
        encoding: "pcm32f".to_string(),
    });
    write
        .send(Message::Text(serde_json::to_string(&start_msg)?))
        .await?;
    eprintln!("[INFO] Sent start message");

    // Flag to signal when session is stopped
    let session_stopped = Arc::new(AtomicBool::new(false));
    let session_stopped_clone = session_stopped.clone();

    // Spawn task to receive and print messages
    let receiver = tokio::spawn(async move {
        while let Some(msg) = read.next().await {
            match msg {
                Ok(Message::Text(text)) => match serde_json::from_str::<ServerMessage>(&text) {
                    Ok(ServerMessage::Ready { session_id }) => {
                        eprintln!("[INFO] Session ready: {}", session_id);
                    }
                    Ok(ServerMessage::Transcript(t)) => {
                        let final_marker = if t.is_final { " [FINAL]" } else { "" };
                        println!(
                            "[{:.2}s - {:.2}s] {}: {}{}",
                            t.start, t.end, t.speaker, t.text, final_marker
                        );
                    }
                    Ok(ServerMessage::Error { code, message }) => {
                        eprintln!("[ERROR] {}: {}", code, message);
                    }
                    Ok(ServerMessage::Stopped { reason }) => {
                        eprintln!("[INFO] Session stopped: {}", reason);
                        session_stopped_clone.store(true, Ordering::SeqCst);
                        break;
                    }
                    Err(e) => {
                        eprintln!("[ERROR] Failed to parse message: {}", e);
                        eprintln!("[DEBUG] Raw message: {}", text);
                    }
                },
                Ok(Message::Close(_)) => {
                    eprintln!("[INFO] Connection closed by server");
                    break;
                }
                Err(e) => {
                    eprintln!("[ERROR] WebSocket error: {}", e);
                    break;
                }
                _ => {}
            }
        }
    });

    // Calculate chunk size in samples
    let chunk_samples =
        (format.sample_rate * args.chunk_ms / 1000) as usize * format.channels as usize;

    // Create channel for streaming decoded audio chunks
    let (audio_tx, mut audio_rx) = mpsc::channel::<AudioChunk>(32);

    // Spawn blocking task to decode audio file and send chunks
    let file_path = args.file.clone();
    let decoder_handle = tokio::task::spawn_blocking(move || {
        decode_audio_streaming(&file_path, chunk_samples, audio_tx)
    });

    eprintln!(
        "[INFO] Streaming audio in {} ms chunks ({} samples each)...",
        args.chunk_ms, chunk_samples
    );

    // Stream chunks as they're decoded
    let mut chunks_sent = 0usize;
    while let Some(chunk) = audio_rx.recv().await {
        // Convert f32 samples to bytes (little-endian)
        let bytes: Vec<u8> = chunk.samples.iter().flat_map(|&s| s.to_le_bytes()).collect();

        write.send(Message::Binary(bytes.into())).await?;
        chunks_sent += 1;

        // Progress indicator
        if args.show_progress && chunks_sent % 50 == 0 {
            eprintln!("[PROGRESS] {} chunks streamed", chunks_sent);
        }

        // Simulate real-time streaming if enabled
        if args.realtime {
            tokio::time::sleep(tokio::time::Duration::from_millis(args.chunk_ms as u64)).await;
        }
    }

    // Wait for decoder to finish and check for errors
    decoder_handle.await??;

    eprintln!("[INFO] Streaming complete: {} chunks sent", chunks_sent);

    // Send stop message
    let stop_msg = ClientMessage::Stop {
        reason: Some("end_of_file".to_string()),
    };
    write
        .send(Message::Text(serde_json::to_string(&stop_msg)?))
        .await?;
    eprintln!("[INFO] Sent stop message, waiting for final results...");

    // Wait for receiver to finish with a timeout
    let timeout = tokio::time::Duration::from_secs(30);
    match tokio::time::timeout(timeout, receiver).await {
        Ok(result) => {
            result?;
        }
        Err(_) => {
            if session_stopped.load(Ordering::SeqCst) {
                eprintln!("[INFO] Session completed");
            } else {
                eprintln!("[WARN] Timeout waiting for server response");
            }
        }
    }

    eprintln!("[INFO] Done!");
    Ok(())
}

/// Probe an audio file to extract format information without decoding.
fn probe_audio_format(path: &PathBuf) -> Result<AudioFormat> {
    use symphonia::core::codecs::CODEC_TYPE_NULL;
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    let file = std::fs::File::open(path)?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }

    let probed = symphonia::default::get_probe().format(
        &hint,
        mss,
        &FormatOptions::default(),
        &MetadataOptions::default(),
    )?;

    let track = probed
        .format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or_else(|| anyhow::anyhow!("No audio track found"))?;

    let sample_rate = track.codec_params.sample_rate.unwrap_or(16000);
    let channels = track
        .codec_params
        .channels
        .map(|c| c.count() as u16)
        .unwrap_or(1);

    Ok(AudioFormat {
        sample_rate,
        channels,
    })
}

/// Decode audio file and stream chunks through the channel.
/// This runs in a blocking thread to avoid blocking the async runtime.
fn decode_audio_streaming(
    path: &PathBuf,
    chunk_samples: usize,
    tx: mpsc::Sender<AudioChunk>,
) -> Result<()> {
    use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
    use symphonia::core::formats::FormatOptions;
    use symphonia::core::io::MediaSourceStream;
    use symphonia::core::meta::MetadataOptions;
    use symphonia::core::probe::Hint;

    let file = std::fs::File::open(path)?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }

    let probed = symphonia::default::get_probe().format(
        &hint,
        mss,
        &FormatOptions::default(),
        &MetadataOptions::default(),
    )?;

    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
        .ok_or_else(|| anyhow::anyhow!("No audio track found"))?;

    let track_id = track.id;
    let mut decoder =
        symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;

    // Buffer for accumulating samples until we have a full chunk
    let mut sample_buffer: Vec<f32> = Vec::with_capacity(chunk_samples * 2);

    loop {
        let packet = match format.next_packet() {
            Ok(p) => p,
            Err(symphonia::core::errors::Error::IoError(ref e))
                if e.kind() == std::io::ErrorKind::UnexpectedEof =>
            {
                break;
            }
            Err(symphonia::core::errors::Error::ResetRequired) => {
                decoder.reset();
                continue;
            }
            Err(_) => break,
        };

        if packet.track_id() != track_id {
            continue;
        }

        let decoded = match decoder.decode(&packet) {
            Ok(d) => d,
            Err(symphonia::core::errors::Error::DecodeError(_)) => continue,
            Err(_) => continue,
        };

        // Append decoded samples to buffer
        append_samples(&decoded, &mut sample_buffer);

        // Send complete chunks as they become available
        while sample_buffer.len() >= chunk_samples {
            let chunk: Vec<f32> = sample_buffer.drain(..chunk_samples).collect();
            if tx.blocking_send(AudioChunk { samples: chunk }).is_err() {
                // Receiver dropped, stop decoding
                return Ok(());
            }
        }
    }

    // Send any remaining samples as a final partial chunk
    if !sample_buffer.is_empty() {
        let _ = tx.blocking_send(AudioChunk {
            samples: sample_buffer,
        });
    }

    Ok(())
}

/// Convert a decoded buffer to `f32` and append it to `out`, interleaving the
/// channels frame by frame (`ch0[0], ch1[0], ..., ch0[1], ch1[1], ...`).
///
/// Integer samples are scaled toward `[-1.0, 1.0)`:
/// - Signed types are divided by 2^(bits-1).
/// - Unsigned types are first re-centred by subtracting their midpoint.
///
/// Float samples pass through unchanged; `f64` is narrowed to `f32`.
fn append_samples(buffer: &symphonia::core::audio::AudioBufferRef, out: &mut Vec<f32>) {
    use symphonia::core::audio::AudioBufferRef;

    match buffer {
        AudioBufferRef::U8(buf) => interleave_into(&**buf, out, |s| (s as f32 - 128.0) / 128.0),
        AudioBufferRef::U16(buf) => {
            interleave_into(&**buf, out, |s| (s as f32 - 32768.0) / 32768.0)
        }
        AudioBufferRef::U24(buf) => {
            interleave_into(&**buf, out, |s| (s.inner() as f32 - 8388608.0) / 8388608.0)
        }
        AudioBufferRef::U32(buf) => interleave_into(&**buf, out, |s| {
            (s as f64 - 2147483648.0) as f32 / 2147483648.0
        }),
        AudioBufferRef::S8(buf) => interleave_into(&**buf, out, |s| s as f32 / 128.0),
        AudioBufferRef::S16(buf) => interleave_into(&**buf, out, |s| s as f32 / 32768.0),
        AudioBufferRef::S24(buf) => {
            interleave_into(&**buf, out, |s| s.inner() as f32 / 8388608.0)
        }
        AudioBufferRef::S32(buf) => interleave_into(&**buf, out, |s| s as f32 / 2147483648.0),
        AudioBufferRef::F32(buf) => interleave_into(&**buf, out, |s| s),
        AudioBufferRef::F64(buf) => interleave_into(&**buf, out, |s| s as f32),
    }
}

/// Interleave every channel of `buf` into `out`, converting each sample with `convert`.
fn interleave_into<S: symphonia::core::sample::Sample>(
    buf: &symphonia::core::audio::AudioBuffer<S>,
    out: &mut Vec<f32>,
    convert: impl Fn(S) -> f32,
) {
    use symphonia::core::audio::Signal;

    // `planes()` collects a slice per channel on every call. Build it once per
    // buffer, not once per frame.
    let planes = buf.planes();
    let channels = planes.planes();
    let frames = buf.frames();

    out.reserve(frames * channels.len());
    for frame in 0..frames {
        for channel in channels {
            out.push(convert(channel[frame]));
        }
    }
}