summaryrefslogtreecommitdiff
path: root/makima/src
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src')
-rw-r--r--makima/src/db/models.rs212
-rw-r--r--makima/src/db/repository.rs429
2 files changed, 640 insertions, 1 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs
index 636d81a..a7d2cda 100644
--- a/makima/src/db/models.rs
+++ b/makima/src/db/models.rs
@@ -1949,6 +1949,64 @@ pub struct CheckpointListResponse {
// Supervisor State (for supervisor resumability)
// ============================================================================
+/// Operational state of a supervisor.
+///
+/// Serialized by serde in snake_case (for example `waiting_for_task`), which
+/// matches the spellings produced by this type's `Display` implementation and
+/// accepted by its `FromStr` implementation. The default value is
+/// [`SupervisorActivityState::Initializing`].
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
+#[serde(rename_all = "snake_case")]
+pub enum SupervisorActivityState {
+    /// Supervisor is initializing (the default state)
+    #[default]
+    Initializing,
+    /// Supervisor is in planning phase
+    Planning,
+    /// Supervisor is actively executing work
+    Executing,
+    /// Supervisor is waiting for spawned tasks to complete
+    WaitingForTask,
+    /// Supervisor is waiting for user input
+    WaitingForUser,
+    /// Supervisor encountered an error
+    Error,
+    /// Supervisor has completed its work
+    Completed,
+}
+
+impl std::fmt::Display for SupervisorActivityState {
+    /// Writes the snake_case name of the state, e.g. `waiting_for_user`.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Pick the label first so there is a single write at the end.
+        let label = match self {
+            Self::Initializing => "initializing",
+            Self::Planning => "planning",
+            Self::Executing => "executing",
+            Self::WaitingForTask => "waiting_for_task",
+            Self::WaitingForUser => "waiting_for_user",
+            Self::Error => "error",
+            Self::Completed => "completed",
+        };
+        f.write_str(label)
+    }
+}
+
+impl std::str::FromStr for SupervisorActivityState {
+    type Err = String;
+
+    /// Parses the exact snake_case state name (case-sensitive, no trimming).
+    ///
+    /// # Errors
+    /// Returns `"Unknown supervisor state: <input>"` for any other string.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let parsed = match s {
+            "initializing" => Self::Initializing,
+            "planning" => Self::Planning,
+            "executing" => Self::Executing,
+            "waiting_for_task" => Self::WaitingForTask,
+            "waiting_for_user" => Self::WaitingForUser,
+            "error" => Self::Error,
+            "completed" => Self::Completed,
+            unknown => return Err(format!("Unknown supervisor state: {}", unknown)),
+        };
+        Ok(parsed)
+    }
+}
+
/// Supervisor state for contract supervisor tasks
/// Enables resumption after interruption
#[derive(Debug, Clone, FromRow, Serialize, ToSchema)]
@@ -1971,10 +2029,59 @@ pub struct SupervisorState {
pub last_activity: DateTime<Utc>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
+
+ // Enhanced fields for Phase 3 crash recovery
+ /// Current supervisor activity state
+ pub state: String,
+ /// Human-readable current activity description
+ pub current_activity: Option<String>,
+ /// Progress percentage (0-100)
+ pub progress: Option<i32>,
+ /// Error message if in error state
+ pub error_message: Option<String>,
+ /// All task UUIDs spawned by this supervisor
+ #[sqlx(try_from = "Vec<Uuid>")]
+ pub spawned_task_ids: Vec<Uuid>,
+ /// Pending question UUID if waiting for user input
+ pub pending_question_id: Option<Uuid>,
+ /// Last LLM response for context restoration
+ pub last_llm_response: Option<String>,
+ /// Number of times this supervisor has been restored
+ pub restoration_count: Option<i32>,
+ /// Timestamp of last restoration
+ pub last_restored_at: Option<DateTime<Utc>>,
+}
+
+impl SupervisorState {
+    /// Parse the stored `state` string into a [`SupervisorActivityState`].
+    ///
+    /// A string that fails to parse is treated as
+    /// [`SupervisorActivityState::Initializing`]; no error is surfaced.
+    pub fn activity_state(&self) -> SupervisorActivityState {
+        match self.state.parse::<SupervisorActivityState>() {
+            Ok(parsed) => parsed,
+            Err(_) => SupervisorActivityState::Initializing,
+        }
+    }
+
+    /// Whether this supervisor can be restored.
+    ///
+    /// Only the `Completed` and `Error` states are considered non-restorable.
+    pub fn is_restorable(&self) -> bool {
+        match self.activity_state() {
+            SupervisorActivityState::Completed | SupervisorActivityState::Error => false,
+            SupervisorActivityState::Initializing
+            | SupervisorActivityState::Planning
+            | SupervisorActivityState::Executing
+            | SupervisorActivityState::WaitingForTask
+            | SupervisorActivityState::WaitingForUser => true,
+        }
+    }
+
+    /// Whether the supervisor is blocked on a spawned task or on user input.
+    pub fn is_waiting(&self) -> bool {
+        match self.activity_state() {
+            SupervisorActivityState::WaitingForTask
+            | SupervisorActivityState::WaitingForUser => true,
+            SupervisorActivityState::Initializing
+            | SupervisorActivityState::Planning
+            | SupervisorActivityState::Executing
+            | SupervisorActivityState::Error
+            | SupervisorActivityState::Completed => false,
+        }
+    }
+
+    /// Whether a pending question id is recorded on this state.
+    pub fn has_pending_question(&self) -> bool {
+        matches!(self.pending_question_id, Some(_))
+    }
}
/// Request to update supervisor state
-#[derive(Debug, Deserialize, ToSchema)]
+#[derive(Debug, Default, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct UpdateSupervisorStateRequest {
/// Updated conversation history
@@ -1983,6 +2090,109 @@ pub struct UpdateSupervisorStateRequest {
pub pending_task_ids: Option<Vec<Uuid>>,
/// Current contract phase
pub phase: Option<String>,
+ /// Current supervisor activity state
+ pub state: Option<String>,
+ /// Human-readable current activity description
+ pub current_activity: Option<String>,
+ /// Progress percentage (0-100)
+ pub progress: Option<i32>,
+ /// Error message if in error state
+ pub error_message: Option<String>,
+ /// All spawned task IDs
+ pub spawned_task_ids: Option<Vec<Uuid>>,
+ /// Pending question UUID
+ pub pending_question_id: Option<Uuid>,
+ /// Clear the pending question (set to None)
+ #[serde(default)]
+ pub clear_pending_question: bool,
+ /// Last LLM response
+ pub last_llm_response: Option<String>,
+}
+
+/// Request to persist supervisor state at a specific save point; (de)serialized with camelCase field names.
+#[derive(Debug, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct SaveSupervisorStateRequest {
+ /// The kind of save point that triggered this save (required; every other field is optional)
+ pub save_point: SupervisorSavePoint,
+ /// Updated conversation history as arbitrary JSON (if available)
+ pub conversation_history: Option<serde_json::Value>,
+ /// Current state as a string; presumably one of the `SupervisorActivityState` names — TODO confirm where it is validated
+ pub state: Option<String>,
+ /// Human-readable description of the current activity
+ pub current_activity: Option<String>,
+ /// Progress percentage; NOTE(review): a 0-100 range is not enforced by this type
+ pub progress: Option<i32>,
+ /// Last LLM response text
+ pub last_llm_response: Option<String>,
+ /// Task that was spawned (for the task_spawn save point)
+ pub spawned_task_id: Option<Uuid>,
+ /// Question that was asked (for the question_asked save point)
+ pub question_id: Option<Uuid>,
+ /// Error message (for the error save point)
+ pub error_message: Option<String>,
+}
+
+/// Kinds of save points that trigger supervisor state persistence; serialized in snake_case (e.g. `llm_response`).
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
+#[serde(rename_all = "snake_case")]
+pub enum SupervisorSavePoint {
+ /// Save after every LLM response
+ LlmResponse,
+ /// Save when a new task is spawned
+ TaskSpawn,
+ /// Save when a question is asked to the user
+ QuestionAsked,
+ /// Save when the phase changes
+ PhaseChange,
+ /// Lightweight heartbeat update
+ Heartbeat,
+ /// Save when an error occurs
+ Error,
+ /// Save when a task completes
+ TaskComplete,
+}
+
+/// Supervisor restoration context; (de)serialized with camelCase field names.
+/// Bundles the information used to restore a supervisor after a crash.
+#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct SupervisorRestorationContext {
+ /// The restored supervisor state record
+ pub state: SupervisorState,
+ /// Summaries of the tasks that were pending when the crash occurred
+ pub pending_tasks: Vec<TaskSummary>,
+ /// Pending question that needs re-delivery (`None` if there is none)
+ pub pending_question: Option<PendingQuestionInfo>,
+ /// Outcome category of the restoration (full, checkpoint, fresh start, or partial)
+ pub restoration_type: RestorationResult,
+ /// Human-readable restoration message
+ pub message: String,
+}
+
+/// Information about a pending question carried in a restoration context; (de)serialized with camelCase field names.
+#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct PendingQuestionInfo {
+ pub question_id: Uuid, // identifier of the question
+ pub question: String, // the question text
+ pub choices: Vec<String>, // presumably the selectable answers; may be empty — TODO confirm
+ pub context: Option<String>, // optional extra context for the question
+ pub multi_select: bool, // NOTE(review): presumably whether several choices may be picked — confirm
+}
+
+/// Outcome category of a restoration attempt; serialized in snake_case (e.g. `full_restore`).
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
+#[serde(rename_all = "snake_case")]
+pub enum RestorationResult {
+ /// State was fully valid and restored as-is
+ FullRestore,
+ /// State was valid but was restored from the last checkpoint
+ CheckpointRestore,
+ /// Started fresh because the state was invalid or missing
+ FreshStart,
+ /// Partial restoration; some context was lost
+ PartialRestore,
}
// ============================================================================
diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs
index b7c5af1..c62de34 100644
--- a/makima/src/db/repository.rs
+++ b/makima/src/db/repository.rs
@@ -3404,6 +3404,435 @@ pub async fn update_supervisor_pending_tasks(
.await
}
+/// Save supervisor state at a specific save point.
+///
+/// Builds an `UPDATE supervisor_states ... WHERE contract_id = $1 RETURNING *`
+/// statement that touches only the columns for which a value was supplied.
+/// `last_activity` and `updated_at` are always set to `NOW()`.
+///
+/// Field semantics:
+/// * `state`, `current_activity`, `progress`, `last_llm_response`,
+///   `error_message`, `conversation_history`: overwritten when `Some`,
+///   left untouched when `None`.
+/// * `spawned_task_id`: appended to the `spawned_task_ids` array when `Some`.
+/// * `pending_question_id`: set when `Some`; when `None` it is cleared to
+///   `NULL` only if `save_point == "llm_response"`, otherwise left untouched.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`; because the row is read back with
+/// `fetch_one`, this includes the case where no row exists for `contract_id`.
+pub async fn save_supervisor_state_at_savepoint(
+    pool: &PgPool,
+    contract_id: Uuid,
+    save_point: &str,
+    state: Option<&str>,
+    current_activity: Option<&str>,
+    progress: Option<i32>,
+    last_llm_response: Option<&str>,
+    spawned_task_id: Option<Uuid>,
+    pending_question_id: Option<Uuid>,
+    error_message: Option<&str>,
+    conversation_history: Option<serde_json::Value>,
+) -> Result<SupervisorState, sqlx::Error> {
+    // Build dynamic UPDATE based on what fields are provided
+    let mut updates = vec![
+        "last_activity = NOW()".to_string(),
+        "updated_at = NOW()".to_string(),
+    ];
+
+    // `$1` is reserved for `contract_id`. Placeholders for the optional fields
+    // are handed out sequentially starting at `$2`, so the N-th placeholder
+    // generated here corresponds to the N-th `.bind()` call after
+    // `contract_id` below. The numbering must be gap-free: previously the
+    // first optional field was emitted as `$3` while its value was bound as
+    // parameter 2, which misaligned every bind.
+    let mut last_param = 1;
+    let mut next_placeholder = || {
+        last_param += 1;
+        format!("${}", last_param)
+    };
+
+    // NOTE: the order of these checks must match the order of the binds below.
+    if state.is_some() {
+        updates.push(format!("state = {}", next_placeholder()));
+    }
+    if current_activity.is_some() {
+        updates.push(format!("current_activity = {}", next_placeholder()));
+    }
+    if progress.is_some() {
+        updates.push(format!("progress = {}", next_placeholder()));
+    }
+    if last_llm_response.is_some() {
+        updates.push(format!("last_llm_response = {}", next_placeholder()));
+    }
+    if spawned_task_id.is_some() {
+        updates.push(format!(
+            "spawned_task_ids = array_append(spawned_task_ids, {})",
+            next_placeholder()
+        ));
+    }
+    if pending_question_id.is_some() {
+        updates.push(format!("pending_question_id = {}", next_placeholder()));
+    } else if save_point == "llm_response" {
+        // Clear pending question on LLM response (question was answered)
+        updates.push("pending_question_id = NULL".to_string());
+    }
+    if error_message.is_some() {
+        updates.push(format!("error_message = {}", next_placeholder()));
+    }
+    if conversation_history.is_some() {
+        updates.push(format!("conversation_history = {}", next_placeholder()));
+    }
+
+    let query = format!(
+        "UPDATE supervisor_states SET {} WHERE contract_id = $1 RETURNING *",
+        updates.join(", ")
+    );
+
+    let mut query_builder = sqlx::query_as::<_, SupervisorState>(&query).bind(contract_id);
+
+    // Bind values in the same order the placeholders were generated above.
+    if let Some(v) = state {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = current_activity {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = progress {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = last_llm_response {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = spawned_task_id {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = pending_question_id {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = error_message {
+        query_builder = query_builder.bind(v);
+    }
+    if let Some(v) = conversation_history {
+        query_builder = query_builder.bind(v);
+    }
+
+    query_builder.fetch_one(pool).await
+}
+
+/// Lightweight activity-state update (intended for heartbeats).
+///
+/// `state` is always overwritten. `current_activity` and `progress` are only
+/// overwritten when a value is supplied; `COALESCE` keeps the stored value for
+/// `None`. `last_activity` and `updated_at` are refreshed to `NOW()`.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`, including when no row exists for
+/// `contract_id` (the updated row is read back with `fetch_one`).
+pub async fn update_supervisor_activity_state(
+    pool: &PgPool,
+    contract_id: Uuid,
+    state: &str,
+    current_activity: Option<&str>,
+    progress: Option<i32>,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET state = $2,
+ current_activity = COALESCE($3, current_activity),
+ progress = COALESCE($4, progress),
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .bind(state)
+        .bind(current_activity)
+        .bind(progress)
+        .fetch_one(pool)
+        .await?;
+    Ok(updated)
+}
+
+/// Record a newly spawned task on the supervisor.
+///
+/// Appends `task_id` to both `spawned_task_ids` and `pending_task_ids`, sets
+/// `state` to `'waiting_for_task'`, and refreshes `last_activity` /
+/// `updated_at`.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`, including when no row exists for
+/// `contract_id` (the updated row is read back with `fetch_one`).
+pub async fn add_spawned_task_to_supervisor(
+    pool: &PgPool,
+    contract_id: Uuid,
+    task_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET spawned_task_ids = array_append(spawned_task_ids, $2),
+ pending_task_ids = array_append(pending_task_ids, $2),
+ state = 'waiting_for_task',
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .bind(task_id)
+        .fetch_one(pool)
+        .await?;
+    Ok(updated)
+}
+
+/// Record a pending question on the supervisor.
+///
+/// Stores `question_id` in `pending_question_id`, sets `state` to
+/// `'waiting_for_user'`, and refreshes `last_activity` / `updated_at`.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`, including when no row exists for
+/// `contract_id` (the updated row is read back with `fetch_one`).
+pub async fn set_supervisor_pending_question(
+    pool: &PgPool,
+    contract_id: Uuid,
+    question_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET pending_question_id = $2,
+ state = 'waiting_for_user',
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .bind(question_id)
+        .fetch_one(pool)
+        .await?;
+    Ok(updated)
+}
+
+/// Clear the pending question (after the user responds).
+///
+/// Sets `pending_question_id` to `NULL`, moves `state` to `'executing'`, and
+/// refreshes `last_activity` / `updated_at`.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`, including when no row exists for
+/// `contract_id` (the updated row is read back with `fetch_one`).
+pub async fn clear_supervisor_pending_question(
+    pool: &PgPool,
+    contract_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET pending_question_id = NULL,
+ state = 'executing',
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .fetch_one(pool)
+        .await?;
+    Ok(updated)
+}
+
+/// Set supervisor error state.
+pub async fn set_supervisor_error(
+ pool: &PgPool,
+ contract_id: Uuid,
+ error_message: &str,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET state = 'error',
+ error_message = $2,
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#,
+ )
+ .bind(contract_id)
+ .bind(error_message)
+ .fetch_one(pool)
+ .await
+}
+
+/// Store the latest LLM response for restoration context.
+///
+/// Overwrites `last_llm_response`, sets `state` to `'executing'`, and
+/// refreshes `last_activity` / `updated_at`. `conversation_history` is only
+/// replaced when a value is supplied (`COALESCE` keeps the stored value for
+/// `None`).
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`, including when no row exists for
+/// `contract_id` (the updated row is read back with `fetch_one`).
+pub async fn update_supervisor_llm_response(
+    pool: &PgPool,
+    contract_id: Uuid,
+    last_llm_response: &str,
+    conversation_history: Option<serde_json::Value>,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET last_llm_response = $2,
+ conversation_history = COALESCE($3, conversation_history),
+ state = 'executing',
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .bind(last_llm_response)
+        .bind(conversation_history)
+        .fetch_one(pool)
+        .await?;
+    Ok(updated)
+}
+
+/// Mark the supervisor as restored.
+///
+/// Increments `restoration_count` (a `NULL` count is treated as 0), stamps
+/// `last_restored_at`, clears `error_message`, and refreshes `last_activity` /
+/// `updated_at`. The `state` column is not modified by this statement.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error`, including when no row exists for
+/// `contract_id` (the updated row is read back with `fetch_one`).
+pub async fn mark_supervisor_restored(
+    pool: &PgPool,
+    contract_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET restoration_count = COALESCE(restoration_count, 0) + 1,
+ last_restored_at = NOW(),
+ error_message = NULL,
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .fetch_one(pool)
+        .await?;
+    Ok(updated)
+}
+
+/// Load the supervisor state together with its pending tasks.
+///
+/// Returns `Ok(None)` when no supervisor state exists for `contract_id`.
+/// Otherwise returns the state and the `tasks` rows whose id is listed in
+/// `pending_task_ids` and whose `owner_id` matches, oldest first. The task
+/// query is skipped entirely when `pending_task_ids` is empty.
+///
+/// # Errors
+/// Propagates any `sqlx::Error` from the state lookup or the task query.
+pub async fn get_supervisor_restoration_context(
+    pool: &PgPool,
+    contract_id: Uuid,
+    owner_id: Uuid,
+) -> Result<Option<(SupervisorState, Vec<Task>)>, sqlx::Error> {
+    let state = match get_supervisor_state(pool, contract_id).await? {
+        Some(found) => found,
+        None => return Ok(None),
+    };
+
+    // Nothing pending: no need to hit the tasks table.
+    if state.pending_task_ids.is_empty() {
+        return Ok(Some((state, Vec::new())));
+    }
+
+    let pending_tasks = sqlx::query_as::<_, Task>(
+        r#"
+ SELECT * FROM tasks
+ WHERE id = ANY($1) AND owner_id = $2
+ ORDER BY created_at ASC
+ "#,
+    )
+    .bind(&state.pending_task_ids)
+    .bind(owner_id)
+    .fetch_all(pool)
+    .await?;
+
+    Ok(Some((state, pending_tasks)))
+}
+
+/// Validate supervisor state consistency before restoration.
+///
+/// Checks run in this order; the first one that fails decides the result:
+/// 1. No state row for `contract_id` -> `NotFound`.
+/// 2. The supervisor's own task (`state.task_id`) cannot be loaded -> `Invalid`.
+/// 3. Any id in `pending_task_ids` cannot be loaded -> `PartiallyValid` with
+///    the missing ids (one `get_task` lookup per pending id).
+/// 4. The contract's phase differs from the stored phase -> `PhaseStale`.
+///
+/// Otherwise the state is returned as `Valid`.
+///
+/// NOTE(review): when `get_contract_for_owner` returns `None`, the phase check
+/// is skipped and the state is still reported as `Valid` — confirm this is
+/// intended.
+///
+/// # Errors
+/// Propagates any `sqlx::Error` from the underlying lookups.
+pub async fn validate_supervisor_state(
+    pool: &PgPool,
+    contract_id: Uuid,
+    owner_id: Uuid,
+) -> Result<StateValidationResult, sqlx::Error> {
+    let state = match get_supervisor_state(pool, contract_id).await? {
+        Some(found) => found,
+        None => return Ok(StateValidationResult::NotFound),
+    };
+
+    // The supervisor's own task must still exist.
+    if get_task(pool, state.task_id).await?.is_none() {
+        return Ok(StateValidationResult::Invalid {
+            reason: "Supervisor task no longer exists".to_string(),
+        });
+    }
+
+    // Collect every pending task id that can no longer be loaded.
+    let mut missing_ids = Vec::new();
+    for &pending_id in &state.pending_task_ids {
+        if get_task(pool, pending_id).await?.is_none() {
+            missing_ids.push(pending_id);
+        }
+    }
+    if !missing_ids.is_empty() {
+        return Ok(StateValidationResult::PartiallyValid {
+            state,
+            invalid_task_ids: missing_ids,
+        });
+    }
+
+    // Compare the stored phase against the contract's current phase.
+    match get_contract_for_owner(pool, contract_id, owner_id).await? {
+        Some(contract) if contract.phase != state.phase => {
+            Ok(StateValidationResult::PhaseStale {
+                state,
+                current_phase: contract.phase,
+            })
+        }
+        _ => Ok(StateValidationResult::Valid(state)),
+    }
+}
+
+/// Result of supervisor state validation, as produced by `validate_supervisor_state`.
+#[derive(Debug)]
+pub enum StateValidationResult {
+ /// State passed every check and can be restored as-is
+ Valid(SupervisorState),
+ /// No supervisor state was found for this contract
+ NotFound,
+ /// State is invalid and cannot be used; `reason` is a human-readable explanation
+ Invalid { reason: String },
+ /// State is partially valid: the listed pending task ids could not be loaded
+ PartiallyValid {
+ state: SupervisorState,
+ invalid_task_ids: Vec<Uuid>,
+ },
+ /// Stored phase doesn't match the contract's current phase (`current_phase`)
+ PhaseStale {
+ state: SupervisorState,
+ current_phase: String,
+ },
+}
+
+/// Remove a task from the supervisor's pending tasks (when the task completes).
+///
+/// Uses `array_remove` on `pending_task_ids` and refreshes `last_activity` /
+/// `updated_at`. Returns `Ok(None)` when no supervisor state row exists for
+/// `contract_id` (the row is read back with `fetch_optional`).
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error` if the statement fails.
+pub async fn remove_pending_task_from_supervisor(
+    pool: &PgPool,
+    contract_id: Uuid,
+    task_id: Uuid,
+) -> Result<Option<SupervisorState>, sqlx::Error> {
+    let sql = r#"
+ UPDATE supervisor_states
+ SET pending_task_ids = array_remove(pending_task_ids, $2),
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $1
+ RETURNING *
+ "#;
+
+    let maybe_updated = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .bind(task_id)
+        .fetch_optional(pool)
+        .await?;
+    Ok(maybe_updated)
+}
+
+/// Insert or replace the supervisor state for a contract, including the
+/// activity fields (`state`, `current_activity`, `progress`).
+///
+/// On a `contract_id` conflict every supplied column is overwritten with the
+/// new value — `current_activity` and `progress` included, so passing `None`
+/// resets them to `NULL`. `last_activity` and `updated_at` are set to `NOW()`
+/// on update.
+///
+/// # Errors
+/// Returns the underlying `sqlx::Error` if the statement fails.
+pub async fn upsert_enhanced_supervisor_state(
+    pool: &PgPool,
+    contract_id: Uuid,
+    task_id: Uuid,
+    conversation_history: serde_json::Value,
+    pending_task_ids: &[Uuid],
+    phase: &str,
+    state: &str,
+    current_activity: Option<&str>,
+    progress: Option<i32>,
+) -> Result<SupervisorState, sqlx::Error> {
+    let sql = r#"
+ INSERT INTO supervisor_states (
+ contract_id, task_id, conversation_history, pending_task_ids,
+ phase, state, current_activity, progress, last_activity
+ )
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, NOW())
+ ON CONFLICT (contract_id) DO UPDATE SET
+ task_id = EXCLUDED.task_id,
+ conversation_history = EXCLUDED.conversation_history,
+ pending_task_ids = EXCLUDED.pending_task_ids,
+ phase = EXCLUDED.phase,
+ state = EXCLUDED.state,
+ current_activity = EXCLUDED.current_activity,
+ progress = EXCLUDED.progress,
+ last_activity = NOW(),
+ updated_at = NOW()
+ RETURNING *
+ "#;
+
+    // Bind order follows the $1..$8 placeholders in the VALUES list.
+    let stored = sqlx::query_as::<_, SupervisorState>(sql)
+        .bind(contract_id)
+        .bind(task_id)
+        .bind(conversation_history)
+        .bind(pending_task_ids)
+        .bind(phase)
+        .bind(state)
+        .bind(current_activity)
+        .bind(progress)
+        .fetch_one(pool)
+        .await?;
+    Ok(stored)
+}
+
// ============================================================================
// Contract Supervisor
// ============================================================================