summaryrefslogtreecommitdiff
path: root/tools
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-20 15:36:04 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit01088f4f1915e36a7d0d8d8756f62f8207a48911 (patch)
tree8fdbba900f3f4bba32bae76e2e0378848a90cf93 /tools
parentab9166170043ba5e0ce974e5b7accf0939d686e3 (diff)
downloadsoryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.tar.gz
soryu-01088f4f1915e36a7d0d8d8756f62f8207a48911.zip
Implement makima listen websockets server
Diffstat (limited to 'tools')
-rw-r--r--tools/stt-client/Cargo.toml16
-rw-r--r--tools/stt-client/src/main.rs443
2 files changed, 459 insertions, 0 deletions
diff --git a/tools/stt-client/Cargo.toml b/tools/stt-client/Cargo.toml
new file mode 100644
index 0000000..314b27f
--- /dev/null
+++ b/tools/stt-client/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "stt-client"
+version = "0.1.0"
+edition = "2021"
+description = "WebSocket client for testing the Makima STT streaming endpoint"
+
+[dependencies]
+tokio = { version = "1.0", features = ["full"] }
+tokio-tungstenite = { version = "0.24", features = ["native-tls"] }
+futures = "0.3"
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+clap = { version = "4", features = ["derive"] }
+symphonia = { version = "0.5", features = ["mp3", "aac", "flac", "ogg", "vorbis", "wav", "pcm"] }
+anyhow = "1.0"
+url = "2.5"
diff --git a/tools/stt-client/src/main.rs b/tools/stt-client/src/main.rs
new file mode 100644
index 0000000..73a9f8e
--- /dev/null
+++ b/tools/stt-client/src/main.rs
@@ -0,0 +1,443 @@
+//! STT WebSocket client for testing the Makima STT streaming endpoint.
+//!
+//! This tool reads an audio file and streams it to the server via WebSocket,
+//! printing transcription results as they arrive. Large files are decoded
+//! and streamed asynchronously without loading the entire file into memory.
+
+use std::path::PathBuf;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
+use anyhow::Result;
+use clap::Parser;
+use futures::{SinkExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use tokio::sync::mpsc;
+use tokio_tungstenite::{connect_async, tungstenite::Message};
+use url::Url;
+
+#[derive(Parser)]
+#[command(name = "stt-client")]
+#[command(about = "WebSocket client for testing the Makima STT streaming endpoint")]
+struct Args {
+ /// Audio file to stream (supports MP3, WAV, FLAC, OGG, AAC)
+ #[arg(short, long)]
+ file: PathBuf,
+
+ /// Server WebSocket URL
+ #[arg(short, long, default_value = "ws://localhost:8080/api/v1/listen")]
+ url: String,
+
+ /// Chunk size in milliseconds for streaming
+ #[arg(short, long, default_value = "100")]
+ chunk_ms: u32,
+
+ /// Simulate real-time streaming (add delays between chunks)
+ #[arg(long, default_value = "true")]
+ realtime: bool,
+
+ /// Show progress during streaming (may interleave with transcripts)
+ #[arg(long, default_value = "false")]
+ show_progress: bool,
+}
+
+#[derive(Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+enum ClientMessage {
+ Start(StartMessage),
+ Stop { reason: Option<String> },
+}
+
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+struct StartMessage {
+ sample_rate: u32,
+ channels: u16,
+ encoding: String,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "camelCase")]
+enum ServerMessage {
+ Ready { session_id: String },
+ Transcript(TranscriptMessage),
+ Error { code: String, message: String },
+ Stopped { reason: String },
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+struct TranscriptMessage {
+ speaker: String,
+ start: f32,
+ end: f32,
+ text: String,
+ is_final: bool,
+}
+
+/// Audio format information extracted from file header.
+struct AudioFormat {
+ sample_rate: u32,
+ channels: u16,
+}
+
+/// A chunk of decoded audio samples.
+struct AudioChunk {
+ samples: Vec<f32>,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ let args = Args::parse();
+
+ // Probe audio file to get format info without decoding
+ eprintln!("[INFO] Probing audio file: {:?}", args.file);
+ let format = probe_audio_format(&args.file)?;
+ eprintln!(
+ "[INFO] Audio format: {}Hz, {} channel(s)",
+ format.sample_rate, format.channels
+ );
+
+ // Connect to WebSocket
+ let url = Url::parse(&args.url)?;
+ eprintln!("[INFO] Connecting to {}...", url);
+ let (ws_stream, _) = connect_async(url.as_str()).await?;
+ let (mut write, mut read) = ws_stream.split();
+
+ // Send start message
+ let start_msg = ClientMessage::Start(StartMessage {
+ sample_rate: format.sample_rate,
+ channels: format.channels,
+ encoding: "pcm32f".to_string(),
+ });
+ write
+ .send(Message::Text(serde_json::to_string(&start_msg)?))
+ .await?;
+ eprintln!("[INFO] Sent start message");
+
+ // Flag to signal when session is stopped
+ let session_stopped = Arc::new(AtomicBool::new(false));
+ let session_stopped_clone = session_stopped.clone();
+
+ // Spawn task to receive and print messages
+ let receiver = tokio::spawn(async move {
+ while let Some(msg) = read.next().await {
+ match msg {
+ Ok(Message::Text(text)) => match serde_json::from_str::<ServerMessage>(&text) {
+ Ok(ServerMessage::Ready { session_id }) => {
+ eprintln!("[INFO] Session ready: {}", session_id);
+ }
+ Ok(ServerMessage::Transcript(t)) => {
+ let final_marker = if t.is_final { " [FINAL]" } else { "" };
+ println!(
+ "[{:.2}s - {:.2}s] {}: {}{}",
+ t.start, t.end, t.speaker, t.text, final_marker
+ );
+ }
+ Ok(ServerMessage::Error { code, message }) => {
+ eprintln!("[ERROR] {}: {}", code, message);
+ }
+ Ok(ServerMessage::Stopped { reason }) => {
+ eprintln!("[INFO] Session stopped: {}", reason);
+ session_stopped_clone.store(true, Ordering::SeqCst);
+ break;
+ }
+ Err(e) => {
+ eprintln!("[ERROR] Failed to parse message: {}", e);
+ eprintln!("[DEBUG] Raw message: {}", text);
+ }
+ },
+ Ok(Message::Close(_)) => {
+ eprintln!("[INFO] Connection closed by server");
+ break;
+ }
+ Err(e) => {
+ eprintln!("[ERROR] WebSocket error: {}", e);
+ break;
+ }
+ _ => {}
+ }
+ }
+ });
+
+ // Calculate chunk size in samples
+ let chunk_samples =
+ (format.sample_rate * args.chunk_ms / 1000) as usize * format.channels as usize;
+
+ // Create channel for streaming decoded audio chunks
+ let (audio_tx, mut audio_rx) = mpsc::channel::<AudioChunk>(32);
+
+ // Spawn blocking task to decode audio file and send chunks
+ let file_path = args.file.clone();
+ let decoder_handle = tokio::task::spawn_blocking(move || {
+ decode_audio_streaming(&file_path, chunk_samples, audio_tx)
+ });
+
+ eprintln!(
+ "[INFO] Streaming audio in {} ms chunks ({} samples each)...",
+ args.chunk_ms, chunk_samples
+ );
+
+ // Stream chunks as they're decoded
+ let mut chunks_sent = 0usize;
+ while let Some(chunk) = audio_rx.recv().await {
+ // Convert f32 samples to bytes (little-endian)
+ let bytes: Vec<u8> = chunk.samples.iter().flat_map(|&s| s.to_le_bytes()).collect();
+
+ write.send(Message::Binary(bytes.into())).await?;
+ chunks_sent += 1;
+
+ // Progress indicator
+ if args.show_progress && chunks_sent % 50 == 0 {
+ eprintln!("[PROGRESS] {} chunks streamed", chunks_sent);
+ }
+
+ // Simulate real-time streaming if enabled
+ if args.realtime {
+ tokio::time::sleep(tokio::time::Duration::from_millis(args.chunk_ms as u64)).await;
+ }
+ }
+
+ // Wait for decoder to finish and check for errors
+ decoder_handle.await??;
+
+ eprintln!("[INFO] Streaming complete: {} chunks sent", chunks_sent);
+
+ // Send stop message
+ let stop_msg = ClientMessage::Stop {
+ reason: Some("end_of_file".to_string()),
+ };
+ write
+ .send(Message::Text(serde_json::to_string(&stop_msg)?))
+ .await?;
+ eprintln!("[INFO] Sent stop message, waiting for final results...");
+
+ // Wait for receiver to finish with a timeout
+ let timeout = tokio::time::Duration::from_secs(30);
+ match tokio::time::timeout(timeout, receiver).await {
+ Ok(result) => {
+ result?;
+ }
+ Err(_) => {
+ if session_stopped.load(Ordering::SeqCst) {
+ eprintln!("[INFO] Session completed");
+ } else {
+ eprintln!("[WARN] Timeout waiting for server response");
+ }
+ }
+ }
+
+ eprintln!("[INFO] Done!");
+ Ok(())
+}
+
+/// Probe an audio file to extract format information without decoding.
+fn probe_audio_format(path: &PathBuf) -> Result<AudioFormat> {
+ use symphonia::core::codecs::CODEC_TYPE_NULL;
+ use symphonia::core::formats::FormatOptions;
+ use symphonia::core::io::MediaSourceStream;
+ use symphonia::core::meta::MetadataOptions;
+ use symphonia::core::probe::Hint;
+
+ let file = std::fs::File::open(path)?;
+ let mss = MediaSourceStream::new(Box::new(file), Default::default());
+
+ let mut hint = Hint::new();
+ if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
+ hint.with_extension(ext);
+ }
+
+ let probed = symphonia::default::get_probe().format(
+ &hint,
+ mss,
+ &FormatOptions::default(),
+ &MetadataOptions::default(),
+ )?;
+
+ let track = probed
+ .format
+ .tracks()
+ .iter()
+ .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
+ .ok_or_else(|| anyhow::anyhow!("No audio track found"))?;
+
+ let sample_rate = track.codec_params.sample_rate.unwrap_or(16000);
+ let channels = track
+ .codec_params
+ .channels
+ .map(|c| c.count() as u16)
+ .unwrap_or(1);
+
+ Ok(AudioFormat {
+ sample_rate,
+ channels,
+ })
+}
+
+/// Decode audio file and stream chunks through the channel.
+/// This runs in a blocking thread to avoid blocking the async runtime.
+fn decode_audio_streaming(
+ path: &PathBuf,
+ chunk_samples: usize,
+ tx: mpsc::Sender<AudioChunk>,
+) -> Result<()> {
+ use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
+ use symphonia::core::formats::FormatOptions;
+ use symphonia::core::io::MediaSourceStream;
+ use symphonia::core::meta::MetadataOptions;
+ use symphonia::core::probe::Hint;
+
+ let file = std::fs::File::open(path)?;
+ let mss = MediaSourceStream::new(Box::new(file), Default::default());
+
+ let mut hint = Hint::new();
+ if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
+ hint.with_extension(ext);
+ }
+
+ let probed = symphonia::default::get_probe().format(
+ &hint,
+ mss,
+ &FormatOptions::default(),
+ &MetadataOptions::default(),
+ )?;
+
+ let mut format = probed.format;
+ let track = format
+ .tracks()
+ .iter()
+ .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
+ .ok_or_else(|| anyhow::anyhow!("No audio track found"))?;
+
+ let track_id = track.id;
+ let mut decoder =
+ symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
+
+ // Buffer for accumulating samples until we have a full chunk
+ let mut sample_buffer: Vec<f32> = Vec::with_capacity(chunk_samples * 2);
+
+ loop {
+ let packet = match format.next_packet() {
+ Ok(p) => p,
+ Err(symphonia::core::errors::Error::IoError(ref e))
+ if e.kind() == std::io::ErrorKind::UnexpectedEof =>
+ {
+ break;
+ }
+ Err(symphonia::core::errors::Error::ResetRequired) => {
+ decoder.reset();
+ continue;
+ }
+ Err(_) => break,
+ };
+
+ if packet.track_id() != track_id {
+ continue;
+ }
+
+ let decoded = match decoder.decode(&packet) {
+ Ok(d) => d,
+ Err(symphonia::core::errors::Error::DecodeError(_)) => continue,
+ Err(_) => continue,
+ };
+
+ // Append decoded samples to buffer
+ append_samples(&decoded, &mut sample_buffer);
+
+ // Send complete chunks as they become available
+ while sample_buffer.len() >= chunk_samples {
+ let chunk: Vec<f32> = sample_buffer.drain(..chunk_samples).collect();
+ if tx.blocking_send(AudioChunk { samples: chunk }).is_err() {
+ // Receiver dropped, stop decoding
+ return Ok(());
+ }
+ }
+ }
+
+ // Send any remaining samples as a final partial chunk
+ if !sample_buffer.is_empty() {
+ let _ = tx.blocking_send(AudioChunk {
+ samples: sample_buffer,
+ });
+ }
+
+ Ok(())
+}
+
+/// Append decoded audio samples to the output buffer.
+fn append_samples(buffer: &symphonia::core::audio::AudioBufferRef, out: &mut Vec<f32>) {
+ use symphonia::core::audio::{AudioBufferRef, Signal};
+
+ match buffer {
+ AudioBufferRef::U8(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push((plane[frame] as f32 - 128.0) / 128.0);
+ }
+ }
+ }
+ AudioBufferRef::U16(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push((plane[frame] as f32 - 32768.0) / 32768.0);
+ }
+ }
+ }
+ AudioBufferRef::U24(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push((plane[frame].inner() as f32 - 8388608.0) / 8388608.0);
+ }
+ }
+ }
+ AudioBufferRef::U32(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push((plane[frame] as f64 - 2147483648.0) as f32 / 2147483648.0);
+ }
+ }
+ }
+ AudioBufferRef::S8(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push(plane[frame] as f32 / 128.0);
+ }
+ }
+ }
+ AudioBufferRef::S16(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push(plane[frame] as f32 / 32768.0);
+ }
+ }
+ }
+ AudioBufferRef::S24(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push(plane[frame].inner() as f32 / 8388608.0);
+ }
+ }
+ }
+ AudioBufferRef::S32(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push(plane[frame] as f32 / 2147483648.0);
+ }
+ }
+ }
+ AudioBufferRef::F32(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push(plane[frame]);
+ }
+ }
+ }
+ AudioBufferRef::F64(buf) => {
+ for frame in 0..buf.frames() {
+ for plane in buf.planes().planes() {
+ out.push(plane[frame] as f32);
+ }
+ }
+ }
+ }
+}