//! STT WebSocket client for testing the Makima STT streaming endpoint.
//!
//! This tool reads an audio file and streams it to the server via WebSocket,
//! printing transcription results as they arrive. Large files are decoded
//! and streamed asynchronously without loading the entire file into memory.
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use anyhow::Result;
use clap::Parser;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use url::Url;
// Command-line arguments for the STT streaming test client.
// (Regular comment on purpose: a doc comment here would feed into clap's
// generated help alongside the explicit `about` below.)
#[derive(Parser)]
#[command(name = "stt-client")]
#[command(about = "WebSocket client for testing the Makima STT streaming endpoint")]
struct Args {
    /// Audio file to stream (supports MP3, WAV, FLAC, OGG, AAC)
    #[arg(short, long)]
    file: PathBuf,

    /// Server WebSocket URL
    #[arg(short, long, default_value = "ws://localhost:8080/api/v1/listen")]
    url: String,

    /// Chunk size in milliseconds for streaming
    #[arg(short, long, default_value = "100")]
    chunk_ms: u32,

    /// Simulate real-time streaming (add delays between chunks).
    /// Enabled by default; pass `--realtime=false` to send as fast as possible
    // A plain `bool` field is a `SetTrue` flag in clap derive, so combined with
    // a default of "true" it could never be switched off. `ArgAction::Set`
    // makes it take a value, while `num_args = 0..=1` plus
    // `default_missing_value` keeps a bare `--realtime` working as before.
    #[arg(
        long,
        default_value_t = true,
        num_args = 0..=1,
        default_missing_value = "true",
        action = clap::ArgAction::Set
    )]
    realtime: bool,

    /// Show progress during streaming (may interleave with transcripts)
    #[arg(long, default_value = "false")]
    show_progress: bool,
}
/// Control messages sent from this client to the server as JSON text frames.
///
/// Serialized with an internal `type` tag holding the camelCase variant name
/// (`"start"` / `"stop"`).
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
enum ClientMessage {
/// Announces the audio format; sent once before any binary audio frames.
Start(StartMessage),
/// Sent after the last audio chunk; `reason` is an optional free-form tag
/// (this client sends `"end_of_file"`).
Stop { reason: Option<String> },
}
/// Payload of the `start` control message describing the audio to be streamed.
///
/// Field names are serialized in camelCase (e.g. `sampleRate`).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct StartMessage {
/// Sample rate of the streamed audio in Hz.
sample_rate: u32,
/// Number of channels in the streamed audio.
channels: u16,
/// Sample encoding identifier; this client always sends `"pcm32f"`.
encoding: String,
}
/// Messages received from the server as JSON text frames.
///
/// Deserialized using an internal `type` tag whose value is the camelCase
/// variant name (`"ready"`, `"transcript"`, `"error"`, `"stopped"`).
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "camelCase")]
enum ServerMessage {
    /// The server accepted the `start` message and opened a session.
    Ready {
        // The enum-level `rename_all` only renames the variant tags, not the
        // fields inside struct variants, so without this alias a camelCase
        // `sessionId` (the convention used by every other message here) would
        // be rejected with "missing field `session_id`". Both spellings are
        // accepted. NOTE(review): confirm which one the server actually emits.
        #[serde(alias = "sessionId")]
        session_id: String,
    },
    /// A transcription segment.
    Transcript(TranscriptMessage),
    /// The server reported an error; the client logs it and keeps listening.
    Error { code: String, message: String },
    /// The server ended the session; the receiver task stops on this message.
    Stopped { reason: String },
}
/// A transcription segment reported by the server.
///
/// Field names are deserialized from camelCase (e.g. `isFinal`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct TranscriptMessage {
/// Speaker label for the segment.
speaker: String,
/// Segment start time; printed by this client with an `s` suffix
/// (presumably seconds from stream start — confirm against the server).
start: f32,
/// Segment end time; same unit as `start`.
end: f32,
/// Transcribed text.
text: String,
/// Whether the server marked this segment as final (printed as `[FINAL]`).
is_final: bool,
}
/// Audio format information read from the file's first decodable track,
/// without decoding any audio.
struct AudioFormat {
/// Sample rate in Hz.
sample_rate: u32,
/// Number of channels.
channels: u16,
}
/// A chunk of decoded audio samples, passed from the decoder thread to the
/// streaming loop over an mpsc channel.
struct AudioChunk {
/// Interleaved `f32` samples: every channel of one frame, then the next frame.
samples: Vec<f32>,
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
// Probe audio file to get format info without decoding
eprintln!("[INFO] Probing audio file: {:?}", args.file);
let format = probe_audio_format(&args.file)?;
eprintln!(
"[INFO] Audio format: {}Hz, {} channel(s)",
format.sample_rate, format.channels
);
// Connect to WebSocket
let url = Url::parse(&args.url)?;
eprintln!("[INFO] Connecting to {}...", url);
let (ws_stream, _) = connect_async(url.as_str()).await?;
let (mut write, mut read) = ws_stream.split();
// Send start message
let start_msg = ClientMessage::Start(StartMessage {
sample_rate: format.sample_rate,
channels: format.channels,
encoding: "pcm32f".to_string(),
});
write
.send(Message::Text(serde_json::to_string(&start_msg)?))
.await?;
eprintln!("[INFO] Sent start message");
// Flag to signal when session is stopped
let session_stopped = Arc::new(AtomicBool::new(false));
let session_stopped_clone = session_stopped.clone();
// Spawn task to receive and print messages
let receiver = tokio::spawn(async move {
while let Some(msg) = read.next().await {
match msg {
Ok(Message::Text(text)) => match serde_json::from_str::<ServerMessage>(&text) {
Ok(ServerMessage::Ready { session_id }) => {
eprintln!("[INFO] Session ready: {}", session_id);
}
Ok(ServerMessage::Transcript(t)) => {
let final_marker = if t.is_final { " [FINAL]" } else { "" };
println!(
"[{:.2}s - {:.2}s] {}: {}{}",
t.start, t.end, t.speaker, t.text, final_marker
);
}
Ok(ServerMessage::Error { code, message }) => {
eprintln!("[ERROR] {}: {}", code, message);
}
Ok(ServerMessage::Stopped { reason }) => {
eprintln!("[INFO] Session stopped: {}", reason);
session_stopped_clone.store(true, Ordering::SeqCst);
break;
}
Err(e) => {
eprintln!("[ERROR] Failed to parse message: {}", e);
eprintln!("[DEBUG] Raw message: {}", text);
}
},
Ok(Message::Close(_)) => {
eprintln!("[INFO] Connection closed by server");
break;
}
Err(e) => {
eprintln!("[ERROR] WebSocket error: {}", e);
break;
}
_ => {}
}
}
});
// Calculate chunk size in samples
let chunk_samples =
(format.sample_rate * args.chunk_ms / 1000) as usize * format.channels as usize;
// Create channel for streaming decoded audio chunks
let (audio_tx, mut audio_rx) = mpsc::channel::<AudioChunk>(32);
// Spawn blocking task to decode audio file and send chunks
let file_path = args.file.clone();
let decoder_handle = tokio::task::spawn_blocking(move || {
decode_audio_streaming(&file_path, chunk_samples, audio_tx)
});
eprintln!(
"[INFO] Streaming audio in {} ms chunks ({} samples each)...",
args.chunk_ms, chunk_samples
);
// Stream chunks as they're decoded
let mut chunks_sent = 0usize;
while let Some(chunk) = audio_rx.recv().await {
// Convert f32 samples to bytes (little-endian)
let bytes: Vec<u8> = chunk.samples.iter().flat_map(|&s| s.to_le_bytes()).collect();
write.send(Message::Binary(bytes.into())).await?;
chunks_sent += 1;
// Progress indicator
if args.show_progress && chunks_sent % 50 == 0 {
eprintln!("[PROGRESS] {} chunks streamed", chunks_sent);
}
// Simulate real-time streaming if enabled
if args.realtime {
tokio::time::sleep(tokio::time::Duration::from_millis(args.chunk_ms as u64)).await;
}
}
// Wait for decoder to finish and check for errors
decoder_handle.await??;
eprintln!("[INFO] Streaming complete: {} chunks sent", chunks_sent);
// Send stop message
let stop_msg = ClientMessage::Stop {
reason: Some("end_of_file".to_string()),
};
write
.send(Message::Text(serde_json::to_string(&stop_msg)?))
.await?;
eprintln!("[INFO] Sent stop message, waiting for final results...");
// Wait for receiver to finish with a timeout
let timeout = tokio::time::Duration::from_secs(30);
match tokio::time::timeout(timeout, receiver).await {
Ok(result) => {
result?;
}
Err(_) => {
if session_stopped.load(Ordering::SeqCst) {
eprintln!("[INFO] Session completed");
} else {
eprintln!("[WARN] Timeout waiting for server response");
}
}
}
eprintln!("[INFO] Done!");
Ok(())
}
/// Probe an audio file to extract format information without decoding.
fn probe_audio_format(path: &PathBuf) -> Result<AudioFormat> {
use symphonia::core::codecs::CODEC_TYPE_NULL;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
let file = std::fs::File::open(path)?;
let mss = MediaSourceStream::new(Box::new(file), Default::default());
let mut hint = Hint::new();
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
hint.with_extension(ext);
}
let probed = symphonia::default::get_probe().format(
&hint,
mss,
&FormatOptions::default(),
&MetadataOptions::default(),
)?;
let track = probed
.format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.ok_or_else(|| anyhow::anyhow!("No audio track found"))?;
let sample_rate = track.codec_params.sample_rate.unwrap_or(16000);
let channels = track
.codec_params
.channels
.map(|c| c.count() as u16)
.unwrap_or(1);
Ok(AudioFormat {
sample_rate,
channels,
})
}
/// Decode audio file and stream chunks through the channel.
/// This runs in a blocking thread to avoid blocking the async runtime.
fn decode_audio_streaming(
path: &PathBuf,
chunk_samples: usize,
tx: mpsc::Sender<AudioChunk>,
) -> Result<()> {
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
let file = std::fs::File::open(path)?;
let mss = MediaSourceStream::new(Box::new(file), Default::default());
let mut hint = Hint::new();
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
hint.with_extension(ext);
}
let probed = symphonia::default::get_probe().format(
&hint,
mss,
&FormatOptions::default(),
&MetadataOptions::default(),
)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.ok_or_else(|| anyhow::anyhow!("No audio track found"))?;
let track_id = track.id;
let mut decoder =
symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
// Buffer for accumulating samples until we have a full chunk
let mut sample_buffer: Vec<f32> = Vec::with_capacity(chunk_samples * 2);
loop {
let packet = match format.next_packet() {
Ok(p) => p,
Err(symphonia::core::errors::Error::IoError(ref e))
if e.kind() == std::io::ErrorKind::UnexpectedEof =>
{
break;
}
Err(symphonia::core::errors::Error::ResetRequired) => {
decoder.reset();
continue;
}
Err(_) => break,
};
if packet.track_id() != track_id {
continue;
}
let decoded = match decoder.decode(&packet) {
Ok(d) => d,
Err(symphonia::core::errors::Error::DecodeError(_)) => continue,
Err(_) => continue,
};
// Append decoded samples to buffer
append_samples(&decoded, &mut sample_buffer);
// Send complete chunks as they become available
while sample_buffer.len() >= chunk_samples {
let chunk: Vec<f32> = sample_buffer.drain(..chunk_samples).collect();
if tx.blocking_send(AudioChunk { samples: chunk }).is_err() {
// Receiver dropped, stop decoding
return Ok(());
}
}
}
// Send any remaining samples as a final partial chunk
if !sample_buffer.is_empty() {
let _ = tx.blocking_send(AudioChunk {
samples: sample_buffer,
});
}
Ok(())
}
/// Convert a decoded buffer to normalized `f32` samples and push them onto
/// `out` in interleaved order (all channels of frame 0, then frame 1, ...).
///
/// Unsigned integer formats are re-centred around zero before scaling; signed
/// integer formats are scaled by their magnitude range; float formats are
/// passed through (`f64` is narrowed to `f32`).
fn append_samples(buffer: &symphonia::core::audio::AudioBufferRef, out: &mut Vec<f32>) {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    // Visits the buffer frame by frame and, within each frame, plane by plane,
    // pushing `$convert` evaluated with `$raw` bound to the raw sample value.
    macro_rules! push_interleaved {
        ($buf:expr, |$raw:ident| $convert:expr) => {{
            let frame_count = $buf.frames();
            for idx in 0..frame_count {
                for channel in $buf.planes().planes() {
                    let $raw = channel[idx];
                    out.push($convert);
                }
            }
        }};
    }

    match buffer {
        AudioBufferRef::U8(buf) => push_interleaved!(buf, |s| (s as f32 - 128.0) / 128.0),
        AudioBufferRef::U16(buf) => push_interleaved!(buf, |s| (s as f32 - 32768.0) / 32768.0),
        AudioBufferRef::U24(buf) => {
            push_interleaved!(buf, |s| (s.inner() as f32 - 8388608.0) / 8388608.0)
        }
        // Re-centre in f64 first: f32 cannot represent the full u32 range exactly.
        AudioBufferRef::U32(buf) => {
            push_interleaved!(buf, |s| (s as f64 - 2147483648.0) as f32 / 2147483648.0)
        }
        AudioBufferRef::S8(buf) => push_interleaved!(buf, |s| s as f32 / 128.0),
        AudioBufferRef::S16(buf) => push_interleaved!(buf, |s| s as f32 / 32768.0),
        AudioBufferRef::S24(buf) => push_interleaved!(buf, |s| s.inner() as f32 / 8388608.0),
        AudioBufferRef::S32(buf) => push_interleaved!(buf, |s| s as f32 / 2147483648.0),
        AudioBufferRef::F32(buf) => push_interleaved!(buf, |s| s),
        AudioBufferRef::F64(buf) => push_interleaved!(buf, |s| s as f32),
    }
}