summaryrefslogtreecommitdiff
path: root/makima/src/daemon/tui/ws_client.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/daemon/tui/ws_client.rs')
-rw-r--r--makima/src/daemon/tui/ws_client.rs353
1 files changed, 353 insertions, 0 deletions
diff --git a/makima/src/daemon/tui/ws_client.rs b/makima/src/daemon/tui/ws_client.rs
new file mode 100644
index 0000000..3462467
--- /dev/null
+++ b/makima/src/daemon/tui/ws_client.rs
@@ -0,0 +1,353 @@
+//! TUI WebSocket client for task output streaming.
+//!
+//! Uses a dedicated async thread to handle WebSocket communication,
+//! bridging async/sync worlds via channels.
+
+use std::sync::mpsc as std_mpsc;
+use std::thread;
+use std::time::Duration;
+
+use serde::{Deserialize, Serialize};
+use tokio::runtime::Runtime;
+use tokio::sync::mpsc as tokio_mpsc;
+use uuid::Uuid;
+
+/// Commands sent from TUI to WebSocket client
+#[derive(Debug, Clone)]
+pub enum WsCommand {
+ /// Subscribe to task output
+ Subscribe { task_id: Uuid },
+ /// Unsubscribe from task output
+ Unsubscribe { task_id: Uuid },
+ /// Shutdown the WebSocket client
+ Shutdown,
+}
+
+/// Events sent from WebSocket client to TUI
+#[derive(Debug, Clone)]
+pub enum WsEvent {
+ /// WebSocket connected
+ Connected,
+ /// WebSocket disconnected
+ Disconnected,
+ /// WebSocket reconnecting
+ Reconnecting { attempt: u32 },
+ /// Subscription confirmed
+ Subscribed { task_id: Uuid },
+ /// Unsubscription confirmed
+ Unsubscribed { task_id: Uuid },
+ /// Task output received
+ TaskOutput(TaskOutputEvent),
+ /// Error occurred
+ Error { message: String },
+}
+
+/// Task output event from server
+#[derive(Debug, Clone)]
+pub struct TaskOutputEvent {
+ pub task_id: Uuid,
+ pub message_type: String,
+ pub content: String,
+ pub tool_name: Option<String>,
+ pub tool_input: Option<serde_json::Value>,
+ pub is_error: Option<bool>,
+ pub cost_usd: Option<f64>,
+ pub duration_ms: Option<u64>,
+ pub is_partial: bool,
+}
+
+/// Messages sent to the WebSocket server
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+enum ClientMessage {
+ SubscribeOutput {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ UnsubscribeOutput {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+}
+
+/// Messages received from the WebSocket server
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+enum ServerMessage {
+ OutputSubscribed {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ OutputUnsubscribed {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ TaskOutput {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ #[serde(rename = "messageType")]
+ message_type: String,
+ content: String,
+ #[serde(rename = "toolName")]
+ tool_name: Option<String>,
+ #[serde(rename = "toolInput")]
+ tool_input: Option<serde_json::Value>,
+ #[serde(rename = "isError")]
+ is_error: Option<bool>,
+ #[serde(rename = "costUsd")]
+ cost_usd: Option<f64>,
+ #[serde(rename = "durationMs")]
+ duration_ms: Option<u64>,
+ #[serde(rename = "isPartial")]
+ is_partial: bool,
+ },
+ Error {
+ code: String,
+ message: String,
+ },
+ // Other message types we don't care about
+ #[serde(other)]
+ Other,
+}
+
+/// TUI WebSocket client handle
+pub struct TuiWsClient {
+ /// Command sender to WebSocket thread
+ command_tx: tokio_mpsc::Sender<WsCommand>,
+ /// Event receiver from WebSocket thread
+ event_rx: std_mpsc::Receiver<WsEvent>,
+}
+
+impl TuiWsClient {
+ /// Start a new WebSocket client in a dedicated thread
+ pub fn start(api_url: String, api_key: String) -> Self {
+ let (command_tx, command_rx) = tokio_mpsc::channel(32);
+ let (event_tx, event_rx) = std_mpsc::channel();
+
+ // Spawn as daemon thread so it doesn't block process exit
+ thread::Builder::new()
+ .name("ws-client".to_string())
+ .spawn(move || {
+ let rt = match Runtime::new() {
+ Ok(rt) => rt,
+ Err(e) => {
+ let _ = event_tx.send(WsEvent::Error {
+ message: format!("Failed to create tokio runtime: {}", e),
+ });
+ return;
+ }
+ };
+ rt.block_on(run_ws_client(api_url, api_key, command_rx, event_tx));
+ })
+ .ok();
+
+ Self {
+ command_tx,
+ event_rx,
+ }
+ }
+
+ /// Send a command to the WebSocket client (non-blocking)
+ pub fn send(&self, command: WsCommand) {
+ // Use try_send to avoid blocking on shutdown
+ let _ = self.command_tx.try_send(command);
+ }
+
+ /// Subscribe to task output
+ pub fn subscribe(&self, task_id: Uuid) {
+ self.send(WsCommand::Subscribe { task_id });
+ }
+
+ /// Unsubscribe from task output
+ pub fn unsubscribe(&self, task_id: Uuid) {
+ self.send(WsCommand::Unsubscribe { task_id });
+ }
+
+ /// Shutdown the WebSocket client
+ pub fn shutdown(&self) {
+ self.send(WsCommand::Shutdown);
+ }
+
+ /// Try to receive an event (non-blocking)
+ pub fn try_recv(&self) -> Option<WsEvent> {
+ self.event_rx.try_recv().ok()
+ }
+
+ /// Receive an event with timeout
+ pub fn recv_timeout(&self, timeout: Duration) -> Option<WsEvent> {
+ self.event_rx.recv_timeout(timeout).ok()
+ }
+}
+
+impl Drop for TuiWsClient {
+ fn drop(&mut self) {
+ // Try to send shutdown command, but don't wait
+ let _ = self.command_tx.try_send(WsCommand::Shutdown);
+ }
+}
+
+/// WebSocket client main loop
+async fn run_ws_client(
+ api_url: String,
+ api_key: String,
+ mut command_rx: tokio_mpsc::Receiver<WsCommand>,
+ event_tx: std_mpsc::Sender<WsEvent>,
+) {
+ use futures::{SinkExt, StreamExt};
+ use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest, tungstenite::Message};
+
+ // Build WebSocket URL from HTTP URL
+ let ws_url = api_url
+ .replace("https://", "wss://")
+ .replace("http://", "ws://");
+ let ws_url = format!("{}/api/v1/mesh/tasks/subscribe", ws_url);
+
+ let mut reconnect_attempt = 0u32;
+ let max_reconnect_delay = Duration::from_secs(30);
+ let initial_delay = Duration::from_secs(1);
+
+ loop {
+ // Build request with API key header
+ let mut request = match ws_url.clone().into_client_request() {
+ Ok(r) => r,
+ Err(e) => {
+ let _ = event_tx.send(WsEvent::Error {
+ message: format!("Invalid URL: {}", e),
+ });
+ return;
+ }
+ };
+
+ // Send both headers - server will try tool key first, then API key
+ if let Ok(header_value) = api_key.parse() {
+ request.headers_mut().insert("x-makima-tool-key", header_value);
+ }
+ if let Ok(header_value) = api_key.parse() {
+ request.headers_mut().insert("x-makima-api-key", header_value);
+ }
+
+ if reconnect_attempt > 0 {
+ let _ = event_tx.send(WsEvent::Reconnecting {
+ attempt: reconnect_attempt,
+ });
+
+ // Exponential backoff
+ let delay = std::cmp::min(
+ initial_delay * 2u32.saturating_pow(reconnect_attempt - 1),
+ max_reconnect_delay,
+ );
+ tokio::time::sleep(delay).await;
+ }
+
+ // Try to connect
+ let (ws_stream, _) = match connect_async(request).await {
+ Ok(result) => {
+ reconnect_attempt = 0;
+ let _ = event_tx.send(WsEvent::Connected);
+ result
+ }
+ Err(e) => {
+ reconnect_attempt += 1;
+ let _ = event_tx.send(WsEvent::Error {
+ message: format!("Connection failed: {}", e),
+ });
+ continue;
+ }
+ };
+
+ let (mut write, mut read) = ws_stream.split();
+
+ // Main message loop
+ loop {
+ tokio::select! {
+ // Handle commands from TUI
+ cmd = command_rx.recv() => {
+ match cmd {
+ Some(WsCommand::Subscribe { task_id }) => {
+ let msg = ClientMessage::SubscribeOutput { task_id };
+ if let Ok(json) = serde_json::to_string(&msg) {
+ let _ = write.send(Message::Text(json)).await;
+ }
+ }
+ Some(WsCommand::Unsubscribe { task_id }) => {
+ let msg = ClientMessage::UnsubscribeOutput { task_id };
+ if let Ok(json) = serde_json::to_string(&msg) {
+ let _ = write.send(Message::Text(json)).await;
+ }
+ }
+ Some(WsCommand::Shutdown) | None => {
+ let _ = write.close().await;
+ return;
+ }
+ }
+ }
+
+ // Handle messages from server
+ msg = read.next() => {
+ match msg {
+ Some(Ok(Message::Text(text))) => {
+ if let Ok(server_msg) = serde_json::from_str::<ServerMessage>(&text) {
+ match server_msg {
+ ServerMessage::OutputSubscribed { task_id } => {
+ let _ = event_tx.send(WsEvent::Subscribed { task_id });
+ }
+ ServerMessage::OutputUnsubscribed { task_id } => {
+ let _ = event_tx.send(WsEvent::Unsubscribed { task_id });
+ }
+ ServerMessage::TaskOutput {
+ task_id,
+ message_type,
+ content,
+ tool_name,
+ tool_input,
+ is_error,
+ cost_usd,
+ duration_ms,
+ is_partial,
+ } => {
+ let _ = event_tx.send(WsEvent::TaskOutput(TaskOutputEvent {
+ task_id,
+ message_type,
+ content,
+ tool_name,
+ tool_input,
+ is_error,
+ cost_usd,
+ duration_ms,
+ is_partial,
+ }));
+ }
+ ServerMessage::Error { code, message } => {
+ let _ = event_tx.send(WsEvent::Error {
+ message: format!("{}: {}", code, message),
+ });
+ }
+ ServerMessage::Other => {
+ // Ignore other message types
+ }
+ }
+ }
+ }
+ Some(Ok(Message::Ping(data))) => {
+ let _ = write.send(Message::Pong(data)).await;
+ }
+ Some(Ok(Message::Close(_))) | None => {
+ let _ = event_tx.send(WsEvent::Disconnected);
+ reconnect_attempt += 1;
+ break; // Reconnect
+ }
+ Some(Err(e)) => {
+ let _ = event_tx.send(WsEvent::Error {
+ message: format!("WebSocket error: {}", e),
+ });
+ let _ = event_tx.send(WsEvent::Disconnected);
+ reconnect_attempt += 1;
+ break; // Reconnect
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+ }
+}