summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-02-01 00:47:02 +0000
committersoryu <soryu@soryu.co>2026-02-01 00:47:02 +0000
commit999ecf644f58af7de0b0a36b22a69897d8056a1c (patch)
treecd294c110c753034b5f9137b80a4aa493dd1a969
parent10d9b4ce345ac74161108818ad5532a74336cc3d (diff)
downloadsoryu-999ecf644f58af7de0b0a36b22a69897d8056a1c.tar.gz
soryu-999ecf644f58af7de0b0a36b22a69897d8056a1c.zip
[WIP] Heartbeat checkpoint - 2026-02-01 00:47:02 UTC
-rw-r--r--makima/src/daemon/ws/protocol.rs103
-rw-r--r--makima/src/server/handlers/mesh_daemon.rs352
-rw-r--r--makima/src/server/handlers/mesh_supervisor.rs507
3 files changed, 960 insertions, 2 deletions
diff --git a/makima/src/daemon/ws/protocol.rs b/makima/src/daemon/ws/protocol.rs
index bfe6326..8c2994a 100644
--- a/makima/src/daemon/ws/protocol.rs
+++ b/makima/src/daemon/ws/protocol.rs
@@ -397,6 +397,63 @@ pub enum DaemonMessage {
#[serde(rename = "baseSha")]
base_sha: String,
},
+
+ // =========================================================================
+ // Supervisor State Update Messages (Phase 3: Crash Recovery)
+ // =========================================================================
+
+ /// Supervisor state update for crash recovery.
+ /// Sent periodically or at key save points to persist state.
+ ///
+ /// On the server side (see the `SupervisorStateUpdate` arm in `mesh_daemon.rs`)
+ /// the payload is handed to `repository::save_supervisor_state_at_savepoint`,
+ /// keyed by `contract_id`; `task_id` is only used for log correlation there.
+ SupervisorStateUpdate {
+ /// Task ID of the supervisor.
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Contract ID the supervisor state is stored under.
+ #[serde(rename = "contractId")]
+ contract_id: Uuid,
+ /// Save point type that triggered this update
+ /// (e.g. "task_spawn", "question_asked", "llm_response").
+ #[serde(rename = "savePoint")]
+ save_point: String,
+ /// Current supervisor activity state (e.g. "error"); `None` when not reported.
+ state: Option<String>,
+ /// Human-readable current activity.
+ #[serde(rename = "currentActivity")]
+ current_activity: Option<String>,
+ /// Progress percentage (0-100). The range is a convention; nothing in this
+ /// message definition enforces it.
+ progress: Option<i32>,
+ /// Last LLM response for context restoration.
+ #[serde(rename = "lastLlmResponse")]
+ last_llm_response: Option<String>,
+ /// Task that was just spawned (if save_point is "task_spawn").
+ #[serde(rename = "spawnedTaskId")]
+ spawned_task_id: Option<Uuid>,
+ /// Question ID (if save_point is "question_asked").
+ #[serde(rename = "questionId")]
+ question_id: Option<Uuid>,
+ /// Error message (if state is "error").
+ #[serde(rename = "errorMessage")]
+ error_message: Option<String>,
+ /// Updated conversation history (sent on llm_response save points).
+ /// Carried as untyped JSON; its schema is not defined here.
+ #[serde(rename = "conversationHistory")]
+ conversation_history: Option<serde_json::Value>,
+ },
+
+ /// Supervisor heartbeat for lightweight state updates.
+ ///
+ /// The server-side arm in `mesh_daemon.rs` forwards this to
+ /// `repository::update_supervisor_activity_state`, substituting "executing"
+ /// when `state` is `None`.
+ SupervisorHeartbeat {
+ /// Task ID of the supervisor.
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Contract ID the activity state is stored under.
+ #[serde(rename = "contractId")]
+ contract_id: Uuid,
+ /// Current state (optional).
+ state: Option<String>,
+ /// Current activity description (optional).
+ #[serde(rename = "currentActivity")]
+ current_activity: Option<String>,
+ /// Progress percentage (optional).
+ progress: Option<i32>,
+ },
}
/// Information about a branch (used in BranchList message).
@@ -857,6 +914,52 @@ impl DaemonMessage {
pub fn revoke_tool_key(task_id: Uuid) -> Self {
Self::RevokeToolKey { task_id }
}
+
+    /// Build a [`DaemonMessage::SupervisorStateUpdate`] from borrowed inputs.
+    ///
+    /// Every borrowed string is copied into an owned `String`; the identifiers,
+    /// `progress` and `conversation_history` are moved into the message unchanged.
+    pub fn supervisor_state_update(
+        task_id: Uuid,
+        contract_id: Uuid,
+        save_point: &str,
+        state: Option<&str>,
+        current_activity: Option<&str>,
+        progress: Option<i32>,
+        last_llm_response: Option<&str>,
+        spawned_task_id: Option<Uuid>,
+        question_id: Option<Uuid>,
+        error_message: Option<&str>,
+        conversation_history: Option<serde_json::Value>,
+    ) -> Self {
+        // Convert the borrowed pieces up front so the struct literal below is
+        // plain field shorthand.
+        let save_point = save_point.to_owned();
+        let state = state.map(str::to_owned);
+        let current_activity = current_activity.map(str::to_owned);
+        let last_llm_response = last_llm_response.map(str::to_owned);
+        let error_message = error_message.map(str::to_owned);
+
+        Self::SupervisorStateUpdate {
+            task_id,
+            contract_id,
+            save_point,
+            state,
+            current_activity,
+            progress,
+            last_llm_response,
+            spawned_task_id,
+            question_id,
+            error_message,
+            conversation_history,
+        }
+    }
+
+    /// Build a [`DaemonMessage::SupervisorHeartbeat`] from borrowed inputs.
+    ///
+    /// The optional strings are copied into owned `String`s; everything else is
+    /// moved into the message unchanged.
+    pub fn supervisor_heartbeat(
+        task_id: Uuid,
+        contract_id: Uuid,
+        state: Option<&str>,
+        current_activity: Option<&str>,
+        progress: Option<i32>,
+    ) -> Self {
+        let state = state.map(str::to_owned);
+        let current_activity = current_activity.map(str::to_owned);
+
+        Self::SupervisorHeartbeat {
+            task_id,
+            contract_id,
+            state,
+            current_activity,
+            progress,
+        }
+    }
}
#[cfg(test)]
diff --git a/makima/src/server/handlers/mesh_daemon.rs b/makima/src/server/handlers/mesh_daemon.rs
index 1152502..4c6a045 100644
--- a/makima/src/server/handlers/mesh_daemon.rs
+++ b/makima/src/server/handlers/mesh_daemon.rs
@@ -1190,8 +1190,32 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re
let pool = pool.clone();
let state = state.clone();
tokio::spawn(async move {
- if worktree_intact {
- // Worktree exists - task can be resumed on this daemon
+ // First, get the task to check if it's a supervisor
+ let task = match repository::get_task(&pool, task_id).await {
+ Ok(Some(t)) => t,
+ Ok(None) => {
+ tracing::warn!(task_id = %task_id, "Task not found during recovery");
+ return;
+ }
+ Err(e) => {
+ tracing::error!(task_id = %task_id, error = %e, "Failed to get task during recovery");
+ return;
+ }
+ };
+
+ // Handle supervisor-specific recovery
+ if task.is_supervisor {
+ handle_supervisor_recovery(
+ &pool,
+ &state,
+ task_id,
+ task.contract_id,
+ owner_id,
+ worktree_intact,
+ &previous_state,
+ ).await;
+ } else if worktree_intact {
+ // Regular task - worktree exists, task can be resumed on this daemon
// Update task status to 'pending' so it can be picked up
match sqlx::query(
r#"
@@ -1973,6 +1997,98 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re
let _ = tx.send(response);
}
}
+ Ok(DaemonMessage::SupervisorStateUpdate {
+ task_id,
+ contract_id,
+ save_point,
+ state: supervisor_state,
+ current_activity,
+ progress,
+ last_llm_response,
+ spawned_task_id,
+ question_id,
+ error_message,
+ conversation_history,
+ }) => {
+ tracing::debug!(
+ task_id = %task_id,
+ contract_id = %contract_id,
+ save_point = %save_point,
+ "Received supervisor state update"
+ );
+
+ // Persist the state update
+ if let Some(ref pool) = state.db_pool {
+ let pool = pool.clone();
+ tokio::spawn(async move {
+ match repository::save_supervisor_state_at_savepoint(
+ &pool,
+ contract_id,
+ &save_point,
+ supervisor_state.as_deref(),
+ current_activity.as_deref(),
+ progress,
+ last_llm_response.as_deref(),
+ spawned_task_id,
+ question_id,
+ error_message.as_deref(),
+ conversation_history,
+ ).await {
+ Ok(_) => {
+ tracing::trace!(
+ task_id = %task_id,
+ contract_id = %contract_id,
+ save_point = %save_point,
+ "Supervisor state saved"
+ );
+ }
+ Err(e) => {
+ tracing::warn!(
+ task_id = %task_id,
+ contract_id = %contract_id,
+ save_point = %save_point,
+ error = %e,
+ "Failed to save supervisor state"
+ );
+ }
+ }
+ });
+ }
+ }
+ Ok(DaemonMessage::SupervisorHeartbeat {
+ task_id,
+ contract_id,
+ state: supervisor_state,
+ current_activity,
+ progress,
+ }) => {
+ tracing::trace!(
+ task_id = %task_id,
+ contract_id = %contract_id,
+ "Received supervisor heartbeat"
+ );
+
+ // Lightweight state update
+ if let Some(ref pool) = state.db_pool {
+ let pool = pool.clone();
+ tokio::spawn(async move {
+ if let Err(e) = repository::update_supervisor_activity_state(
+ &pool,
+ contract_id,
+ supervisor_state.as_deref().unwrap_or("executing"),
+ current_activity.as_deref(),
+ progress,
+ ).await {
+ tracing::warn!(
+ task_id = %task_id,
+ contract_id = %contract_id,
+ error = %e,
+ "Failed to update supervisor activity state"
+ );
+ }
+ });
+ }
+ }
Ok(DaemonMessage::MergePatchToSupervisor {
task_id,
supervisor_task_id,
@@ -2180,3 +2296,235 @@ async fn handle_daemon_disconnect_tasks(pool: &sqlx::PgPool, daemon_id: Uuid) ->
Ok(())
}
+
+// =============================================================================
+// Supervisor Recovery Protocol (Phase 3: Crash Recovery)
+// =============================================================================
+
+/// Handle supervisor-specific recovery after daemon crash/restart.
+///
+/// What this function does, in order:
+/// 1. Returns immediately if the task has no `contract_id`.
+/// 2. Loads the persisted supervisor state for the contract. If there is none,
+///    or it cannot be loaded, the task is requeued as `pending` for a fresh start.
+/// 3. Validates the state via `repository::validate_supervisor_state`.
+/// 4. If the state is `NotFound` or `Invalid`, requeues the task with the
+///    validation message and stops.
+/// 5. Otherwise logs any pending question / pending tasks, builds a JSON
+///    restoration context, requeues the task as `pending`, and broadcasts a
+///    task update.
+///
+/// The restoration context is only written to the log here; this function does
+/// not itself deliver it to the supervisor process.
+///
+/// # Parameters
+/// - `pool`: database pool used for all queries.
+/// - `state`: shared server state, used to broadcast the task update.
+/// - `task_id` / `owner_id`: identify the task row to requeue.
+/// - `contract_id`: contract whose supervisor state is loaded; `None` aborts.
+/// - `worktree_intact`, `previous_state`: only logged by this function.
+async fn handle_supervisor_recovery(
+    pool: &sqlx::PgPool,
+    state: &SharedState,
+    task_id: Uuid,
+    contract_id: Option<Uuid>,
+    owner_id: Uuid,
+    worktree_intact: bool,
+    previous_state: &str,
+) {
+    let Some(contract_id) = contract_id else {
+        // NOTE(review): the message says "treating as regular task", but this
+        // function just returns and the caller's `if task.is_supervisor { .. }
+        // else if ..` chain does not fall back to the regular-task branch, so
+        // the task row is left untouched. Confirm that is intended.
+        tracing::warn!(task_id = %task_id, "Supervisor has no contract_id, treating as regular task");
+        return;
+    };
+
+    tracing::info!(
+        task_id = %task_id,
+        contract_id = %contract_id,
+        worktree_intact = worktree_intact,
+        previous_state = %previous_state,
+        "Starting supervisor recovery protocol"
+    );
+
+    // Step 1-2: Load supervisor state from database
+    let supervisor_state = match repository::get_supervisor_state(pool, contract_id).await {
+        Ok(Some(s)) => s,
+        Ok(None) => {
+            tracing::warn!(
+                task_id = %task_id,
+                contract_id = %contract_id,
+                "No supervisor state found - will start fresh"
+            );
+            // Mark task as pending for fresh start
+            if let Err(e) = requeue_supervisor_task(
+                pool,
+                task_id,
+                owner_id,
+                "Supervisor restarted - no previous state found",
+            )
+            .await
+            {
+                tracing::error!(task_id = %task_id, error = %e, "Failed to update supervisor for fresh start");
+            }
+            return;
+        }
+        Err(e) => {
+            tracing::error!(
+                task_id = %task_id,
+                contract_id = %contract_id,
+                error = %e,
+                "Failed to load supervisor state - will start fresh"
+            );
+            // The log above promises a fresh start, so actually requeue the
+            // task; previously this path returned without touching the row and
+            // the supervisor stayed stuck in its pre-crash status.
+            if let Err(e) = requeue_supervisor_task(
+                pool,
+                task_id,
+                owner_id,
+                "Supervisor restarted - failed to load previous state",
+            )
+            .await
+            {
+                tracing::error!(task_id = %task_id, error = %e, "Failed to update supervisor for fresh start");
+            }
+            return;
+        }
+    };
+
+    // Step 3: Validate state consistency
+    let validation = match repository::validate_supervisor_state(pool, contract_id, owner_id).await {
+        Ok(v) => v,
+        Err(e) => {
+            // NOTE(review): like the no-contract case, this leaves the task row
+            // untouched; decide whether it should be requeued as well.
+            tracing::error!(
+                task_id = %task_id,
+                error = %e,
+                "Failed to validate supervisor state"
+            );
+            return;
+        }
+    };
+
+    let (is_valid, restoration_message) = match validation {
+        repository::StateValidationResult::Valid(_) => {
+            (true, "Supervisor state valid - restoring full context".to_string())
+        }
+        repository::StateValidationResult::NotFound => {
+            (false, "Supervisor state not found - starting fresh".to_string())
+        }
+        repository::StateValidationResult::Invalid { reason } => {
+            (false, format!("Supervisor state invalid: {} - starting fresh", reason))
+        }
+        repository::StateValidationResult::PartiallyValid { invalid_task_ids, .. } => {
+            (true, format!(
+                "Supervisor state partially valid ({} tasks missing) - restoring with available context",
+                invalid_task_ids.len()
+            ))
+        }
+        repository::StateValidationResult::PhaseStale { current_phase, .. } => {
+            (true, format!(
+                "Supervisor state valid but phase changed to '{}' - restoring with updated phase",
+                current_phase
+            ))
+        }
+    };
+
+    tracing::info!(
+        task_id = %task_id,
+        contract_id = %contract_id,
+        is_valid = is_valid,
+        message = %restoration_message,
+        "Supervisor state validation complete"
+    );
+
+    // Step 4-5: Handle invalid state or valid restoration
+    if !is_valid {
+        // Start from last checkpoint or fresh
+        if let Err(e) = requeue_supervisor_task(pool, task_id, owner_id, &restoration_message).await {
+            tracing::error!(task_id = %task_id, error = %e, "Failed to update supervisor for checkpoint start");
+        }
+        return;
+    }
+
+    // Step 6: Check for pending questions - prepare for re-delivery
+    let pending_question_id = supervisor_state.pending_question_id;
+    if let Some(question_id) = pending_question_id {
+        tracing::info!(
+            task_id = %task_id,
+            question_id = %question_id,
+            "Supervisor has pending question - will re-deliver after restoration"
+        );
+    }
+
+    // Step 7: Check for waiting tasks
+    let pending_task_count = supervisor_state.pending_task_ids.len();
+    if pending_task_count > 0 {
+        tracing::info!(
+            task_id = %task_id,
+            pending_task_count = pending_task_count,
+            "Supervisor has pending tasks - will resume waiting state"
+        );
+    }
+
+    // Step 8-9: Prepare restoration context and update task for resumption
+    let restoration_context = serde_json::json!({
+        "restored": true,
+        "restoration_count": supervisor_state.restoration_count.unwrap_or(0) + 1,
+        "last_state": supervisor_state.state,
+        "last_activity": supervisor_state.current_activity,
+        "phase": supervisor_state.phase,
+        "pending_tasks": supervisor_state.pending_task_ids,
+        "has_pending_question": pending_question_id.is_some(),
+        "pending_question_id": pending_question_id,
+        "message": restoration_message,
+    });
+
+    // Update task status and prepare for resumption with restoration context
+    if let Err(e) = requeue_supervisor_task(pool, task_id, owner_id, &restoration_message).await {
+        tracing::error!(
+            task_id = %task_id,
+            error = %e,
+            "Failed to update supervisor for restoration"
+        );
+        return;
+    }
+
+    tracing::info!(
+        task_id = %task_id,
+        contract_id = %contract_id,
+        restoration_context = %restoration_context,
+        "Supervisor marked for restoration"
+    );
+
+    // Broadcast update
+    state.broadcast_task_update(TaskUpdateNotification {
+        task_id,
+        owner_id: Some(owner_id),
+        version: 0,
+        status: "pending".into(),
+        updated_fields: vec![
+            "status".into(),
+            "daemon_id".into(),
+            "interrupted_at".into(),
+        ],
+        updated_by: "supervisor_recovery".into(),
+    });
+}
+
+/// Reset a supervisor task row to `pending`, detach it from its daemon and
+/// record `message` in `error_message`, scoped to `owner_id`.
+///
+/// Returns the database error, if any; a statement that matches no row still
+/// counts as success.
+async fn requeue_supervisor_task(
+    pool: &sqlx::PgPool,
+    task_id: Uuid,
+    owner_id: Uuid,
+    message: &str,
+) -> Result<(), sqlx::Error> {
+    sqlx::query(
+        r#"
+        UPDATE tasks
+        SET status = 'pending',
+            daemon_id = NULL,
+            error_message = $3,
+            interrupted_at = NOW(),
+            updated_at = NOW()
+        WHERE id = $1 AND owner_id = $2
+        "#,
+    )
+    .bind(task_id)
+    .bind(owner_id)
+    .bind(message)
+    .execute(pool)
+    .await
+    .map(|_| ())
+}
diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs
index 3411ec0..b33c1c9 100644
--- a/makima/src/server/handlers/mesh_supervisor.rs
+++ b/makima/src/server/handlers/mesh_supervisor.rs
@@ -2748,3 +2748,510 @@ pub async fn spawn_red_team_task(
// It will remain pending and can be started later
Ok(task)
}
+
+// =============================================================================
+// Supervisor State Persistence Handlers (Phase 3: Crash Recovery)
+// =============================================================================
+
+/// Request to save supervisor state at a save point.
+///
+/// Request body of `POST /api/v1/mesh/supervisor/state/save`; fields are
+/// deserialized from camelCase JSON keys. Only `save_point` is required.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct SaveStateRequest {
+ /// The save point type; passed through unchanged to
+ /// `repository::save_supervisor_state_at_savepoint`.
+ pub save_point: String,
+ /// Updated conversation history (if available), as untyped JSON
+ pub conversation_history: Option<serde_json::Value>,
+ /// Current state
+ pub state: Option<String>,
+ /// Current activity description
+ pub current_activity: Option<String>,
+ /// Progress percentage (not range-checked by the handler)
+ pub progress: Option<i32>,
+ /// Last LLM response
+ pub last_llm_response: Option<String>,
+ /// Task that was spawned (for task_spawn save point)
+ pub spawned_task_id: Option<Uuid>,
+ /// Question that was asked (for question_asked save point)
+ pub question_id: Option<Uuid>,
+ /// Error message (for error save point)
+ pub error_message: Option<String>,
+}
+
+/// Response for state save operation.
+///
+/// Also returned by `mark_supervisor_restored`. Serialized with camelCase keys.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct SaveStateResponse {
+ /// Set to `true` by both handlers that build this response.
+ pub success: bool,
+ /// `id` of the supervisor state record returned by the repository call.
+ pub state_id: Uuid,
+ /// The save point that was recorded: the request's `save_point`, or
+ /// "restoration_complete" when produced by `mark_supervisor_restored`.
+ pub save_point: String,
+ /// Human-readable status message.
+ pub message: String,
+}
+
+/// Save supervisor state at a specific save point.
+///
+/// This endpoint is called by the supervisor to persist its state for crash recovery.
+/// State should be saved at key points: LLM response, task spawn, question asked, phase change, heartbeat.
+///
+/// The caller is authenticated with `verify_supervisor_auth`; the contract is
+/// taken from the authenticated supervisor task, not from the request body.
+///
+/// Responses:
+/// - `200` with [`SaveStateResponse`] on success.
+/// - `400` if the supervisor task has no contract.
+/// - `404` if the supervisor task cannot be found for its owner.
+/// - `500` on database errors (details are logged, not returned).
+/// - `503` if no database pool is configured.
+/// - whatever error `verify_supervisor_auth` produces on failed authentication.
+#[utoipa::path(
+    post,
+    path = "/api/v1/mesh/supervisor/state/save",
+    request_body = SaveStateRequest,
+    responses(
+        (status = 200, description = "State saved", body = SaveStateResponse),
+        (status = 400, description = "Supervisor has no associated contract"),
+        (status = 401, description = "Unauthorized"),
+        (status = 403, description = "Forbidden - not a supervisor"),
+        (status = 404, description = "Supervisor task not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database unavailable"),
+    ),
+    security(
+        ("tool_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn save_supervisor_state(
+    State(state): State<SharedState>,
+    headers: HeaderMap,
+    Json(request): Json<SaveStateRequest>,
+) -> impl IntoResponse {
+    let (supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await {
+        Ok(ids) => ids,
+        Err(e) => return e.into_response(),
+    };
+
+    // `db_pool` is optional; answer with an error instead of panicking.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    // Get the supervisor task to find its contract
+    let supervisor = match repository::get_task_for_owner(pool, supervisor_id, owner_id).await {
+        Ok(Some(t)) => t,
+        Ok(None) => {
+            return (
+                StatusCode::NOT_FOUND,
+                Json(ApiError::new("NOT_FOUND", "Supervisor task not found")),
+            ).into_response();
+        }
+        Err(e) => {
+            tracing::error!(error = %e, "Failed to get supervisor task");
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")),
+            ).into_response();
+        }
+    };
+
+    let Some(contract_id) = supervisor.contract_id else {
+        return (
+            StatusCode::BAD_REQUEST,
+            Json(ApiError::new("NO_CONTRACT", "Supervisor has no associated contract")),
+        ).into_response();
+    };
+
+    // Save the state at the specified save point. The conversation history is
+    // moved (not cloned) since it can be large and is not needed afterwards;
+    // `save_point` stays in `request` for logging and the response.
+    let result = repository::save_supervisor_state_at_savepoint(
+        pool,
+        contract_id,
+        &request.save_point,
+        request.state.as_deref(),
+        request.current_activity.as_deref(),
+        request.progress,
+        request.last_llm_response.as_deref(),
+        request.spawned_task_id,
+        request.question_id,
+        request.error_message.as_deref(),
+        request.conversation_history,
+    ).await;
+
+    match result {
+        Ok(saved_state) => {
+            tracing::debug!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                save_point = %request.save_point,
+                "Supervisor state saved"
+            );
+            (
+                StatusCode::OK,
+                Json(SaveStateResponse {
+                    success: true,
+                    state_id: saved_state.id,
+                    save_point: request.save_point,
+                    message: "State saved successfully".to_string(),
+                }),
+            ).into_response()
+        }
+        Err(e) => {
+            tracing::error!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                save_point = %request.save_point,
+                error = %e,
+                "Failed to save supervisor state"
+            );
+            // The underlying error is logged above; keep database details out
+            // of the client-facing message, as the other error branches do.
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to save state")),
+            ).into_response()
+        }
+    }
+}
+
+/// Response for supervisor restoration.
+///
+/// Returned by `get_restoration_context`; serialized with camelCase keys.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct RestorationContextResponse {
+ /// Whether restoration context was found (`false` for missing or invalid state)
+ pub found: bool,
+ /// The restoration result type: "fresh_start", "full_restore",
+ /// "partial_restore" or "checkpoint_restore"
+ pub restoration_type: String,
+ /// Human-readable message about the restoration
+ pub message: String,
+ /// The supervisor state (if found)
+ pub state: Option<crate::db::models::SupervisorState>,
+ /// Pending tasks the supervisor was waiting on (empty when none or on lookup failure)
+ pub pending_tasks: Vec<TaskSummary>,
+ /// Pending question that needs re-delivery
+ pub pending_question: Option<PendingQuestionSummary>,
+ /// Last LLM response for context
+ pub last_llm_response: Option<String>,
+ /// Restoration count (how many times this supervisor has been restored);
+ /// 0 when no usable state exists or the stored count is unset
+ pub restoration_count: i32,
+}
+
+/// Get supervisor restoration context for crash recovery.
+///
+/// This endpoint retrieves the saved state and context needed to restore
+/// a supervisor after a crash or daemon restart.
+///
+/// The state is validated with `repository::validate_supervisor_state`, scoped
+/// to the authenticated owner. Missing or invalid state is reported as `200`
+/// with `found = false` and `restorationType = "fresh_start"`; as written, this
+/// handler itself never returns `404`.
+///
+/// Responses produced here:
+/// - `200` with [`RestorationContextResponse`].
+/// - `500` if validation fails with a database error.
+/// - `503` if no database pool is configured.
+#[utoipa::path(
+    get,
+    path = "/api/v1/mesh/supervisor/contracts/{contract_id}/restore",
+    params(
+        ("contract_id" = Uuid, Path, description = "Contract ID")
+    ),
+    responses(
+        (status = 200, description = "Restoration context", body = RestorationContextResponse),
+        (status = 401, description = "Unauthorized"),
+        (status = 404, description = "Contract or state not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database unavailable"),
+    ),
+    security(
+        ("bearer_auth" = []),
+        ("api_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn get_restoration_context(
+    State(state): State<SharedState>,
+    Path(contract_id): Path<Uuid>,
+    auth: crate::server::auth::Authenticated,
+) -> impl IntoResponse {
+    let crate::server::auth::Authenticated(auth_info) = auth;
+
+    // `db_pool` is optional; answer with an error instead of panicking.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    // Validate supervisor state
+    let validation = match repository::validate_supervisor_state(pool, contract_id, auth_info.owner_id).await {
+        Ok(v) => v,
+        Err(e) => {
+            tracing::error!(error = %e, "Failed to validate supervisor state");
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to validate state")),
+            ).into_response();
+        }
+    };
+
+    // Resolve the state, restoration type and message in a single exhaustive
+    // match. `validation` is consumed exactly once here, so it is never read
+    // again after the state has been moved out of it, and no catch-all arm is
+    // needed.
+    let (supervisor_state, restoration_type, message) = match validation {
+        repository::StateValidationResult::NotFound => {
+            return (
+                StatusCode::OK,
+                Json(fresh_start_context(
+                    "No previous supervisor state found. Starting fresh.".to_string(),
+                )),
+            ).into_response();
+        }
+        repository::StateValidationResult::Invalid { reason } => {
+            tracing::warn!(contract_id = %contract_id, reason = %reason, "Invalid supervisor state");
+            return (
+                StatusCode::OK,
+                Json(fresh_start_context(format!(
+                    "Previous state invalid: {}. Starting fresh.",
+                    reason
+                ))),
+            ).into_response();
+        }
+        repository::StateValidationResult::Valid(found) => {
+            let message = format!(
+                "Supervisor state found. Last activity: {}. Restoring from {} phase.",
+                found.last_activity.format("%Y-%m-%d %H:%M:%S UTC"),
+                found.phase
+            );
+            (found, "full_restore", message)
+        }
+        repository::StateValidationResult::PartiallyValid { state: found, invalid_task_ids, .. } => {
+            let message = format!(
+                "Partial state found. {} task(s) no longer exist. Restoring with available context.",
+                invalid_task_ids.len()
+            );
+            (found, "partial_restore", message)
+        }
+        repository::StateValidationResult::PhaseStale { state: found, current_phase, .. } => {
+            let message = format!(
+                "State found but phase changed from {} to {}. Restoring with updated phase.",
+                found.phase, current_phase
+            );
+            (found, "checkpoint_restore", message)
+        }
+    };
+
+    // Get full restoration context
+    let (pending_tasks, pending_question) = get_restoration_details(
+        pool,
+        &state,
+        &supervisor_state,
+        auth_info.owner_id,
+    ).await;
+
+    let restoration_count = supervisor_state.restoration_count.unwrap_or(0);
+    // Cloned because the full state is also returned in the `state` field.
+    let last_llm_response = supervisor_state.last_llm_response.clone();
+
+    (
+        StatusCode::OK,
+        Json(RestorationContextResponse {
+            found: true,
+            restoration_type: restoration_type.to_string(),
+            message,
+            state: Some(supervisor_state),
+            pending_tasks,
+            pending_question,
+            last_llm_response,
+            restoration_count,
+        }),
+    ).into_response()
+}
+
+/// Build the "fresh_start" response used when no usable supervisor state exists.
+fn fresh_start_context(message: String) -> RestorationContextResponse {
+    RestorationContextResponse {
+        found: false,
+        restoration_type: "fresh_start".to_string(),
+        message,
+        state: None,
+        pending_tasks: vec![],
+        pending_question: None,
+        last_llm_response: None,
+        restoration_count: 0,
+    }
+}
+
+/// Collect the extra context needed to restore a supervisor.
+///
+/// Returns a pair of:
+/// - summaries of the tasks listed in `supervisor_state.pending_task_ids` that
+///   belong to `owner_id`, oldest first (empty if the list is empty or the
+///   query fails; a failure is logged as a warning);
+/// - the pending question referenced by `supervisor_state.pending_question_id`,
+///   if `state.get_pending_question` still knows about it.
+async fn get_restoration_details(
+    pool: &sqlx::PgPool,
+    state: &SharedState,
+    supervisor_state: &crate::db::models::SupervisorState,
+    owner_id: Uuid,
+) -> (Vec<TaskSummary>, Option<PendingQuestionSummary>) {
+    // Pending tasks: skip the query entirely when nothing is being waited on.
+    let pending_tasks: Vec<TaskSummary> = if supervisor_state.pending_task_ids.is_empty() {
+        vec![]
+    } else {
+        let lookup = sqlx::query_as::<_, Task>(
+            r#"
+            SELECT * FROM tasks
+            WHERE id = ANY($1) AND owner_id = $2
+            ORDER BY created_at ASC
+            "#,
+        )
+        .bind(&supervisor_state.pending_task_ids)
+        .bind(owner_id)
+        .fetch_all(pool)
+        .await;
+
+        match lookup {
+            Ok(rows) => rows.into_iter().map(TaskSummary::from).collect(),
+            Err(e) => {
+                tracing::warn!(error = %e, "Failed to get pending tasks for restoration");
+                vec![]
+            }
+        }
+    };
+
+    // Pending question: resolved through the shared state, not the database.
+    let pending_question = supervisor_state
+        .pending_question_id
+        .and_then(|question_id| state.get_pending_question(question_id))
+        .map(|q| PendingQuestionSummary {
+            question_id: q.question_id,
+            task_id: q.task_id,
+            contract_id: q.contract_id,
+            question: q.question,
+            choices: q.choices,
+            context: q.context,
+            created_at: q.created_at,
+            multi_select: q.multi_select,
+            question_type: q.question_type,
+        });
+
+    (pending_tasks, pending_question)
+}
+
+/// Mark supervisor as restored after successful restoration.
+///
+/// Authenticates the caller with `verify_supervisor_auth` against the given
+/// contract, then calls `repository::mark_supervisor_restored`.
+///
+/// Responses produced here:
+/// - `200` with [`SaveStateResponse`] whose `savePoint` is "restoration_complete".
+/// - `500` if the repository call fails (details are logged, not returned).
+/// - `503` if no database pool is configured.
+/// - whatever error `verify_supervisor_auth` produces on failed authentication.
+///
+/// NOTE(review): `404` is documented below but this handler has no branch that
+/// returns it; a missing state presumably surfaces as a repository error (500).
+#[utoipa::path(
+    post,
+    path = "/api/v1/mesh/supervisor/contracts/{contract_id}/restored",
+    params(
+        ("contract_id" = Uuid, Path, description = "Contract ID")
+    ),
+    responses(
+        (status = 200, description = "Restoration marked", body = SaveStateResponse),
+        (status = 401, description = "Unauthorized"),
+        (status = 404, description = "Contract or state not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database unavailable"),
+    ),
+    security(
+        ("tool_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn mark_supervisor_restored(
+    State(state): State<SharedState>,
+    Path(contract_id): Path<Uuid>,
+    headers: HeaderMap,
+) -> impl IntoResponse {
+    let (supervisor_id, _owner_id) = match verify_supervisor_auth(&state, &headers, Some(contract_id)).await {
+        Ok(ids) => ids,
+        Err(e) => return e.into_response(),
+    };
+
+    // `db_pool` is optional; answer with an error instead of panicking.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    match repository::mark_supervisor_restored(pool, contract_id).await {
+        Ok(saved_state) => {
+            tracing::info!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                restoration_count = ?saved_state.restoration_count,
+                "Supervisor marked as restored"
+            );
+            (
+                StatusCode::OK,
+                Json(SaveStateResponse {
+                    success: true,
+                    state_id: saved_state.id,
+                    save_point: "restoration_complete".to_string(),
+                    message: format!(
+                        "Supervisor restored successfully (restoration #{})",
+                        saved_state.restoration_count.unwrap_or(1)
+                    ),
+                }),
+            ).into_response()
+        }
+        Err(e) => {
+            tracing::error!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                error = %e,
+                "Failed to mark supervisor as restored"
+            );
+            // The underlying error is logged above; keep database details out
+            // of the client-facing message.
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to mark restored")),
+            ).into_response()
+        }
+    }
+}
+
+/// Lightweight heartbeat update for supervisor state.
+///
+/// Used for frequent heartbeats to avoid the overhead of full state saves. The
+/// handler resolves the contract from the authenticated supervisor task and
+/// calls `repository::update_supervisor_activity_state`; when the request
+/// carries no `state`, "executing" is used.
+///
+/// Responses:
+/// - `200` with an empty body on success.
+/// - `400` if the supervisor task has no contract.
+/// - `404` if the supervisor task cannot be found for its owner.
+/// - `500` on database errors.
+/// - `503` if no database pool is configured.
+/// - whatever error `verify_supervisor_auth` produces on failed authentication.
+#[utoipa::path(
+    post,
+    path = "/api/v1/mesh/supervisor/state/heartbeat",
+    request_body = HeartbeatRequest,
+    responses(
+        (status = 200, description = "Heartbeat recorded"),
+        (status = 400, description = "Supervisor has no associated contract"),
+        (status = 401, description = "Unauthorized"),
+        (status = 403, description = "Forbidden - not a supervisor"),
+        (status = 404, description = "Supervisor task not found"),
+        (status = 500, description = "Internal server error"),
+        (status = 503, description = "Database unavailable"),
+    ),
+    security(
+        ("tool_key" = [])
+    ),
+    tag = "Mesh Supervisor"
+)]
+pub async fn supervisor_heartbeat(
+    State(state): State<SharedState>,
+    headers: HeaderMap,
+    Json(request): Json<HeartbeatRequest>,
+) -> impl IntoResponse {
+    let (supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await {
+        Ok(ids) => ids,
+        Err(e) => return e.into_response(),
+    };
+
+    // `db_pool` is optional; answer with an error instead of panicking.
+    let Some(pool) = state.db_pool.as_ref() else {
+        return (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")),
+        ).into_response();
+    };
+
+    // Get the supervisor task to find its contract
+    let supervisor = match repository::get_task_for_owner(pool, supervisor_id, owner_id).await {
+        Ok(Some(t)) => t,
+        Ok(None) => {
+            return (
+                StatusCode::NOT_FOUND,
+                Json(ApiError::new("NOT_FOUND", "Supervisor task not found")),
+            ).into_response();
+        }
+        Err(e) => {
+            tracing::error!(error = %e, "Failed to get supervisor task");
+            return (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")),
+            ).into_response();
+        }
+    };
+
+    let Some(contract_id) = supervisor.contract_id else {
+        return (
+            StatusCode::BAD_REQUEST,
+            Json(ApiError::new("NO_CONTRACT", "Supervisor has no associated contract")),
+        ).into_response();
+    };
+
+    // Update activity state (lightweight)
+    match repository::update_supervisor_activity_state(
+        pool,
+        contract_id,
+        request.state.as_deref().unwrap_or("executing"),
+        request.current_activity.as_deref(),
+        request.progress,
+    ).await {
+        Ok(_) => {
+            tracing::trace!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                "Supervisor heartbeat recorded"
+            );
+            StatusCode::OK.into_response()
+        }
+        Err(e) => {
+            tracing::warn!(
+                supervisor_id = %supervisor_id,
+                contract_id = %contract_id,
+                error = %e,
+                "Failed to record supervisor heartbeat"
+            );
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                Json(ApiError::new("DB_ERROR", "Failed to record heartbeat")),
+            ).into_response()
+        }
+    }
+}
+
+/// Request for supervisor heartbeat.
+///
+/// Request body of `POST /api/v1/mesh/supervisor/state/heartbeat`; fields are
+/// deserialized from camelCase JSON keys and are all optional.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct HeartbeatRequest {
+ /// Current state (optional); the handler falls back to "executing" when absent
+ pub state: Option<String>,
+ /// Current activity description (optional)
+ pub current_activity: Option<String>,
+ /// Progress percentage (optional)
+ pub progress: Option<i32>,
+}