//! STT WebSocket client for testing the Makima STT streaming endpoint. //! //! This tool reads an audio file and streams it to the server via WebSocket, //! printing transcription results as they arrive. Large files are decoded //! and streamed asynchronously without loading the entire file into memory. use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use anyhow::Result; use clap::Parser; use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tokio_tungstenite::{connect_async, tungstenite::Message}; use url::Url; #[derive(Parser)] #[command(name = "stt-client")] #[command(about = "WebSocket client for testing the Makima STT streaming endpoint")] struct Args { /// Audio file to stream (supports MP3, WAV, FLAC, OGG, AAC) #[arg(short, long)] file: PathBuf, /// Server WebSocket URL #[arg(short, long, default_value = "ws://localhost:8080/api/v1/listen")] url: String, /// Chunk size in milliseconds for streaming #[arg(short, long, default_value = "100")] chunk_ms: u32, /// Simulate real-time streaming (add delays between chunks) #[arg(long, default_value = "true")] realtime: bool, /// Show progress during streaming (may interleave with transcripts) #[arg(long, default_value = "false")] show_progress: bool, } #[derive(Serialize)] #[serde(tag = "type", rename_all = "camelCase")] enum ClientMessage { Start(StartMessage), Stop { reason: Option }, } #[derive(Serialize)] #[serde(rename_all = "camelCase")] struct StartMessage { sample_rate: u32, channels: u16, encoding: String, } #[derive(Deserialize, Debug)] #[serde(tag = "type", rename_all = "camelCase")] enum ServerMessage { Ready { session_id: String }, Transcript(TranscriptMessage), Error { code: String, message: String }, Stopped { reason: String }, } #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] struct TranscriptMessage { speaker: String, start: f32, end: f32, text: String, is_final: bool, } /// Audio format information extracted from file header. struct AudioFormat { sample_rate: u32, channels: u16, } /// A chunk of decoded audio samples. struct AudioChunk { samples: Vec, } #[tokio::main] async fn main() -> Result<()> { let args = Args::parse(); // Probe audio file to get format info without decoding eprintln!("[INFO] Probing audio file: {:?}", args.file); let format = probe_audio_format(&args.file)?; eprintln!( "[INFO] Audio format: {}Hz, {} channel(s)", format.sample_rate, format.channels ); // Connect to WebSocket let url = Url::parse(&args.url)?; eprintln!("[INFO] Connecting to {}...", url); let (ws_stream, _) = connect_async(url.as_str()).await?; let (mut write, mut read) = ws_stream.split(); // Send start message let start_msg = ClientMessage::Start(StartMessage { sample_rate: format.sample_rate, channels: format.channels, encoding: "pcm32f".to_string(), }); write .send(Message::Text(serde_json::to_string(&start_msg)?)) .await?; eprintln!("[INFO] Sent start message"); // Flag to signal when session is stopped let session_stopped = Arc::new(AtomicBool::new(false)); let session_stopped_clone = session_stopped.clone(); // Spawn task to receive and print messages let receiver = tokio::spawn(async move { while let Some(msg) = read.next().await { match msg { Ok(Message::Text(text)) => match serde_json::from_str::(&text) { Ok(ServerMessage::Ready { session_id }) => { eprintln!("[INFO] Session ready: {}", session_id); } Ok(ServerMessage::Transcript(t)) => { let final_marker = if t.is_final { " [FINAL]" } else { "" }; println!( "[{:.2}s - {:.2}s] {}: {}{}", t.start, t.end, t.speaker, t.text, final_marker ); } Ok(ServerMessage::Error { code, message }) => { eprintln!("[ERROR] {}: {}", code, message); } Ok(ServerMessage::Stopped { reason }) => { eprintln!("[INFO] Session stopped: {}", reason); session_stopped_clone.store(true, Ordering::SeqCst); break; } Err(e) => { eprintln!("[ERROR] Failed to parse message: {}", e); eprintln!("[DEBUG] Raw message: {}", text); } }, Ok(Message::Close(_)) => { eprintln!("[INFO] Connection closed by server"); break; } Err(e) => { eprintln!("[ERROR] WebSocket error: {}", e); break; } _ => {} } } }); // Calculate chunk size in samples let chunk_samples = (format.sample_rate * args.chunk_ms / 1000) as usize * format.channels as usize; // Create channel for streaming decoded audio chunks let (audio_tx, mut audio_rx) = mpsc::channel::(32); // Spawn blocking task to decode audio file and send chunks let file_path = args.file.clone(); let decoder_handle = tokio::task::spawn_blocking(move || { decode_audio_streaming(&file_path, chunk_samples, audio_tx) }); eprintln!( "[INFO] Streaming audio in {} ms chunks ({} samples each)...", args.chunk_ms, chunk_samples ); // Stream chunks as they're decoded let mut chunks_sent = 0usize; while let Some(chunk) = audio_rx.recv().await { // Convert f32 samples to bytes (little-endian) let bytes: Vec = chunk.samples.iter().flat_map(|&s| s.to_le_bytes()).collect(); write.send(Message::Binary(bytes.into())).await?; chunks_sent += 1; // Progress indicator if args.show_progress && chunks_sent % 50 == 0 { eprintln!("[PROGRESS] {} chunks streamed", chunks_sent); } // Simulate real-time streaming if enabled if args.realtime { tokio::time::sleep(tokio::time::Duration::from_millis(args.chunk_ms as u64)).await; } } // Wait for decoder to finish and check for errors decoder_handle.await??; eprintln!("[INFO] Streaming complete: {} chunks sent", chunks_sent); // Send stop message let stop_msg = ClientMessage::Stop { reason: Some("end_of_file".to_string()), }; write .send(Message::Text(serde_json::to_string(&stop_msg)?)) .await?; eprintln!("[INFO] Sent stop message, waiting for final results..."); // Wait for receiver to finish with a timeout let timeout = tokio::time::Duration::from_secs(30); match tokio::time::timeout(timeout, receiver).await { Ok(result) => { result?; } Err(_) => { if session_stopped.load(Ordering::SeqCst) { eprintln!("[INFO] Session completed"); } else { eprintln!("[WARN] Timeout waiting for server response"); } } } eprintln!("[INFO] Done!"); Ok(()) } /// Probe an audio file to extract format information without decoding. fn probe_audio_format(path: &PathBuf) -> Result { use symphonia::core::codecs::CODEC_TYPE_NULL; use symphonia::core::formats::FormatOptions; use symphonia::core::io::MediaSourceStream; use symphonia::core::meta::MetadataOptions; use symphonia::core::probe::Hint; let file = std::fs::File::open(path)?; let mss = MediaSourceStream::new(Box::new(file), Default::default()); let mut hint = Hint::new(); if let Some(ext) = path.extension().and_then(|e| e.to_str()) { hint.with_extension(ext); } let probed = symphonia::default::get_probe().format( &hint, mss, &FormatOptions::default(), &MetadataOptions::default(), )?; let track = probed .format .tracks() .iter() .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) .ok_or_else(|| anyhow::anyhow!("No audio track found"))?; let sample_rate = track.codec_params.sample_rate.unwrap_or(16000); let channels = track .codec_params .channels .map(|c| c.count() as u16) .unwrap_or(1); Ok(AudioFormat { sample_rate, channels, }) } /// Decode audio file and stream chunks through the channel. /// This runs in a blocking thread to avoid blocking the async runtime. fn decode_audio_streaming( path: &PathBuf, chunk_samples: usize, tx: mpsc::Sender, ) -> Result<()> { use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL}; use symphonia::core::formats::FormatOptions; use symphonia::core::io::MediaSourceStream; use symphonia::core::meta::MetadataOptions; use symphonia::core::probe::Hint; let file = std::fs::File::open(path)?; let mss = MediaSourceStream::new(Box::new(file), Default::default()); let mut hint = Hint::new(); if let Some(ext) = path.extension().and_then(|e| e.to_str()) { hint.with_extension(ext); } let probed = symphonia::default::get_probe().format( &hint, mss, &FormatOptions::default(), &MetadataOptions::default(), )?; let mut format = probed.format; let track = format .tracks() .iter() .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) .ok_or_else(|| anyhow::anyhow!("No audio track found"))?; let track_id = track.id; let mut decoder = symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?; // Buffer for accumulating samples until we have a full chunk let mut sample_buffer: Vec = Vec::with_capacity(chunk_samples * 2); loop { let packet = match format.next_packet() { Ok(p) => p, Err(symphonia::core::errors::Error::IoError(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { break; } Err(symphonia::core::errors::Error::ResetRequired) => { decoder.reset(); continue; } Err(_) => break, }; if packet.track_id() != track_id { continue; } let decoded = match decoder.decode(&packet) { Ok(d) => d, Err(symphonia::core::errors::Error::DecodeError(_)) => continue, Err(_) => continue, }; // Append decoded samples to buffer append_samples(&decoded, &mut sample_buffer); // Send complete chunks as they become available while sample_buffer.len() >= chunk_samples { let chunk: Vec = sample_buffer.drain(..chunk_samples).collect(); if tx.blocking_send(AudioChunk { samples: chunk }).is_err() { // Receiver dropped, stop decoding return Ok(()); } } } // Send any remaining samples as a final partial chunk if !sample_buffer.is_empty() { let _ = tx.blocking_send(AudioChunk { samples: sample_buffer, }); } Ok(()) } /// Append decoded audio samples to the output buffer. fn append_samples(buffer: &symphonia::core::audio::AudioBufferRef, out: &mut Vec) { use symphonia::core::audio::{AudioBufferRef, Signal}; match buffer { AudioBufferRef::U8(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push((plane[frame] as f32 - 128.0) / 128.0); } } } AudioBufferRef::U16(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push((plane[frame] as f32 - 32768.0) / 32768.0); } } } AudioBufferRef::U24(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push((plane[frame].inner() as f32 - 8388608.0) / 8388608.0); } } } AudioBufferRef::U32(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push((plane[frame] as f64 - 2147483648.0) as f32 / 2147483648.0); } } } AudioBufferRef::S8(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push(plane[frame] as f32 / 128.0); } } } AudioBufferRef::S16(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push(plane[frame] as f32 / 32768.0); } } } AudioBufferRef::S24(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push(plane[frame].inner() as f32 / 8388608.0); } } } AudioBufferRef::S32(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push(plane[frame] as f32 / 2147483648.0); } } } AudioBufferRef::F32(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push(plane[frame]); } } } AudioBufferRef::F64(buf) => { for frame in 0..buf.frames() { for plane in buf.planes().planes() { out.push(plane[frame] as f32); } } } } }