summaryrefslogblamecommitdiff
path: root/makima/src/daemon/tui/ws_client.rs
blob: 3462467fff1b6c850c11694c8a6f1c075b24a9c1 (plain) (tree)
































































































































































































































































































































































                                                                                                         
//! TUI WebSocket client for task output streaming.
//!
//! The WebSocket connection runs on a dedicated OS thread that owns its own
//! tokio runtime; the synchronous TUI talks to it through channels (tokio mpsc
//! for commands, std mpsc for events).

use std::sync::mpsc as std_mpsc;
use std::thread;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use tokio::runtime::Runtime;
use tokio::sync::mpsc as tokio_mpsc;
use uuid::Uuid;

/// Commands sent from the TUI to the WebSocket client thread.
///
/// Every payload is a plain `Uuid`, so the type also derives `PartialEq`/`Eq`;
/// callers can compare commands directly (e.g. in tests or when deduplicating).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsCommand {
    /// Start streaming output for `task_id`.
    Subscribe { task_id: Uuid },
    /// Stop streaming output for `task_id`.
    Unsubscribe { task_id: Uuid },
    /// Close the connection and stop the client thread.
    Shutdown,
}

/// Events sent from the WebSocket client thread to the TUI.
///
/// Delivered over a `std::sync::mpsc` channel and read with
/// `TuiWsClient::try_recv` / `TuiWsClient::recv_timeout`.
#[derive(Debug, Clone)]
pub enum WsEvent {
    /// WebSocket connected (handshake succeeded)
    Connected,
    /// WebSocket disconnected (server closed the stream or a read error occurred)
    Disconnected,
    /// About to retry; `attempt` counts consecutive failed or dropped connections
    Reconnecting { attempt: u32 },
    /// Subscription confirmed by the server (`outputSubscribed`)
    Subscribed { task_id: Uuid },
    /// Unsubscription confirmed by the server (`outputUnsubscribed`)
    Unsubscribed { task_id: Uuid },
    /// Task output received (`taskOutput`)
    TaskOutput(TaskOutputEvent),
    /// Error occurred: runtime/connection failures, WebSocket errors, or a
    /// server `error` message formatted as `"<code>: <message>"`
    Error { message: String },
}

/// Task output event from server.
///
/// Carried by `WsEvent::TaskOutput`; every field is copied one-to-one from the
/// server's `taskOutput` message (wire names shown in backticks).
#[derive(Debug, Clone)]
pub struct TaskOutputEvent {
    /// Task that produced this output (`taskId`).
    pub task_id: Uuid,
    /// Server-supplied kind of output (`messageType`); the set of values is
    /// defined by the server.
    pub message_type: String,
    /// Output payload (`content`).
    pub content: String,
    /// Tool name, when the server includes `toolName`.
    pub tool_name: Option<String>,
    /// Raw tool input JSON, when the server includes `toolInput`.
    pub tool_input: Option<serde_json::Value>,
    /// `isError` flag, when present.
    pub is_error: Option<bool>,
    /// `costUsd`, when present.
    pub cost_usd: Option<f64>,
    /// `durationMs`, when present.
    pub duration_ms: Option<u64>,
    /// `isPartial` flag from the server.
    /// NOTE(review): presumably marks incremental/streamed chunks — confirm.
    pub is_partial: bool,
}

/// Client-to-server frames, serialized as JSON objects tagged by `type`.
///
/// The tag is the camelCase variant name (`subscribeOutput`,
/// `unsubscribeOutput`); `rename_all` on an enum only renames variants, so the
/// task id field is renamed to `taskId` explicitly.
#[derive(Serialize, Clone, Debug)]
#[serde(rename_all = "camelCase", tag = "type")]
enum ClientMessage {
    /// Ask the server to start streaming output for a task.
    SubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Ask the server to stop streaming output for a task.
    UnsubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
}

/// Messages received from the WebSocket server.
///
/// Decoded from JSON objects tagged by `type`, whose value is the camelCase
/// variant name (e.g. `outputSubscribed`, `taskOutput`). Field names are
/// renamed one by one because `rename_all` on an enum only renames variants.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
enum ServerMessage {
    /// Subscription confirmed for a task.
    OutputSubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscription confirmed for a task.
    OutputUnsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Output produced by a task.
    TaskOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        #[serde(rename = "messageType")]
        message_type: String,
        content: String,
        #[serde(rename = "toolName")]
        tool_name: Option<String>,
        #[serde(rename = "toolInput")]
        tool_input: Option<serde_json::Value>,
        #[serde(rename = "isError")]
        is_error: Option<bool>,
        #[serde(rename = "costUsd")]
        cost_usd: Option<f64>,
        #[serde(rename = "durationMs")]
        duration_ms: Option<u64>,
        // A missing `isPartial` defaults to `false` instead of failing the whole
        // frame: the read loop silently drops frames that fail to parse, so a
        // strict field here would silently lose task output.
        #[serde(rename = "isPartial", default)]
        is_partial: bool,
    },
    /// Server-reported error, surfaced to the TUI as `"<code>: <message>"`.
    Error {
        code: String,
        message: String,
    },
    /// Any other `type` tag; ignored by the client.
    #[serde(other)]
    Other,
}

/// TUI-side handle to the background WebSocket client.
///
/// Created by `TuiWsClient::start`. Commands travel to the client thread over
/// a tokio channel; events come back over a std channel so the synchronous TUI
/// loop can poll them without an async runtime.
pub struct TuiWsClient {
    /// Command sender to WebSocket thread (bounded, capacity 32)
    command_tx: tokio_mpsc::Sender<WsCommand>,
    /// Event receiver from WebSocket thread
    event_rx: std_mpsc::Receiver<WsEvent>,
}

impl TuiWsClient {
    /// Start a new WebSocket client in a dedicated thread
    pub fn start(api_url: String, api_key: String) -> Self {
        let (command_tx, command_rx) = tokio_mpsc::channel(32);
        let (event_tx, event_rx) = std_mpsc::channel();

        // Spawn as daemon thread so it doesn't block process exit
        thread::Builder::new()
            .name("ws-client".to_string())
            .spawn(move || {
                let rt = match Runtime::new() {
                    Ok(rt) => rt,
                    Err(e) => {
                        let _ = event_tx.send(WsEvent::Error {
                            message: format!("Failed to create tokio runtime: {}", e),
                        });
                        return;
                    }
                };
                rt.block_on(run_ws_client(api_url, api_key, command_rx, event_tx));
            })
            .ok();

        Self {
            command_tx,
            event_rx,
        }
    }

    /// Send a command to the WebSocket client (non-blocking)
    pub fn send(&self, command: WsCommand) {
        // Use try_send to avoid blocking on shutdown
        let _ = self.command_tx.try_send(command);
    }

    /// Subscribe to task output
    pub fn subscribe(&self, task_id: Uuid) {
        self.send(WsCommand::Subscribe { task_id });
    }

    /// Unsubscribe from task output
    pub fn unsubscribe(&self, task_id: Uuid) {
        self.send(WsCommand::Unsubscribe { task_id });
    }

    /// Shutdown the WebSocket client
    pub fn shutdown(&self) {
        self.send(WsCommand::Shutdown);
    }

    /// Try to receive an event (non-blocking)
    pub fn try_recv(&self) -> Option<WsEvent> {
        self.event_rx.try_recv().ok()
    }

    /// Receive an event with timeout
    pub fn recv_timeout(&self, timeout: Duration) -> Option<WsEvent> {
        self.event_rx.recv_timeout(timeout).ok()
    }
}

impl Drop for TuiWsClient {
    /// Requests shutdown of the client thread without ever blocking.
    ///
    /// The request is discarded if the command channel is full or already
    /// closed. `command_tx` (the only sender created in this file) is dropped
    /// right after this runs; the client loop treats a closed command channel
    /// as a shutdown too.
    fn drop(&mut self) {
        let request = WsCommand::Shutdown;
        // Best effort: there is nobody left to report a failed send to.
        let _ = self.command_tx.try_send(request);
    }
}

/// WebSocket client main loop
async fn run_ws_client(
    api_url: String,
    api_key: String,
    mut command_rx: tokio_mpsc::Receiver<WsCommand>,
    event_tx: std_mpsc::Sender<WsEvent>,
) {
    use futures::{SinkExt, StreamExt};
    use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest, tungstenite::Message};

    // Build WebSocket URL from HTTP URL
    let ws_url = api_url
        .replace("https://", "wss://")
        .replace("http://", "ws://");
    let ws_url = format!("{}/api/v1/mesh/tasks/subscribe", ws_url);

    let mut reconnect_attempt = 0u32;
    let max_reconnect_delay = Duration::from_secs(30);
    let initial_delay = Duration::from_secs(1);

    loop {
        // Build request with API key header
        let mut request = match ws_url.clone().into_client_request() {
            Ok(r) => r,
            Err(e) => {
                let _ = event_tx.send(WsEvent::Error {
                    message: format!("Invalid URL: {}", e),
                });
                return;
            }
        };

        // Send both headers - server will try tool key first, then API key
        if let Ok(header_value) = api_key.parse() {
            request.headers_mut().insert("x-makima-tool-key", header_value);
        }
        if let Ok(header_value) = api_key.parse() {
            request.headers_mut().insert("x-makima-api-key", header_value);
        }

        if reconnect_attempt > 0 {
            let _ = event_tx.send(WsEvent::Reconnecting {
                attempt: reconnect_attempt,
            });

            // Exponential backoff
            let delay = std::cmp::min(
                initial_delay * 2u32.saturating_pow(reconnect_attempt - 1),
                max_reconnect_delay,
            );
            tokio::time::sleep(delay).await;
        }

        // Try to connect
        let (ws_stream, _) = match connect_async(request).await {
            Ok(result) => {
                reconnect_attempt = 0;
                let _ = event_tx.send(WsEvent::Connected);
                result
            }
            Err(e) => {
                reconnect_attempt += 1;
                let _ = event_tx.send(WsEvent::Error {
                    message: format!("Connection failed: {}", e),
                });
                continue;
            }
        };

        let (mut write, mut read) = ws_stream.split();

        // Main message loop
        loop {
            tokio::select! {
                // Handle commands from TUI
                cmd = command_rx.recv() => {
                    match cmd {
                        Some(WsCommand::Subscribe { task_id }) => {
                            let msg = ClientMessage::SubscribeOutput { task_id };
                            if let Ok(json) = serde_json::to_string(&msg) {
                                let _ = write.send(Message::Text(json)).await;
                            }
                        }
                        Some(WsCommand::Unsubscribe { task_id }) => {
                            let msg = ClientMessage::UnsubscribeOutput { task_id };
                            if let Ok(json) = serde_json::to_string(&msg) {
                                let _ = write.send(Message::Text(json)).await;
                            }
                        }
                        Some(WsCommand::Shutdown) | None => {
                            let _ = write.close().await;
                            return;
                        }
                    }
                }

                // Handle messages from server
                msg = read.next() => {
                    match msg {
                        Some(Ok(Message::Text(text))) => {
                            if let Ok(server_msg) = serde_json::from_str::<ServerMessage>(&text) {
                                match server_msg {
                                    ServerMessage::OutputSubscribed { task_id } => {
                                        let _ = event_tx.send(WsEvent::Subscribed { task_id });
                                    }
                                    ServerMessage::OutputUnsubscribed { task_id } => {
                                        let _ = event_tx.send(WsEvent::Unsubscribed { task_id });
                                    }
                                    ServerMessage::TaskOutput {
                                        task_id,
                                        message_type,
                                        content,
                                        tool_name,
                                        tool_input,
                                        is_error,
                                        cost_usd,
                                        duration_ms,
                                        is_partial,
                                    } => {
                                        let _ = event_tx.send(WsEvent::TaskOutput(TaskOutputEvent {
                                            task_id,
                                            message_type,
                                            content,
                                            tool_name,
                                            tool_input,
                                            is_error,
                                            cost_usd,
                                            duration_ms,
                                            is_partial,
                                        }));
                                    }
                                    ServerMessage::Error { code, message } => {
                                        let _ = event_tx.send(WsEvent::Error {
                                            message: format!("{}: {}", code, message),
                                        });
                                    }
                                    ServerMessage::Other => {
                                        // Ignore other message types
                                    }
                                }
                            }
                        }
                        Some(Ok(Message::Ping(data))) => {
                            let _ = write.send(Message::Pong(data)).await;
                        }
                        Some(Ok(Message::Close(_))) | None => {
                            let _ = event_tx.send(WsEvent::Disconnected);
                            reconnect_attempt += 1;
                            break; // Reconnect
                        }
                        Some(Err(e)) => {
                            let _ = event_tx.send(WsEvent::Error {
                                message: format!("WebSocket error: {}", e),
                            });
                            let _ = event_tx.send(WsEvent::Disconnected);
                            reconnect_attempt += 1;
                            break; // Reconnect
                        }
                        _ => {}
                    }
                }
            }
        }
    }
}