diff options
| author | soryu <soryu@soryu.co> | 2026-02-01 00:42:53 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2026-02-01 00:42:53 +0000 |
| commit | 96ad3af6051af69e2e8b34b35e8b40926bdd13a1 (patch) | |
| tree | 2e2aedd39c66dedf7da301273306a0c77440ecf4 | |
| parent | bb14010db99b40792372bfcb4348cf4984f30b3f (diff) | |
| download | soryu-96ad3af6051af69e2e8b34b35e8b40926bdd13a1.tar.gz soryu-96ad3af6051af69e2e8b34b35e8b40926bdd13a1.zip | |
feat: Implement Phase 3 Tasks 3.3 and 3.4 - Supervisor State Persistence and Restoration
Task 3.3: Supervisor State Persistence
- Add migration 20260201000001_enhanced_supervisor_state.sql with new columns:
- state (supervisor state enum)
- current_activity (description)
- progress (0-100)
- error_message (for failed states)
- spawned_task_ids (tasks created by supervisor)
- pending_questions (questions awaiting user response)
- restoration_count, last_restored_at, restoration_source (restoration tracking)
- Update SupervisorState model with new fields
- Add PendingQuestion struct for tracking unanswered questions
- Add SupervisorRestorationContext for returning restoration info
- Add StateValidationResult and StateRecoveryAction for state validation
State persistence functions in repository.rs:
- update_supervisor_detailed_state() - Update state, activity, progress
- add_supervisor_spawned_task() - Track spawned tasks
- add_supervisor_pending_question() - Track pending questions
- remove_supervisor_pending_question() - Clear answered questions
- save_supervisor_state_full() - Full state save (UPSERT)
- mark_supervisor_restored() - Increment restoration count
- get_supervisors_with_pending_questions() - Find supervisors with pending questions
- get_supervisor_state_for_restoration() - Load state for restoration
- validate_spawned_tasks() - Validate task consistency
- update_supervisor_phase() - Update on phase change
- update_supervisor_heartbeat_state() - Lightweight heartbeat update
State save points:
- On task spawn (save_state_on_task_spawn)
- On question asked (save_state_on_question_asked)
- On question answered (clear_pending_question)
- On phase change (save_state_on_phase_change)
- On heartbeat (update_supervisor_heartbeat_state)
Task 3.4: Supervisor Restoration Protocol
- Add restoration detection when supervisor starts with existing state
- Implement validate_supervisor_state() for state consistency checks
- Implement restore_supervisor() with validation and context generation
- Add redeliver_pending_questions() for re-delivering questions after crash
- Add generate_restoration_context_message() for Claude context injection
- Update resume_supervisor endpoint to return RestorationInfo
- Re-deliver pending questions when supervisor resumes
Restoration flow:
1. Daemon restarts or task reassigned
2. Load supervisor state from supervisor_states
3. If NOT FOUND: Start fresh, log warning
4. If FOUND: Validate state consistency
5. If INVALID: Start from last checkpoint
6. If VALID: Restore conversation history
7. Check for pending questions - re-deliver to user
8. Check for waiting tasks - resume waiting state
9. Send restoration context to Claude
10. Resume execution from last state
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
| -rw-r--r-- | makima/migrations/20260201000001_enhanced_supervisor_state.sql | 56 | ||||
| -rw-r--r-- | makima/src/db/models.rs | 102 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 304 | ||||
| -rw-r--r-- | makima/src/server/handlers/contracts.rs | 16 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_daemon.rs | 126 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_supervisor.rs | 478 |
6 files changed, 1057 insertions, 25 deletions
diff --git a/makima/migrations/20260201000001_enhanced_supervisor_state.sql b/makima/migrations/20260201000001_enhanced_supervisor_state.sql new file mode 100644 index 0000000..5411b73 --- /dev/null +++ b/makima/migrations/20260201000001_enhanced_supervisor_state.sql @@ -0,0 +1,56 @@ +-- Enhanced supervisor state persistence for restoration after crashes. +-- Adds additional fields to supervisor_states to track detailed state for recovery. + +-- Add state tracking field (matches SupervisorStateEnum: initializing, idle, working, +-- waiting_for_user, waiting_for_tasks, blocked, completed, failed, interrupted) +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS state VARCHAR(50) NOT NULL DEFAULT 'initializing'; + +-- Add current activity description for monitoring +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS current_activity TEXT; + +-- Add progress percentage (0-100) +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS progress INTEGER DEFAULT 0 + CHECK (progress >= 0 AND progress <= 100); + +-- Add error message for failed states +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS error_message TEXT; + +-- Add spawned task IDs (tasks this supervisor has created) +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS spawned_task_ids UUID[] DEFAULT ARRAY[]::UUID[]; + +-- Add pending questions (questions waiting for user response) +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS pending_questions JSONB DEFAULT '[]'; + +-- Add restoration metadata +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS restoration_count INTEGER DEFAULT 0; + +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS last_restored_at TIMESTAMPTZ; + +ALTER TABLE supervisor_states + ADD COLUMN IF NOT EXISTS restoration_source VARCHAR(50); + +-- Index for finding supervisors by state (useful for finding blocked/failed supervisors) +CREATE INDEX IF NOT EXISTS idx_supervisor_states_state ON supervisor_states(state); + +-- Index for finding supervisors with pending questions +CREATE INDEX IF NOT EXISTS idx_supervisor_states_pending_questions + ON supervisor_states USING gin(pending_questions) + WHERE pending_questions != '[]'::jsonb; + +COMMENT ON COLUMN supervisor_states.state IS 'Current supervisor state: initializing, idle, working, waiting_for_user, waiting_for_tasks, blocked, completed, failed, interrupted'; +COMMENT ON COLUMN supervisor_states.current_activity IS 'Human-readable description of current activity'; +COMMENT ON COLUMN supervisor_states.progress IS 'Progress percentage (0-100)'; +COMMENT ON COLUMN supervisor_states.error_message IS 'Error message when state is failed or blocked'; +COMMENT ON COLUMN supervisor_states.spawned_task_ids IS 'Array of task UUIDs spawned by this supervisor'; +COMMENT ON COLUMN supervisor_states.pending_questions IS 'Array of questions awaiting user response: [{id, question, choices, context, asked_at}]'; +COMMENT ON COLUMN supervisor_states.restoration_count IS 'Number of times this supervisor has been restored after interruption'; +COMMENT ON COLUMN supervisor_states.last_restored_at IS 'Timestamp of last restoration'; +COMMENT ON COLUMN supervisor_states.restoration_source IS 'Source of last restoration: daemon_restart, task_reassignment, manual'; diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index cc30465..fcbd044 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -1971,6 +1971,50 @@ pub struct SupervisorState { pub last_activity: DateTime<Utc>, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, + /// Current supervisor state (initializing, idle, working, waiting_for_user, etc.) + pub state: String, + /// Human-readable description of current activity + pub current_activity: Option<String>, + /// Progress percentage (0-100) + pub progress: i32, + /// Error message when state is failed or blocked + pub error_message: Option<String>, + /// Tasks spawned by this supervisor + #[sqlx(try_from = "Vec<Uuid>")] + pub spawned_task_ids: Vec<Uuid>, + /// Pending questions awaiting user response + #[sqlx(json)] + pub pending_questions: serde_json::Value, + /// Number of times this supervisor has been restored + pub restoration_count: i32, + /// Timestamp of last restoration + pub last_restored_at: Option<DateTime<Utc>>, + /// Source of last restoration (daemon_restart, task_reassignment, manual) + pub restoration_source: Option<String>, +} + +/// Pending question structure for supervisor state +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PendingQuestion { + /// Unique question ID + pub id: Uuid, + /// The question text + pub question: String, + /// Optional choices (empty for free-form) + #[serde(default)] + pub choices: Vec<String>, + /// Optional context + pub context: Option<String>, + /// Question type: general, phase_confirmation, contract_complete + #[serde(default = "default_question_type")] + pub question_type: String, + /// When the question was asked + pub asked_at: DateTime<Utc>, +} + +fn default_question_type() -> String { + "general".to_string() } /// Request to update supervisor state @@ -1983,6 +2027,64 @@ pub struct UpdateSupervisorStateRequest { pub pending_task_ids: Option<Vec<Uuid>>, /// Current contract phase pub phase: Option<String>, + /// Current supervisor state + pub state: Option<String>, + /// Current activity description + pub current_activity: Option<String>, + /// Progress percentage + pub progress: Option<i32>, + /// Error message + pub error_message: Option<String>, + /// Spawned task IDs + pub spawned_task_ids: Option<Vec<Uuid>>, + /// Pending questions + pub pending_questions: Option<serde_json::Value>, +} + +/// Restoration context returned when restoring a supervisor +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorRestorationContext { + /// Whether restoration was successful + pub success: bool, + /// Previous state before restoration + pub previous_state: SupervisorStateEnum, + /// Restored conversation history + pub conversation_history: serde_json::Value, + /// Pending questions that need re-delivery + pub pending_questions: Vec<PendingQuestion>, + /// Tasks still being waited on + pub waiting_task_ids: Vec<Uuid>, + /// Spawned tasks to check status of + pub spawned_task_ids: Vec<Uuid>, + /// Restoration count (incremented) + pub restoration_count: i32, + /// Context message for Claude + pub restoration_context_message: String, + /// Any warnings during restoration + pub warnings: Vec<String>, +} + +/// Validation result for supervisor state consistency +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateValidationResult { + pub is_valid: bool, + pub issues: Vec<String>, + /// Suggested recovery action + pub recovery_action: StateRecoveryAction, +} + +/// Action to take when state validation fails +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum StateRecoveryAction { + /// State is valid, proceed with restoration + Proceed, + /// Start from last checkpoint + UseCheckpoint, + /// Start fresh + StartFresh, + /// Manual intervention required + ManualIntervention, } // ============================================================================ diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index 1ac188c..d1ec3ef 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -3404,6 +3404,310 @@ pub async fn update_supervisor_pending_tasks( .await } +/// Update supervisor state with detailed activity tracking. +/// Called at key save points: LLM response, task spawn, question asked, phase change. +pub async fn update_supervisor_detailed_state( + pool: &PgPool, + contract_id: Uuid, + state: &str, + current_activity: Option<&str>, + progress: i32, + error_message: Option<&str>, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET state = $1, + current_activity = $2, + progress = $3, + error_message = $4, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $5 + RETURNING * + "#, + ) + .bind(state) + .bind(current_activity) + .bind(progress) + .bind(error_message) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Add a spawned task ID to the supervisor's list. +pub async fn add_supervisor_spawned_task( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET spawned_task_ids = array_append(spawned_task_ids, $1), + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(task_id) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Add a pending question to the supervisor state. +pub async fn add_supervisor_pending_question( + pool: &PgPool, + contract_id: Uuid, + question: serde_json::Value, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_questions = pending_questions || $1::jsonb, + state = 'waiting_for_user', + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(question) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Remove a pending question by ID. +pub async fn remove_supervisor_pending_question( + pool: &PgPool, + contract_id: Uuid, + question_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_questions = ( + SELECT COALESCE(jsonb_agg(elem), '[]'::jsonb) + FROM jsonb_array_elements(pending_questions) elem + WHERE (elem->>'id')::uuid != $1 + ), + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(question_id) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Comprehensive state save - used at major save points. +pub async fn save_supervisor_state_full( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, + conversation_history: serde_json::Value, + pending_task_ids: &[Uuid], + phase: &str, + state: &str, + current_activity: Option<&str>, + progress: i32, + error_message: Option<&str>, + spawned_task_ids: &[Uuid], + pending_questions: serde_json::Value, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + INSERT INTO supervisor_states ( + contract_id, task_id, conversation_history, pending_task_ids, phase, + state, current_activity, progress, error_message, spawned_task_ids, + pending_questions, last_activity + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, NOW()) + ON CONFLICT (contract_id) DO UPDATE SET + task_id = EXCLUDED.task_id, + conversation_history = EXCLUDED.conversation_history, + pending_task_ids = EXCLUDED.pending_task_ids, + phase = EXCLUDED.phase, + state = EXCLUDED.state, + current_activity = EXCLUDED.current_activity, + progress = EXCLUDED.progress, + error_message = EXCLUDED.error_message, + spawned_task_ids = EXCLUDED.spawned_task_ids, + pending_questions = EXCLUDED.pending_questions, + last_activity = NOW(), + updated_at = NOW() + RETURNING * + "#, + ) + .bind(contract_id) + .bind(task_id) + .bind(conversation_history) + .bind(pending_task_ids) + .bind(phase) + .bind(state) + .bind(current_activity) + .bind(progress) + .bind(error_message) + .bind(spawned_task_ids) + .bind(pending_questions) + .fetch_one(pool) + .await +} + +/// Mark supervisor as restored from a crash/interruption. +pub async fn mark_supervisor_restored( + pool: &PgPool, + contract_id: Uuid, + restoration_source: &str, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET restoration_count = restoration_count + 1, + last_restored_at = NOW(), + restoration_source = $1, + state = 'initializing', + error_message = NULL, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(restoration_source) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Get supervisors with pending questions (for re-delivery after restoration). +pub async fn get_supervisors_with_pending_questions( + pool: &PgPool, + owner_id: Uuid, +) -> Result<Vec<SupervisorState>, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + SELECT ss.* + FROM supervisor_states ss + JOIN contracts c ON c.id = ss.contract_id + WHERE c.owner_id = $1 + AND ss.pending_questions != '[]'::jsonb + AND c.status = 'active' + ORDER BY ss.last_activity DESC + "#, + ) + .bind(owner_id) + .fetch_all(pool) + .await +} + +/// Get supervisor state with full details for restoration. +/// Includes validation info. +pub async fn get_supervisor_state_for_restoration( + pool: &PgPool, + contract_id: Uuid, +) -> Result<Option<SupervisorState>, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + SELECT * FROM supervisor_states WHERE contract_id = $1 + "#, + ) + .bind(contract_id) + .fetch_optional(pool) + .await +} + +/// Validate spawned tasks are in expected states. +/// Returns map of task_id -> (status, updated_at). +pub async fn validate_spawned_tasks( + pool: &PgPool, + task_ids: &[Uuid], +) -> Result<std::collections::HashMap<Uuid, (String, chrono::DateTime<Utc>)>, sqlx::Error> { + use sqlx::Row; + + let rows = sqlx::query( + r#" + SELECT id, status, updated_at + FROM tasks + WHERE id = ANY($1) + "#, + ) + .bind(task_ids) + .fetch_all(pool) + .await?; + + let mut result = std::collections::HashMap::new(); + for row in rows { + let id: Uuid = row.get("id"); + let status: String = row.get("status"); + let updated_at: chrono::DateTime<Utc> = row.get("updated_at"); + result.insert(id, (status, updated_at)); + } + Ok(result) +} + +/// Update supervisor state when phase changes. +pub async fn update_supervisor_phase( + pool: &PgPool, + contract_id: Uuid, + new_phase: &str, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET phase = $1, + state = 'working', + current_activity = 'Phase changed to ' || $1, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(new_phase) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Update supervisor state on heartbeat (lightweight update). +pub async fn update_supervisor_heartbeat_state( + pool: &PgPool, + contract_id: Uuid, + state: &str, + current_activity: Option<&str>, + progress: i32, + pending_task_ids: &[Uuid], +) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + UPDATE supervisor_states + SET state = $1, + current_activity = $2, + progress = $3, + pending_task_ids = $4, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $5 + "#, + ) + .bind(state) + .bind(current_activity) + .bind(progress) + .bind(pending_task_ids) + .bind(contract_id) + .execute(pool) + .await?; + Ok(()) +} + // ============================================================================ // Supervisor Heartbeats // ============================================================================ diff --git a/makima/src/server/handlers/contracts.rs b/makima/src/server/handlers/contracts.rs index b15667d..5a87616 100644 --- a/makima/src/server/handlers/contracts.rs +++ b/makima/src/server/handlers/contracts.rs @@ -1461,6 +1461,22 @@ pub async fn change_phase( .await { Ok(PhaseChangeResult::Success(updated_contract)) => { + // Save supervisor state on phase change (Task 3.3) + // This is a key save point for restoration + let new_phase_for_state = updated_contract.phase.clone(); + let contract_id_for_state = updated_contract.id; + let pool_for_state = pool.clone(); + tokio::spawn(async move { + if let Err(e) = repository::update_supervisor_phase(&pool_for_state, contract_id_for_state, &new_phase_for_state).await { + tracing::warn!( + contract_id = %contract_id_for_state, + new_phase = %new_phase_for_state, + error = %e, + "Failed to save supervisor state on phase change" + ); + } + }); + // Notify supervisor of phase change if let Some(supervisor_task_id) = updated_contract.supervisor_task_id { if let Ok(Some(supervisor)) = repository::get_task_for_owner(pool, supervisor_task_id, auth.owner_id).await { diff --git a/makima/src/server/handlers/mesh_daemon.rs b/makima/src/server/handlers/mesh_daemon.rs index 887183a..34e2cc3 100644 --- a/makima/src/server/handlers/mesh_daemon.rs +++ b/makima/src/server/handlers/mesh_daemon.rs @@ -933,7 +933,7 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re "Supervisor heartbeat received" ); - // Store heartbeat in database + // Store heartbeat in database and update supervisor state (Task 3.3) if let Some(ref pool) = state.db_pool { let pool = pool.clone(); let pending_ids = pending_task_ids.clone(); @@ -960,6 +960,22 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re ); } + // Update supervisor_states table (lightweight heartbeat state update - Task 3.3) + if let Err(e) = repository::update_supervisor_heartbeat_state( + &pool, + contract_id, + &state_str, + activity.as_deref(), + progress as i32, + &pending_ids, + ).await { + tracing::debug!( + contract_id = %contract_id, + error = %e, + "Failed to update supervisor state from heartbeat (may not exist yet)" + ); + } + // Also update the daemon heartbeat if let Ok(Some(task)) = repository::get_task(&pool, task_id).await { if let Some(daemon_id) = task.daemon_id { @@ -1035,55 +1051,117 @@ async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_re updated_by: "daemon".into(), }); - // Initialize supervisor_state when supervisor task starts running + // Initialize or restore supervisor_state when supervisor task starts running (Task 3.4) if updated_task.is_supervisor && new_status_owned == "running" { if let Some(contract_id) = updated_task.contract_id { - // Get contract to get its phase - match repository::get_contract_for_owner( - &pool, - contract_id, - updated_task.owner_id, - ).await { - Ok(Some(contract)) => { - match repository::upsert_supervisor_state( + // Check if supervisor state already exists (restoration scenario) + match repository::get_supervisor_state(&pool, contract_id).await { + Ok(Some(existing_state)) => { + // State exists - this is a restoration + tracing::info!( + task_id = %task_id, + contract_id = %contract_id, + existing_state = %existing_state.state, + restoration_count = existing_state.restoration_count, + "Supervisor starting with existing state - restoration in progress" + ); + + // Mark as restored (increments restoration_count) + match repository::mark_supervisor_restored( &pool, contract_id, - task_id, - serde_json::json!([]), // Empty conversation - &[], // No pending tasks - &contract.phase, + "daemon_restart", ).await { - Ok(_) => { + Ok(restored_state) => { tracing::info!( task_id = %task_id, contract_id = %contract_id, - phase = %contract.phase, - "Initialized supervisor state for running supervisor" + restoration_count = restored_state.restoration_count, + "Supervisor restoration marked" ); + + // Check for pending questions to re-deliver + if let Ok(questions) = serde_json::from_value::<Vec<crate::db::models::PendingQuestion>>( + restored_state.pending_questions.clone() + ) { + if !questions.is_empty() { + tracing::info!( + contract_id = %contract_id, + question_count = questions.len(), + "Pending questions found for re-delivery" + ); + // Questions will be re-delivered by the supervisor when it restores + } + } } Err(e) => { tracing::warn!( task_id = %task_id, contract_id = %contract_id, error = %e, - "Failed to initialize supervisor state" + "Failed to mark supervisor as restored" ); } } } Ok(None) => { - tracing::warn!( - task_id = %task_id, - contract_id = %contract_id, - "Contract not found when initializing supervisor state" - ); + // No existing state - fresh start + // Get contract to get its phase + match repository::get_contract_for_owner( + &pool, + contract_id, + updated_task.owner_id, + ).await { + Ok(Some(contract)) => { + match repository::upsert_supervisor_state( + &pool, + contract_id, + task_id, + serde_json::json!([]), // Empty conversation + &[], // No pending tasks + &contract.phase, + ).await { + Ok(_) => { + tracing::info!( + task_id = %task_id, + contract_id = %contract_id, + phase = %contract.phase, + "Initialized fresh supervisor state" + ); + } + Err(e) => { + tracing::warn!( + task_id = %task_id, + contract_id = %contract_id, + error = %e, + "Failed to initialize supervisor state" + ); + } + } + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + contract_id = %contract_id, + "Contract not found when initializing supervisor state" + ); + } + Err(e) => { + tracing::warn!( + task_id = %task_id, + contract_id = %contract_id, + error = %e, + "Failed to get contract for supervisor state" + ); + } + } } Err(e) => { tracing::warn!( task_id = %task_id, contract_id = %contract_id, error = %e, - "Failed to get contract for supervisor state" + "Failed to check existing supervisor state" ); } } diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs index 3411ec0..a29b666 100644 --- a/makima/src/server/handlers/mesh_supervisor.rs +++ b/makima/src/server/handlers/mesh_supervisor.rs @@ -14,8 +14,9 @@ use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use uuid::Uuid; -use crate::db::models::{CreateTaskRequest, Task, TaskSummary, UpdateTaskRequest}; +use crate::db::models::{CreateTaskRequest, PendingQuestion, Task, TaskSummary, UpdateTaskRequest}; use crate::db::repository; +use sqlx::PgPool; use crate::server::auth::Authenticated; use crate::server::handlers::mesh::{extract_auth, AuthSource}; use crate::server::messages::ApiError; @@ -748,6 +749,9 @@ pub async fn spawn_task( } else { tracing::info!(task_id = %updated_task.id, daemon_id = %daemon.id, repo_url = ?repo_url, "Task spawn command sent"); + // Save state: task spawn is a key save point (Task 3.3) + save_state_on_task_spawn(pool, request.contract_id, updated_task.id).await; + // Broadcast task status update notification to WebSocket subscribers state.broadcast_task_update(TaskUpdateNotification { task_id: updated_task.id, @@ -1770,6 +1774,17 @@ pub async fn ask_question( request.question_type.clone(), ); + // Save state: question asked is a key save point (Task 3.3) + let pending_question = PendingQuestion { + id: question_id, + question: request.question.clone(), + choices: request.choices.clone(), + context: request.context.clone(), + question_type: request.question_type.clone(), + asked_at: chrono::Utc::now(), + }; + save_state_on_question_asked(pool, contract_id, pending_question).await; + // Broadcast question as task output entry for the task's chat let question_data = serde_json::json!({ "question_id": question_id.to_string(), @@ -1864,6 +1879,9 @@ pub async fn ask_question( // Clean up the response state.cleanup_question_response(question_id); + // Clear pending question from supervisor state (Task 3.3) + clear_pending_question(pool, contract_id, question_id).await; + return ( StatusCode::OK, Json(AskQuestionResponse { @@ -1879,6 +1897,9 @@ pub async fn ask_question( // Remove the pending question on timeout state.remove_pending_question(question_id); + // Clear pending question from supervisor state on timeout (Task 3.3) + clear_pending_question(pool, contract_id, question_id).await; + return ( StatusCode::REQUEST_TIMEOUT, Json(AskQuestionResponse { @@ -2031,6 +2052,8 @@ pub struct ResumeSupervisorResponse { pub daemon_id: Option<Uuid>, pub resumed_from: ResumedFromInfo, pub status: String, + /// Restoration context (Task 3.4) + pub restoration: Option<RestorationInfo>, } #[derive(Debug, Serialize, ToSchema)] @@ -2041,6 +2064,24 @@ pub struct ResumedFromInfo { pub message_count: i32, } +/// Information about supervisor restoration (Task 3.4) +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RestorationInfo { + /// Previous state before restoration + pub previous_state: String, + /// How many times this supervisor has been restored + pub restoration_count: i32, + /// Number of pending questions to re-deliver + pub pending_questions_count: usize, + /// Number of tasks being waited on + pub waiting_tasks_count: usize, + /// Number of tasks spawned before crash + pub spawned_tasks_count: usize, + /// Any warnings from state validation + pub warnings: Vec<String>, +} + /// Resume interrupted supervisor with specified mode. /// /// POST /api/v1/contracts/{id}/supervisor/resume @@ -2350,6 +2391,31 @@ pub async fn resume_supervisor( "Supervisor resume requested" ); + // Build restoration info (Task 3.4) + let pending_questions: Vec<PendingQuestion> = serde_json::from_value( + supervisor_state.pending_questions.clone() + ).unwrap_or_default(); + + let restoration_info = RestorationInfo { + previous_state: supervisor_state.state.clone(), + restoration_count: supervisor_state.restoration_count, + pending_questions_count: pending_questions.len(), + waiting_tasks_count: supervisor_state.pending_task_ids.len(), + spawned_tasks_count: supervisor_state.spawned_task_ids.len(), + warnings: vec![], // Could add validation warnings here + }; + + // Re-deliver pending questions if any (Task 3.4) + if !pending_questions.is_empty() { + redeliver_pending_questions( + &state, + supervisor_state.task_id, + contract_id, + auth_info.owner_id, + &pending_questions, + ).await; + } + Json(ResumeSupervisorResponse { supervisor_task_id: supervisor_state.task_id, daemon_id: response_daemon_id, @@ -2359,6 +2425,7 @@ pub async fn resume_supervisor( message_count, }, status: response_status, + restoration: Some(restoration_info), }) .into_response() } @@ -2748,3 +2815,412 @@ pub async fn spawn_red_team_task( // It will remain pending and can be started later Ok(task) } + +// ============================================================================= +// Supervisor State Persistence Helpers (Task 3.3) +// ============================================================================= + +use crate::db::models::{ + SupervisorRestorationContext, SupervisorStateEnum, + StateValidationResult, StateRecoveryAction, +}; + +/// Save supervisor state on task spawn. +/// This is called when a supervisor spawns a new task. +pub async fn save_state_on_task_spawn( + pool: &PgPool, + contract_id: Uuid, + spawned_task_id: Uuid, +) { + if let Err(e) = repository::add_supervisor_spawned_task(pool, contract_id, spawned_task_id).await { + tracing::warn!( + contract_id = %contract_id, + spawned_task_id = %spawned_task_id, + error = %e, + "Failed to save spawned task to supervisor state" + ); + } else { + tracing::debug!( + contract_id = %contract_id, + spawned_task_id = %spawned_task_id, + "Saved spawned task to supervisor state" + ); + } + + // Also update state to working + if let Err(e) = repository::update_supervisor_detailed_state( + pool, + contract_id, + "working", + Some(&format!("Spawned task {}", spawned_task_id)), + 0, // Progress resets when spawning new work + None, + ).await { + tracing::warn!(contract_id = %contract_id, error = %e, "Failed to update supervisor state on task spawn"); + } +} + +/// Save supervisor state on question asked. +/// This is called when a supervisor asks a question and is waiting for user input. +pub async fn save_state_on_question_asked( + pool: &PgPool, + contract_id: Uuid, + question: PendingQuestion, +) { + let question_json = match serde_json::to_value(&[&question]) { + Ok(v) => v, + Err(e) => { + tracing::warn!(contract_id = %contract_id, error = %e, "Failed to serialize pending question"); + return; + } + }; + + if let Err(e) = repository::add_supervisor_pending_question(pool, contract_id, question_json).await { + tracing::warn!( + contract_id = %contract_id, + question_id = %question.id, + error = %e, + "Failed to save pending question to supervisor state" + ); + } else { + tracing::debug!( + contract_id = %contract_id, + question_id = %question.id, + "Saved pending question to supervisor state" + ); + } +} + +/// Clear pending question after it's answered. +pub async fn clear_pending_question( + pool: &PgPool, + contract_id: Uuid, + question_id: Uuid, +) { + if let Err(e) = repository::remove_supervisor_pending_question(pool, contract_id, question_id).await { + tracing::warn!( + contract_id = %contract_id, + question_id = %question_id, + error = %e, + "Failed to remove pending question from supervisor state" + ); + } + + // Update state back to working (if no more pending questions) + match repository::get_supervisor_state(pool, contract_id).await { + Ok(Some(state)) => { + let questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone()) + .unwrap_or_default(); + if questions.is_empty() { + let _ = repository::update_supervisor_detailed_state( + pool, + contract_id, + "working", + Some("Resumed after user response"), + state.progress, + None, + ).await; + } + } + Ok(None) => {} + Err(e) => { + tracing::warn!(contract_id = %contract_id, error = %e, "Failed to check supervisor state after clearing question"); + } + } +} + +/// Save supervisor state on phase change. +pub async fn save_state_on_phase_change( + pool: &PgPool, + contract_id: Uuid, + new_phase: &str, +) { + if let Err(e) = repository::update_supervisor_phase(pool, contract_id, new_phase).await { + tracing::warn!( + contract_id = %contract_id, + new_phase = %new_phase, + error = %e, + "Failed to update supervisor state on phase change" + ); + } else { + tracing::info!( + contract_id = %contract_id, + new_phase = %new_phase, + "Updated supervisor state on phase change" + ); + } +} + +// ============================================================================= +// Supervisor Restoration Protocol (Task 3.4) +// ============================================================================= + +/// Validate supervisor state consistency before restoration. +/// Checks that spawned tasks and pending questions are in expected states. +pub async fn validate_supervisor_state( + pool: &PgPool, + state: &crate::db::models::SupervisorState, +) -> StateValidationResult { + let mut issues = Vec::new(); + + // Validate spawned tasks + if !state.spawned_task_ids.is_empty() { + match repository::validate_spawned_tasks(pool, &state.spawned_task_ids).await { + Ok(task_statuses) => { + for task_id in &state.spawned_task_ids { + if !task_statuses.contains_key(task_id) { + issues.push(format!("Spawned task {} not found in database", task_id)); + } + } + } + Err(e) => { + issues.push(format!("Failed to validate spawned tasks: {}", e)); + } + } + } + + // Validate pending questions + let pending_questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone()) + .unwrap_or_default(); + + // Check if questions are not too old (e.g., more than 24 hours) + for question in &pending_questions { + let age = chrono::Utc::now() - question.asked_at; + if age.num_hours() > 24 { + issues.push(format!( + "Pending question {} is {} hours old, may be stale", + question.id, age.num_hours() + )); + } + } + + // Validate conversation history + if let Some(history) = state.conversation_history.as_array() { + if history.is_empty() && state.restoration_count > 0 { + issues.push("Conversation history is empty after previous restoration".to_string()); + } + } + + // Determine recovery action + let recovery_action = if issues.is_empty() { + StateRecoveryAction::Proceed + } else if issues.iter().any(|i| i.contains("not found")) { + // Missing tasks suggest corruption - use checkpoint + StateRecoveryAction::UseCheckpoint + } else if issues.len() > 3 { + // Many issues suggest manual intervention needed + StateRecoveryAction::ManualIntervention + } else { + // Minor issues - proceed with warnings + StateRecoveryAction::Proceed + }; + + StateValidationResult { + is_valid: issues.is_empty(), + issues, + recovery_action, + } +} + +/// Restore supervisor from saved state after daemon crash or task reassignment. +/// Returns restoration context to send to the supervisor. +pub async fn restore_supervisor( + pool: &PgPool, + contract_id: Uuid, + restoration_source: &str, +) -> Result<SupervisorRestorationContext, String> { + // Step 1: Load supervisor state + let state = match repository::get_supervisor_state_for_restoration(pool, contract_id).await { + Ok(Some(s)) => s, + Ok(None) => { + tracing::warn!( + contract_id = %contract_id, + "No supervisor state found for restoration - starting fresh" + ); + return Ok(SupervisorRestorationContext { + success: true, + previous_state: SupervisorStateEnum::Initializing, + conversation_history: serde_json::json!([]), + pending_questions: vec![], + waiting_task_ids: vec![], + spawned_task_ids: vec![], + restoration_count: 0, + restoration_context_message: "No previous state found. Starting fresh.".to_string(), + warnings: vec!["No previous supervisor state found".to_string()], + }); + } + Err(e) => { + return Err(format!("Failed to load supervisor state: {}", e)); + } + }; + + // Step 2: Parse previous state + let previous_state: SupervisorStateEnum = state.state.parse().unwrap_or(SupervisorStateEnum::Interrupted); + + // Step 3: Validate state consistency + let validation = validate_supervisor_state(pool, &state).await; + let mut warnings = validation.issues.clone(); + + // Step 4: Handle based on validation result + let (conversation_history, pending_questions, restoration_message) = match validation.recovery_action { + StateRecoveryAction::Proceed => { + // State is valid, use it + let questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone()) + .unwrap_or_default(); + + let message = format!( + "Restored from {} state. {} pending questions, {} spawned tasks, {} waiting tasks.", + state.state, + questions.len(), + state.spawned_task_ids.len(), + state.pending_task_ids.len() + ); + + (state.conversation_history.clone(), questions, message) + } + StateRecoveryAction::UseCheckpoint => { + // State is corrupted, try to use checkpoint + warnings.push("State validation failed, attempting checkpoint recovery".to_string()); + + // TODO: Implement checkpoint-based recovery + // For now, start with empty questions but preserve conversation + let message = "Restored from last checkpoint due to state inconsistency.".to_string(); + (state.conversation_history.clone(), vec![], message) + } + StateRecoveryAction::StartFresh => { + warnings.push("Starting fresh due to unrecoverable state".to_string()); + let message = "Starting fresh due to unrecoverable state corruption.".to_string(); + (serde_json::json!([]), vec![], message) + } + StateRecoveryAction::ManualIntervention => { + warnings.push("Manual intervention may be required".to_string()); + // Still try to restore but with warning + let questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone()) + .unwrap_or_default(); + let message = "Restored with warnings - manual intervention may be required.".to_string(); + (state.conversation_history.clone(), questions, message) + } + }; + + // Step 5: Mark supervisor as restored + let new_state = match repository::mark_supervisor_restored(pool, contract_id, restoration_source).await { + Ok(s) => s, + Err(e) => { + return Err(format!("Failed to mark supervisor as restored: {}", e)); + } + }; + + // Step 6: Build restoration context + let context = SupervisorRestorationContext { + success: true, + previous_state, + conversation_history, + pending_questions, + waiting_task_ids: state.pending_task_ids.clone(), + spawned_task_ids: state.spawned_task_ids.clone(), + restoration_count: new_state.restoration_count, + restoration_context_message: restoration_message, + warnings, + }; + + tracing::info!( + contract_id = %contract_id, + restoration_source = %restoration_source, + restoration_count = new_state.restoration_count, + pending_questions_count = context.pending_questions.len(), + waiting_tasks_count = context.waiting_task_ids.len(), + spawned_tasks_count = context.spawned_task_ids.len(), + "Supervisor restoration completed" + ); + + Ok(context) +} + +/// Re-deliver pending questions to the user after restoration. +/// This ensures questions asked before crash are shown to the user again. +pub async fn redeliver_pending_questions( + state: &SharedState, + supervisor_id: Uuid, + contract_id: Uuid, + owner_id: Uuid, + questions: &[PendingQuestion], +) { + for question in questions { + // Add to in-memory question state + state.add_supervisor_question( + supervisor_id, + contract_id, + owner_id, + question.question.clone(), + question.choices.clone(), + question.context.clone(), + false, // Assume single select for restored questions + question.question_type.clone(), + ); + + // Broadcast to WebSocket clients + let question_data = serde_json::json!({ + "question_id": question.id.to_string(), + "choices": question.choices, + "context": question.context, + "question_type": question.question_type, + "is_restored": true, + "originally_asked_at": question.asked_at.to_rfc3339(), + }); + + state.broadcast_task_output(TaskOutputNotification { + task_id: supervisor_id, + owner_id: Some(owner_id), + message_type: "supervisor_question".to_string(), + content: question.question.clone(), + tool_name: None, + tool_input: Some(question_data), + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + + tracing::info!( + supervisor_id = %supervisor_id, + question_id = %question.id, + "Re-delivered pending question after restoration" + ); + } +} + +/// Generate restoration context message for Claude. +/// This message is injected into the conversation to inform Claude about the restoration. +pub fn generate_restoration_context_message(context: &SupervisorRestorationContext) -> String { + let mut message = String::new(); + + message.push_str("=== SUPERVISOR RESTORATION NOTICE ===\n\n"); + message.push_str(&format!("This supervisor has been restored after interruption. {}\n\n", context.restoration_context_message)); + message.push_str(&format!("Restoration count: {}\n", context.restoration_count)); + + if !context.pending_questions.is_empty() { + message.push_str(&format!("\nPending questions ({}): These have been re-delivered to the user.\n", context.pending_questions.len())); + for q in &context.pending_questions { + message.push_str(&format!(" - {}: {}\n", q.id, q.question)); + } + } + + if !context.waiting_task_ids.is_empty() { + message.push_str(&format!("\nWaiting on {} task(s) to complete. Check their status before continuing.\n", context.waiting_task_ids.len())); + } + + if !context.spawned_task_ids.is_empty() { + message.push_str(&format!("\n{} task(s) were spawned before interruption. Their status may need verification.\n", context.spawned_task_ids.len())); + } + + if !context.warnings.is_empty() { + message.push_str("\nWarnings:\n"); + for warning in &context.warnings { + message.push_str(&format!(" - {}\n", warning)); + } + } + + message.push_str("\n=== END RESTORATION NOTICE ===\n"); + + message +} |
