From 87044a747b47bd83249d61a45842c7f7b2eae56d Mon Sep 17 00:00:00 2001 From: soryu Date: Sun, 11 Jan 2026 05:52:14 +0000 Subject: Contract system --- makima/src/daemon/ws/client.rs | 290 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 makima/src/daemon/ws/client.rs (limited to 'makima/src/daemon/ws/client.rs') diff --git a/makima/src/daemon/ws/client.rs b/makima/src/daemon/ws/client.rs new file mode 100644 index 0000000..67594a2 --- /dev/null +++ b/makima/src/daemon/ws/client.rs @@ -0,0 +1,290 @@ +//! WebSocket client for connecting to the makima server. + +use std::sync::Arc; +use std::time::Duration; + +use backoff::backoff::Backoff; +use backoff::ExponentialBackoff; +use futures::{SinkExt, StreamExt}; +use tokio::sync::{mpsc, RwLock}; +use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest, Message}}; +use uuid::Uuid; + +use super::protocol::{DaemonCommand, DaemonMessage}; +use crate::daemon::config::ServerConfig; +use crate::daemon::error::{DaemonError, Result}; + +/// WebSocket client state. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + /// Not connected to server. + Disconnected, + /// Currently connecting. + Connecting, + /// Connected and authenticated. + Connected, + /// Connection failed, will retry. + Reconnecting, + /// Permanently failed (e.g., auth failure). + Failed, +} + +/// WebSocket client for daemon-server communication. +pub struct WsClient { + config: ServerConfig, + machine_id: String, + hostname: String, + max_concurrent_tasks: i32, + state: Arc>, + daemon_id: Arc>>, + /// Channel to receive messages to send to server. + outgoing_rx: mpsc::Receiver, + /// Sender for outgoing messages (clone this to send messages). + outgoing_tx: mpsc::Sender, + /// Channel to send received commands to the task manager. + incoming_tx: mpsc::Sender, +} + +impl WsClient { + /// Create a new WebSocket client. + pub fn new( + config: ServerConfig, + machine_id: String, + hostname: String, + max_concurrent_tasks: i32, + incoming_tx: mpsc::Sender, + ) -> Self { + let (outgoing_tx, outgoing_rx) = mpsc::channel(256); + + Self { + config, + machine_id, + hostname, + max_concurrent_tasks, + state: Arc::new(RwLock::new(ConnectionState::Disconnected)), + daemon_id: Arc::new(RwLock::new(None)), + outgoing_rx, + outgoing_tx, + incoming_tx, + } + } + + /// Get a sender for outgoing messages. + pub fn sender(&self) -> mpsc::Sender { + self.outgoing_tx.clone() + } + + /// Get current connection state. + pub async fn state(&self) -> ConnectionState { + *self.state.read().await + } + + /// Get daemon ID if authenticated. + pub async fn daemon_id(&self) -> Option { + *self.daemon_id.read().await + } + + /// Run the WebSocket client with automatic reconnection. + pub async fn run(&mut self) -> Result<()> { + let mut backoff = ExponentialBackoff { + initial_interval: Duration::from_secs(self.config.reconnect_interval_secs), + max_interval: Duration::from_secs(60), + max_elapsed_time: if self.config.max_reconnect_attempts > 0 { + Some(Duration::from_secs( + self.config.reconnect_interval_secs * self.config.max_reconnect_attempts as u64 * 10, + )) + } else { + None // Infinite retries + }, + ..Default::default() + }; + + loop { + *self.state.write().await = ConnectionState::Connecting; + tracing::info!("Connecting to server: {}", self.config.url); + + match self.connect_and_run().await { + Ok(()) => { + // Clean shutdown + tracing::info!("WebSocket connection closed cleanly"); + break; + } + Err(DaemonError::AuthFailed(msg)) => { + tracing::error!("Authentication failed: {}", msg); + *self.state.write().await = ConnectionState::Failed; + return Err(DaemonError::AuthFailed(msg)); + } + Err(e) => { + tracing::warn!("Connection error: {}", e); + *self.state.write().await = ConnectionState::Reconnecting; + + if let Some(delay) = backoff.next_backoff() { + tracing::info!("Reconnecting in {:?}...", delay); + tokio::time::sleep(delay).await; + } else { + tracing::error!("Max reconnection attempts reached"); + *self.state.write().await = ConnectionState::Failed; + return Err(DaemonError::ConnectionLost); + } + } + } + } + + Ok(()) + } + + /// Connect to server and run the message loop. + async fn connect_and_run(&mut self) -> Result<()> { + // Build WebSocket URL + let ws_url = format!("{}/api/v1/mesh/daemons/connect", self.config.url); + tracing::debug!("Connecting to WebSocket: {}", ws_url); + + // Build request with API key header + let mut request = ws_url.into_client_request()?; + request.headers_mut().insert( + "x-makima-api-key", + self.config.api_key.parse().map_err(|_| { + DaemonError::AuthFailed("Invalid API key format".into()) + })?, + ); + + // Connect with API key in headers + let (ws_stream, _response) = connect_async(request).await?; + let (mut write, mut read) = ws_stream.split(); + + // Send daemon info after connection (server authenticated us via header) + let info_msg = DaemonMessage::authenticate( + &self.config.api_key, + &self.machine_id, + &self.hostname, + self.max_concurrent_tasks, + ); + let info_json = serde_json::to_string(&info_msg)?; + write.send(Message::Text(info_json)).await?; + + // Wait for authentication response + let auth_response = read + .next() + .await + .ok_or(DaemonError::ConnectionLost)??; + + let auth_text = match auth_response { + Message::Text(text) => text, + Message::Close(_) => return Err(DaemonError::ConnectionLost), + _ => return Err(DaemonError::AuthFailed("Unexpected response type".into())), + }; + + let command: DaemonCommand = serde_json::from_str(&auth_text)?; + match command { + DaemonCommand::Authenticated { daemon_id } => { + tracing::info!("Authenticated with daemon ID: {}", daemon_id); + *self.daemon_id.write().await = Some(daemon_id); + *self.state.write().await = ConnectionState::Connected; + + // Send daemon directories info to server + let working_directory = std::env::current_dir() + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| ".".to_string()); + let home_directory = dirs::home_dir() + .map(|h| h.join(".makima").join("home")) + .unwrap_or_else(|| std::path::PathBuf::from("~/.makima/home")); + // Create home directory if it doesn't exist + if let Err(e) = std::fs::create_dir_all(&home_directory) { + tracing::warn!("Failed to create home directory {:?}: {}", home_directory, e); + } + let home_directory_str = home_directory.to_string_lossy().to_string(); + let worktrees_directory = dirs::home_dir() + .map(|h| h.join(".makima").join("worktrees").to_string_lossy().to_string()) + .unwrap_or_else(|| "~/.makima/worktrees".to_string()); + + let dirs_msg = DaemonMessage::DaemonDirectories { + working_directory, + home_directory: home_directory_str, + worktrees_directory, + }; + let dirs_json = serde_json::to_string(&dirs_msg)?; + write.send(Message::Text(dirs_json)).await?; + tracing::info!("Sent daemon directories info to server"); + } + DaemonCommand::Error { code, message } => { + return Err(DaemonError::AuthFailed(format!("{}: {}", code, message))); + } + _ => { + return Err(DaemonError::AuthFailed( + "Unexpected response to authentication".into(), + )); + } + } + + // Start main message loop + let heartbeat_interval = Duration::from_secs(self.config.heartbeat_interval_secs); + let mut heartbeat_timer = tokio::time::interval(heartbeat_interval); + + loop { + tokio::select! { + // Handle incoming server commands + msg = read.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + tracing::info!("Received WebSocket message: {} bytes", text.len()); + match serde_json::from_str::(&text) { + Ok(command) => { + tracing::info!("Parsed command: {:?}", command); + tracing::info!("Sending command to task manager channel..."); + if self.incoming_tx.send(command).await.is_err() { + tracing::warn!("Command receiver dropped, shutting down"); + break; + } + tracing::info!("Command sent to task manager successfully"); + } + Err(e) => { + tracing::warn!("Failed to parse server message: {}", e); + tracing::debug!("Raw message: {}", text); + } + } + } + Some(Ok(Message::Ping(data))) => { + write.send(Message::Pong(data)).await?; + } + Some(Ok(Message::Close(_))) | None => { + tracing::info!("Server closed connection"); + return Err(DaemonError::ConnectionLost); + } + Some(Err(e)) => { + tracing::warn!("WebSocket error: {}", e); + return Err(e.into()); + } + _ => {} + } + } + + // Handle outgoing messages + msg = self.outgoing_rx.recv() => { + match msg { + Some(message) => { + let json = serde_json::to_string(&message)?; + tracing::trace!("Sending message: {}", json); + write.send(Message::Text(json)).await?; + } + None => { + // Sender dropped, shutdown + tracing::info!("Outgoing channel closed, shutting down"); + break; + } + } + } + + // Send heartbeat + _ = heartbeat_timer.tick() => { + // Get active task IDs from task manager + // For now, send empty list - will be connected to task manager + let heartbeat = DaemonMessage::heartbeat(vec![]); + let json = serde_json::to_string(&heartbeat)?; + write.send(Message::Text(json)).await?; + } + } + } + + Ok(()) + } +} -- cgit v1.2.3