summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/mesh_supervisor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server/handlers/mesh_supervisor.rs')
-rw-r--r--makima/src/server/handlers/mesh_supervisor.rs478
1 files changed, 477 insertions, 1 deletions
diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs
index 3411ec0..a29b666 100644
--- a/makima/src/server/handlers/mesh_supervisor.rs
+++ b/makima/src/server/handlers/mesh_supervisor.rs
@@ -14,8 +14,9 @@ use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;
-use crate::db::models::{CreateTaskRequest, Task, TaskSummary, UpdateTaskRequest};
+use crate::db::models::{CreateTaskRequest, PendingQuestion, Task, TaskSummary, UpdateTaskRequest};
use crate::db::repository;
+use sqlx::PgPool;
use crate::server::auth::Authenticated;
use crate::server::handlers::mesh::{extract_auth, AuthSource};
use crate::server::messages::ApiError;
@@ -748,6 +749,9 @@ pub async fn spawn_task(
} else {
tracing::info!(task_id = %updated_task.id, daemon_id = %daemon.id, repo_url = ?repo_url, "Task spawn command sent");
+ // Save state: task spawn is a key save point (Task 3.3)
+ save_state_on_task_spawn(pool, request.contract_id, updated_task.id).await;
+
// Broadcast task status update notification to WebSocket subscribers
state.broadcast_task_update(TaskUpdateNotification {
task_id: updated_task.id,
@@ -1770,6 +1774,17 @@ pub async fn ask_question(
request.question_type.clone(),
);
+ // Save state: question asked is a key save point (Task 3.3)
+ let pending_question = PendingQuestion {
+ id: question_id,
+ question: request.question.clone(),
+ choices: request.choices.clone(),
+ context: request.context.clone(),
+ question_type: request.question_type.clone(),
+ asked_at: chrono::Utc::now(),
+ };
+ save_state_on_question_asked(pool, contract_id, pending_question).await;
+
// Broadcast question as task output entry for the task's chat
let question_data = serde_json::json!({
"question_id": question_id.to_string(),
@@ -1864,6 +1879,9 @@ pub async fn ask_question(
// Clean up the response
state.cleanup_question_response(question_id);
+ // Clear pending question from supervisor state (Task 3.3)
+ clear_pending_question(pool, contract_id, question_id).await;
+
return (
StatusCode::OK,
Json(AskQuestionResponse {
@@ -1879,6 +1897,9 @@ pub async fn ask_question(
// Remove the pending question on timeout
state.remove_pending_question(question_id);
+ // Clear pending question from supervisor state on timeout (Task 3.3)
+ clear_pending_question(pool, contract_id, question_id).await;
+
return (
StatusCode::REQUEST_TIMEOUT,
Json(AskQuestionResponse {
@@ -2031,6 +2052,8 @@ pub struct ResumeSupervisorResponse {
pub daemon_id: Option<Uuid>,
pub resumed_from: ResumedFromInfo,
pub status: String,
+ /// Restoration context (Task 3.4)
+ pub restoration: Option<RestorationInfo>,
}
#[derive(Debug, Serialize, ToSchema)]
@@ -2041,6 +2064,24 @@ pub struct ResumedFromInfo {
pub message_count: i32,
}
+/// Information about supervisor restoration (Task 3.4)
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct RestorationInfo {
+ /// Previous state before restoration
+ pub previous_state: String,
+ /// How many times this supervisor has been restored
+ pub restoration_count: i32,
+ /// Number of pending questions to re-deliver
+ pub pending_questions_count: usize,
+ /// Number of tasks being waited on
+ pub waiting_tasks_count: usize,
+ /// Number of tasks spawned before crash
+ pub spawned_tasks_count: usize,
+ /// Any warnings from state validation
+ pub warnings: Vec<String>,
+}
+
/// Resume interrupted supervisor with specified mode.
///
/// POST /api/v1/contracts/{id}/supervisor/resume
@@ -2350,6 +2391,31 @@ pub async fn resume_supervisor(
"Supervisor resume requested"
);
+ // Build restoration info (Task 3.4)
+ let pending_questions: Vec<PendingQuestion> = serde_json::from_value(
+ supervisor_state.pending_questions.clone()
+ ).unwrap_or_default();
+
+ let restoration_info = RestorationInfo {
+ previous_state: supervisor_state.state.clone(),
+ restoration_count: supervisor_state.restoration_count,
+ pending_questions_count: pending_questions.len(),
+ waiting_tasks_count: supervisor_state.pending_task_ids.len(),
+ spawned_tasks_count: supervisor_state.spawned_task_ids.len(),
+ warnings: vec![], // Could add validation warnings here
+ };
+
+ // Re-deliver pending questions if any (Task 3.4)
+ if !pending_questions.is_empty() {
+ redeliver_pending_questions(
+ &state,
+ supervisor_state.task_id,
+ contract_id,
+ auth_info.owner_id,
+ &pending_questions,
+ ).await;
+ }
+
Json(ResumeSupervisorResponse {
supervisor_task_id: supervisor_state.task_id,
daemon_id: response_daemon_id,
@@ -2359,6 +2425,7 @@ pub async fn resume_supervisor(
message_count,
},
status: response_status,
+ restoration: Some(restoration_info),
})
.into_response()
}
@@ -2748,3 +2815,412 @@ pub async fn spawn_red_team_task(
// It will remain pending and can be started later
Ok(task)
}
+
+// =============================================================================
+// Supervisor State Persistence Helpers (Task 3.3)
+// =============================================================================
+
+use crate::db::models::{
+ SupervisorRestorationContext, SupervisorStateEnum,
+ StateValidationResult, StateRecoveryAction,
+};
+
+/// Save supervisor state on task spawn.
+/// This is called when a supervisor spawns a new task.
+pub async fn save_state_on_task_spawn(
+ pool: &PgPool,
+ contract_id: Uuid,
+ spawned_task_id: Uuid,
+) {
+ if let Err(e) = repository::add_supervisor_spawned_task(pool, contract_id, spawned_task_id).await {
+ tracing::warn!(
+ contract_id = %contract_id,
+ spawned_task_id = %spawned_task_id,
+ error = %e,
+ "Failed to save spawned task to supervisor state"
+ );
+ } else {
+ tracing::debug!(
+ contract_id = %contract_id,
+ spawned_task_id = %spawned_task_id,
+ "Saved spawned task to supervisor state"
+ );
+ }
+
+ // Also update state to working
+ if let Err(e) = repository::update_supervisor_detailed_state(
+ pool,
+ contract_id,
+ "working",
+ Some(&format!("Spawned task {}", spawned_task_id)),
+ 0, // Progress resets when spawning new work
+ None,
+ ).await {
+ tracing::warn!(contract_id = %contract_id, error = %e, "Failed to update supervisor state on task spawn");
+ }
+}
+
+/// Save supervisor state on question asked.
+/// This is called when a supervisor asks a question and is waiting for user input.
+pub async fn save_state_on_question_asked(
+ pool: &PgPool,
+ contract_id: Uuid,
+ question: PendingQuestion,
+) {
+ let question_json = match serde_json::to_value(&[&question]) {
+ Ok(v) => v,
+ Err(e) => {
+ tracing::warn!(contract_id = %contract_id, error = %e, "Failed to serialize pending question");
+ return;
+ }
+ };
+
+ if let Err(e) = repository::add_supervisor_pending_question(pool, contract_id, question_json).await {
+ tracing::warn!(
+ contract_id = %contract_id,
+ question_id = %question.id,
+ error = %e,
+ "Failed to save pending question to supervisor state"
+ );
+ } else {
+ tracing::debug!(
+ contract_id = %contract_id,
+ question_id = %question.id,
+ "Saved pending question to supervisor state"
+ );
+ }
+}
+
+/// Clear pending question after it's answered.
+pub async fn clear_pending_question(
+ pool: &PgPool,
+ contract_id: Uuid,
+ question_id: Uuid,
+) {
+ if let Err(e) = repository::remove_supervisor_pending_question(pool, contract_id, question_id).await {
+ tracing::warn!(
+ contract_id = %contract_id,
+ question_id = %question_id,
+ error = %e,
+ "Failed to remove pending question from supervisor state"
+ );
+ }
+
+ // Update state back to working (if no more pending questions)
+ match repository::get_supervisor_state(pool, contract_id).await {
+ Ok(Some(state)) => {
+ let questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone())
+ .unwrap_or_default();
+ if questions.is_empty() {
+ let _ = repository::update_supervisor_detailed_state(
+ pool,
+ contract_id,
+ "working",
+ Some("Resumed after user response"),
+ state.progress,
+ None,
+ ).await;
+ }
+ }
+ Ok(None) => {}
+ Err(e) => {
+ tracing::warn!(contract_id = %contract_id, error = %e, "Failed to check supervisor state after clearing question");
+ }
+ }
+}
+
+/// Save supervisor state on phase change.
+pub async fn save_state_on_phase_change(
+ pool: &PgPool,
+ contract_id: Uuid,
+ new_phase: &str,
+) {
+ if let Err(e) = repository::update_supervisor_phase(pool, contract_id, new_phase).await {
+ tracing::warn!(
+ contract_id = %contract_id,
+ new_phase = %new_phase,
+ error = %e,
+ "Failed to update supervisor state on phase change"
+ );
+ } else {
+ tracing::info!(
+ contract_id = %contract_id,
+ new_phase = %new_phase,
+ "Updated supervisor state on phase change"
+ );
+ }
+}
+
+// =============================================================================
+// Supervisor Restoration Protocol (Task 3.4)
+// =============================================================================
+
+/// Validate supervisor state consistency before restoration.
+/// Checks that spawned tasks and pending questions are in expected states.
+pub async fn validate_supervisor_state(
+ pool: &PgPool,
+ state: &crate::db::models::SupervisorState,
+) -> StateValidationResult {
+ let mut issues = Vec::new();
+
+ // Validate spawned tasks
+ if !state.spawned_task_ids.is_empty() {
+ match repository::validate_spawned_tasks(pool, &state.spawned_task_ids).await {
+ Ok(task_statuses) => {
+ for task_id in &state.spawned_task_ids {
+ if !task_statuses.contains_key(task_id) {
+ issues.push(format!("Spawned task {} not found in database", task_id));
+ }
+ }
+ }
+ Err(e) => {
+ issues.push(format!("Failed to validate spawned tasks: {}", e));
+ }
+ }
+ }
+
+ // Validate pending questions
+ let pending_questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone())
+ .unwrap_or_default();
+
+ // Check if questions are not too old (e.g., more than 24 hours)
+ for question in &pending_questions {
+ let age = chrono::Utc::now() - question.asked_at;
+ if age.num_hours() > 24 {
+ issues.push(format!(
+ "Pending question {} is {} hours old, may be stale",
+ question.id, age.num_hours()
+ ));
+ }
+ }
+
+ // Validate conversation history
+ if let Some(history) = state.conversation_history.as_array() {
+ if history.is_empty() && state.restoration_count > 0 {
+ issues.push("Conversation history is empty after previous restoration".to_string());
+ }
+ }
+
+ // Determine recovery action
+ let recovery_action = if issues.is_empty() {
+ StateRecoveryAction::Proceed
+ } else if issues.iter().any(|i| i.contains("not found")) {
+ // Missing tasks suggest corruption - use checkpoint
+ StateRecoveryAction::UseCheckpoint
+ } else if issues.len() > 3 {
+ // Many issues suggest manual intervention needed
+ StateRecoveryAction::ManualIntervention
+ } else {
+ // Minor issues - proceed with warnings
+ StateRecoveryAction::Proceed
+ };
+
+ StateValidationResult {
+ is_valid: issues.is_empty(),
+ issues,
+ recovery_action,
+ }
+}
+
+/// Restore supervisor from saved state after daemon crash or task reassignment.
+/// Returns restoration context to send to the supervisor.
+pub async fn restore_supervisor(
+ pool: &PgPool,
+ contract_id: Uuid,
+ restoration_source: &str,
+) -> Result<SupervisorRestorationContext, String> {
+ // Step 1: Load supervisor state
+ let state = match repository::get_supervisor_state_for_restoration(pool, contract_id).await {
+ Ok(Some(s)) => s,
+ Ok(None) => {
+ tracing::warn!(
+ contract_id = %contract_id,
+ "No supervisor state found for restoration - starting fresh"
+ );
+ return Ok(SupervisorRestorationContext {
+ success: true,
+ previous_state: SupervisorStateEnum::Initializing,
+ conversation_history: serde_json::json!([]),
+ pending_questions: vec![],
+ waiting_task_ids: vec![],
+ spawned_task_ids: vec![],
+ restoration_count: 0,
+ restoration_context_message: "No previous state found. Starting fresh.".to_string(),
+ warnings: vec!["No previous supervisor state found".to_string()],
+ });
+ }
+ Err(e) => {
+ return Err(format!("Failed to load supervisor state: {}", e));
+ }
+ };
+
+ // Step 2: Parse previous state
+ let previous_state: SupervisorStateEnum = state.state.parse().unwrap_or(SupervisorStateEnum::Interrupted);
+
+ // Step 3: Validate state consistency
+ let validation = validate_supervisor_state(pool, &state).await;
+ let mut warnings = validation.issues.clone();
+
+ // Step 4: Handle based on validation result
+ let (conversation_history, pending_questions, restoration_message) = match validation.recovery_action {
+ StateRecoveryAction::Proceed => {
+ // State is valid, use it
+ let questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone())
+ .unwrap_or_default();
+
+ let message = format!(
+ "Restored from {} state. {} pending questions, {} spawned tasks, {} waiting tasks.",
+ state.state,
+ questions.len(),
+ state.spawned_task_ids.len(),
+ state.pending_task_ids.len()
+ );
+
+ (state.conversation_history.clone(), questions, message)
+ }
+ StateRecoveryAction::UseCheckpoint => {
+ // State is corrupted, try to use checkpoint
+ warnings.push("State validation failed, attempting checkpoint recovery".to_string());
+
+ // TODO: Implement checkpoint-based recovery
+ // For now, start with empty questions but preserve conversation
+ let message = "Restored from last checkpoint due to state inconsistency.".to_string();
+ (state.conversation_history.clone(), vec![], message)
+ }
+ StateRecoveryAction::StartFresh => {
+ warnings.push("Starting fresh due to unrecoverable state".to_string());
+ let message = "Starting fresh due to unrecoverable state corruption.".to_string();
+ (serde_json::json!([]), vec![], message)
+ }
+ StateRecoveryAction::ManualIntervention => {
+ warnings.push("Manual intervention may be required".to_string());
+ // Still try to restore but with warning
+ let questions: Vec<PendingQuestion> = serde_json::from_value(state.pending_questions.clone())
+ .unwrap_or_default();
+ let message = "Restored with warnings - manual intervention may be required.".to_string();
+ (state.conversation_history.clone(), questions, message)
+ }
+ };
+
+ // Step 5: Mark supervisor as restored
+ let new_state = match repository::mark_supervisor_restored(pool, contract_id, restoration_source).await {
+ Ok(s) => s,
+ Err(e) => {
+ return Err(format!("Failed to mark supervisor as restored: {}", e));
+ }
+ };
+
+ // Step 6: Build restoration context
+ let context = SupervisorRestorationContext {
+ success: true,
+ previous_state,
+ conversation_history,
+ pending_questions,
+ waiting_task_ids: state.pending_task_ids.clone(),
+ spawned_task_ids: state.spawned_task_ids.clone(),
+ restoration_count: new_state.restoration_count,
+ restoration_context_message: restoration_message,
+ warnings,
+ };
+
+ tracing::info!(
+ contract_id = %contract_id,
+ restoration_source = %restoration_source,
+ restoration_count = new_state.restoration_count,
+ pending_questions_count = context.pending_questions.len(),
+ waiting_tasks_count = context.waiting_task_ids.len(),
+ spawned_tasks_count = context.spawned_task_ids.len(),
+ "Supervisor restoration completed"
+ );
+
+ Ok(context)
+}
+
+/// Re-deliver pending questions to the user after restoration.
+/// This ensures questions asked before crash are shown to the user again.
+pub async fn redeliver_pending_questions(
+ state: &SharedState,
+ supervisor_id: Uuid,
+ contract_id: Uuid,
+ owner_id: Uuid,
+ questions: &[PendingQuestion],
+) {
+ for question in questions {
+ // Add to in-memory question state
+ state.add_supervisor_question(
+ supervisor_id,
+ contract_id,
+ owner_id,
+ question.question.clone(),
+ question.choices.clone(),
+ question.context.clone(),
+ false, // Assume single select for restored questions
+ question.question_type.clone(),
+ );
+
+ // Broadcast to WebSocket clients
+ let question_data = serde_json::json!({
+ "question_id": question.id.to_string(),
+ "choices": question.choices,
+ "context": question.context,
+ "question_type": question.question_type,
+ "is_restored": true,
+ "originally_asked_at": question.asked_at.to_rfc3339(),
+ });
+
+ state.broadcast_task_output(TaskOutputNotification {
+ task_id: supervisor_id,
+ owner_id: Some(owner_id),
+ message_type: "supervisor_question".to_string(),
+ content: question.question.clone(),
+ tool_name: None,
+ tool_input: Some(question_data),
+ is_error: None,
+ cost_usd: None,
+ duration_ms: None,
+ is_partial: false,
+ });
+
+ tracing::info!(
+ supervisor_id = %supervisor_id,
+ question_id = %question.id,
+ "Re-delivered pending question after restoration"
+ );
+ }
+}
+
+/// Generate restoration context message for Claude.
+/// This message is injected into the conversation to inform Claude about the restoration.
+pub fn generate_restoration_context_message(context: &SupervisorRestorationContext) -> String {
+ let mut message = String::new();
+
+ message.push_str("=== SUPERVISOR RESTORATION NOTICE ===\n\n");
+ message.push_str(&format!("This supervisor has been restored after interruption. {}\n\n", context.restoration_context_message));
+ message.push_str(&format!("Restoration count: {}\n", context.restoration_count));
+
+ if !context.pending_questions.is_empty() {
+ message.push_str(&format!("\nPending questions ({}): These have been re-delivered to the user.\n", context.pending_questions.len()));
+ for q in &context.pending_questions {
+ message.push_str(&format!(" - {}: {}\n", q.id, q.question));
+ }
+ }
+
+ if !context.waiting_task_ids.is_empty() {
+ message.push_str(&format!("\nWaiting on {} task(s) to complete. Check their status before continuing.\n", context.waiting_task_ids.len()));
+ }
+
+ if !context.spawned_task_ids.is_empty() {
+ message.push_str(&format!("\n{} task(s) were spawned before interruption. Their status may need verification.\n", context.spawned_task_ids.len()));
+ }
+
+ if !context.warnings.is_empty() {
+ message.push_str("\nWarnings:\n");
+ for warning in &context.warnings {
+ message.push_str(&format!(" - {}\n", warning));
+ }
+ }
+
+ message.push_str("\n=== END RESTORATION NOTICE ===\n");
+
+ message
+}