summaryrefslogtreecommitdiff
path: root/makima/src/daemon/ws/client.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-11 05:52:14 +0000
committersoryu <soryu@soryu.co>2026-01-15 00:21:16 +0000
commit87044a747b47bd83249d61a45842c7f7b2eae56d (patch)
treeef2000ce79ffcc2723ef841acef5aa1deb1d5378 /makima/src/daemon/ws/client.rs
parent077820c4167c168072d217a1b01df840463a12a8 (diff)
downloadsoryu-87044a747b47bd83249d61a45842c7f7b2eae56d.tar.gz
soryu-87044a747b47bd83249d61a45842c7f7b2eae56d.zip
Contract system
Diffstat (limited to 'makima/src/daemon/ws/client.rs')
-rw-r--r--makima/src/daemon/ws/client.rs290
1 files changed, 290 insertions, 0 deletions
diff --git a/makima/src/daemon/ws/client.rs b/makima/src/daemon/ws/client.rs
new file mode 100644
index 0000000..67594a2
--- /dev/null
+++ b/makima/src/daemon/ws/client.rs
@@ -0,0 +1,290 @@
+//! WebSocket client for connecting to the makima server.
+
+use std::sync::Arc;
+use std::time::Duration;
+
+use backoff::backoff::Backoff;
+use backoff::ExponentialBackoff;
+use futures::{SinkExt, StreamExt};
+use tokio::sync::{mpsc, RwLock};
+use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest, Message}};
+use uuid::Uuid;
+
+use super::protocol::{DaemonCommand, DaemonMessage};
+use crate::daemon::config::ServerConfig;
+use crate::daemon::error::{DaemonError, Result};
+
+/// WebSocket client state.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ConnectionState {
+ /// Not connected to server.
+ Disconnected,
+ /// Currently connecting.
+ Connecting,
+ /// Connected and authenticated.
+ Connected,
+ /// Connection failed, will retry.
+ Reconnecting,
+ /// Permanently failed (e.g., auth failure).
+ Failed,
+}
+
+/// WebSocket client for daemon-server communication.
+pub struct WsClient {
+ config: ServerConfig,
+ machine_id: String,
+ hostname: String,
+ max_concurrent_tasks: i32,
+ state: Arc<RwLock<ConnectionState>>,
+ daemon_id: Arc<RwLock<Option<Uuid>>>,
+ /// Channel to receive messages to send to server.
+ outgoing_rx: mpsc::Receiver<DaemonMessage>,
+ /// Sender for outgoing messages (clone this to send messages).
+ outgoing_tx: mpsc::Sender<DaemonMessage>,
+ /// Channel to send received commands to the task manager.
+ incoming_tx: mpsc::Sender<DaemonCommand>,
+}
+
+impl WsClient {
+ /// Create a new WebSocket client.
+ pub fn new(
+ config: ServerConfig,
+ machine_id: String,
+ hostname: String,
+ max_concurrent_tasks: i32,
+ incoming_tx: mpsc::Sender<DaemonCommand>,
+ ) -> Self {
+ let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
+
+ Self {
+ config,
+ machine_id,
+ hostname,
+ max_concurrent_tasks,
+ state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
+ daemon_id: Arc::new(RwLock::new(None)),
+ outgoing_rx,
+ outgoing_tx,
+ incoming_tx,
+ }
+ }
+
+ /// Get a sender for outgoing messages.
+ pub fn sender(&self) -> mpsc::Sender<DaemonMessage> {
+ self.outgoing_tx.clone()
+ }
+
+ /// Get current connection state.
+ pub async fn state(&self) -> ConnectionState {
+ *self.state.read().await
+ }
+
+ /// Get daemon ID if authenticated.
+ pub async fn daemon_id(&self) -> Option<Uuid> {
+ *self.daemon_id.read().await
+ }
+
+ /// Run the WebSocket client with automatic reconnection.
+ pub async fn run(&mut self) -> Result<()> {
+ let mut backoff = ExponentialBackoff {
+ initial_interval: Duration::from_secs(self.config.reconnect_interval_secs),
+ max_interval: Duration::from_secs(60),
+ max_elapsed_time: if self.config.max_reconnect_attempts > 0 {
+ Some(Duration::from_secs(
+ self.config.reconnect_interval_secs * self.config.max_reconnect_attempts as u64 * 10,
+ ))
+ } else {
+ None // Infinite retries
+ },
+ ..Default::default()
+ };
+
+ loop {
+ *self.state.write().await = ConnectionState::Connecting;
+ tracing::info!("Connecting to server: {}", self.config.url);
+
+ match self.connect_and_run().await {
+ Ok(()) => {
+ // Clean shutdown
+ tracing::info!("WebSocket connection closed cleanly");
+ break;
+ }
+ Err(DaemonError::AuthFailed(msg)) => {
+ tracing::error!("Authentication failed: {}", msg);
+ *self.state.write().await = ConnectionState::Failed;
+ return Err(DaemonError::AuthFailed(msg));
+ }
+ Err(e) => {
+ tracing::warn!("Connection error: {}", e);
+ *self.state.write().await = ConnectionState::Reconnecting;
+
+ if let Some(delay) = backoff.next_backoff() {
+ tracing::info!("Reconnecting in {:?}...", delay);
+ tokio::time::sleep(delay).await;
+ } else {
+ tracing::error!("Max reconnection attempts reached");
+ *self.state.write().await = ConnectionState::Failed;
+ return Err(DaemonError::ConnectionLost);
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Connect to server and run the message loop.
+ async fn connect_and_run(&mut self) -> Result<()> {
+ // Build WebSocket URL
+ let ws_url = format!("{}/api/v1/mesh/daemons/connect", self.config.url);
+ tracing::debug!("Connecting to WebSocket: {}", ws_url);
+
+ // Build request with API key header
+ let mut request = ws_url.into_client_request()?;
+ request.headers_mut().insert(
+ "x-makima-api-key",
+ self.config.api_key.parse().map_err(|_| {
+ DaemonError::AuthFailed("Invalid API key format".into())
+ })?,
+ );
+
+ // Connect with API key in headers
+ let (ws_stream, _response) = connect_async(request).await?;
+ let (mut write, mut read) = ws_stream.split();
+
+ // Send daemon info after connection (server authenticated us via header)
+ let info_msg = DaemonMessage::authenticate(
+ &self.config.api_key,
+ &self.machine_id,
+ &self.hostname,
+ self.max_concurrent_tasks,
+ );
+ let info_json = serde_json::to_string(&info_msg)?;
+ write.send(Message::Text(info_json)).await?;
+
+ // Wait for authentication response
+ let auth_response = read
+ .next()
+ .await
+ .ok_or(DaemonError::ConnectionLost)??;
+
+ let auth_text = match auth_response {
+ Message::Text(text) => text,
+ Message::Close(_) => return Err(DaemonError::ConnectionLost),
+ _ => return Err(DaemonError::AuthFailed("Unexpected response type".into())),
+ };
+
+ let command: DaemonCommand = serde_json::from_str(&auth_text)?;
+ match command {
+ DaemonCommand::Authenticated { daemon_id } => {
+ tracing::info!("Authenticated with daemon ID: {}", daemon_id);
+ *self.daemon_id.write().await = Some(daemon_id);
+ *self.state.write().await = ConnectionState::Connected;
+
+ // Send daemon directories info to server
+ let working_directory = std::env::current_dir()
+ .map(|p| p.to_string_lossy().to_string())
+ .unwrap_or_else(|_| ".".to_string());
+ let home_directory = dirs::home_dir()
+ .map(|h| h.join(".makima").join("home"))
+ .unwrap_or_else(|| std::path::PathBuf::from("~/.makima/home"));
+ // Create home directory if it doesn't exist
+ if let Err(e) = std::fs::create_dir_all(&home_directory) {
+ tracing::warn!("Failed to create home directory {:?}: {}", home_directory, e);
+ }
+ let home_directory_str = home_directory.to_string_lossy().to_string();
+ let worktrees_directory = dirs::home_dir()
+ .map(|h| h.join(".makima").join("worktrees").to_string_lossy().to_string())
+ .unwrap_or_else(|| "~/.makima/worktrees".to_string());
+
+ let dirs_msg = DaemonMessage::DaemonDirectories {
+ working_directory,
+ home_directory: home_directory_str,
+ worktrees_directory,
+ };
+ let dirs_json = serde_json::to_string(&dirs_msg)?;
+ write.send(Message::Text(dirs_json)).await?;
+ tracing::info!("Sent daemon directories info to server");
+ }
+ DaemonCommand::Error { code, message } => {
+ return Err(DaemonError::AuthFailed(format!("{}: {}", code, message)));
+ }
+ _ => {
+ return Err(DaemonError::AuthFailed(
+ "Unexpected response to authentication".into(),
+ ));
+ }
+ }
+
+ // Start main message loop
+ let heartbeat_interval = Duration::from_secs(self.config.heartbeat_interval_secs);
+ let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);
+
+ loop {
+ tokio::select! {
+ // Handle incoming server commands
+ msg = read.next() => {
+ match msg {
+ Some(Ok(Message::Text(text))) => {
+ tracing::info!("Received WebSocket message: {} bytes", text.len());
+ match serde_json::from_str::<DaemonCommand>(&text) {
+ Ok(command) => {
+ tracing::info!("Parsed command: {:?}", command);
+ tracing::info!("Sending command to task manager channel...");
+ if self.incoming_tx.send(command).await.is_err() {
+ tracing::warn!("Command receiver dropped, shutting down");
+ break;
+ }
+ tracing::info!("Command sent to task manager successfully");
+ }
+ Err(e) => {
+ tracing::warn!("Failed to parse server message: {}", e);
+ tracing::debug!("Raw message: {}", text);
+ }
+ }
+ }
+ Some(Ok(Message::Ping(data))) => {
+ write.send(Message::Pong(data)).await?;
+ }
+ Some(Ok(Message::Close(_))) | None => {
+ tracing::info!("Server closed connection");
+ return Err(DaemonError::ConnectionLost);
+ }
+ Some(Err(e)) => {
+ tracing::warn!("WebSocket error: {}", e);
+ return Err(e.into());
+ }
+ _ => {}
+ }
+ }
+
+ // Handle outgoing messages
+ msg = self.outgoing_rx.recv() => {
+ match msg {
+ Some(message) => {
+ let json = serde_json::to_string(&message)?;
+ tracing::trace!("Sending message: {}", json);
+ write.send(Message::Text(json)).await?;
+ }
+ None => {
+ // Sender dropped, shutdown
+ tracing::info!("Outgoing channel closed, shutting down");
+ break;
+ }
+ }
+ }
+
+ // Send heartbeat
+ _ = heartbeat_timer.tick() => {
+ // Get active task IDs from task manager
+ // For now, send empty list - will be connected to task manager
+ let heartbeat = DaemonMessage::heartbeat(vec![]);
+ let json = serde_json::to_string(&heartbeat)?;
+ write.send(Message::Text(json)).await?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+}