// summary | refs | log | blame | commit | diff
// path: root/makima/daemon/src/ws/client.rs
// blob: ba1263fc8e404a780458cb9907e6e65e981b61f8 (plain) (tree)

































































































































































































































































































                                                                                                         
//! WebSocket client for connecting to the makima server.

use std::sync::Arc;
use std::time::Duration;

use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest, Message}};
use uuid::Uuid;

use super::protocol::{DaemonCommand, DaemonMessage};
use crate::config::ServerConfig;
use crate::error::{DaemonError, Result};

/// Lifecycle phase of the daemon's WebSocket link to the server.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection to the server is open.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The link is up and the server has authenticated this daemon.
    Connected,
    /// The connection failed; another attempt will be made.
    Reconnecting,
    /// Gave up for good (e.g., authentication was rejected).
    Failed,
}

/// WebSocket client for daemon-server communication.
///
/// Holds the server settings, the identity details reported to the server
/// during authentication, the shared connection state, and the channel
/// endpoints used to exchange messages with the rest of the daemon.
pub struct WsClient {
    /// Server settings: URL, API key, and the reconnect / heartbeat intervals.
    config: ServerConfig,
    /// Machine identifier included in the authentication message.
    machine_id: String,
    /// Hostname included in the authentication message.
    hostname: String,
    /// Task capacity included in the authentication message.
    max_concurrent_tasks: i32,
    /// Current connection state; starts as `Disconnected`.
    state: Arc<RwLock<ConnectionState>>,
    /// Daemon ID assigned by the server; `None` until the server replies
    /// with an `Authenticated` command.
    daemon_id: Arc<RwLock<Option<Uuid>>>,
    /// Channel to receive messages to send to server.
    outgoing_rx: mpsc::Receiver<DaemonMessage>,
    /// Sender for outgoing messages (clone this to send messages).
    outgoing_tx: mpsc::Sender<DaemonMessage>,
    /// Channel to send received commands to the task manager.
    incoming_tx: mpsc::Sender<DaemonCommand>,
}

impl WsClient {
    /// Create a new WebSocket client.
    ///
    /// The client starts in [`ConnectionState::Disconnected`] with no daemon
    /// ID. Nothing is sent to the server until [`WsClient::run`] is awaited.
    ///
    /// * `config` - server URL, API key, and reconnect/heartbeat settings.
    /// * `machine_id`, `hostname`, `max_concurrent_tasks` - forwarded to the
    ///   server in the authentication message.
    /// * `incoming_tx` - receives the commands parsed from the server once
    ///   the connection is authenticated.
    pub fn new(
        config: ServerConfig,
        machine_id: String,
        hostname: String,
        max_concurrent_tasks: i32,
        incoming_tx: mpsc::Sender<DaemonCommand>,
    ) -> Self {
        // Bounded queue: producers wait once 256 messages are pending.
        let (outgoing_tx, outgoing_rx) = mpsc::channel(256);

        Self {
            config,
            machine_id,
            hostname,
            max_concurrent_tasks,
            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
            daemon_id: Arc::new(RwLock::new(None)),
            outgoing_rx,
            outgoing_tx,
            incoming_tx,
        }
    }

    /// Get a sender for outgoing messages.
    ///
    /// Every call returns a fresh clone of the same channel sender; messages
    /// pushed through it are serialized and written to the socket by the
    /// message loop while a connection is up.
    pub fn sender(&self) -> mpsc::Sender<DaemonMessage> {
        self.outgoing_tx.clone()
    }

    /// Get current connection state.
    pub async fn state(&self) -> ConnectionState {
        *self.state.read().await
    }

    /// Get the daemon ID most recently assigned by the server.
    ///
    /// Returns `None` until the first successful authentication. The value
    /// is not cleared when the connection later drops.
    pub async fn daemon_id(&self) -> Option<Uuid> {
        *self.daemon_id.read().await
    }

    /// Run the WebSocket client with automatic reconnection.
    ///
    /// Connects, authenticates, and runs the message loop. When a session
    /// ends with an error, it is retried after an exponentially growing
    /// delay that starts at `reconnect_interval_secs` and is capped at 60
    /// seconds.
    ///
    /// When `max_reconnect_attempts` is greater than zero, retries are
    /// bounded by a *time budget* of
    /// `reconnect_interval_secs * max_reconnect_attempts * 10` seconds rather
    /// than by an attempt counter; otherwise retries continue indefinitely.
    /// Both the delay and the budget start over after every session that
    /// reached the authenticated state.
    ///
    /// # Errors
    ///
    /// * `DaemonError::AuthFailed` - authentication failed; never retried.
    /// * `DaemonError::ConnectionLost` - the retry budget was exhausted.
    pub async fn run(&mut self) -> Result<()> {
        let max_elapsed_time = if self.config.max_reconnect_attempts > 0 {
            // Saturate instead of overflowing on extreme configuration values.
            Some(Duration::from_secs(
                self.config
                    .reconnect_interval_secs
                    .saturating_mul(self.config.max_reconnect_attempts as u64)
                    .saturating_mul(10),
            ))
        } else {
            None // Infinite retries
        };

        let mut backoff = ExponentialBackoff {
            initial_interval: Duration::from_secs(self.config.reconnect_interval_secs),
            max_interval: Duration::from_secs(60),
            max_elapsed_time,
            ..Default::default()
        };
        // The struct-update syntax above takes `current_interval` from
        // `Default`, not from our `initial_interval`; `reset()` re-seeds it
        // so the first retry delay honors the configured interval.
        backoff.reset();

        loop {
            *self.state.write().await = ConnectionState::Connecting;
            tracing::info!("Connecting to server: {}", self.config.url);

            match self.connect_and_run().await {
                Ok(()) => {
                    // Clean shutdown
                    tracing::info!("WebSocket connection closed cleanly");
                    *self.state.write().await = ConnectionState::Disconnected;
                    return Ok(());
                }
                Err(DaemonError::AuthFailed(msg)) => {
                    tracing::error!("Authentication failed: {}", msg);
                    *self.state.write().await = ConnectionState::Failed;
                    return Err(DaemonError::AuthFailed(msg));
                }
                Err(e) => {
                    tracing::warn!("Connection error: {}", e);

                    // `connect_and_run` sets `Connected` only after the server
                    // authenticated us, so this tells a dropped session apart
                    // from a failed connection attempt.
                    let was_connected = {
                        let mut state = self.state.write().await;
                        let was_connected = *state == ConnectionState::Connected;
                        *state = ConnectionState::Reconnecting;
                        was_connected
                    };
                    if was_connected {
                        // Without this the elapsed-time budget would include
                        // the time spent connected, and a long-lived daemon
                        // would give up on its first disconnect.
                        backoff.reset();
                    }

                    if let Some(delay) = backoff.next_backoff() {
                        tracing::info!("Reconnecting in {:?}...", delay);
                        tokio::time::sleep(delay).await;
                    } else {
                        tracing::error!("Max reconnection attempts reached");
                        *self.state.write().await = ConnectionState::Failed;
                        return Err(DaemonError::ConnectionLost);
                    }
                }
            }
        }
    }

    /// Build the `DaemonDirectories` message describing this daemon's paths.
    ///
    /// Reports the current working directory (falling back to `"."`), plus
    /// `<home>/.makima/home` and `<home>/.makima/worktrees`. The home
    /// directory is created if missing; a failure to create it is logged and
    /// otherwise ignored.
    fn directories_message() -> DaemonMessage {
        let working_directory = std::env::current_dir()
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or_else(|_| ".".to_string());

        let user_home = dirs::home_dir();

        // NOTE(review): when the user's home cannot be resolved, the fallback
        // is the literal relative path "~/.makima/home" (no tilde expansion
        // happens here), so `create_dir_all` creates a directory named `~`
        // under the current working directory — confirm this is intended.
        let home_directory = user_home
            .as_ref()
            .map(|h| h.join(".makima").join("home"))
            .unwrap_or_else(|| std::path::PathBuf::from("~/.makima/home"));
        // Create home directory if it doesn't exist
        if let Err(e) = std::fs::create_dir_all(&home_directory) {
            tracing::warn!("Failed to create home directory {:?}: {}", home_directory, e);
        }

        let worktrees_directory = user_home
            .as_ref()
            .map(|h| h.join(".makima").join("worktrees").to_string_lossy().to_string())
            .unwrap_or_else(|| "~/.makima/worktrees".to_string());

        DaemonMessage::DaemonDirectories {
            working_directory,
            home_directory: home_directory.to_string_lossy().to_string(),
            worktrees_directory,
        }
    }

    /// Connect to server and run the message loop.
    ///
    /// Steps: open the WebSocket with the API key in the `x-makima-api-key`
    /// header, send the authentication message, wait for the server's reply,
    /// report the daemon directories, then loop forwarding server commands to
    /// `incoming_tx`, queued outgoing messages to the socket, and periodic
    /// heartbeats.
    ///
    /// Returns `Ok(())` only for a local shutdown: the command receiver was
    /// dropped or the outgoing channel was closed. The latter cannot happen
    /// while `self` is alive, because `self` holds a sender itself.
    ///
    /// # Errors
    ///
    /// * `DaemonError::AuthFailed` - the API key is not a valid header value,
    ///   or the first reply was not an `Authenticated` command.
    /// * `DaemonError::ConnectionLost` - the server closed the connection or
    ///   the stream ended.
    /// * Any connection, WebSocket, or JSON error propagated through `?`.
    async fn connect_and_run(&mut self) -> Result<()> {
        // Build WebSocket URL
        let ws_url = format!("{}/api/v1/mesh/daemons/connect", self.config.url);
        tracing::debug!("Connecting to WebSocket: {}", ws_url);

        // Build request with API key header
        let mut request = ws_url.into_client_request()?;
        request.headers_mut().insert(
            "x-makima-api-key",
            self.config.api_key.parse().map_err(|_| {
                DaemonError::AuthFailed("Invalid API key format".into())
            })?,
        );

        // Connect with API key in headers
        let (ws_stream, _response) = connect_async(request).await?;
        let (mut write, mut read) = ws_stream.split();

        // Send daemon info after connection (server authenticated us via header)
        let info_msg = DaemonMessage::authenticate(
            &self.config.api_key,
            &self.machine_id,
            &self.hostname,
            self.max_concurrent_tasks,
        );
        let info_json = serde_json::to_string(&info_msg)?;
        write.send(Message::Text(info_json)).await?;

        // Wait for authentication response
        let auth_response = read
            .next()
            .await
            .ok_or(DaemonError::ConnectionLost)??;

        let auth_text = match auth_response {
            Message::Text(text) => text,
            Message::Close(_) => return Err(DaemonError::ConnectionLost),
            _ => return Err(DaemonError::AuthFailed("Unexpected response type".into())),
        };

        let command: DaemonCommand = serde_json::from_str(&auth_text)?;
        match command {
            DaemonCommand::Authenticated { daemon_id } => {
                tracing::info!("Authenticated with daemon ID: {}", daemon_id);
                *self.daemon_id.write().await = Some(daemon_id);
                *self.state.write().await = ConnectionState::Connected;

                // Send daemon directories info to server
                let dirs_msg = Self::directories_message();
                let dirs_json = serde_json::to_string(&dirs_msg)?;
                write.send(Message::Text(dirs_json)).await?;
                tracing::info!("Sent daemon directories info to server");
            }
            DaemonCommand::Error { code, message } => {
                return Err(DaemonError::AuthFailed(format!("{}: {}", code, message)));
            }
            _ => {
                return Err(DaemonError::AuthFailed(
                    "Unexpected response to authentication".into(),
                ));
            }
        }

        // Start main message loop.
        // `tokio::time::interval` panics on a zero period, so a configured
        // interval of 0 is clamped to one second instead of crashing the task.
        if self.config.heartbeat_interval_secs == 0 {
            tracing::warn!("heartbeat_interval_secs is 0; using 1 second instead");
        }
        let heartbeat_interval = Duration::from_secs(self.config.heartbeat_interval_secs.max(1));
        // The first tick completes immediately, so one heartbeat goes out
        // right after authentication.
        let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);

        loop {
            tokio::select! {
                // Handle incoming server commands
                msg = read.next() => {
                    match msg {
                        Some(Ok(Message::Text(text))) => {
                            tracing::info!("Received WebSocket message: {} bytes", text.len());
                            match serde_json::from_str::<DaemonCommand>(&text) {
                                Ok(command) => {
                                    tracing::info!("Parsed command: {:?}", command);
                                    tracing::info!("Sending command to task manager channel...");
                                    if self.incoming_tx.send(command).await.is_err() {
                                        tracing::warn!("Command receiver dropped, shutting down");
                                        break;
                                    }
                                    tracing::info!("Command sent to task manager successfully");
                                }
                                Err(e) => {
                                    // An unparseable message is logged and
                                    // skipped; the connection stays up.
                                    tracing::warn!("Failed to parse server message: {}", e);
                                    tracing::debug!("Raw message: {}", text);
                                }
                            }
                        }
                        Some(Ok(Message::Ping(data))) => {
                            write.send(Message::Pong(data)).await?;
                        }
                        Some(Ok(Message::Close(_))) | None => {
                            tracing::info!("Server closed connection");
                            return Err(DaemonError::ConnectionLost);
                        }
                        Some(Err(e)) => {
                            tracing::warn!("WebSocket error: {}", e);
                            return Err(e.into());
                        }
                        // Other frame types (binary, pong, ...) are ignored.
                        _ => {}
                    }
                }

                // Handle outgoing messages
                msg = self.outgoing_rx.recv() => {
                    match msg {
                        Some(message) => {
                            let json = serde_json::to_string(&message)?;
                            tracing::trace!("Sending message: {}", json);
                            write.send(Message::Text(json)).await?;
                        }
                        None => {
                            // Sender dropped, shutdown
                            tracing::info!("Outgoing channel closed, shutting down");
                            break;
                        }
                    }
                }

                // Send heartbeat
                _ = heartbeat_timer.tick() => {
                    // Get active task IDs from task manager
                    // For now, send empty list - will be connected to task manager
                    let heartbeat = DaemonMessage::heartbeat(vec![]);
                    let json = serde_json::to_string(&heartbeat)?;
                    write.send(Message::Text(json)).await?;
                }
            }
        }

        Ok(())
    }
}