diff options
Diffstat (limited to 'makima/daemon/src/ws')
| -rw-r--r-- | makima/daemon/src/ws/client.rs | 290 | ||||
| -rw-r--r-- | makima/daemon/src/ws/mod.rs | 7 | ||||
| -rw-r--r-- | makima/daemon/src/ws/protocol.rs | 511 |
3 files changed, 808 insertions, 0 deletions
diff --git a/makima/daemon/src/ws/client.rs b/makima/daemon/src/ws/client.rs new file mode 100644 index 0000000..ba1263f --- /dev/null +++ b/makima/daemon/src/ws/client.rs @@ -0,0 +1,290 @@ +//! WebSocket client for connecting to the makima server. + +use std::sync::Arc; +use std::time::Duration; + +use backoff::backoff::Backoff; +use backoff::ExponentialBackoff; +use futures::{SinkExt, StreamExt}; +use tokio::sync::{mpsc, RwLock}; +use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest, Message}}; +use uuid::Uuid; + +use super::protocol::{DaemonCommand, DaemonMessage}; +use crate::config::ServerConfig; +use crate::error::{DaemonError, Result}; + +/// WebSocket client state. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + /// Not connected to server. + Disconnected, + /// Currently connecting. + Connecting, + /// Connected and authenticated. + Connected, + /// Connection failed, will retry. + Reconnecting, + /// Permanently failed (e.g., auth failure). + Failed, +} + +/// WebSocket client for daemon-server communication. +pub struct WsClient { + config: ServerConfig, + machine_id: String, + hostname: String, + max_concurrent_tasks: i32, + state: Arc<RwLock<ConnectionState>>, + daemon_id: Arc<RwLock<Option<Uuid>>>, + /// Channel to receive messages to send to server. + outgoing_rx: mpsc::Receiver<DaemonMessage>, + /// Sender for outgoing messages (clone this to send messages). + outgoing_tx: mpsc::Sender<DaemonMessage>, + /// Channel to send received commands to the task manager. + incoming_tx: mpsc::Sender<DaemonCommand>, +} + +impl WsClient { + /// Create a new WebSocket client. + pub fn new( + config: ServerConfig, + machine_id: String, + hostname: String, + max_concurrent_tasks: i32, + incoming_tx: mpsc::Sender<DaemonCommand>, + ) -> Self { + let (outgoing_tx, outgoing_rx) = mpsc::channel(256); + + Self { + config, + machine_id, + hostname, + max_concurrent_tasks, + state: Arc::new(RwLock::new(ConnectionState::Disconnected)), + daemon_id: Arc::new(RwLock::new(None)), + outgoing_rx, + outgoing_tx, + incoming_tx, + } + } + + /// Get a sender for outgoing messages. + pub fn sender(&self) -> mpsc::Sender<DaemonMessage> { + self.outgoing_tx.clone() + } + + /// Get current connection state. + pub async fn state(&self) -> ConnectionState { + *self.state.read().await + } + + /// Get daemon ID if authenticated. + pub async fn daemon_id(&self) -> Option<Uuid> { + *self.daemon_id.read().await + } + + /// Run the WebSocket client with automatic reconnection. + pub async fn run(&mut self) -> Result<()> { + let mut backoff = ExponentialBackoff { + initial_interval: Duration::from_secs(self.config.reconnect_interval_secs), + max_interval: Duration::from_secs(60), + max_elapsed_time: if self.config.max_reconnect_attempts > 0 { + Some(Duration::from_secs( + self.config.reconnect_interval_secs * self.config.max_reconnect_attempts as u64 * 10, + )) + } else { + None // Infinite retries + }, + ..Default::default() + }; + + loop { + *self.state.write().await = ConnectionState::Connecting; + tracing::info!("Connecting to server: {}", self.config.url); + + match self.connect_and_run().await { + Ok(()) => { + // Clean shutdown + tracing::info!("WebSocket connection closed cleanly"); + break; + } + Err(DaemonError::AuthFailed(msg)) => { + tracing::error!("Authentication failed: {}", msg); + *self.state.write().await = ConnectionState::Failed; + return Err(DaemonError::AuthFailed(msg)); + } + Err(e) => { + tracing::warn!("Connection error: {}", e); + *self.state.write().await = ConnectionState::Reconnecting; + + if let Some(delay) = backoff.next_backoff() { + tracing::info!("Reconnecting in {:?}...", delay); + tokio::time::sleep(delay).await; + } else { + tracing::error!("Max reconnection attempts reached"); + *self.state.write().await = ConnectionState::Failed; + return Err(DaemonError::ConnectionLost); + } + } + } + } + + Ok(()) + } + + /// Connect to server and run the message loop. + async fn connect_and_run(&mut self) -> Result<()> { + // Build WebSocket URL + let ws_url = format!("{}/api/v1/mesh/daemons/connect", self.config.url); + tracing::debug!("Connecting to WebSocket: {}", ws_url); + + // Build request with API key header + let mut request = ws_url.into_client_request()?; + request.headers_mut().insert( + "x-makima-api-key", + self.config.api_key.parse().map_err(|_| { + DaemonError::AuthFailed("Invalid API key format".into()) + })?, + ); + + // Connect with API key in headers + let (ws_stream, _response) = connect_async(request).await?; + let (mut write, mut read) = ws_stream.split(); + + // Send daemon info after connection (server authenticated us via header) + let info_msg = DaemonMessage::authenticate( + &self.config.api_key, + &self.machine_id, + &self.hostname, + self.max_concurrent_tasks, + ); + let info_json = serde_json::to_string(&info_msg)?; + write.send(Message::Text(info_json)).await?; + + // Wait for authentication response + let auth_response = read + .next() + .await + .ok_or(DaemonError::ConnectionLost)??; + + let auth_text = match auth_response { + Message::Text(text) => text, + Message::Close(_) => return Err(DaemonError::ConnectionLost), + _ => return Err(DaemonError::AuthFailed("Unexpected response type".into())), + }; + + let command: DaemonCommand = serde_json::from_str(&auth_text)?; + match command { + DaemonCommand::Authenticated { daemon_id } => { + tracing::info!("Authenticated with daemon ID: {}", daemon_id); + *self.daemon_id.write().await = Some(daemon_id); + *self.state.write().await = ConnectionState::Connected; + + // Send daemon directories info to server + let working_directory = std::env::current_dir() + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|_| ".".to_string()); + let home_directory = dirs::home_dir() + .map(|h| h.join(".makima").join("home")) + .unwrap_or_else(|| std::path::PathBuf::from("~/.makima/home")); + // Create home directory if it doesn't exist + if let Err(e) = std::fs::create_dir_all(&home_directory) { + tracing::warn!("Failed to create home directory {:?}: {}", home_directory, e); + } + let home_directory_str = home_directory.to_string_lossy().to_string(); + let worktrees_directory = dirs::home_dir() + .map(|h| h.join(".makima").join("worktrees").to_string_lossy().to_string()) + .unwrap_or_else(|| "~/.makima/worktrees".to_string()); + + let dirs_msg = DaemonMessage::DaemonDirectories { + working_directory, + home_directory: home_directory_str, + worktrees_directory, + }; + let dirs_json = serde_json::to_string(&dirs_msg)?; + write.send(Message::Text(dirs_json)).await?; + tracing::info!("Sent daemon directories info to server"); + } + DaemonCommand::Error { code, message } => { + return Err(DaemonError::AuthFailed(format!("{}: {}", code, message))); + } + _ => { + return Err(DaemonError::AuthFailed( + "Unexpected response to authentication".into(), + )); + } + } + + // Start main message loop + let heartbeat_interval = Duration::from_secs(self.config.heartbeat_interval_secs); + let mut heartbeat_timer = tokio::time::interval(heartbeat_interval); + + loop { + tokio::select! { + // Handle incoming server commands + msg = read.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + tracing::info!("Received WebSocket message: {} bytes", text.len()); + match serde_json::from_str::<DaemonCommand>(&text) { + Ok(command) => { + tracing::info!("Parsed command: {:?}", command); + tracing::info!("Sending command to task manager channel..."); + if self.incoming_tx.send(command).await.is_err() { + tracing::warn!("Command receiver dropped, shutting down"); + break; + } + tracing::info!("Command sent to task manager successfully"); + } + Err(e) => { + tracing::warn!("Failed to parse server message: {}", e); + tracing::debug!("Raw message: {}", text); + } + } + } + Some(Ok(Message::Ping(data))) => { + write.send(Message::Pong(data)).await?; + } + Some(Ok(Message::Close(_))) | None => { + tracing::info!("Server closed connection"); + return Err(DaemonError::ConnectionLost); + } + Some(Err(e)) => { + tracing::warn!("WebSocket error: {}", e); + return Err(e.into()); + } + _ => {} + } + } + + // Handle outgoing messages + msg = self.outgoing_rx.recv() => { + match msg { + Some(message) => { + let json = serde_json::to_string(&message)?; + tracing::trace!("Sending message: {}", json); + write.send(Message::Text(json)).await?; + } + None => { + // Sender dropped, shutdown + tracing::info!("Outgoing channel closed, shutting down"); + break; + } + } + } + + // Send heartbeat + _ = heartbeat_timer.tick() => { + // Get active task IDs from task manager + // For now, send empty list - will be connected to task manager + let heartbeat = DaemonMessage::heartbeat(vec![]); + let json = serde_json::to_string(&heartbeat)?; + write.send(Message::Text(json)).await?; + } + } + } + + Ok(()) + } +} diff --git a/makima/daemon/src/ws/mod.rs b/makima/daemon/src/ws/mod.rs new file mode 100644 index 0000000..5a0e9d1 --- /dev/null +++ b/makima/daemon/src/ws/mod.rs @@ -0,0 +1,7 @@ +//! WebSocket client and protocol types for daemon-server communication. + +pub mod client; +pub mod protocol; + +pub use client::{ConnectionState, WsClient}; +pub use protocol::{BranchInfo, DaemonCommand, DaemonMessage}; diff --git a/makima/daemon/src/ws/protocol.rs b/makima/daemon/src/ws/protocol.rs new file mode 100644 index 0000000..7c2ad6d --- /dev/null +++ b/makima/daemon/src/ws/protocol.rs @@ -0,0 +1,511 @@ +//! Protocol types for daemon-server communication. +//! +//! These types mirror the server's protocol exactly for compatibility. + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Message from daemon to server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonMessage { + /// Authentication request (first message required). + Authenticate { + #[serde(rename = "apiKey")] + api_key: String, + #[serde(rename = "machineId")] + machine_id: String, + hostname: String, + #[serde(rename = "maxConcurrentTasks")] + max_concurrent_tasks: i32, + }, + + /// Periodic heartbeat with current status. + Heartbeat { + #[serde(rename = "activeTasks")] + active_tasks: Vec<Uuid>, + }, + + /// Task output streaming (stdout/stderr from Claude Code). + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + output: String, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + + /// Task status change notification. + TaskStatusChange { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "oldStatus")] + old_status: String, + #[serde(rename = "newStatus")] + new_status: String, + }, + + /// Task progress update with summary. + TaskProgress { + #[serde(rename = "taskId")] + task_id: Uuid, + summary: String, + }, + + /// Task completion notification. + TaskComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + error: Option<String>, + }, + + /// Register a tool key for orchestrator API access. + RegisterToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + /// The API key for this orchestrator to use when calling mesh endpoints. + key: String, + }, + + /// Revoke a tool key when task completes. + RevokeToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + // ========================================================================= + // Merge Response Messages (sent by daemon after processing merge commands) + // ========================================================================= + + /// Response to ListBranches command. + BranchList { + #[serde(rename = "taskId")] + task_id: Uuid, + branches: Vec<BranchInfo>, + }, + + /// Response to MergeStatus command. + MergeStatusResponse { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "inProgress")] + in_progress: bool, + #[serde(rename = "sourceBranch")] + source_branch: Option<String>, + #[serde(rename = "conflictedFiles")] + conflicted_files: Vec<String>, + }, + + /// Response to merge operations (MergeStart, MergeResolve, MergeCommit, MergeAbort). + MergeResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + #[serde(rename = "commitSha")] + commit_sha: Option<String>, + /// Present only when conflicts occurred. + conflicts: Option<Vec<String>>, + }, + + /// Response to CheckMergeComplete command. + MergeCompleteCheck { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "canComplete")] + can_complete: bool, + #[serde(rename = "unmergedBranches")] + unmerged_branches: Vec<String>, + #[serde(rename = "mergedCount")] + merged_count: u32, + #[serde(rename = "skippedCount")] + skipped_count: u32, + }, + + // ========================================================================= + // Completion Action Response Messages + // ========================================================================= + + /// Response to RetryCompletionAction command. + CompletionActionResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// PR URL if action was "pr" and successful. + #[serde(rename = "prUrl")] + pr_url: Option<String>, + }, + + /// Report daemon's available directories for task output. + DaemonDirectories { + /// Current working directory of the daemon. + #[serde(rename = "workingDirectory")] + working_directory: String, + /// Path to ~/.makima/home directory (for cloning completed work). + #[serde(rename = "homeDirectory")] + home_directory: String, + /// Path to worktrees directory (~/.makima/worktrees). + #[serde(rename = "worktreesDirectory")] + worktrees_directory: String, + }, + + /// Response to CloneWorktree command. + CloneWorktreeResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// The path where the worktree was cloned. + #[serde(rename = "targetDir")] + target_dir: Option<String>, + }, + + /// Response to CheckTargetExists command. + CheckTargetExistsResult { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Whether the target directory exists. + exists: bool, + /// The path that was checked. + #[serde(rename = "targetDir")] + target_dir: String, + }, +} + +/// Information about a branch (used in BranchList message). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct BranchInfo { + /// Full branch name. + pub name: String, + /// Task ID extracted from branch name (if parseable). + #[serde(rename = "taskId")] + pub task_id: Option<Uuid>, + /// Whether this branch has been merged. + #[serde(rename = "isMerged")] + pub is_merged: bool, + /// Short SHA of the last commit. + #[serde(rename = "lastCommit")] + pub last_commit: String, + /// Subject line of the last commit. + #[serde(rename = "lastCommitMessage")] + pub last_commit_message: String, +} + +/// Command from server to daemon. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonCommand { + /// Confirm successful authentication. + Authenticated { + #[serde(rename = "daemonId")] + daemon_id: Uuid, + }, + + /// Spawn a new task in a container. + SpawnTask { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages). + #[serde(rename = "taskName")] + task_name: String, + plan: String, + #[serde(rename = "repoUrl")] + repo_url: Option<String>, + #[serde(rename = "baseBranch")] + base_branch: Option<String>, + /// Target branch to merge into (used for completion actions). + #[serde(rename = "targetBranch")] + target_branch: Option<String>, + /// Parent task ID if this is a subtask. + #[serde(rename = "parentTaskId")] + parent_task_id: Option<Uuid>, + /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask). + depth: i32, + /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks). + #[serde(rename = "isOrchestrator")] + is_orchestrator: bool, + /// Path to user's local repository (outside ~/.makima) for completion actions. + #[serde(rename = "targetRepoPath")] + target_repo_path: Option<String>, + /// Action on completion: "none", "branch", "merge", "pr". + #[serde(rename = "completionAction")] + completion_action: Option<String>, + /// Task ID to continue from (copy worktree from this task). + #[serde(rename = "continueFromTaskId")] + continue_from_task_id: Option<Uuid>, + /// Files to copy from parent task's worktree. + #[serde(rename = "copyFiles")] + copy_files: Option<Vec<String>>, + }, + + /// Pause a running task. + PauseTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + /// Resume a paused task. + ResumeTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + /// Interrupt a task (gracefully or forced). + InterruptTask { + #[serde(rename = "taskId")] + task_id: Uuid, + graceful: bool, + }, + + /// Send a message to a running task. + SendMessage { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + + /// Inject context about sibling task progress. + InjectSiblingContext { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "siblingTaskId")] + sibling_task_id: Uuid, + #[serde(rename = "siblingName")] + sibling_name: String, + #[serde(rename = "siblingStatus")] + sibling_status: String, + #[serde(rename = "progressSummary")] + progress_summary: Option<String>, + #[serde(rename = "changedFiles")] + changed_files: Vec<String>, + }, + + // ========================================================================= + // Merge Commands (for orchestrators to merge subtask branches) + // ========================================================================= + + /// List all subtask branches for a task. + ListBranches { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + /// Start merging a subtask branch. + MergeStart { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "sourceBranch")] + source_branch: String, + }, + + /// Get current merge status. + MergeStatus { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + /// Resolve a merge conflict. + MergeResolve { + #[serde(rename = "taskId")] + task_id: Uuid, + file: String, + /// "ours" or "theirs" + strategy: String, + }, + + /// Commit the current merge. + MergeCommit { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + + /// Abort the current merge. + MergeAbort { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + /// Skip merging a subtask branch (mark as intentionally not merged). + MergeSkip { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "subtaskId")] + subtask_id: Uuid, + reason: String, + }, + + /// Check if all subtask branches have been merged or skipped (completion gate). + CheckMergeComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + // ========================================================================= + // Completion Action Commands + // ========================================================================= + + /// Retry a completion action for a completed task. + RetryCompletionAction { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages). + #[serde(rename = "taskName")] + task_name: String, + /// The action to execute: "branch", "merge", or "pr". + action: String, + /// Path to the target repository. + #[serde(rename = "targetRepoPath")] + target_repo_path: String, + /// Target branch to merge into (for merge/pr actions). + #[serde(rename = "targetBranch")] + target_branch: Option<String>, + }, + + /// Clone worktree to a target directory. + CloneWorktree { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to the target directory. + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Check if a target directory exists. + CheckTargetExists { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to check. + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Error response. + Error { + code: String, + message: String, + }, +} + +impl DaemonMessage { + /// Create an authentication message. + pub fn authenticate( + api_key: &str, + machine_id: &str, + hostname: &str, + max_concurrent_tasks: i32, + ) -> Self { + Self::Authenticate { + api_key: api_key.to_string(), + machine_id: machine_id.to_string(), + hostname: hostname.to_string(), + max_concurrent_tasks, + } + } + + /// Create a heartbeat message. + pub fn heartbeat(active_tasks: Vec<Uuid>) -> Self { + Self::Heartbeat { active_tasks } + } + + /// Create a task output message. + pub fn task_output(task_id: Uuid, output: String, is_partial: bool) -> Self { + Self::TaskOutput { + task_id, + output, + is_partial, + } + } + + /// Create a task status change message. + pub fn task_status_change(task_id: Uuid, old_status: &str, new_status: &str) -> Self { + Self::TaskStatusChange { + task_id, + old_status: old_status.to_string(), + new_status: new_status.to_string(), + } + } + + /// Create a task progress message. + pub fn task_progress(task_id: Uuid, summary: String) -> Self { + Self::TaskProgress { task_id, summary } + } + + /// Create a task complete message. + pub fn task_complete(task_id: Uuid, success: bool, error: Option<String>) -> Self { + Self::TaskComplete { + task_id, + success, + error, + } + } + + /// Create a register tool key message. + pub fn register_tool_key(task_id: Uuid, key: String) -> Self { + Self::RegisterToolKey { task_id, key } + } + + /// Create a revoke tool key message. + pub fn revoke_tool_key(task_id: Uuid) -> Self { + Self::RevokeToolKey { task_id } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_daemon_message_serialization() { + let msg = DaemonMessage::authenticate("key123", "machine-abc", "worker-1", 4); + let json = serde_json::to_string(&msg).unwrap(); + assert!(json.contains("\"type\":\"authenticate\"")); + assert!(json.contains("\"apiKey\":\"key123\"")); + assert!(json.contains("\"machineId\":\"machine-abc\"")); + } + + #[test] + fn test_daemon_command_deserialization() { + let json = r#"{"type":"spawnTask","taskId":"550e8400-e29b-41d4-a716-446655440000","plan":"Build the feature","repoUrl":"https://github.com/test/repo","baseBranch":"main","parentTaskId":null,"depth":0,"isOrchestrator":false}"#; + let cmd: DaemonCommand = serde_json::from_str(json).unwrap(); + match cmd { + DaemonCommand::SpawnTask { + plan, + repo_url, + base_branch, + parent_task_id, + depth, + is_orchestrator, + .. + } => { + assert_eq!(plan, "Build the feature"); + assert_eq!(repo_url, Some("https://github.com/test/repo".to_string())); + assert_eq!(base_branch, Some("main".to_string())); + assert_eq!(parent_task_id, None); + assert_eq!(depth, 0); + assert!(!is_orchestrator); + } + _ => panic!("Expected SpawnTask"), + } + } + + #[test] + fn test_orchestrator_spawn_deserialization() { + let json = r#"{"type":"spawnTask","taskId":"550e8400-e29b-41d4-a716-446655440000","plan":"Coordinate subtasks","repoUrl":"https://github.com/test/repo","baseBranch":"main","parentTaskId":null,"depth":0,"isOrchestrator":true}"#; + let cmd: DaemonCommand = serde_json::from_str(json).unwrap(); + match cmd { + DaemonCommand::SpawnTask { + is_orchestrator, + depth, + .. + } => { + assert!(is_orchestrator); + assert_eq!(depth, 0); + } + _ => panic!("Expected SpawnTask"), + } + } +} |
