//! WebSocket client for connecting to the makima server. use std::sync::Arc; use std::time::Duration; use backoff::backoff::Backoff; use backoff::ExponentialBackoff; use futures::{SinkExt, StreamExt}; use tokio::sync::{mpsc, RwLock}; use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest, Message}}; use uuid::Uuid; use super::protocol::{DaemonCommand, DaemonMessage}; use crate::daemon::config::ServerConfig; use crate::daemon::error::{DaemonError, Result}; /// WebSocket client state. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ConnectionState { /// Not connected to server. Disconnected, /// Currently connecting. Connecting, /// Connected and authenticated. Connected, /// Connection failed, will retry. Reconnecting, /// Permanently failed (e.g., auth failure). Failed, } /// WebSocket client for daemon-server communication. pub struct WsClient { config: ServerConfig, machine_id: String, hostname: String, max_concurrent_tasks: i32, state: Arc>, daemon_id: Arc>>, /// Channel to receive messages to send to server. outgoing_rx: mpsc::Receiver, /// Sender for outgoing messages (clone this to send messages). outgoing_tx: mpsc::Sender, /// Channel to send received commands to the task manager. incoming_tx: mpsc::Sender, } impl WsClient { /// Create a new WebSocket client. pub fn new( config: ServerConfig, machine_id: String, hostname: String, max_concurrent_tasks: i32, incoming_tx: mpsc::Sender, ) -> Self { let (outgoing_tx, outgoing_rx) = mpsc::channel(256); Self { config, machine_id, hostname, max_concurrent_tasks, state: Arc::new(RwLock::new(ConnectionState::Disconnected)), daemon_id: Arc::new(RwLock::new(None)), outgoing_rx, outgoing_tx, incoming_tx, } } /// Get a sender for outgoing messages. pub fn sender(&self) -> mpsc::Sender { self.outgoing_tx.clone() } /// Get current connection state. pub async fn state(&self) -> ConnectionState { *self.state.read().await } /// Get daemon ID if authenticated. pub async fn daemon_id(&self) -> Option { *self.daemon_id.read().await } /// Run the WebSocket client with automatic reconnection. pub async fn run(&mut self) -> Result<()> { let mut backoff = ExponentialBackoff { initial_interval: Duration::from_secs(self.config.reconnect_interval_secs), max_interval: Duration::from_secs(60), max_elapsed_time: if self.config.max_reconnect_attempts > 0 { Some(Duration::from_secs( self.config.reconnect_interval_secs * self.config.max_reconnect_attempts as u64 * 10, )) } else { None // Infinite retries }, ..Default::default() }; loop { *self.state.write().await = ConnectionState::Connecting; tracing::info!("Connecting to server: {}", self.config.url); match self.connect_and_run().await { Ok(()) => { // Clean shutdown tracing::info!("WebSocket connection closed cleanly"); break; } Err(DaemonError::AuthFailed(msg)) => { tracing::error!("Authentication failed: {}", msg); *self.state.write().await = ConnectionState::Failed; return Err(DaemonError::AuthFailed(msg)); } Err(e) => { tracing::warn!("Connection error: {}", e); *self.state.write().await = ConnectionState::Reconnecting; if let Some(delay) = backoff.next_backoff() { tracing::info!("Reconnecting in {:?}...", delay); tokio::time::sleep(delay).await; } else { tracing::error!("Max reconnection attempts reached"); *self.state.write().await = ConnectionState::Failed; return Err(DaemonError::ConnectionLost); } } } } Ok(()) } /// Connect to server and run the message loop. async fn connect_and_run(&mut self) -> Result<()> { // Build WebSocket URL let ws_url = format!("{}/api/v1/mesh/daemons/connect", self.config.url); tracing::debug!("Connecting to WebSocket: {}", ws_url); // Build request with API key header let mut request = ws_url.into_client_request()?; request.headers_mut().insert( "x-makima-api-key", self.config.api_key.parse().map_err(|_| { DaemonError::AuthFailed("Invalid API key format".into()) })?, ); // Connect with API key in headers let (ws_stream, _response) = connect_async(request).await?; let (mut write, mut read) = ws_stream.split(); // Send daemon info after connection (server authenticated us via header) let info_msg = DaemonMessage::authenticate( &self.config.api_key, &self.machine_id, &self.hostname, self.max_concurrent_tasks, ); let info_json = serde_json::to_string(&info_msg)?; write.send(Message::Text(info_json)).await?; // Wait for authentication response let auth_response = read .next() .await .ok_or(DaemonError::ConnectionLost)??; let auth_text = match auth_response { Message::Text(text) => text, Message::Close(_) => return Err(DaemonError::ConnectionLost), _ => return Err(DaemonError::AuthFailed("Unexpected response type".into())), }; let command: DaemonCommand = serde_json::from_str(&auth_text)?; match command { DaemonCommand::Authenticated { daemon_id } => { tracing::info!("Authenticated with daemon ID: {}", daemon_id); *self.daemon_id.write().await = Some(daemon_id); *self.state.write().await = ConnectionState::Connected; // Send daemon directories info to server let working_directory = std::env::current_dir() .map(|p| p.to_string_lossy().to_string()) .unwrap_or_else(|_| ".".to_string()); let home_directory = dirs::home_dir() .map(|h| h.join(".makima").join("home")) .unwrap_or_else(|| std::path::PathBuf::from("~/.makima/home")); // Create home directory if it doesn't exist if let Err(e) = std::fs::create_dir_all(&home_directory) { tracing::warn!("Failed to create home directory {:?}: {}", home_directory, e); } let home_directory_str = home_directory.to_string_lossy().to_string(); let worktrees_directory = dirs::home_dir() .map(|h| h.join(".makima").join("worktrees").to_string_lossy().to_string()) .unwrap_or_else(|| "~/.makima/worktrees".to_string()); let dirs_msg = DaemonMessage::DaemonDirectories { working_directory, home_directory: home_directory_str, worktrees_directory, }; let dirs_json = serde_json::to_string(&dirs_msg)?; write.send(Message::Text(dirs_json)).await?; tracing::info!("Sent daemon directories info to server"); } DaemonCommand::Error { code, message } => { return Err(DaemonError::AuthFailed(format!("{}: {}", code, message))); } _ => { return Err(DaemonError::AuthFailed( "Unexpected response to authentication".into(), )); } } // Start main message loop let heartbeat_interval = Duration::from_secs(self.config.heartbeat_interval_secs); let mut heartbeat_timer = tokio::time::interval(heartbeat_interval); loop { tokio::select! { // Handle incoming server commands msg = read.next() => { match msg { Some(Ok(Message::Text(text))) => { tracing::info!("Received WebSocket message: {} bytes", text.len()); match serde_json::from_str::(&text) { Ok(command) => { tracing::info!("Parsed command: {:?}", command); tracing::info!("Sending command to task manager channel..."); if self.incoming_tx.send(command).await.is_err() { tracing::warn!("Command receiver dropped, shutting down"); break; } tracing::info!("Command sent to task manager successfully"); } Err(e) => { tracing::warn!("Failed to parse server message: {}", e); tracing::debug!("Raw message: {}", text); } } } Some(Ok(Message::Ping(data))) => { write.send(Message::Pong(data)).await?; } Some(Ok(Message::Close(_))) | None => { tracing::info!("Server closed connection"); return Err(DaemonError::ConnectionLost); } Some(Err(e)) => { tracing::warn!("WebSocket error: {}", e); return Err(e.into()); } _ => {} } } // Handle outgoing messages msg = self.outgoing_rx.recv() => { match msg { Some(message) => { let json = serde_json::to_string(&message)?; tracing::trace!("Sending message: {}", json); write.send(Message::Text(json)).await?; } None => { // Sender dropped, shutdown tracing::info!("Outgoing channel closed, shutting down"); break; } } } // Send heartbeat _ = heartbeat_timer.tick() => { // Get active task IDs from task manager // For now, send empty list - will be connected to task manager let heartbeat = DaemonMessage::heartbeat(vec![]); let json = serde_json::to_string(&heartbeat)?; write.send(Message::Text(json)).await?; } } } Ok(()) } }