diff options
Diffstat (limited to 'makima/src/daemon/tui/ws_client.rs')
| -rw-r--r-- | makima/src/daemon/tui/ws_client.rs | 353 |
1 files changed, 353 insertions, 0 deletions
diff --git a/makima/src/daemon/tui/ws_client.rs b/makima/src/daemon/tui/ws_client.rs new file mode 100644 index 0000000..3462467 --- /dev/null +++ b/makima/src/daemon/tui/ws_client.rs @@ -0,0 +1,353 @@ +//! TUI WebSocket client for task output streaming. +//! +//! Uses a dedicated async thread to handle WebSocket communication, +//! bridging async/sync worlds via channels. + +use std::sync::mpsc as std_mpsc; +use std::thread; +use std::time::Duration; + +use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; +use tokio::sync::mpsc as tokio_mpsc; +use uuid::Uuid; + +/// Commands sent from TUI to WebSocket client +#[derive(Debug, Clone)] +pub enum WsCommand { + /// Subscribe to task output + Subscribe { task_id: Uuid }, + /// Unsubscribe from task output + Unsubscribe { task_id: Uuid }, + /// Shutdown the WebSocket client + Shutdown, +} + +/// Events sent from WebSocket client to TUI +#[derive(Debug, Clone)] +pub enum WsEvent { + /// WebSocket connected + Connected, + /// WebSocket disconnected + Disconnected, + /// WebSocket reconnecting + Reconnecting { attempt: u32 }, + /// Subscription confirmed + Subscribed { task_id: Uuid }, + /// Unsubscription confirmed + Unsubscribed { task_id: Uuid }, + /// Task output received + TaskOutput(TaskOutputEvent), + /// Error occurred + Error { message: String }, +} + +/// Task output event from server +#[derive(Debug, Clone)] +pub struct TaskOutputEvent { + pub task_id: Uuid, + pub message_type: String, + pub content: String, + pub tool_name: Option<String>, + pub tool_input: Option<serde_json::Value>, + pub is_error: Option<bool>, + pub cost_usd: Option<f64>, + pub duration_ms: Option<u64>, + pub is_partial: bool, +} + +/// Messages sent to the WebSocket server +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +enum ClientMessage { + SubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + UnsubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, +} + +/// Messages received from the WebSocket server +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +enum ServerMessage { + OutputSubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + OutputUnsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "messageType")] + message_type: String, + content: String, + #[serde(rename = "toolName")] + tool_name: Option<String>, + #[serde(rename = "toolInput")] + tool_input: Option<serde_json::Value>, + #[serde(rename = "isError")] + is_error: Option<bool>, + #[serde(rename = "costUsd")] + cost_usd: Option<f64>, + #[serde(rename = "durationMs")] + duration_ms: Option<u64>, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + Error { + code: String, + message: String, + }, + // Other message types we don't care about + #[serde(other)] + Other, +} + +/// TUI WebSocket client handle +pub struct TuiWsClient { + /// Command sender to WebSocket thread + command_tx: tokio_mpsc::Sender<WsCommand>, + /// Event receiver from WebSocket thread + event_rx: std_mpsc::Receiver<WsEvent>, +} + +impl TuiWsClient { + /// Start a new WebSocket client in a dedicated thread + pub fn start(api_url: String, api_key: String) -> Self { + let (command_tx, command_rx) = tokio_mpsc::channel(32); + let (event_tx, event_rx) = std_mpsc::channel(); + + // Spawn as daemon thread so it doesn't block process exit + thread::Builder::new() + .name("ws-client".to_string()) + .spawn(move || { + let rt = match Runtime::new() { + Ok(rt) => rt, + Err(e) => { + let _ = event_tx.send(WsEvent::Error { + message: format!("Failed to create tokio runtime: {}", e), + }); + return; + } + }; + rt.block_on(run_ws_client(api_url, api_key, command_rx, event_tx)); + }) + .ok(); + + Self { + command_tx, + event_rx, + } + } + + /// Send a command to the WebSocket client (non-blocking) + pub fn send(&self, command: WsCommand) { + // Use try_send to avoid blocking on shutdown + let _ = self.command_tx.try_send(command); + } + + /// Subscribe to task output + pub fn subscribe(&self, task_id: Uuid) { + self.send(WsCommand::Subscribe { task_id }); + } + + /// Unsubscribe from task output + pub fn unsubscribe(&self, task_id: Uuid) { + self.send(WsCommand::Unsubscribe { task_id }); + } + + /// Shutdown the WebSocket client + pub fn shutdown(&self) { + self.send(WsCommand::Shutdown); + } + + /// Try to receive an event (non-blocking) + pub fn try_recv(&self) -> Option<WsEvent> { + self.event_rx.try_recv().ok() + } + + /// Receive an event with timeout + pub fn recv_timeout(&self, timeout: Duration) -> Option<WsEvent> { + self.event_rx.recv_timeout(timeout).ok() + } +} + +impl Drop for TuiWsClient { + fn drop(&mut self) { + // Try to send shutdown command, but don't wait + let _ = self.command_tx.try_send(WsCommand::Shutdown); + } +} + +/// WebSocket client main loop +async fn run_ws_client( + api_url: String, + api_key: String, + mut command_rx: tokio_mpsc::Receiver<WsCommand>, + event_tx: std_mpsc::Sender<WsEvent>, +) { + use futures::{SinkExt, StreamExt}; + use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest, tungstenite::Message}; + + // Build WebSocket URL from HTTP URL + let ws_url = api_url + .replace("https://", "wss://") + .replace("http://", "ws://"); + let ws_url = format!("{}/api/v1/mesh/tasks/subscribe", ws_url); + + let mut reconnect_attempt = 0u32; + let max_reconnect_delay = Duration::from_secs(30); + let initial_delay = Duration::from_secs(1); + + loop { + // Build request with API key header + let mut request = match ws_url.clone().into_client_request() { + Ok(r) => r, + Err(e) => { + let _ = event_tx.send(WsEvent::Error { + message: format!("Invalid URL: {}", e), + }); + return; + } + }; + + // Send both headers - server will try tool key first, then API key + if let Ok(header_value) = api_key.parse() { + request.headers_mut().insert("x-makima-tool-key", header_value); + } + if let Ok(header_value) = api_key.parse() { + request.headers_mut().insert("x-makima-api-key", header_value); + } + + if reconnect_attempt > 0 { + let _ = event_tx.send(WsEvent::Reconnecting { + attempt: reconnect_attempt, + }); + + // Exponential backoff + let delay = std::cmp::min( + initial_delay * 2u32.saturating_pow(reconnect_attempt - 1), + max_reconnect_delay, + ); + tokio::time::sleep(delay).await; + } + + // Try to connect + let (ws_stream, _) = match connect_async(request).await { + Ok(result) => { + reconnect_attempt = 0; + let _ = event_tx.send(WsEvent::Connected); + result + } + Err(e) => { + reconnect_attempt += 1; + let _ = event_tx.send(WsEvent::Error { + message: format!("Connection failed: {}", e), + }); + continue; + } + }; + + let (mut write, mut read) = ws_stream.split(); + + // Main message loop + loop { + tokio::select! { + // Handle commands from TUI + cmd = command_rx.recv() => { + match cmd { + Some(WsCommand::Subscribe { task_id }) => { + let msg = ClientMessage::SubscribeOutput { task_id }; + if let Ok(json) = serde_json::to_string(&msg) { + let _ = write.send(Message::Text(json)).await; + } + } + Some(WsCommand::Unsubscribe { task_id }) => { + let msg = ClientMessage::UnsubscribeOutput { task_id }; + if let Ok(json) = serde_json::to_string(&msg) { + let _ = write.send(Message::Text(json)).await; + } + } + Some(WsCommand::Shutdown) | None => { + let _ = write.close().await; + return; + } + } + } + + // Handle messages from server + msg = read.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + if let Ok(server_msg) = serde_json::from_str::<ServerMessage>(&text) { + match server_msg { + ServerMessage::OutputSubscribed { task_id } => { + let _ = event_tx.send(WsEvent::Subscribed { task_id }); + } + ServerMessage::OutputUnsubscribed { task_id } => { + let _ = event_tx.send(WsEvent::Unsubscribed { task_id }); + } + ServerMessage::TaskOutput { + task_id, + message_type, + content, + tool_name, + tool_input, + is_error, + cost_usd, + duration_ms, + is_partial, + } => { + let _ = event_tx.send(WsEvent::TaskOutput(TaskOutputEvent { + task_id, + message_type, + content, + tool_name, + tool_input, + is_error, + cost_usd, + duration_ms, + is_partial, + })); + } + ServerMessage::Error { code, message } => { + let _ = event_tx.send(WsEvent::Error { + message: format!("{}: {}", code, message), + }); + } + ServerMessage::Other => { + // Ignore other message types + } + } + } + } + Some(Ok(Message::Ping(data))) => { + let _ = write.send(Message::Pong(data)).await; + } + Some(Ok(Message::Close(_))) | None => { + let _ = event_tx.send(WsEvent::Disconnected); + reconnect_attempt += 1; + break; // Reconnect + } + Some(Err(e)) => { + let _ = event_tx.send(WsEvent::Error { + message: format!("WebSocket error: {}", e), + }); + let _ = event_tx.send(WsEvent::Disconnected); + reconnect_attempt += 1; + break; // Reconnect + } + _ => {} + } + } + } + } + } +} |
