diff options
| author | soryu <soryu@soryu.co> | 2026-02-01 00:47:02 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2026-02-01 00:47:02 +0000 |
| commit | 999ecf644f58af7de0b0a36b22a69897d8056a1c (patch) | |
| tree | cd294c110c753034b5f9137b80a4aa493dd1a969 | |
| parent | 10d9b4ce345ac74161108818ad5532a74336cc3d (diff) | |
| download | soryu-999ecf644f58af7de0b0a36b22a69897d8056a1c.tar.gz soryu-999ecf644f58af7de0b0a36b22a69897d8056a1c.zip | |
[WIP] Heartbeat checkpoint - 2026-02-01 00:47:02 UTC
| -rw-r--r-- | makima/src/daemon/ws/protocol.rs | 103 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_daemon.rs | 352 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_supervisor.rs | 507 |
3 files changed, 960 insertions, 2 deletions
diff --git a/makima/src/daemon/ws/protocol.rs b/makima/src/daemon/ws/protocol.rs index bfe6326..8c2994a 100644 --- a/makima/src/daemon/ws/protocol.rs +++ b/makima/src/daemon/ws/protocol.rs @@ -397,6 +397,63 @@ pub enum DaemonMessage { #[serde(rename = "baseSha")] base_sha: String, }, + + // ========================================================================= + // Supervisor State Update Messages (Phase 3: Crash Recovery) + // ========================================================================= + + /// Supervisor state update for crash recovery. + /// Sent periodically or at key save points to persist state. + SupervisorStateUpdate { + /// Task ID of the supervisor. + #[serde(rename = "taskId")] + task_id: Uuid, + /// Contract ID. + #[serde(rename = "contractId")] + contract_id: Uuid, + /// Save point type that triggered this update. + #[serde(rename = "savePoint")] + save_point: String, + /// Current supervisor activity state. + state: Option<String>, + /// Human-readable current activity. + #[serde(rename = "currentActivity")] + current_activity: Option<String>, + /// Progress percentage (0-100). + progress: Option<i32>, + /// Last LLM response for context restoration. + #[serde(rename = "lastLlmResponse")] + last_llm_response: Option<String>, + /// Task that was just spawned (if save_point is "task_spawn"). + #[serde(rename = "spawnedTaskId")] + spawned_task_id: Option<Uuid>, + /// Question ID (if save_point is "question_asked"). + #[serde(rename = "questionId")] + question_id: Option<Uuid>, + /// Error message (if state is "error"). + #[serde(rename = "errorMessage")] + error_message: Option<String>, + /// Updated conversation history (sent on llm_response save points). + #[serde(rename = "conversationHistory")] + conversation_history: Option<serde_json::Value>, + }, + + /// Supervisor heartbeat for lightweight state updates. + SupervisorHeartbeat { + /// Task ID of the supervisor. + #[serde(rename = "taskId")] + task_id: Uuid, + /// Contract ID. 
+ #[serde(rename = "contractId")] + contract_id: Uuid, + /// Current state (optional). + state: Option<String>, + /// Current activity description (optional). + #[serde(rename = "currentActivity")] + current_activity: Option<String>, + /// Progress percentage (optional). + progress: Option<i32>, + }, } /// Information about a branch (used in BranchList message). @@ -857,6 +914,52 @@ impl DaemonMessage { pub fn revoke_tool_key(task_id: Uuid) -> Self { Self::RevokeToolKey { task_id } } + + /// Create a supervisor state update message. + pub fn supervisor_state_update( + task_id: Uuid, + contract_id: Uuid, + save_point: &str, + state: Option<&str>, + current_activity: Option<&str>, + progress: Option<i32>, + last_llm_response: Option<&str>, + spawned_task_id: Option<Uuid>, + question_id: Option<Uuid>, + error_message: Option<&str>, + conversation_history: Option<serde_json::Value>, + ) -> Self { + Self::SupervisorStateUpdate { + task_id, + contract_id, + save_point: save_point.to_string(), + state: state.map(|s| s.to_string()), + current_activity: current_activity.map(|s| s.to_string()), + progress, + last_llm_response: last_llm_response.map(|s| s.to_string()), + spawned_task_id, + question_id, + error_message: error_message.map(|s| s.to_string()), + conversation_history, + } + } + + /// Create a supervisor heartbeat message. 
+ pub fn supervisor_heartbeat( + task_id: Uuid, + contract_id: Uuid, + state: Option<&str>, + current_activity: Option<&str>, + progress: Option<i32>, + ) -> Self { + Self::SupervisorHeartbeat { + task_id, + contract_id, + state: state.map(|s| s.to_string()), + current_activity: current_activity.map(|s| s.to_string()), + progress, + } + } } #[cfg(test)] diff --git a/makima/src/server/handlers/mesh_daemon.rs b/makima/src/server/handlers/mesh_daemon.rs index 1152502..4c6a045 100644 --- a/makima/src/server/handlers/mesh_daemon.rs +++ b/makima/src/server/handlers/mesh_daemon.rs @@ -1190,8 +1190,32 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re let pool = pool.clone(); let state = state.clone(); tokio::spawn(async move { - if worktree_intact { - // Worktree exists - task can be resumed on this daemon + // First, get the task to check if it's a supervisor + let task = match repository::get_task(&pool, task_id).await { + Ok(Some(t)) => t, + Ok(None) => { + tracing::warn!(task_id = %task_id, "Task not found during recovery"); + return; + } + Err(e) => { + tracing::error!(task_id = %task_id, error = %e, "Failed to get task during recovery"); + return; + } + }; + + // Handle supervisor-specific recovery + if task.is_supervisor { + handle_supervisor_recovery( + &pool, + &state, + task_id, + task.contract_id, + owner_id, + worktree_intact, + &previous_state, + ).await; + } else if worktree_intact { + // Regular task - worktree exists, task can be resumed on this daemon // Update task status to 'pending' so it can be picked up match sqlx::query( r#" @@ -1973,6 +1997,98 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re let _ = tx.send(response); } } + Ok(DaemonMessage::SupervisorStateUpdate { + task_id, + contract_id, + save_point, + state: supervisor_state, + current_activity, + progress, + last_llm_response, + spawned_task_id, + question_id, + error_message, + conversation_history, + }) => { + 
tracing::debug!( + task_id = %task_id, + contract_id = %contract_id, + save_point = %save_point, + "Received supervisor state update" + ); + + // Persist the state update + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + tokio::spawn(async move { + match repository::save_supervisor_state_at_savepoint( + &pool, + contract_id, + &save_point, + supervisor_state.as_deref(), + current_activity.as_deref(), + progress, + last_llm_response.as_deref(), + spawned_task_id, + question_id, + error_message.as_deref(), + conversation_history, + ).await { + Ok(_) => { + tracing::trace!( + task_id = %task_id, + contract_id = %contract_id, + save_point = %save_point, + "Supervisor state saved" + ); + } + Err(e) => { + tracing::warn!( + task_id = %task_id, + contract_id = %contract_id, + save_point = %save_point, + error = %e, + "Failed to save supervisor state" + ); + } + } + }); + } + } + Ok(DaemonMessage::SupervisorHeartbeat { + task_id, + contract_id, + state: supervisor_state, + current_activity, + progress, + }) => { + tracing::trace!( + task_id = %task_id, + contract_id = %contract_id, + "Received supervisor heartbeat" + ); + + // Lightweight state update + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + tokio::spawn(async move { + if let Err(e) = repository::update_supervisor_activity_state( + &pool, + contract_id, + supervisor_state.as_deref().unwrap_or("executing"), + current_activity.as_deref(), + progress, + ).await { + tracing::warn!( + task_id = %task_id, + contract_id = %contract_id, + error = %e, + "Failed to update supervisor activity state" + ); + } + }); + } + } Ok(DaemonMessage::MergePatchToSupervisor { task_id, supervisor_task_id, @@ -2180,3 +2296,235 @@ async fn handle_daemon_disconnect_tasks(pool: &sqlx::PgPool, daemon_id: Uuid) -> Ok(()) } + +// ============================================================================= +// Supervisor Recovery Protocol (Phase 3: Crash Recovery) +// 
=============================================================================
+
+/// Handle supervisor-specific recovery after a daemon crash/restart.
+///
+/// Server-side part of the Supervisor Restoration Protocol:
+/// 1. Load the supervisor state stored for the contract.
+/// 2. If it is missing or cannot be loaded: requeue the task for a fresh start.
+/// 3. Validate the stored state.
+/// 4. If validation errors or reports the state as unusable: requeue for a fresh start.
+/// 5. Otherwise log any pending question / pending tasks recorded in the state,
+///    build a restoration summary (logged only) and requeue the task.
+///
+/// This function does not re-deliver questions or replay conversation history
+/// itself; it only resets the task row and logs what a later restoration has
+/// to pick up.
+///
+/// Every exit path resets the task to `pending` with `daemon_id = NULL`, so a
+/// supervisor is never left attached to the daemon that lost it.
+///
+/// `worktree_intact` and `previous_state` are only used for logging here.
+async fn handle_supervisor_recovery(
+    pool: &sqlx::PgPool,
+    state: &SharedState,
+    task_id: Uuid,
+    contract_id: Option<Uuid>,
+    owner_id: Uuid,
+    worktree_intact: bool,
+    previous_state: &str,
+) {
+    let Some(contract_id) = contract_id else {
+        // Without a contract there is no supervisor state to look up, but the
+        // task must still be released from the lost daemon.
+        tracing::warn!(
+            task_id = %task_id,
+            "Supervisor has no contract_id - requeueing without restoration context"
+        );
+        requeue_supervisor_task(
+            pool,
+            state,
+            task_id,
+            owner_id,
+            "Supervisor restarted - no contract associated",
+        )
+        .await;
+        return;
+    };
+
+    tracing::info!(
+        task_id = %task_id,
+        contract_id = %contract_id,
+        worktree_intact = worktree_intact,
+        previous_state = %previous_state,
+        "Starting supervisor recovery protocol"
+    );
+
+    // Step 1-2: Load supervisor state from database
+    let supervisor_state = match repository::get_supervisor_state(pool, contract_id).await {
+        Ok(Some(s)) => s,
+        Ok(None) => {
+            tracing::warn!(
+                task_id = %task_id,
+                contract_id = %contract_id,
+                "No supervisor state found - will start fresh"
+            );
+            requeue_supervisor_task(
+                pool,
+                state,
+                task_id,
+                owner_id,
+                "Supervisor restarted - no previous state found",
+            )
+            .await;
+            return;
+        }
+        Err(e) => {
+            tracing::error!(
+                task_id = %task_id,
+                contract_id = %contract_id,
+                error = %e,
+                "Failed to load supervisor state - will start fresh"
+            );
+            // Previously this path returned without touching the task, which
+            // contradicted the log message and left the task stranded.
+            requeue_supervisor_task(
+                pool,
+                state,
+                task_id,
+                owner_id,
+                "Supervisor restarted - failed to load previous state",
+            )
+            .await;
+            return;
+        }
+    };
+
+    // Step 3: Validate state consistency
+    let validation = match repository::validate_supervisor_state(pool, contract_id, owner_id).await {
+        Ok(v) => v,
+        Err(e) => {
+            tracing::error!(
+                task_id = %task_id,
+                error = %e,
+                "Failed to validate supervisor state - will start fresh"
+            );
+            requeue_supervisor_task(
+                pool,
+                state,
+                task_id,
+                owner_id,
+                "Supervisor restarted - failed to validate previous state",
+            )
+            .await;
+            return;
+        }
+    };
+
+    let (is_valid, restoration_message) = match validation {
+        repository::StateValidationResult::Valid(_) => {
+            (true, "Supervisor state valid - restoring full context".to_string())
+        }
+        repository::StateValidationResult::NotFound => {
+            (false, "Supervisor state not found - starting fresh".to_string())
+        }
+        repository::StateValidationResult::Invalid { reason } => {
+            (false, format!("Supervisor state invalid: {} - starting fresh", reason))
+        }
+        repository::StateValidationResult::PartiallyValid { invalid_task_ids, .. } => {
+            (true, format!(
+                "Supervisor state partially valid ({} tasks missing) - restoring with available context",
+                invalid_task_ids.len()
+            ))
+        }
+        repository::StateValidationResult::PhaseStale { current_phase, .. } => {
+            (true, format!(
+                "Supervisor state valid but phase changed to '{}' - restoring with updated phase",
+                current_phase
+            ))
+        }
+    };
+
+    tracing::info!(
+        task_id = %task_id,
+        contract_id = %contract_id,
+        is_valid = is_valid,
+        message = %restoration_message,
+        "Supervisor state validation complete"
+    );
+
+    // Step 4: Unusable state - requeue for a fresh start
+    if !is_valid {
+        requeue_supervisor_task(pool, state, task_id, owner_id, &restoration_message).await;
+        return;
+    }
+
+    // Pending question recorded in the saved state (logged for re-delivery)
+    let pending_question_id = supervisor_state.pending_question_id;
+    if let Some(question_id) = pending_question_id {
+        tracing::info!(
+            task_id = %task_id,
+            question_id = %question_id,
+            "Supervisor has pending question - will re-deliver after restoration"
+        );
+    }
+
+    // Tasks the supervisor was waiting on according to the saved state
+    let pending_task_count = supervisor_state.pending_task_ids.len();
+    if pending_task_count > 0 {
+        tracing::info!(
+            task_id = %task_id,
+            pending_task_count = pending_task_count,
+            "Supervisor has pending tasks - will resume waiting state"
+        );
+    }
+
+    // Step 5: Restoration summary. It is only written to the log below.
+    let restoration_context = serde_json::json!({
+        "restored": true,
+        "restoration_count": supervisor_state.restoration_count.unwrap_or(0) + 1,
+        "last_state": supervisor_state.state,
+        "last_activity": supervisor_state.current_activity,
+        "phase": supervisor_state.phase,
+        "pending_tasks": supervisor_state.pending_task_ids,
+        "has_pending_question": pending_question_id.is_some(),
+        "pending_question_id": pending_question_id,
+        "message": restoration_message,
+    });
+
+    if requeue_supervisor_task(pool, state, task_id, owner_id, &restoration_message).await {
+        tracing::info!(
+            task_id = %task_id,
+            contract_id = %contract_id,
+            restoration_context = %restoration_context,
+            "Supervisor marked for restoration"
+        );
+    }
+}
+
+/// Reset a supervisor task to `pending`, detach it from its daemon, and
+/// broadcast the change.
+///
+/// `message` is stored in the task's `error_message` column. Returns `true`
+/// when a row was updated; failures are logged and reported as `false`.
+async fn requeue_supervisor_task(
+    pool: &sqlx::PgPool,
+    state: &SharedState,
+    task_id: Uuid,
+    owner_id: Uuid,
+    message: &str,
+) -> bool {
+    let result = sqlx::query(
+        r#"
+        UPDATE tasks
+        SET status = 'pending',
+            daemon_id = NULL,
+            error_message = $3,
+            interrupted_at = NOW(),
+            updated_at = NOW()
+        WHERE id = $1 AND owner_id = $2
+        "#,
+    )
+    .bind(task_id)
+    .bind(owner_id)
+    .bind(message)
+    .execute(pool)
+    .await;
+
+    match result {
+        Ok(done) if done.rows_affected() > 0 => {
+            state.broadcast_task_update(TaskUpdateNotification {
+                task_id,
+                owner_id: Some(owner_id),
+                version: 0,
+                status: "pending".into(),
+                updated_fields: vec![
+                    "status".into(),
+                    "daemon_id".into(),
+                    "error_message".into(),
+                    "interrupted_at".into(),
+                ],
+                updated_by: "supervisor_recovery".into(),
+            });
+            true
+        }
+        Ok(_) => {
+            // Nothing matched id + owner_id, so there is no change to announce.
+            tracing::warn!(
+                task_id = %task_id,
+                reason = %message,
+                "Supervisor task not updated during recovery - no matching row"
+            );
+            false
+        }
+        Err(e) => {
+            tracing::error!(
+                task_id = %task_id,
+                error = %e,
+                reason = %message,
+                "Failed to requeue supervisor task during recovery"
+            );
+            false
+        }
+    }
+}
diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs
index 3411ec0..b33c1c9 100644
--- a/makima/src/server/handlers/mesh_supervisor.rs
+++ b/makima/src/server/handlers/mesh_supervisor.rs
@@ -2748,3 +2748,510 @@ pub async fn spawn_red_team_task(
     // It will remain pending and can be started later
     Ok(task)
 }
+
+// =============================================================================
+// Supervisor State Persistence Handlers (Phase 3: Crash Recovery)
+// =============================================================================
+
+/// Request to save supervisor state at a save point.
+// Deserialized from the JSON body of the state-save endpoint. Because of
+// `rename_all = "camelCase"` the wire names are `savePoint`,
+// `conversationHistory`, `currentActivity`, `lastLlmResponse`,
+// `spawnedTaskId`, `questionId` and `errorMessage`.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct SaveStateRequest {
+    /// The save point type
+    // The only required field; every other field may be omitted or null.
+    pub save_point: String,
+    /// Updated conversation history (if available)
+    // Arbitrary JSON; its shape is not constrained by this type.
+    pub conversation_history: Option<serde_json::Value>,
+    /// Current state
+    pub state: Option<String>,
+    /// Current activity description
+    pub current_activity: Option<String>,
+    /// Progress percentage
+    // NOTE(review): no range check is applied here - confirm whether 0-100 is
+    // enforced downstream.
+    pub progress: Option<i32>,
+    /// Last LLM response
+    pub last_llm_response: Option<String>,
+    /// Task that was spawned (for task_spawn save point)
+    pub spawned_task_id: Option<Uuid>,
+    /// Question that was asked (for question_asked save point)
+    pub question_id: Option<Uuid>,
+    /// Error message (for error save point)
+    pub error_message: Option<String>,
+}
+
+/// Response for state save operation.
+// Serialized with camelCase keys: `success`, `stateId`, `savePoint`, `message`.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct SaveStateResponse {
+    // Whether the operation succeeded.
+    pub success: bool,
+    // Identifier of the stored supervisor state record.
+    pub state_id: Uuid,
+    // Save point label the response refers to.
+    pub save_point: String,
+    // Human-readable status text.
+    pub message: String,
+}
+
+/// Save supervisor state at a specific save point.
+///
+/// This endpoint is called by the supervisor to persist its state for crash recovery.
+/// State should be saved at key points: LLM response, task spawn, question asked, phase change, heartbeat.
+#[utoipa::path(
+    post,
+    path = "/api/v1/mesh/supervisor/state/save",
+    request_body = SaveStateRequest,
+    responses(
+        (status = 200, description = "State saved", body = SaveStateResponse),
+        (status = 400, description = "Supervisor has no associated contract"),
+        (status = 401, description = "Unauthorized"),
+        (status = 403, description = "Forbidden - not a supervisor"),
+        (status = 404, description = "Supervisor task not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database not configured"),
+    ),
+    security(
+        ("tool_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn save_supervisor_state(
+    State(state): State<SharedState>,
+    headers: HeaderMap,
+    Json(request): Json<SaveStateRequest>,
+) -> impl IntoResponse {
+    let (supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await {
+        Ok(ids) => ids,
+        Err(e) => return e.into_response(),
+    };
+
+    // `db_pool` is optional; answer 503 instead of panicking when it is absent.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    // Get the supervisor task to find its contract
+    let supervisor = match repository::get_task_for_owner(pool, supervisor_id, owner_id).await {
+        Ok(Some(t)) => t,
+        Ok(None) => {
+            return (
+                StatusCode::NOT_FOUND,
+                Json(ApiError::new("NOT_FOUND", "Supervisor task not found")),
+            ).into_response();
+        }
+        Err(e) => {
+            tracing::error!(error = %e, "Failed to get supervisor task");
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")),
+            ).into_response();
+        }
+    };
+
+    let Some(contract_id) = supervisor.contract_id else {
+        return (
+            StatusCode::BAD_REQUEST,
+            Json(ApiError::new("NO_CONTRACT", "Supervisor has no associated contract")),
+        ).into_response();
+    };
+
+    // Take the request apart once so the (potentially large) conversation
+    // history is moved into the repository call instead of being cloned.
+    let SaveStateRequest {
+        save_point,
+        conversation_history,
+        state: supervisor_state,
+        current_activity,
+        progress,
+        last_llm_response,
+        spawned_task_id,
+        question_id,
+        error_message,
+    } = request;
+
+    // Save the state at the specified save point
+    let result = repository::save_supervisor_state_at_savepoint(
+        pool,
+        contract_id,
+        &save_point,
+        supervisor_state.as_deref(),
+        current_activity.as_deref(),
+        progress,
+        last_llm_response.as_deref(),
+        spawned_task_id,
+        question_id,
+        error_message.as_deref(),
+        conversation_history,
+    ).await;
+
+    match result {
+        Ok(saved_state) => {
+            tracing::debug!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                save_point = %save_point,
+                "Supervisor state saved"
+            );
+            (
+                StatusCode::OK,
+                Json(SaveStateResponse {
+                    success: true,
+                    state_id: saved_state.id,
+                    save_point,
+                    message: "State saved successfully".to_string(),
+                }),
+            ).into_response()
+        }
+        Err(e) => {
+            tracing::error!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                save_point = %save_point,
+                error = %e,
+                "Failed to save supervisor state"
+            );
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", format!("Failed to save state: {}", e))),
+            ).into_response()
+        }
+    }
+}
+
+/// Response for supervisor restoration.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct RestorationContextResponse {
+    /// Whether restoration context was found
+    pub found: bool,
+    /// The restoration result type
+    pub restoration_type: String,
+    /// Human-readable message about the restoration
+    pub message: String,
+    /// The supervisor state (if found)
+    pub state: Option<crate::db::models::SupervisorState>,
+    /// Pending tasks the supervisor was waiting on
+    pub pending_tasks: Vec<TaskSummary>,
+    /// Pending question that needs re-delivery
+    pub pending_question: Option<PendingQuestionSummary>,
+    /// Last LLM response for context
+    pub last_llm_response: Option<String>,
+    /// Restoration count (how many times this supervisor has been restored)
+    pub restoration_count: i32,
+}
+
+/// Get supervisor restoration context for crash recovery.
+///
+/// This endpoint retrieves the saved state and context needed to restore
+/// a supervisor after a crash or daemon restart.
+#[utoipa::path(
+    get,
+    path = "/api/v1/mesh/supervisor/contracts/{contract_id}/restore",
+    params(
+        ("contract_id" = Uuid, Path, description = "Contract ID")
+    ),
+    responses(
+        (status = 200, description = "Restoration context", body = RestorationContextResponse),
+        (status = 401, description = "Unauthorized"),
+        (status = 404, description = "Contract or state not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database not configured"),
+    ),
+    security(
+        ("bearer_auth" = []),
+        ("api_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn get_restoration_context(
+    State(state): State<SharedState>,
+    Path(contract_id): Path<Uuid>,
+    auth: crate::server::auth::Authenticated,
+) -> impl IntoResponse {
+    let crate::server::auth::Authenticated(auth_info) = auth;
+
+    // `db_pool` is optional; answer 503 instead of panicking when it is absent.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    // Validate supervisor state
+    let validation = match repository::validate_supervisor_state(pool, contract_id, auth_info.owner_id).await {
+        Ok(v) => v,
+        Err(e) => {
+            tracing::error!(error = %e, "Failed to validate supervisor state");
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to validate state")),
+            ).into_response();
+        }
+    };
+
+    // "Nothing saved" and "saved but unusable" are both answered with 200 and
+    // `found: false`, differing only in the message.
+    let fresh_start = |message: String| {
+        (
+            StatusCode::OK,
+            Json(RestorationContextResponse {
+                found: false,
+                restoration_type: "fresh_start".to_string(),
+                message,
+                state: None,
+                pending_tasks: vec![],
+                pending_question: None,
+                last_llm_response: None,
+                restoration_count: 0,
+            }),
+        ).into_response()
+    };
+
+    // Consume `validation` exactly once: each restorable variant yields the
+    // state together with its type label and message, so the value is never
+    // inspected again after the state has been moved out of it.
+    let (supervisor_state, restoration_type, message) = match validation {
+        repository::StateValidationResult::NotFound => {
+            return fresh_start("No previous supervisor state found. Starting fresh.".to_string());
+        }
+        repository::StateValidationResult::Invalid { reason } => {
+            tracing::warn!(contract_id = %contract_id, reason = %reason, "Invalid supervisor state");
+            return fresh_start(format!("Previous state invalid: {}. Starting fresh.", reason));
+        }
+        repository::StateValidationResult::Valid(s) => {
+            let message = format!(
+                "Supervisor state found. Last activity: {}. Restoring from {} phase.",
+                s.last_activity.format("%Y-%m-%d %H:%M:%S UTC"),
+                s.phase
+            );
+            (s, "full_restore", message)
+        }
+        repository::StateValidationResult::PartiallyValid { state: s, invalid_task_ids, .. } => {
+            let message = format!(
+                "Partial state found. {} task(s) no longer exist. Restoring with available context.",
+                invalid_task_ids.len()
+            );
+            (s, "partial_restore", message)
+        }
+        repository::StateValidationResult::PhaseStale { state: s, current_phase, .. } => {
+            let message = format!(
+                "State found but phase changed from {} to {}. Restoring with updated phase.",
+                s.phase, current_phase
+            );
+            (s, "checkpoint_restore", message)
+        }
+    };
+
+    // Get full restoration context
+    let (pending_tasks, pending_question) = get_restoration_details(
+        pool,
+        &state,
+        &supervisor_state,
+        auth_info.owner_id,
+    ).await;
+
+    // Copy these out before the state itself is moved into the response.
+    let restoration_count = supervisor_state.restoration_count.unwrap_or(0);
+    let last_llm_response = supervisor_state.last_llm_response.clone();
+
+    (
+        StatusCode::OK,
+        Json(RestorationContextResponse {
+            found: true,
+            restoration_type: restoration_type.to_string(),
+            message,
+            state: Some(supervisor_state),
+            pending_tasks,
+            pending_question,
+            last_llm_response,
+            restoration_count,
+        }),
+    ).into_response()
+}
+
+/// Helper function to get restoration details (pending tasks and questions).
+///
+/// Returns the tasks listed in `supervisor_state.pending_task_ids` that belong
+/// to `owner_id` (oldest first), and the pending question referenced by
+/// `supervisor_state.pending_question_id` when `state.get_pending_question`
+/// still has an entry for it. A failed task query is logged and yields an
+/// empty list rather than an error.
+async fn get_restoration_details(
+    pool: &sqlx::PgPool,
+    state: &SharedState,
+    supervisor_state: &crate::db::models::SupervisorState,
+    owner_id: Uuid,
+) -> (Vec<TaskSummary>, Option<PendingQuestionSummary>) {
+    // Get pending tasks (skip the query entirely when nothing is pending)
+    let pending_tasks = if !supervisor_state.pending_task_ids.is_empty() {
+        match sqlx::query_as::<_, Task>(
+            r#"
+            SELECT * FROM tasks
+            WHERE id = ANY($1) AND owner_id = $2
+            ORDER BY created_at ASC
+            "#,
+        )
+        .bind(&supervisor_state.pending_task_ids)
+        .bind(owner_id)
+        .fetch_all(pool)
+        .await
+        {
+            Ok(tasks) => tasks.into_iter().map(TaskSummary::from).collect(),
+            Err(e) => {
+                tracing::warn!(error = %e, "Failed to get pending tasks for restoration");
+                vec![]
+            }
+        }
+    } else {
+        vec![]
+    };
+
+    // Get pending question if any
+    // NOTE(review): the question is looked up by ID only; confirm that it
+    // cannot belong to a different owner or contract.
+    let pending_question = if let Some(question_id) = supervisor_state.pending_question_id {
+        state.get_pending_question(question_id).map(|q| PendingQuestionSummary {
+            question_id: q.question_id,
+            task_id: q.task_id,
+            contract_id: q.contract_id,
+            question: q.question,
+            choices: q.choices,
+            context: q.context,
+            created_at: q.created_at,
+            multi_select: q.multi_select,
+            question_type: q.question_type,
+        })
+    } else {
+        None
+    };
+
+    (pending_tasks, pending_question)
+}
+
+/// Mark supervisor as restored after successful restoration.
+#[utoipa::path(
+    post,
+    path = "/api/v1/mesh/supervisor/contracts/{contract_id}/restored",
+    params(
+        ("contract_id" = Uuid, Path, description = "Contract ID")
+    ),
+    responses(
+        (status = 200, description = "Restoration marked", body = SaveStateResponse),
+        (status = 401, description = "Unauthorized"),
+        (status = 404, description = "Contract or state not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database not configured"),
+    ),
+    security(
+        ("tool_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn mark_supervisor_restored(
+    State(state): State<SharedState>,
+    Path(contract_id): Path<Uuid>,
+    headers: HeaderMap,
+) -> impl IntoResponse {
+    // The contract ID from the path is handed to the auth check together with
+    // the request headers.
+    let (supervisor_id, _owner_id) = match verify_supervisor_auth(&state, &headers, Some(contract_id)).await {
+        Ok(ids) => ids,
+        Err(e) => return e.into_response(),
+    };
+
+    // `db_pool` is optional; answer 503 instead of panicking when it is absent.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    match repository::mark_supervisor_restored(pool, contract_id).await {
+        Ok(saved_state) => {
+            tracing::info!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                restoration_count = ?saved_state.restoration_count,
+                "Supervisor marked as restored"
+            );
+            (
+                StatusCode::OK,
+                Json(SaveStateResponse {
+                    success: true,
+                    state_id: saved_state.id,
+                    save_point: "restoration_complete".to_string(),
+                    message: format!(
+                        "Supervisor restored successfully (restoration #{})",
+                        // A missing count is reported as the first restoration.
+                        saved_state.restoration_count.unwrap_or(1)
+                    ),
+                }),
+            ).into_response()
+        }
+        Err(e) => {
+            tracing::error!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                error = %e,
+                "Failed to mark supervisor as restored"
+            );
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", format!("Failed to mark restored: {}", e))),
+            ).into_response()
+        }
+    }
+}
+
+/// Lightweight heartbeat update for supervisor state.
+///
+/// This is a minimal update that only touches last_activity and optionally progress/state.
+/// Used for frequent heartbeats to avoid overhead of full state saves.
+#[utoipa::path(
+    post,
+    path = "/api/v1/mesh/supervisor/state/heartbeat",
+    request_body = HeartbeatRequest,
+    responses(
+        (status = 200, description = "Heartbeat recorded"),
+        (status = 400, description = "Supervisor has no associated contract"),
+        (status = 401, description = "Unauthorized"),
+        (status = 403, description = "Forbidden - not a supervisor"),
+        (status = 404, description = "Supervisor task not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database not configured"),
+    ),
+    security(
+        ("tool_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn supervisor_heartbeat(
+    State(state): State<SharedState>,
+    headers: HeaderMap,
+    Json(request): Json<HeartbeatRequest>,
+) -> impl IntoResponse {
+    let (supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await {
+        Ok(ids) => ids,
+        Err(e) => return e.into_response(),
+    };
+
+    // `db_pool` is optional; answer 503 instead of panicking when it is absent.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    // Get the supervisor task to find its contract
+    let supervisor = match repository::get_task_for_owner(pool, supervisor_id, owner_id).await {
+        Ok(Some(t)) => t,
+        Ok(None) => {
+            return (
+                StatusCode::NOT_FOUND,
+                Json(ApiError::new("NOT_FOUND", "Supervisor task not found")),
+            ).into_response();
+        }
+        Err(e) => {
+            tracing::error!(error = %e, "Failed to get supervisor task");
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")),
+            ).into_response();
+        }
+    };
+
+    let Some(contract_id) = supervisor.contract_id else {
+        return (
+            StatusCode::BAD_REQUEST,
+            Json(ApiError::new("NO_CONTRACT", "Supervisor has no associated contract")),
+        ).into_response();
+    };
+
+    // Update activity state (lightweight). A heartbeat that does not name a
+    // state is recorded as "executing".
+    match repository::update_supervisor_activity_state(
+        pool,
+        contract_id,
+        request.state.as_deref().unwrap_or("executing"),
+        request.current_activity.as_deref(),
+        request.progress,
+    ).await {
+        Ok(_) => {
+            tracing::trace!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                "Supervisor heartbeat recorded"
+            );
+            StatusCode::OK.into_response()
+        }
+        Err(e) => {
+            // Logged at warn level rather than error.
+            tracing::warn!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                error = %e,
+                "Failed to record supervisor heartbeat"
+            );
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to record heartbeat")),
+            ).into_response()
+        }
+    }
+}
+
+/// Request for supervisor heartbeat.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct HeartbeatRequest {
+    /// Current state (optional)
+    pub state: Option<String>,
+    /// Current activity description (optional)
+    pub current_activity: Option<String>,
+    /// Progress percentage (optional)
+    pub progress: Option<i32>,
+} |
