diff options
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/db/models.rs | 212 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 429 |
2 files changed, 640 insertions, 1 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 636d81a..a7d2cda 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -1949,6 +1949,64 @@ pub struct CheckpointListResponse { // Supervisor State (for supervisor resumability) // ============================================================================ +/// Supervisor activity state enum +/// Tracks the current operational state of the supervisor +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum SupervisorActivityState { + /// Supervisor is initializing + Initializing, + /// Supervisor is in planning phase + Planning, + /// Supervisor is actively executing work + Executing, + /// Supervisor is waiting for spawned tasks to complete + WaitingForTask, + /// Supervisor is waiting for user input + WaitingForUser, + /// Supervisor encountered an error + Error, + /// Supervisor has completed its work + Completed, +} + +impl Default for SupervisorActivityState { + fn default() -> Self { + Self::Initializing + } +} + +impl std::fmt::Display for SupervisorActivityState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Initializing => write!(f, "initializing"), + Self::Planning => write!(f, "planning"), + Self::Executing => write!(f, "executing"), + Self::WaitingForTask => write!(f, "waiting_for_task"), + Self::WaitingForUser => write!(f, "waiting_for_user"), + Self::Error => write!(f, "error"), + Self::Completed => write!(f, "completed"), + } + } +} + +impl std::str::FromStr for SupervisorActivityState { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "initializing" => Ok(Self::Initializing), + "planning" => Ok(Self::Planning), + "executing" => Ok(Self::Executing), + "waiting_for_task" => Ok(Self::WaitingForTask), + "waiting_for_user" => Ok(Self::WaitingForUser), + "error" => Ok(Self::Error), + "completed" => Ok(Self::Completed), + _ => Err(format!("Unknown supervisor state: {}", s)), + } + } +} + /// Supervisor state for contract supervisor tasks /// Enables resumption after interruption #[derive(Debug, Clone, FromRow, Serialize, ToSchema)] @@ -1971,10 +2029,59 @@ pub struct SupervisorState { pub last_activity: DateTime<Utc>, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, + + // Enhanced fields for Phase 3 crash recovery + /// Current supervisor activity state + pub state: String, + /// Human-readable current activity description + pub current_activity: Option<String>, + /// Progress percentage (0-100) + pub progress: Option<i32>, + /// Error message if in error state + pub error_message: Option<String>, + /// All task UUIDs spawned by this supervisor + #[sqlx(try_from = "Vec<Uuid>")] + pub spawned_task_ids: Vec<Uuid>, + /// Pending question UUID if waiting for user input + pub pending_question_id: Option<Uuid>, + /// Last LLM response for context restoration + pub last_llm_response: Option<String>, + /// Number of times this supervisor has been restored + pub restoration_count: Option<i32>, + /// Timestamp of last restoration + pub last_restored_at: Option<DateTime<Utc>>, +} + +impl SupervisorState { + /// Get the parsed supervisor activity state + pub fn activity_state(&self) -> SupervisorActivityState { + self.state.parse().unwrap_or(SupervisorActivityState::Initializing) + } + + /// Check if this supervisor state is restorable + pub fn is_restorable(&self) -> bool { + let state = self.activity_state(); + // Can restore if not in a terminal or error state + !matches!(state, SupervisorActivityState::Completed | SupervisorActivityState::Error) + } + + /// Check if supervisor is waiting for something + pub fn is_waiting(&self) -> bool { + let state = self.activity_state(); + matches!( + state, + SupervisorActivityState::WaitingForTask | SupervisorActivityState::WaitingForUser + ) + } + + /// Check if supervisor has pending questions + pub fn has_pending_question(&self) -> bool { + self.pending_question_id.is_some() + } } /// Request to update supervisor state -#[derive(Debug, Deserialize, ToSchema)] +#[derive(Debug, Default, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct UpdateSupervisorStateRequest { /// Updated conversation history @@ -1983,6 +2090,109 @@ pub struct UpdateSupervisorStateRequest { pub pending_task_ids: Option<Vec<Uuid>>, /// Current contract phase pub phase: Option<String>, + /// Current supervisor activity state + pub state: Option<String>, + /// Human-readable current activity description + pub current_activity: Option<String>, + /// Progress percentage (0-100) + pub progress: Option<i32>, + /// Error message if in error state + pub error_message: Option<String>, + /// All spawned task IDs + pub spawned_task_ids: Option<Vec<Uuid>>, + /// Pending question UUID + pub pending_question_id: Option<Uuid>, + /// Clear the pending question (set to None) + #[serde(default)] + pub clear_pending_question: bool, + /// Last LLM response + pub last_llm_response: Option<String>, +} + +/// Request to save supervisor state at specific save points +#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SaveSupervisorStateRequest { + /// The save point type that triggered this save + pub save_point: SupervisorSavePoint, + /// Updated conversation history (if available) + pub conversation_history: Option<serde_json::Value>, + /// Current state + pub state: Option<String>, + /// Current activity description + pub current_activity: Option<String>, + /// Progress percentage + pub progress: Option<i32>, + /// Last LLM response + pub last_llm_response: Option<String>, + /// Task that was spawned (for task_spawn save point) + pub spawned_task_id: Option<Uuid>, + /// Question that was asked (for question_asked save point) + pub question_id: Option<Uuid>, + /// Error message (for error save point) + pub error_message: Option<String>, +} + +/// Types of save points for supervisor state persistence +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum SupervisorSavePoint { + /// Save on every LLM response + LlmResponse, + /// Save when spawning a new task + TaskSpawn, + /// Save when asking a question to user + QuestionAsked, + /// Save when phase changes + PhaseChange, + /// Lightweight heartbeat update + Heartbeat, + /// Save when an error occurs + Error, + /// Save when task completes + TaskComplete, +} + +/// Supervisor restoration context +/// Contains all information needed to restore a supervisor after crash +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorRestorationContext { + /// The restored supervisor state + pub state: SupervisorState, + /// Tasks that were pending when the crash occurred + pub pending_tasks: Vec<TaskSummary>, + /// Pending question that needs re-delivery (if any) + pub pending_question: Option<PendingQuestionInfo>, + /// Whether state was fully valid or partially recovered + pub restoration_type: RestorationResult, + /// Human-readable restoration message + pub message: String, +} + +/// Information about a pending question for restoration +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PendingQuestionInfo { + pub question_id: Uuid, + pub question: String, + pub choices: Vec<String>, + pub context: Option<String>, + pub multi_select: bool, +} + +/// Result of restoration attempt +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum RestorationResult { + /// State was fully valid and restored + FullRestore, + /// State was valid but restored from last checkpoint + CheckpointRestore, + /// Started fresh due to invalid/missing state + FreshStart, + /// Partial restoration with some context lost + PartialRestore, } // ============================================================================ diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index b7c5af1..c62de34 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -3404,6 +3404,435 @@ pub async fn update_supervisor_pending_tasks( .await } +/// Save supervisor state at a specific save point. +/// This is the main function for persisting supervisor state during operation. +pub async fn save_supervisor_state_at_savepoint( + pool: &PgPool, + contract_id: Uuid, + save_point: &str, + state: Option<&str>, + current_activity: Option<&str>, + progress: Option<i32>, + last_llm_response: Option<&str>, + spawned_task_id: Option<Uuid>, + pending_question_id: Option<Uuid>, + error_message: Option<&str>, + conversation_history: Option<serde_json::Value>, +) -> Result<SupervisorState, sqlx::Error> { + // Build dynamic UPDATE based on what fields are provided + let mut updates = vec!["last_activity = NOW()".to_string(), "updated_at = NOW()".to_string()]; + let mut bind_idx = 2; // Start at 2 since $1 is contract_id + + if state.is_some() { + bind_idx += 1; + updates.push(format!("state = ${}", bind_idx)); + } + if current_activity.is_some() { + bind_idx += 1; + updates.push(format!("current_activity = ${}", bind_idx)); + } + if progress.is_some() { + bind_idx += 1; + updates.push(format!("progress = ${}", bind_idx)); + } + if last_llm_response.is_some() { + bind_idx += 1; + updates.push(format!("last_llm_response = ${}", bind_idx)); + } + if spawned_task_id.is_some() { + bind_idx += 1; + updates.push(format!("spawned_task_ids = array_append(spawned_task_ids, ${})", bind_idx)); + } + if pending_question_id.is_some() { + bind_idx += 1; + updates.push(format!("pending_question_id = ${}", bind_idx)); + } else if save_point == "llm_response" { + // Clear pending question on LLM response (question was answered) + updates.push("pending_question_id = NULL".to_string()); + } + if error_message.is_some() { + bind_idx += 1; + updates.push(format!("error_message = ${}", bind_idx)); + } + if conversation_history.is_some() { + bind_idx += 1; + updates.push(format!("conversation_history = ${}", bind_idx)); + } + + let query = format!( + "UPDATE supervisor_states SET {} WHERE contract_id = $1 RETURNING *", + updates.join(", ") + ); + + let mut query_builder = sqlx::query_as::<_, SupervisorState>(&query).bind(contract_id); + + // Bind values in order + if let Some(v) = state { + query_builder = query_builder.bind(v); + } + if let Some(v) = current_activity { + query_builder = query_builder.bind(v); + } + if let Some(v) = progress { + query_builder = query_builder.bind(v); + } + if let Some(v) = last_llm_response { + query_builder = query_builder.bind(v); + } + if let Some(v) = spawned_task_id { + query_builder = query_builder.bind(v); + } + if let Some(v) = pending_question_id { + query_builder = query_builder.bind(v); + } + if let Some(v) = error_message { + query_builder = query_builder.bind(v); + } + if let Some(v) = conversation_history { + query_builder = query_builder.bind(v); + } + + query_builder.fetch_one(pool).await +} + +/// Update supervisor activity state (lightweight update for heartbeats). +pub async fn update_supervisor_activity_state( + pool: &PgPool, + contract_id: Uuid, + state: &str, + current_activity: Option<&str>, + progress: Option<i32>, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET state = $2, + current_activity = COALESCE($3, current_activity), + progress = COALESCE($4, progress), + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .bind(state) + .bind(current_activity) + .bind(progress) + .fetch_one(pool) + .await +} + +/// Update supervisor with spawned task. +pub async fn add_spawned_task_to_supervisor( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET spawned_task_ids = array_append(spawned_task_ids, $2), + pending_task_ids = array_append(pending_task_ids, $2), + state = 'waiting_for_task', + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .bind(task_id) + .fetch_one(pool) + .await +} + +/// Set pending question on supervisor. +pub async fn set_supervisor_pending_question( + pool: &PgPool, + contract_id: Uuid, + question_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_question_id = $2, + state = 'waiting_for_user', + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .bind(question_id) + .fetch_one(pool) + .await +} + +/// Clear pending question (after user responds). +pub async fn clear_supervisor_pending_question( + pool: &PgPool, + contract_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_question_id = NULL, + state = 'executing', + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Set supervisor error state. +pub async fn set_supervisor_error( + pool: &PgPool, + contract_id: Uuid, + error_message: &str, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET state = 'error', + error_message = $2, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .bind(error_message) + .fetch_one(pool) + .await +} + +/// Save last LLM response for restoration context. +pub async fn update_supervisor_llm_response( + pool: &PgPool, + contract_id: Uuid, + last_llm_response: &str, + conversation_history: Option<serde_json::Value>, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET last_llm_response = $2, + conversation_history = COALESCE($3, conversation_history), + state = 'executing', + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .bind(last_llm_response) + .bind(conversation_history) + .fetch_one(pool) + .await +} + +/// Mark supervisor as restored and increment restoration count. +pub async fn mark_supervisor_restored( + pool: &PgPool, + contract_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET restoration_count = COALESCE(restoration_count, 0) + 1, + last_restored_at = NOW(), + error_message = NULL, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Get supervisor state with full context for restoration. +/// Returns the state along with any pending tasks and questions. +pub async fn get_supervisor_restoration_context( + pool: &PgPool, + contract_id: Uuid, + owner_id: Uuid, +) -> Result<Option<(SupervisorState, Vec<Task>)>, sqlx::Error> { + let state = get_supervisor_state(pool, contract_id).await?; + + let Some(state) = state else { + return Ok(None); + }; + + // Get pending tasks + let pending_tasks = if !state.pending_task_ids.is_empty() { + sqlx::query_as::<_, Task>( + r#" + SELECT * FROM tasks + WHERE id = ANY($1) AND owner_id = $2 + ORDER BY created_at ASC + "#, + ) + .bind(&state.pending_task_ids) + .bind(owner_id) + .fetch_all(pool) + .await? + } else { + vec![] + }; + + Ok(Some((state, pending_tasks))) +} + +/// Validate supervisor state consistency. +/// Checks if the state can be used for restoration or needs recovery. +pub async fn validate_supervisor_state( + pool: &PgPool, + contract_id: Uuid, + owner_id: Uuid, +) -> Result<StateValidationResult, sqlx::Error> { + let state = get_supervisor_state(pool, contract_id).await?; + + let Some(state) = state else { + return Ok(StateValidationResult::NotFound); + }; + + // Check if task still exists + let task = get_task(pool, state.task_id).await?; + if task.is_none() { + return Ok(StateValidationResult::Invalid { + reason: "Supervisor task no longer exists".to_string(), + }); + } + + // Check if pending tasks are valid + let mut invalid_tasks = vec![]; + for task_id in &state.pending_task_ids { + let pending = get_task(pool, *task_id).await?; + if pending.is_none() { + invalid_tasks.push(*task_id); + } + } + + if !invalid_tasks.is_empty() { + return Ok(StateValidationResult::PartiallyValid { + state, + invalid_task_ids: invalid_tasks, + }); + } + + // Check contract phase consistency + let contract = get_contract_for_owner(pool, contract_id, owner_id).await?; + if let Some(contract) = contract { + if contract.phase != state.phase { + return Ok(StateValidationResult::PhaseStale { + state, + current_phase: contract.phase, + }); + } + } + + Ok(StateValidationResult::Valid(state)) +} + +/// Result of supervisor state validation. +#[derive(Debug)] +pub enum StateValidationResult { + /// State is valid and can be restored as-is + Valid(SupervisorState), + /// State not found for this contract + NotFound, + /// State is invalid and cannot be used + Invalid { reason: String }, + /// State is partially valid (some tasks are missing) + PartiallyValid { + state: SupervisorState, + invalid_task_ids: Vec<Uuid>, + }, + /// State phase doesn't match current contract phase + PhaseStale { + state: SupervisorState, + current_phase: String, + }, +} + +/// Remove a task from supervisor's pending tasks (when task completes). +pub async fn remove_pending_task_from_supervisor( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, +) -> Result<Option<SupervisorState>, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_task_ids = array_remove(pending_task_ids, $2), + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .bind(task_id) + .fetch_optional(pool) + .await +} + +/// Enhanced upsert for supervisor state with all new fields. +pub async fn upsert_enhanced_supervisor_state( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, + conversation_history: serde_json::Value, + pending_task_ids: &[Uuid], + phase: &str, + state: &str, + current_activity: Option<&str>, + progress: Option<i32>, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + INSERT INTO supervisor_states ( + contract_id, task_id, conversation_history, pending_task_ids, + phase, state, current_activity, progress, last_activity + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW()) + ON CONFLICT (contract_id) DO UPDATE SET + task_id = EXCLUDED.task_id, + conversation_history = EXCLUDED.conversation_history, + pending_task_ids = EXCLUDED.pending_task_ids, + phase = EXCLUDED.phase, + state = EXCLUDED.state, + current_activity = EXCLUDED.current_activity, + progress = EXCLUDED.progress, + last_activity = NOW(), + updated_at = NOW() + RETURNING * + "#, + ) + .bind(contract_id) + .bind(task_id) + .bind(conversation_history) + .bind(pending_task_ids) + .bind(phase) + .bind(state) + .bind(current_activity) + .bind(progress) + .fetch_one(pool) + .await +} + // ============================================================================ // Contract Supervisor // ============================================================================ |
