diff options
Diffstat (limited to 'makima/src/server/handlers/mesh_daemon.rs')
| -rw-r--r-- | makima/src/server/handlers/mesh_daemon.rs | 959 |
1 files changed, 959 insertions, 0 deletions
diff --git a/makima/src/server/handlers/mesh_daemon.rs b/makima/src/server/handlers/mesh_daemon.rs new file mode 100644 index 0000000..644d0bc --- /dev/null +++ b/makima/src/server/handlers/mesh_daemon.rs @@ -0,0 +1,959 @@ +//! WebSocket handler for daemon connections. +//! +//! Daemons connect to report task progress, stream output, and receive commands. +//! Each daemon manages Claude Code containers on its local machine. +//! +//! ## Authentication +//! +//! Daemons authenticate via the `X-Api-Key` header in the WebSocket upgrade request. +//! The API key is validated against the database and the daemon is associated with +//! the corresponding owner_id for data isolation. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; +use futures::{SinkExt, StreamExt}; +use serde::Deserialize; +use sqlx::Row; +use tokio::sync::mpsc; +use uuid::Uuid; + +use crate::db::repository; +use crate::server::auth::{hash_api_key, API_KEY_HEADER}; +use crate::server::messages::ApiError; +use crate::server::state::{ + DaemonCommand, SharedState, TaskOutputNotification, TaskUpdateNotification, +}; + +// ============================================================================= +// Claude Code JSON Output Parsing +// ============================================================================= + +/// Claude Code stream-json message structure +#[derive(Debug, Deserialize)] +struct ClaudeMessage { + #[serde(rename = "type")] + msg_type: String, + subtype: Option<String>, + message: Option<ClaudeMessageContent>, + tool_name: Option<String>, + tool_input: Option<serde_json::Value>, + tool_result: Option<ClaudeToolResult>, + result: Option<String>, + cost_usd: Option<f64>, + duration_ms: Option<u64>, + error: Option<String>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeMessageContent { + content: Option<Vec<ClaudeContentBlock>>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeContentBlock { + #[serde(rename = "type")] + block_type: String, + text: Option<String>, + name: Option<String>, + input: Option<serde_json::Value>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeToolResult { + content: Option<String>, + is_error: Option<bool>, +} + +/// Parse a line of Claude Code output into a structured notification +fn parse_claude_output(task_id: Uuid, owner_id: Uuid, line: &str, is_partial: bool) -> Option<TaskOutputNotification> { + let trimmed = line.trim(); + if trimmed.is_empty() { + return None; + } + + // Try to parse as JSON + if trimmed.starts_with('{') { + if let Ok(msg) = serde_json::from_str::<ClaudeMessage>(trimmed) { + return parse_claude_message(task_id, owner_id, msg, is_partial); + } + } + + // Not JSON or failed to parse - treat as raw output + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "raw".to_string(), + content: trimmed.to_string(), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) +} + +fn parse_claude_message(task_id: Uuid, owner_id: Uuid, msg: ClaudeMessage, is_partial: bool) -> Option<TaskOutputNotification> { + match msg.msg_type.as_str() { + "system" => { + // System messages (init, etc.) - include subtype info + let content = match msg.subtype.as_deref() { + Some("init") => "Session started".to_string(), + Some(sub) => format!("System: {}", sub), + None => "System message".to_string(), + }; + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content, + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + "assistant" => { + // Extract text content from message blocks + if let Some(message) = msg.message { + if let Some(blocks) = message.content { + // Check for text blocks + let text_content: Vec<String> = blocks + .iter() + .filter(|b| b.block_type == "text") + .filter_map(|b| b.text.clone()) + .collect(); + + if !text_content.is_empty() { + return Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "assistant".to_string(), + content: text_content.join("\n"), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }); + } + + // Check for tool_use blocks + if let Some(tool_block) = blocks.iter().find(|b| b.block_type == "tool_use") { + return Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_use".to_string(), + content: format!("Using tool: {}", tool_block.name.as_deref().unwrap_or("unknown")), + tool_name: tool_block.name.clone(), + tool_input: tool_block.input.clone(), + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }); + } + } + } + None + } + + "tool_use" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_use".to_string(), + content: format!("Using tool: {}", msg.tool_name.as_deref().unwrap_or("unknown")), + tool_name: msg.tool_name, + tool_input: msg.tool_input, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + "tool_result" => { + if let Some(result) = msg.tool_result { + let content = result.content.unwrap_or_default(); + // Truncate long results + let content = if content.len() > 500 { + format!("{}...", &content[..500]) + } else { + content + }; + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_result".to_string(), + content, + tool_name: None, + tool_input: None, + is_error: result.is_error, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } else { + None + } + } + + "result" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "result".to_string(), + content: msg.result.unwrap_or_else(|| "Task completed".to_string()), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: msg.cost_usd, + duration_ms: msg.duration_ms, + is_partial, + }) + } + + "error" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "error".to_string(), + content: msg.error.unwrap_or_else(|| "An error occurred".to_string()), + tool_name: None, + tool_input: None, + is_error: Some(true), + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + _ => None, // Skip unknown message types + } +} + +/// Message from daemon to server. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonMessage { + /// Authentication request (first message required) + Authenticate { + #[serde(rename = "apiKey")] + api_key: String, + #[serde(rename = "machineId")] + machine_id: String, + hostname: String, + #[serde(rename = "maxConcurrentTasks")] + max_concurrent_tasks: i32, + }, + /// Periodic heartbeat with current status + Heartbeat { + #[serde(rename = "activeTasks")] + active_tasks: Vec<Uuid>, + }, + /// Task output streaming (stdout/stderr from Claude Code) + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + output: String, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + /// Task status change notification + TaskStatusChange { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "oldStatus")] + old_status: String, + #[serde(rename = "newStatus")] + new_status: String, + }, + /// Task progress update with summary + TaskProgress { + #[serde(rename = "taskId")] + task_id: Uuid, + summary: String, + }, + /// Task completion notification + TaskComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + error: Option<String>, + }, + /// Register a tool key for orchestrator API access + RegisterToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + /// The API key for this orchestrator to use when calling mesh endpoints + key: String, + }, + /// Revoke a tool key when task completes + RevokeToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Response to RetryCompletionAction command + CompletionActionResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// PR URL if action was "pr" and successful + #[serde(rename = "prUrl")] + pr_url: Option<String>, + }, + /// Report daemon's available directories for task output + DaemonDirectories { + /// Current working directory of the daemon + #[serde(rename = "workingDirectory")] + working_directory: String, + /// Path to ~/.makima/home directory (for cloning completed work) + #[serde(rename = "homeDirectory")] + home_directory: String, + /// Path to worktrees directory (~/.makima/worktrees) + #[serde(rename = "worktreesDirectory")] + worktrees_directory: String, + }, + /// Response to CloneWorktree command + CloneWorktreeResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// The path where the worktree was cloned + #[serde(rename = "targetDir")] + target_dir: Option<String>, + }, + /// Response to CheckTargetExists command + CheckTargetExistsResult { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Whether the target directory exists + exists: bool, + /// The path that was checked + #[serde(rename = "targetDir")] + target_dir: String, + }, +} + +/// Validated daemon authentication result. +#[derive(Debug, Clone)] +struct DaemonAuthResult { + /// User ID from the API key + user_id: Uuid, + /// Owner ID for data isolation + owner_id: Uuid, +} + +/// Validate an API key and return (user_id, owner_id). +async fn validate_daemon_api_key(pool: &sqlx::PgPool, key: &str) -> Result<DaemonAuthResult, String> { + let key_hash = hash_api_key(key); + + // Look up the API key and join with users to get owner_id + let row = sqlx::query( + r#" + SELECT ak.user_id, u.default_owner_id + FROM api_keys ak + JOIN users u ON u.id = ak.user_id + WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL + "#, + ) + .bind(&key_hash) + .fetch_optional(pool) + .await + .map_err(|e| format!("Database error: {}", e))?; + + match row { + Some(row) => { + let user_id: Uuid = row.try_get("user_id") + .map_err(|e| format!("Failed to get user_id: {}", e))?; + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| format!("Failed to get owner_id: {}", e))?; + let owner_id = owner_id.ok_or_else(|| "User has no default owner".to_string())?; + + // Update last_used_at asynchronously (fire and forget) + let pool_clone = pool.clone(); + let key_hash_clone = key_hash.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") + .bind(&key_hash_clone) + .execute(&pool_clone) + .await; + }); + + Ok(DaemonAuthResult { user_id, owner_id }) + } + None => Err("Invalid or revoked API key".to_string()), + } +} + +/// WebSocket upgrade handler for daemon connections. +/// +/// Daemons must authenticate via the `X-Api-Key` header in the WebSocket upgrade request. +/// The API key is validated against the database and used to determine the owner_id +/// for data isolation. +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/connect", + params( + ("X-Api-Key" = String, Header, description = "API key for daemon authentication"), + ), + responses( + (status = 101, description = "WebSocket connection established"), + (status = 401, description = "Missing or invalid API key"), + (status = 503, description = "Database not configured"), + ), + tag = "Mesh" +)] +pub async fn daemon_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, + headers: HeaderMap, +) -> Response { + // Extract API key from headers + let api_key = match headers.get(API_KEY_HEADER).or_else(|| headers.get("x-api-key")) { + Some(value) => match value.to_str() { + Ok(key) if !key.is_empty() => key.to_string(), + _ => { + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("INVALID_API_KEY", "Invalid API key header value")), + ) + .into_response(); + } + }, + None => { + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("MISSING_API_KEY", "X-Api-Key header required")), + ) + .into_response(); + } + }; + + // Validate API key against database + let pool = match state.db_pool.as_ref() { + Some(pool) => pool, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + axum::Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + } + }; + + let auth_result = match validate_daemon_api_key(pool, &api_key).await { + Ok(result) => result, + Err(e) => { + tracing::warn!("Daemon authentication failed: {}", e); + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("AUTH_FAILED", e)), + ) + .into_response(); + } + }; + + tracing::info!( + user_id = %auth_result.user_id, + owner_id = %auth_result.owner_id, + "Daemon authenticated via API key" + ); + + ws.on_upgrade(move |socket| handle_daemon_connection(socket, state, auth_result)) +} + +async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_result: DaemonAuthResult) { + let (mut sender, mut receiver) = socket.split(); + + // Generate a unique connection ID and daemon ID + let connection_id = Uuid::new_v4().to_string(); + let daemon_id = Uuid::new_v4(); + let owner_id = auth_result.owner_id; + + // Create command channel for sending commands to this daemon + let (cmd_tx, mut cmd_rx) = mpsc::channel::<DaemonCommand>(64); + + // Wait for the daemon to send its registration info (hostname, machine_id, etc.) + // The daemon is already authenticated via API key header, but we need metadata + #[allow(unused_assignments)] + let mut registered = false; + + // Wait for registration message with metadata + loop { + tokio::select! { + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<DaemonMessage>(&text) { + Ok(DaemonMessage::Authenticate { api_key: _, machine_id, hostname, max_concurrent_tasks }) => { + // API key was already validated via headers, but we use this message + // for backward compatibility to get the machine_id and hostname + + tracing::info!( + daemon_id = %daemon_id, + owner_id = %owner_id, + hostname = %hostname, + machine_id = %machine_id, + max_concurrent_tasks = max_concurrent_tasks, + "Daemon registered" + ); + + // Register daemon in state with owner_id + state.register_daemon( + connection_id.clone(), + daemon_id, + owner_id, + Some(hostname), + Some(machine_id), + cmd_tx.clone(), + ); + + registered = true; + + // Send authentication confirmation + let response = DaemonCommand::Authenticated { daemon_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + + break; // Exit registration loop, continue to main loop + } + Ok(_) => { + // Non-auth message before registration - still requires registration message + let response = DaemonCommand::Error { + code: "NOT_REGISTERED".into(), + message: "Must send registration message (Authenticate) first".into(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + Err(e) => { + let response = DaemonCommand::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Daemon disconnected during registration"); + return; + } + Some(Err(e)) => { + tracing::warn!("Daemon WebSocket error during registration: {}", e); + return; + } + _ => {} + } + } + } + } + + if !registered { + return; + } + + let daemon_uuid = daemon_id; + + // Main message loop after authentication + loop { + tokio::select! { + // Handle incoming messages from daemon + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<DaemonMessage>(&text) { + Ok(DaemonMessage::Heartbeat { active_tasks }) => { + tracing::trace!( + "Daemon {} heartbeat: {} active tasks", + daemon_uuid, active_tasks.len() + ); + // TODO: Update daemon last_heartbeat_at in DB + } + Ok(DaemonMessage::TaskOutput { task_id, output, is_partial }) => { + // Parse the output line and broadcast structured data + if let Some(notification) = parse_claude_output(task_id, owner_id, &output, is_partial) { + // Broadcast to connected clients + state.broadcast_task_output(notification.clone()); + + // Persist to database (fire and forget) + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let notification = notification.clone(); + tokio::spawn(async move { + if let Err(e) = repository::save_task_output( + &pool, + notification.task_id, + ¬ification.message_type, + ¬ification.content, + notification.tool_name.as_deref(), + notification.tool_input.clone(), + notification.is_error, + notification.cost_usd, + notification.duration_ms, + ).await { + tracing::warn!( + task_id = %notification.task_id, + "Failed to persist task output: {}", + e + ); + } + }); + } + } + } + Ok(DaemonMessage::TaskStatusChange { task_id, old_status, new_status }) => { + tracing::info!( + "Task {} status change: {} -> {}", + task_id, old_status, new_status + ); + + // Update task status in database and broadcast + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let state = state.clone(); + let new_status_owned = new_status.clone(); + tokio::spawn(async move { + match repository::update_task_status( + &pool, + task_id, + &new_status_owned, + None, + ).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: updated_task.version, + status: new_status_owned, + updated_fields: vec!["status".into()], + updated_by: "daemon".into(), + }); + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + "Task not found when updating status" + ); + } + Err(e) => { + tracing::error!( + task_id = %task_id, + "Failed to update task status: {}", + e + ); + } + } + }); + } else { + // No DB, just broadcast + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: new_status, + updated_fields: vec!["status".into()], + updated_by: "daemon".into(), + }); + } + } + Ok(DaemonMessage::TaskProgress { task_id, summary }) => { + tracing::debug!("Task {} progress: {}", task_id, summary); + // TODO: Update task progress_summary in database + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: "running".into(), + updated_fields: vec!["progress_summary".into()], + updated_by: "daemon".into(), + }); + } + Ok(DaemonMessage::TaskComplete { task_id, success, error }) => { + let status = if success { "done" } else { "failed" }; + tracing::info!( + "Task {} completed: success={}, error={:?}", + task_id, success, error + ); + + // Revoke any tool keys for this task + state.revoke_tool_key(task_id); + + // Update task in database with completion info + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let state = state.clone(); + let error_clone = error.clone(); + tokio::spawn(async move { + match repository::complete_task( + &pool, + task_id, + success, + error_clone.as_deref(), + ).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: updated_task.version, + status: updated_task.status.clone(), + updated_fields: vec![ + "status".into(), + "completed_at".into(), + "error_message".into(), + ], + updated_by: "daemon".into(), + }); + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + "Task not found when completing" + ); + } + Err(e) => { + tracing::error!( + task_id = %task_id, + "Failed to complete task: {}", + e + ); + } + } + }); + } else { + // No DB, just broadcast + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: status.into(), + updated_fields: vec!["status".into(), "completed_at".into()], + updated_by: "daemon".into(), + }); + } + } + Ok(DaemonMessage::Authenticate { .. }) => { + // Already authenticated, ignore + } + Ok(DaemonMessage::RegisterToolKey { task_id, key }) => { + tracing::info!( + task_id = %task_id, + "Registering tool key for orchestrator" + ); + state.register_tool_key(key, task_id); + } + Ok(DaemonMessage::RevokeToolKey { task_id }) => { + tracing::info!( + task_id = %task_id, + "Revoking tool key for task" + ); + state.revoke_tool_key(task_id); + } + Ok(DaemonMessage::DaemonDirectories { working_directory, home_directory, worktrees_directory }) => { + tracing::info!( + daemon_id = %daemon_uuid, + working_directory = %working_directory, + home_directory = %home_directory, + worktrees_directory = %worktrees_directory, + "Daemon directories received" + ); + state.update_daemon_directories( + &connection_id, + working_directory, + home_directory, + worktrees_directory, + ); + } + Ok(DaemonMessage::CompletionActionResult { task_id, success, message, pr_url }) => { + tracing::info!( + task_id = %task_id, + success = success, + message = %message, + pr_url = ?pr_url, + "Completion action result received" + ); + + // Update task with PR URL if created + if let Some(ref url) = pr_url { + if let Some(ref pool) = state.db_pool { + let update_req = crate::db::models::UpdateTaskRequest { + pr_url: Some(url.clone()), + ..Default::default() + }; + if let Err(e) = crate::db::repository::update_task(pool, task_id, update_req).await { + tracing::error!("Failed to update task PR URL: {}", e); + } + } + } + + // Broadcast as task output so UI can see the result + let output_text = if success { + format!("✓ Completion action succeeded: {}", message) + } else { + format!("✗ Completion action failed: {}", message) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: Some(!success), + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Ok(DaemonMessage::CloneWorktreeResult { task_id, success, message, target_dir }) => { + tracing::info!( + task_id = %task_id, + success = success, + message = %message, + target_dir = ?target_dir, + "Clone worktree result received" + ); + + // Broadcast as task output so UI can see the result + let output_text = if success { + format!("✓ Clone succeeded: {}", message) + } else { + format!("✗ Clone failed: {}", message) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: Some(!success), + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Ok(DaemonMessage::CheckTargetExistsResult { task_id, exists, target_dir }) => { + tracing::debug!( + task_id = %task_id, + exists = exists, + target_dir = %target_dir, + "Check target exists result received" + ); + + // Broadcast as task output so UI can use the result + let output_text = if exists { + format!("Target directory exists: {}", target_dir) + } else { + format!("Target directory does not exist: {}", target_dir) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Err(e) => { + tracing::warn!("Failed to parse daemon message: {}", e); + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::info!("Daemon {} disconnected", daemon_uuid); + break; + } + Some(Err(e)) => { + tracing::warn!("Daemon {} WebSocket error: {}", daemon_uuid, e); + break; + } + _ => {} + } + } + + // Handle commands to send to daemon + cmd = cmd_rx.recv() => { + match cmd { + Some(command) => { + let json = serde_json::to_string(&command).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + tracing::warn!("Failed to send command to daemon {}", daemon_uuid); + break; + } + } + None => { + // Channel closed + break; + } + } + } + } + } + + // Cleanup on disconnect + state.unregister_daemon(&connection_id); + + // Clear daemon_id from any tasks that were running on this daemon + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + tokio::spawn(async move { + // Find tasks assigned to this daemon that are still active + if let Err(e) = clear_daemon_from_tasks(&pool, daemon_uuid).await { + tracing::error!( + daemon_id = %daemon_uuid, + error = %e, + "Failed to clear daemon from tasks on disconnect" + ); + } + }); + } +} + +/// Clear daemon_id from tasks when daemon disconnects +async fn clear_daemon_from_tasks(pool: &sqlx::PgPool, daemon_id: Uuid) -> Result<(), sqlx::Error> { + // Update tasks that were running on this daemon to failed state + let result = sqlx::query( + r#" + UPDATE tasks + SET daemon_id = NULL, + status = 'failed', + error_message = 'Daemon disconnected', + updated_at = NOW() + WHERE daemon_id = $1 + AND status IN ('starting', 'running', 'initializing') + "#, + ) + .bind(daemon_id) + .execute(pool) + .await?; + + if result.rows_affected() > 0 { + tracing::warn!( + daemon_id = %daemon_id, + tasks_affected = result.rows_affected(), + "Marked tasks as failed due to daemon disconnect" + ); + } + + Ok(()) +} |
