diff options
| author | soryu <soryu@soryu.co> | 2025-12-20 15:36:04 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 01088f4f1915e36a7d0d8d8756f62f8207a48911 (patch) | |
| tree | 8fdbba900f3f4bba32bae76e2e0378848a90cf93 /tools | |
| parent | ab9166170043ba5e0ce974e5b7accf0939d686e3 (diff) | |
| download | soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.tar.gz soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.zip | |
Implement makima listen websockets server
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/stt-client/Cargo.toml | 16 | ||||
| -rw-r--r-- | tools/stt-client/src/main.rs | 443 |
2 files changed, 459 insertions, 0 deletions
diff --git a/tools/stt-client/Cargo.toml b/tools/stt-client/Cargo.toml new file mode 100644 index 0000000..314b27f --- /dev/null +++ b/tools/stt-client/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "stt-client" +version = "0.1.0" +edition = "2021" +description = "WebSocket client for testing the Makima STT streaming endpoint" + +[dependencies] +tokio = { version = "1.0", features = ["full"] } +tokio-tungstenite = { version = "0.24", features = ["native-tls"] } +futures = "0.3" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +clap = { version = "4", features = ["derive"] } +symphonia = { version = "0.5", features = ["mp3", "aac", "flac", "ogg", "vorbis", "wav", "pcm"] } +anyhow = "1.0" +url = "2.5" diff --git a/tools/stt-client/src/main.rs b/tools/stt-client/src/main.rs new file mode 100644 index 0000000..73a9f8e --- /dev/null +++ b/tools/stt-client/src/main.rs @@ -0,0 +1,443 @@ +//! STT WebSocket client for testing the Makima STT streaming endpoint. +//! +//! This tool reads an audio file and streams it to the server via WebSocket, +//! printing transcription results as they arrive. Large files are decoded +//! and streamed asynchronously without loading the entire file into memory. + +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use anyhow::Result; +use clap::Parser; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio_tungstenite::{connect_async, tungstenite::Message}; +use url::Url; + +#[derive(Parser)] +#[command(name = "stt-client")] +#[command(about = "WebSocket client for testing the Makima STT streaming endpoint")] +struct Args { + /// Audio file to stream (supports MP3, WAV, FLAC, OGG, AAC) + #[arg(short, long)] + file: PathBuf, + + /// Server WebSocket URL + #[arg(short, long, default_value = "ws://localhost:8080/api/v1/listen")] + url: String, + + /// Chunk size in milliseconds for streaming + #[arg(short, long, default_value = "100")] + chunk_ms: u32, + + /// Simulate real-time streaming (add delays between chunks) + #[arg(long, default_value = "true")] + realtime: bool, + + /// Show progress during streaming (may interleave with transcripts) + #[arg(long, default_value = "false")] + show_progress: bool, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +enum ClientMessage { + Start(StartMessage), + Stop { reason: Option<String> }, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct StartMessage { + sample_rate: u32, + channels: u16, + encoding: String, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "camelCase")] +enum ServerMessage { + Ready { session_id: String }, + Transcript(TranscriptMessage), + Error { code: String, message: String }, + Stopped { reason: String }, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +struct TranscriptMessage { + speaker: String, + start: f32, + end: f32, + text: String, + is_final: bool, +} + +/// Audio format information extracted from file header. +struct AudioFormat { + sample_rate: u32, + channels: u16, +} + +/// A chunk of decoded audio samples. +struct AudioChunk { + samples: Vec<f32>, +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + // Probe audio file to get format info without decoding + eprintln!("[INFO] Probing audio file: {:?}", args.file); + let format = probe_audio_format(&args.file)?; + eprintln!( + "[INFO] Audio format: {}Hz, {} channel(s)", + format.sample_rate, format.channels + ); + + // Connect to WebSocket + let url = Url::parse(&args.url)?; + eprintln!("[INFO] Connecting to {}...", url); + let (ws_stream, _) = connect_async(url.as_str()).await?; + let (mut write, mut read) = ws_stream.split(); + + // Send start message + let start_msg = ClientMessage::Start(StartMessage { + sample_rate: format.sample_rate, + channels: format.channels, + encoding: "pcm32f".to_string(), + }); + write + .send(Message::Text(serde_json::to_string(&start_msg)?)) + .await?; + eprintln!("[INFO] Sent start message"); + + // Flag to signal when session is stopped + let session_stopped = Arc::new(AtomicBool::new(false)); + let session_stopped_clone = session_stopped.clone(); + + // Spawn task to receive and print messages + let receiver = tokio::spawn(async move { + while let Some(msg) = read.next().await { + match msg { + Ok(Message::Text(text)) => match serde_json::from_str::<ServerMessage>(&text) { + Ok(ServerMessage::Ready { session_id }) => { + eprintln!("[INFO] Session ready: {}", session_id); + } + Ok(ServerMessage::Transcript(t)) => { + let final_marker = if t.is_final { " [FINAL]" } else { "" }; + println!( + "[{:.2}s - {:.2}s] {}: {}{}", + t.start, t.end, t.speaker, t.text, final_marker + ); + } + Ok(ServerMessage::Error { code, message }) => { + eprintln!("[ERROR] {}: {}", code, message); + } + Ok(ServerMessage::Stopped { reason }) => { + eprintln!("[INFO] Session stopped: {}", reason); + session_stopped_clone.store(true, Ordering::SeqCst); + break; + } + Err(e) => { + eprintln!("[ERROR] Failed to parse message: {}", e); + eprintln!("[DEBUG] Raw message: {}", text); + } + }, + Ok(Message::Close(_)) => { + eprintln!("[INFO] Connection closed by server"); + break; + } + Err(e) => { + eprintln!("[ERROR] WebSocket error: {}", e); + break; + } + _ => {} + } + } + }); + + // Calculate chunk size in samples + let chunk_samples = + (format.sample_rate * args.chunk_ms / 1000) as usize * format.channels as usize; + + // Create channel for streaming decoded audio chunks + let (audio_tx, mut audio_rx) = mpsc::channel::<AudioChunk>(32); + + // Spawn blocking task to decode audio file and send chunks + let file_path = args.file.clone(); + let decoder_handle = tokio::task::spawn_blocking(move || { + decode_audio_streaming(&file_path, chunk_samples, audio_tx) + }); + + eprintln!( + "[INFO] Streaming audio in {} ms chunks ({} samples each)...", + args.chunk_ms, chunk_samples + ); + + // Stream chunks as they're decoded + let mut chunks_sent = 0usize; + while let Some(chunk) = audio_rx.recv().await { + // Convert f32 samples to bytes (little-endian) + let bytes: Vec<u8> = chunk.samples.iter().flat_map(|&s| s.to_le_bytes()).collect(); + + write.send(Message::Binary(bytes.into())).await?; + chunks_sent += 1; + + // Progress indicator + if args.show_progress && chunks_sent % 50 == 0 { + eprintln!("[PROGRESS] {} chunks streamed", chunks_sent); + } + + // Simulate real-time streaming if enabled + if args.realtime { + tokio::time::sleep(tokio::time::Duration::from_millis(args.chunk_ms as u64)).await; + } + } + + // Wait for decoder to finish and check for errors + decoder_handle.await??; + + eprintln!("[INFO] Streaming complete: {} chunks sent", chunks_sent); + + // Send stop message + let stop_msg = ClientMessage::Stop { + reason: Some("end_of_file".to_string()), + }; + write + .send(Message::Text(serde_json::to_string(&stop_msg)?)) + .await?; + eprintln!("[INFO] Sent stop message, waiting for final results..."); + + // Wait for receiver to finish with a timeout + let timeout = tokio::time::Duration::from_secs(30); + match tokio::time::timeout(timeout, receiver).await { + Ok(result) => { + result?; + } + Err(_) => { + if session_stopped.load(Ordering::SeqCst) { + eprintln!("[INFO] Session completed"); + } else { + eprintln!("[WARN] Timeout waiting for server response"); + } + } + } + + eprintln!("[INFO] Done!"); + Ok(()) +} + +/// Probe an audio file to extract format information without decoding. +fn probe_audio_format(path: &PathBuf) -> Result<AudioFormat> { + use symphonia::core::codecs::CODEC_TYPE_NULL; + use symphonia::core::formats::FormatOptions; + use symphonia::core::io::MediaSourceStream; + use symphonia::core::meta::MetadataOptions; + use symphonia::core::probe::Hint; + + let file = std::fs::File::open(path)?; + let mss = MediaSourceStream::new(Box::new(file), Default::default()); + + let mut hint = Hint::new(); + if let Some(ext) = path.extension().and_then(|e| e.to_str()) { + hint.with_extension(ext); + } + + let probed = symphonia::default::get_probe().format( + &hint, + mss, + &FormatOptions::default(), + &MetadataOptions::default(), + )?; + + let track = probed + .format + .tracks() + .iter() + .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) + .ok_or_else(|| anyhow::anyhow!("No audio track found"))?; + + let sample_rate = track.codec_params.sample_rate.unwrap_or(16000); + let channels = track + .codec_params + .channels + .map(|c| c.count() as u16) + .unwrap_or(1); + + Ok(AudioFormat { + sample_rate, + channels, + }) +} + +/// Decode audio file and stream chunks through the channel. +/// This runs in a blocking thread to avoid blocking the async runtime. +fn decode_audio_streaming( + path: &PathBuf, + chunk_samples: usize, + tx: mpsc::Sender<AudioChunk>, +) -> Result<()> { + use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL}; + use symphonia::core::formats::FormatOptions; + use symphonia::core::io::MediaSourceStream; + use symphonia::core::meta::MetadataOptions; + use symphonia::core::probe::Hint; + + let file = std::fs::File::open(path)?; + let mss = MediaSourceStream::new(Box::new(file), Default::default()); + + let mut hint = Hint::new(); + if let Some(ext) = path.extension().and_then(|e| e.to_str()) { + hint.with_extension(ext); + } + + let probed = symphonia::default::get_probe().format( + &hint, + mss, + &FormatOptions::default(), + &MetadataOptions::default(), + )?; + + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) + .ok_or_else(|| anyhow::anyhow!("No audio track found"))?; + + let track_id = track.id; + let mut decoder = + symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?; + + // Buffer for accumulating samples until we have a full chunk + let mut sample_buffer: Vec<f32> = Vec::with_capacity(chunk_samples * 2); + + loop { + let packet = match format.next_packet() { + Ok(p) => p, + Err(symphonia::core::errors::Error::IoError(ref e)) + if e.kind() == std::io::ErrorKind::UnexpectedEof => + { + break; + } + Err(symphonia::core::errors::Error::ResetRequired) => { + decoder.reset(); + continue; + } + Err(_) => break, + }; + + if packet.track_id() != track_id { + continue; + } + + let decoded = match decoder.decode(&packet) { + Ok(d) => d, + Err(symphonia::core::errors::Error::DecodeError(_)) => continue, + Err(_) => continue, + }; + + // Append decoded samples to buffer + append_samples(&decoded, &mut sample_buffer); + + // Send complete chunks as they become available + while sample_buffer.len() >= chunk_samples { + let chunk: Vec<f32> = sample_buffer.drain(..chunk_samples).collect(); + if tx.blocking_send(AudioChunk { samples: chunk }).is_err() { + // Receiver dropped, stop decoding + return Ok(()); + } + } + } + + // Send any remaining samples as a final partial chunk + if !sample_buffer.is_empty() { + let _ = tx.blocking_send(AudioChunk { + samples: sample_buffer, + }); + } + + Ok(()) +} + +/// Append decoded audio samples to the output buffer. +fn append_samples(buffer: &symphonia::core::audio::AudioBufferRef, out: &mut Vec<f32>) { + use symphonia::core::audio::{AudioBufferRef, Signal}; + + match buffer { + AudioBufferRef::U8(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push((plane[frame] as f32 - 128.0) / 128.0); + } + } + } + AudioBufferRef::U16(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push((plane[frame] as f32 - 32768.0) / 32768.0); + } + } + } + AudioBufferRef::U24(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push((plane[frame].inner() as f32 - 8388608.0) / 8388608.0); + } + } + } + AudioBufferRef::U32(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push((plane[frame] as f64 - 2147483648.0) as f32 / 2147483648.0); + } + } + } + AudioBufferRef::S8(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push(plane[frame] as f32 / 128.0); + } + } + } + AudioBufferRef::S16(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push(plane[frame] as f32 / 32768.0); + } + } + } + AudioBufferRef::S24(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push(plane[frame].inner() as f32 / 8388608.0); + } + } + } + AudioBufferRef::S32(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push(plane[frame] as f32 / 2147483648.0); + } + } + } + AudioBufferRef::F32(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push(plane[frame]); + } + } + } + AudioBufferRef::F64(buf) => { + for frame in 0..buf.frames() { + for plane in buf.planes().planes() { + out.push(plane[frame] as f32); + } + } + } + } +} |
