summaryrefslogtreecommitdiff
path: root/makima/src/db/repository.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/db/repository.rs')
-rw-r--r--makima/src/db/repository.rs304
1 files changed, 304 insertions, 0 deletions
diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs
index 1ac188c..d1ec3ef 100644
--- a/makima/src/db/repository.rs
+++ b/makima/src/db/repository.rs
@@ -3404,6 +3404,310 @@ pub async fn update_supervisor_pending_tasks(
.await
}
+/// Update supervisor state with detailed activity tracking.
+/// Called at key save points: LLM response, task spawn, question asked, phase change.
+pub async fn update_supervisor_detailed_state(
+ pool: &PgPool,
+ contract_id: Uuid,
+ state: &str,
+ current_activity: Option<&str>,
+ progress: i32,
+ error_message: Option<&str>,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET state = $1,
+ current_activity = $2,
+ progress = $3,
+ error_message = $4,
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $5
+ RETURNING *
+ "#,
+ )
+ .bind(state)
+ .bind(current_activity)
+ .bind(progress)
+ .bind(error_message)
+ .bind(contract_id)
+ .fetch_one(pool)
+ .await
+}
+
+/// Add a spawned task ID to the supervisor's list.
+pub async fn add_supervisor_spawned_task(
+ pool: &PgPool,
+ contract_id: Uuid,
+ task_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET spawned_task_ids = array_append(spawned_task_ids, $1),
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $2
+ RETURNING *
+ "#,
+ )
+ .bind(task_id)
+ .bind(contract_id)
+ .fetch_one(pool)
+ .await
+}
+
+/// Add a pending question to the supervisor state.
+pub async fn add_supervisor_pending_question(
+ pool: &PgPool,
+ contract_id: Uuid,
+ question: serde_json::Value,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET pending_questions = pending_questions || $1::jsonb,
+ state = 'waiting_for_user',
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $2
+ RETURNING *
+ "#,
+ )
+ .bind(question)
+ .bind(contract_id)
+ .fetch_one(pool)
+ .await
+}
+
+/// Remove a pending question by ID.
+pub async fn remove_supervisor_pending_question(
+ pool: &PgPool,
+ contract_id: Uuid,
+ question_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET pending_questions = (
+ SELECT COALESCE(jsonb_agg(elem), '[]'::jsonb)
+ FROM jsonb_array_elements(pending_questions) elem
+ WHERE (elem->>'id')::uuid != $1
+ ),
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $2
+ RETURNING *
+ "#,
+ )
+ .bind(question_id)
+ .bind(contract_id)
+ .fetch_one(pool)
+ .await
+}
+
+/// Comprehensive state save - used at major save points.
+pub async fn save_supervisor_state_full(
+ pool: &PgPool,
+ contract_id: Uuid,
+ task_id: Uuid,
+ conversation_history: serde_json::Value,
+ pending_task_ids: &[Uuid],
+ phase: &str,
+ state: &str,
+ current_activity: Option<&str>,
+ progress: i32,
+ error_message: Option<&str>,
+ spawned_task_ids: &[Uuid],
+ pending_questions: serde_json::Value,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ INSERT INTO supervisor_states (
+ contract_id, task_id, conversation_history, pending_task_ids, phase,
+ state, current_activity, progress, error_message, spawned_task_ids,
+ pending_questions, last_activity
+ )
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, NOW())
+ ON CONFLICT (contract_id) DO UPDATE SET
+ task_id = EXCLUDED.task_id,
+ conversation_history = EXCLUDED.conversation_history,
+ pending_task_ids = EXCLUDED.pending_task_ids,
+ phase = EXCLUDED.phase,
+ state = EXCLUDED.state,
+ current_activity = EXCLUDED.current_activity,
+ progress = EXCLUDED.progress,
+ error_message = EXCLUDED.error_message,
+ spawned_task_ids = EXCLUDED.spawned_task_ids,
+ pending_questions = EXCLUDED.pending_questions,
+ last_activity = NOW(),
+ updated_at = NOW()
+ RETURNING *
+ "#,
+ )
+ .bind(contract_id)
+ .bind(task_id)
+ .bind(conversation_history)
+ .bind(pending_task_ids)
+ .bind(phase)
+ .bind(state)
+ .bind(current_activity)
+ .bind(progress)
+ .bind(error_message)
+ .bind(spawned_task_ids)
+ .bind(pending_questions)
+ .fetch_one(pool)
+ .await
+}
+
+/// Mark supervisor as restored from a crash/interruption.
+pub async fn mark_supervisor_restored(
+ pool: &PgPool,
+ contract_id: Uuid,
+ restoration_source: &str,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET restoration_count = restoration_count + 1,
+ last_restored_at = NOW(),
+ restoration_source = $1,
+ state = 'initializing',
+ error_message = NULL,
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $2
+ RETURNING *
+ "#,
+ )
+ .bind(restoration_source)
+ .bind(contract_id)
+ .fetch_one(pool)
+ .await
+}
+
+/// Get supervisors with pending questions (for re-delivery after restoration).
+pub async fn get_supervisors_with_pending_questions(
+ pool: &PgPool,
+ owner_id: Uuid,
+) -> Result<Vec<SupervisorState>, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ SELECT ss.*
+ FROM supervisor_states ss
+ JOIN contracts c ON c.id = ss.contract_id
+ WHERE c.owner_id = $1
+ AND ss.pending_questions != '[]'::jsonb
+ AND c.status = 'active'
+ ORDER BY ss.last_activity DESC
+ "#,
+ )
+ .bind(owner_id)
+ .fetch_all(pool)
+ .await
+}
+
+/// Get supervisor state with full details for restoration.
+/// Includes validation info.
+pub async fn get_supervisor_state_for_restoration(
+ pool: &PgPool,
+ contract_id: Uuid,
+) -> Result<Option<SupervisorState>, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ SELECT * FROM supervisor_states WHERE contract_id = $1
+ "#,
+ )
+ .bind(contract_id)
+ .fetch_optional(pool)
+ .await
+}
+
+/// Validate spawned tasks are in expected states.
+/// Returns map of task_id -> (status, updated_at).
+pub async fn validate_spawned_tasks(
+ pool: &PgPool,
+ task_ids: &[Uuid],
+) -> Result<std::collections::HashMap<Uuid, (String, chrono::DateTime<Utc>)>, sqlx::Error> {
+ use sqlx::Row;
+
+ let rows = sqlx::query(
+ r#"
+ SELECT id, status, updated_at
+ FROM tasks
+ WHERE id = ANY($1)
+ "#,
+ )
+ .bind(task_ids)
+ .fetch_all(pool)
+ .await?;
+
+ let mut result = std::collections::HashMap::new();
+ for row in rows {
+ let id: Uuid = row.get("id");
+ let status: String = row.get("status");
+ let updated_at: chrono::DateTime<Utc> = row.get("updated_at");
+ result.insert(id, (status, updated_at));
+ }
+ Ok(result)
+}
+
+/// Update supervisor state when phase changes.
+pub async fn update_supervisor_phase(
+ pool: &PgPool,
+ contract_id: Uuid,
+ new_phase: &str,
+) -> Result<SupervisorState, sqlx::Error> {
+ sqlx::query_as::<_, SupervisorState>(
+ r#"
+ UPDATE supervisor_states
+ SET phase = $1,
+ state = 'working',
+ current_activity = 'Phase changed to ' || $1,
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $2
+ RETURNING *
+ "#,
+ )
+ .bind(new_phase)
+ .bind(contract_id)
+ .fetch_one(pool)
+ .await
+}
+
+/// Update supervisor state on heartbeat (lightweight update).
+pub async fn update_supervisor_heartbeat_state(
+ pool: &PgPool,
+ contract_id: Uuid,
+ state: &str,
+ current_activity: Option<&str>,
+ progress: i32,
+ pending_task_ids: &[Uuid],
+) -> Result<(), sqlx::Error> {
+ sqlx::query(
+ r#"
+ UPDATE supervisor_states
+ SET state = $1,
+ current_activity = $2,
+ progress = $3,
+ pending_task_ids = $4,
+ last_activity = NOW(),
+ updated_at = NOW()
+ WHERE contract_id = $5
+ "#,
+ )
+ .bind(state)
+ .bind(current_activity)
+ .bind(progress)
+ .bind(pending_task_ids)
+ .bind(contract_id)
+ .execute(pool)
+ .await?;
+ Ok(())
+}
+
// ============================================================================
// Supervisor Heartbeats
// ============================================================================