summaryrefslogtreecommitdiff
path: root/makima/src/db/repository.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/db/repository.rs')
-rw-r--r--makima/src/db/repository.rs599
1 files changed, 596 insertions, 3 deletions
diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs
index b7c5af1..e308df7 100644
--- a/makima/src/db/repository.rs
+++ b/makima/src/db/repository.rs
@@ -12,9 +12,9 @@ use super::models::{
CreateFileRequest, CreateTaskRequest, CreateTemplateRequest, Daemon, DaemonTaskAssignment,
DaemonWithCapacity, DeliverableDefinition, File, FileSummary, FileVersion, HistoryEvent,
HistoryQueryFilters, MeshChatConversation, MeshChatMessageRecord, PhaseChangeResult,
- PhaseConfig, PhaseDefinition, RedTeamNotification, SupervisorState, Task, TaskCheckpoint,
- TaskEvent, TaskSummary, UpdateContractRequest, UpdateFileRequest, UpdateTaskRequest,
- UpdateTemplateRequest,
+ PhaseConfig, PhaseDefinition, RedTeamNotification, SupervisorHeartbeatRecord, SupervisorState,
+ Task, TaskCheckpoint, TaskEvent, TaskSummary, UpdateContractRequest, UpdateFileRequest,
+ UpdateTaskRequest, UpdateTemplateRequest,
};
/// Repository error types.
@@ -3404,6 +3404,464 @@ pub async fn update_supervisor_pending_tasks(
.await
}
+/// Persist a detailed activity snapshot for a contract's supervisor.
+///
+/// Overwrites `state`, `current_activity`, `progress` and `error_message` on the
+/// `supervisor_states` row for `contract_id`, and stamps `last_activity` and
+/// `updated_at` with the database's `NOW()`. Meant to be called at key save
+/// points: LLM response, task spawn, question asked, phase change.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` from the query; because the row is read back
+/// with `fetch_one`, a missing row for `contract_id` is reported as an error.
+pub async fn update_supervisor_detailed_state(
+    pool: &PgPool,
+    contract_id: Uuid,
+    state: &str,
+    current_activity: Option<&str>,
+    progress: i32,
+    error_message: Option<&str>,
+) -> Result<SupervisorState, sqlx::Error> {
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET state = $1,
+            current_activity = $2,
+            progress = $3,
+            error_message = $4,
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $5
+        RETURNING *
+        "#;
+
+    // Bind order mirrors the $1..$5 placeholders above.
+    let query = sqlx::query_as::<_, SupervisorState>(SQL)
+        .bind(state)
+        .bind(current_activity)
+        .bind(progress)
+        .bind(error_message)
+        .bind(contract_id);
+
+    query.fetch_one(pool).await
+}
+
+/// Append `task_id` to the supervisor's `spawned_task_ids` array.
+///
+/// Uses `array_append`, so the ID is added unconditionally (no de-duplication
+/// is performed by this query). Also refreshes `last_activity` and `updated_at`.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error`; a missing row for `contract_id` surfaces as an
+/// error because the updated row is read back with `fetch_one`.
+pub async fn add_supervisor_spawned_task(
+    pool: &PgPool,
+    contract_id: Uuid,
+    task_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET spawned_task_ids = array_append(spawned_task_ids, $1),
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $2
+        RETURNING *
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL)
+        .bind(task_id)
+        .bind(contract_id);
+
+    query.fetch_one(pool).await
+}
+
+/// Record a question the supervisor is waiting on.
+///
+/// Concatenates `question` onto the `pending_questions` JSONB column with the
+/// `||` operator and switches `state` to `'waiting_for_user'`. `last_activity`
+/// and `updated_at` are refreshed in the same statement.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error`; a missing row for `contract_id` surfaces as an
+/// error because the updated row is read back with `fetch_one`.
+pub async fn add_supervisor_pending_question(
+    pool: &PgPool,
+    contract_id: Uuid,
+    question: serde_json::Value,
+) -> Result<SupervisorState, sqlx::Error> {
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET pending_questions = pending_questions || $1::jsonb,
+            state = 'waiting_for_user',
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $2
+        RETURNING *
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL)
+        .bind(question)
+        .bind(contract_id);
+
+    query.fetch_one(pool).await
+}
+
+/// Remove the pending question whose `id` equals `question_id`.
+///
+/// Rebuilds `pending_questions` from every array element whose `id` field is
+/// different from `question_id`; if nothing remains the column is set to an
+/// empty JSON array. `last_activity` and `updated_at` are refreshed.
+///
+/// Elements that have no `id` key are kept: the filter uses
+/// `IS DISTINCT FROM` rather than `!=`, because `NULL != $1` evaluates to NULL
+/// and would silently drop such elements from the list.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error`. This includes a missing row for `contract_id`
+/// (`fetch_one`) and an element whose `id` cannot be cast to `uuid`.
+pub async fn remove_supervisor_pending_question(
+    pool: &PgPool,
+    contract_id: Uuid,
+    question_id: Uuid,
+) -> Result<SupervisorState, sqlx::Error> {
+    sqlx::query_as::<_, SupervisorState>(
+        r#"
+        UPDATE supervisor_states
+        SET pending_questions = (
+                SELECT COALESCE(jsonb_agg(elem), '[]'::jsonb)
+                FROM jsonb_array_elements(pending_questions) elem
+                WHERE (elem->>'id')::uuid IS DISTINCT FROM $1
+            ),
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $2
+        RETURNING *
+        "#,
+    )
+    .bind(question_id)
+    .bind(contract_id)
+    .fetch_one(pool)
+    .await
+}
+
+/// Upsert the complete supervisor state for a contract.
+///
+/// Inserts a `supervisor_states` row keyed by `contract_id`; when one already
+/// exists (`ON CONFLICT (contract_id)`), every supplied column is overwritten
+/// with the new value. `last_activity` is set to `NOW()` in both paths, and
+/// `updated_at` is set to `NOW()` on the update path. Used at major save points.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the statement.
+pub async fn save_supervisor_state_full(
+    pool: &PgPool,
+    contract_id: Uuid,
+    task_id: Uuid,
+    conversation_history: serde_json::Value,
+    pending_task_ids: &[Uuid],
+    phase: &str,
+    state: &str,
+    current_activity: Option<&str>,
+    progress: i32,
+    error_message: Option<&str>,
+    spawned_task_ids: &[Uuid],
+    pending_questions: serde_json::Value,
+) -> Result<SupervisorState, sqlx::Error> {
+    const SQL: &str = r#"
+        INSERT INTO supervisor_states (
+            contract_id, task_id, conversation_history, pending_task_ids, phase,
+            state, current_activity, progress, error_message, spawned_task_ids,
+            pending_questions, last_activity
+        )
+        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, NOW())
+        ON CONFLICT (contract_id) DO UPDATE SET
+            task_id = EXCLUDED.task_id,
+            conversation_history = EXCLUDED.conversation_history,
+            pending_task_ids = EXCLUDED.pending_task_ids,
+            phase = EXCLUDED.phase,
+            state = EXCLUDED.state,
+            current_activity = EXCLUDED.current_activity,
+            progress = EXCLUDED.progress,
+            error_message = EXCLUDED.error_message,
+            spawned_task_ids = EXCLUDED.spawned_task_ids,
+            pending_questions = EXCLUDED.pending_questions,
+            last_activity = NOW(),
+            updated_at = NOW()
+        RETURNING *
+        "#;
+
+    // Bind order follows the column list in the INSERT ($1..$11).
+    let query = sqlx::query_as::<_, SupervisorState>(SQL)
+        .bind(contract_id)
+        .bind(task_id)
+        .bind(conversation_history)
+        .bind(pending_task_ids)
+        .bind(phase)
+        .bind(state)
+        .bind(current_activity)
+        .bind(progress)
+        .bind(error_message)
+        .bind(spawned_task_ids)
+        .bind(pending_questions);
+
+    query.fetch_one(pool).await
+}
+
+/// Flag a supervisor as restored after a crash or interruption.
+///
+/// Increments `restoration_count`, records `restoration_source` and
+/// `last_restored_at`, resets `state` to `'initializing'`, clears
+/// `error_message`, and refreshes `last_activity` / `updated_at`.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error`; a missing row for `contract_id` surfaces as an
+/// error because the updated row is read back with `fetch_one`.
+pub async fn mark_supervisor_restored(
+    pool: &PgPool,
+    contract_id: Uuid,
+    restoration_source: &str,
+) -> Result<SupervisorState, sqlx::Error> {
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET restoration_count = restoration_count + 1,
+            last_restored_at = NOW(),
+            restoration_source = $1,
+            state = 'initializing',
+            error_message = NULL,
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $2
+        RETURNING *
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL)
+        .bind(restoration_source)
+        .bind(contract_id);
+
+    query.fetch_one(pool).await
+}
+
+/// List an owner's supervisors that still have unanswered questions.
+///
+/// Joins `supervisor_states` to `contracts` and returns the states whose
+/// contract belongs to `owner_id`, has status `'active'`, and whose
+/// `pending_questions` is not the empty JSON array. Results are ordered by
+/// `last_activity`, most recent first. Intended for re-delivering questions
+/// after a restoration.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn get_supervisors_with_pending_questions(
+    pool: &PgPool,
+    owner_id: Uuid,
+) -> Result<Vec<SupervisorState>, sqlx::Error> {
+    const SQL: &str = r#"
+        SELECT ss.*
+        FROM supervisor_states ss
+        JOIN contracts c ON c.id = ss.contract_id
+        WHERE c.owner_id = $1
+          AND ss.pending_questions != '[]'::jsonb
+          AND c.status = 'active'
+        ORDER BY ss.last_activity DESC
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL).bind(owner_id);
+
+    query.fetch_all(pool).await
+}
+
+/// Load the full `supervisor_states` row for a contract, for use during restoration.
+///
+/// Returns `Ok(None)` when the contract has no saved supervisor state. The
+/// query itself performs no validation; it selects every column of the row.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn get_supervisor_state_for_restoration(
+    pool: &PgPool,
+    contract_id: Uuid,
+) -> Result<Option<SupervisorState>, sqlx::Error> {
+    const SQL: &str = r#"
+        SELECT * FROM supervisor_states WHERE contract_id = $1
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL).bind(contract_id);
+
+    query.fetch_optional(pool).await
+}
+
+/// Look up the current status of a set of spawned tasks.
+///
+/// Returns a map of `task_id -> (status, updated_at)` for every ID in
+/// `task_ids` that exists in the `tasks` table. IDs with no matching row are
+/// simply absent from the map, so callers can detect missing tasks by
+/// comparing against their input.
+///
+/// An empty `task_ids` slice yields an empty map without querying the database.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` from the query. Column decode failures are
+/// returned as errors (via `try_get`) instead of panicking.
+pub async fn validate_spawned_tasks(
+    pool: &PgPool,
+    task_ids: &[Uuid],
+) -> Result<std::collections::HashMap<Uuid, (String, chrono::DateTime<Utc>)>, sqlx::Error> {
+    use sqlx::Row;
+
+    // `id = ANY('{}')` can never match, so skip the round trip entirely.
+    if task_ids.is_empty() {
+        return Ok(std::collections::HashMap::new());
+    }
+
+    let rows = sqlx::query(
+        r#"
+        SELECT id, status, updated_at
+        FROM tasks
+        WHERE id = ANY($1)
+        "#,
+    )
+    .bind(task_ids)
+    .fetch_all(pool)
+    .await?;
+
+    rows.iter()
+        .map(|row| -> Result<_, sqlx::Error> {
+            let id: Uuid = row.try_get("id")?;
+            let status: String = row.try_get("status")?;
+            let updated_at: chrono::DateTime<Utc> = row.try_get("updated_at")?;
+            Ok((id, (status, updated_at)))
+        })
+        .collect()
+}
+
+/// Record a phase transition on the supervisor state.
+///
+/// Sets `phase` to `new_phase`, forces `state` to `'working'`, and writes
+/// `"Phase changed to <new_phase>"` into `current_activity`. `last_activity`
+/// and `updated_at` are refreshed in the same statement.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error`; a missing row for `contract_id` surfaces as an
+/// error because the updated row is read back with `fetch_one`.
+pub async fn update_supervisor_phase(
+    pool: &PgPool,
+    contract_id: Uuid,
+    new_phase: &str,
+) -> Result<SupervisorState, sqlx::Error> {
+    // `$1` is used twice: as the new phase and inside the activity message.
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET phase = $1,
+            state = 'working',
+            current_activity = 'Phase changed to ' || $1,
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $2
+        RETURNING *
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL)
+        .bind(new_phase)
+        .bind(contract_id);
+
+    query.fetch_one(pool).await
+}
+
+/// Lightweight supervisor state refresh driven by a heartbeat.
+///
+/// Updates `state`, `current_activity`, `progress` and `pending_task_ids` for
+/// `contract_id` and refreshes `last_activity` / `updated_at`. Nothing is read
+/// back, and the number of affected rows is not inspected, so a contract with
+/// no `supervisor_states` row is not treated as an error.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the statement.
+pub async fn update_supervisor_heartbeat_state(
+    pool: &PgPool,
+    contract_id: Uuid,
+    state: &str,
+    current_activity: Option<&str>,
+    progress: i32,
+    pending_task_ids: &[Uuid],
+) -> Result<(), sqlx::Error> {
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET state = $1,
+            current_activity = $2,
+            progress = $3,
+            pending_task_ids = $4,
+            last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $5
+        "#;
+
+    let outcome = sqlx::query(SQL)
+        .bind(state)
+        .bind(current_activity)
+        .bind(progress)
+        .bind(pending_task_ids)
+        .bind(contract_id)
+        .execute(pool)
+        .await;
+
+    // The query result itself carries nothing the caller needs.
+    outcome.map(|_| ())
+}
+
+// ============================================================================
+// Supervisor Heartbeats
+// ============================================================================
+
+/// Insert a heartbeat row for a supervisor.
+///
+/// Each call adds a new `supervisor_heartbeats` record timestamped with the
+/// database's `NOW()`, building a history used for monitoring and for
+/// detecting dead supervisors. The inserted row is returned.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the insert.
+pub async fn create_supervisor_heartbeat(
+    pool: &PgPool,
+    supervisor_task_id: Uuid,
+    contract_id: Uuid,
+    state: &str,
+    phase: &str,
+    current_activity: Option<&str>,
+    progress: i32,
+    pending_task_ids: &[Uuid],
+) -> Result<SupervisorHeartbeatRecord, sqlx::Error> {
+    const SQL: &str = r#"
+        INSERT INTO supervisor_heartbeats (
+            supervisor_task_id, contract_id, state, phase, current_activity, progress, pending_task_ids, timestamp
+        )
+        VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
+        RETURNING *
+        "#;
+
+    // Bind order follows the column list in the INSERT ($1..$7).
+    let query = sqlx::query_as::<_, SupervisorHeartbeatRecord>(SQL)
+        .bind(supervisor_task_id)
+        .bind(contract_id)
+        .bind(state)
+        .bind(phase)
+        .bind(current_activity)
+        .bind(progress)
+        .bind(pending_task_ids);
+
+    query.fetch_one(pool).await
+}
+
+/// Fetch the most recent heartbeat recorded for a supervisor task.
+///
+/// Returns `Ok(None)` when the task has no heartbeat rows.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn get_latest_supervisor_heartbeat(
+    pool: &PgPool,
+    supervisor_task_id: Uuid,
+) -> Result<Option<SupervisorHeartbeatRecord>, sqlx::Error> {
+    const SQL: &str = r#"
+        SELECT * FROM supervisor_heartbeats
+        WHERE supervisor_task_id = $1
+        ORDER BY timestamp DESC
+        LIMIT 1
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorHeartbeatRecord>(SQL).bind(supervisor_task_id);
+
+    query.fetch_optional(pool).await
+}
+
+/// Fetch up to `limit` heartbeats for a supervisor task, newest first.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn get_supervisor_heartbeats(
+    pool: &PgPool,
+    supervisor_task_id: Uuid,
+    limit: i64,
+) -> Result<Vec<SupervisorHeartbeatRecord>, sqlx::Error> {
+    const SQL: &str = r#"
+        SELECT * FROM supervisor_heartbeats
+        WHERE supervisor_task_id = $1
+        ORDER BY timestamp DESC
+        LIMIT $2
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorHeartbeatRecord>(SQL)
+        .bind(supervisor_task_id)
+        .bind(limit);
+
+    query.fetch_all(pool).await
+}
+
+/// Fetch up to `limit` heartbeats recorded against a contract, newest first.
+///
+/// Unlike [`get_supervisor_heartbeats`], rows are selected by `contract_id`
+/// rather than by supervisor task.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn get_contract_supervisor_heartbeats(
+    pool: &PgPool,
+    contract_id: Uuid,
+    limit: i64,
+) -> Result<Vec<SupervisorHeartbeatRecord>, sqlx::Error> {
+    const SQL: &str = r#"
+        SELECT * FROM supervisor_heartbeats
+        WHERE contract_id = $1
+        ORDER BY timestamp DESC
+        LIMIT $2
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorHeartbeatRecord>(SQL)
+        .bind(contract_id)
+        .bind(limit);
+
+    query.fetch_all(pool).await
+}
+
+/// Purge heartbeat rows older than `ttl_hours` hours.
+///
+/// `ttl_hours` is rendered as text and turned into a PostgreSQL interval
+/// (`'<n> hours'`) on the server side; rows whose `timestamp` is earlier than
+/// `NOW()` minus that interval are deleted.
+///
+/// Returns the number of deleted rows.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the delete.
+pub async fn cleanup_old_heartbeats(
+    pool: &PgPool,
+    ttl_hours: i64,
+) -> Result<u64, sqlx::Error> {
+    const SQL: &str = r#"
+        DELETE FROM supervisor_heartbeats
+        WHERE timestamp < NOW() - ($1 || ' hours')::INTERVAL
+        "#;
+
+    // The placeholder is concatenated with ' hours', so it is bound as text.
+    let ttl_text = ttl_hours.to_string();
+
+    sqlx::query(SQL)
+        .bind(ttl_text)
+        .execute(pool)
+        .await
+        .map(|outcome| outcome.rows_affected())
+}
+
+/// Find running supervisors whose latest heartbeat is older than `timeout_seconds`.
+///
+/// For each `supervisor_task_id` the most recent heartbeat is selected
+/// (`DISTINCT ON ... ORDER BY timestamp DESC`); it is reported when the
+/// corresponding task has status `'running'` and that heartbeat is older than
+/// `NOW()` minus the timeout.
+///
+/// Returns a list of `(supervisor_task_id, contract_id, last_heartbeat_timestamp)`.
+///
+/// Note that the search starts from `supervisor_heartbeats`, so a running
+/// supervisor that has never written a heartbeat row is not reported here.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` from the query. Column decode failures are
+/// returned as errors (via `try_get`) instead of panicking.
+pub async fn find_stale_supervisors(
+    pool: &PgPool,
+    timeout_seconds: i64,
+) -> Result<Vec<(Uuid, Uuid, chrono::DateTime<Utc>)>, sqlx::Error> {
+    use sqlx::Row;
+
+    let rows = sqlx::query(
+        r#"
+        WITH latest_heartbeats AS (
+            SELECT DISTINCT ON (supervisor_task_id)
+                supervisor_task_id,
+                contract_id,
+                timestamp
+            FROM supervisor_heartbeats
+            ORDER BY supervisor_task_id, timestamp DESC
+        )
+        SELECT
+            lh.supervisor_task_id,
+            lh.contract_id,
+            lh.timestamp
+        FROM latest_heartbeats lh
+        JOIN tasks t ON t.id = lh.supervisor_task_id
+        WHERE t.status = 'running'
+          AND lh.timestamp < NOW() - ($1 || ' seconds')::INTERVAL
+        "#,
+    )
+    // The placeholder is concatenated with ' seconds', so it is bound as text.
+    .bind(timeout_seconds.to_string())
+    .fetch_all(pool)
+    .await?;
+
+    rows.iter()
+        .map(|row| -> Result<_, sqlx::Error> {
+            let supervisor_task_id: Uuid = row.try_get("supervisor_task_id")?;
+            let contract_id: Uuid = row.try_get("contract_id")?;
+            let timestamp: chrono::DateTime<Utc> = row.try_get("timestamp")?;
+            Ok((supervisor_task_id, contract_id, timestamp))
+        })
+        .collect()
+}
+
// ============================================================================
// Contract Supervisor
// ============================================================================
@@ -4402,3 +4860,138 @@ pub async fn get_notification_count_for_task(
.map_err(RepositoryError::Database)?;
Ok(result.0)
}
+
+// =============================================================================
+// Supervisor Status API Helpers
+// =============================================================================
+
+/// Fetch a combined supervisor status view for a contract.
+///
+/// Joins the contract's `supervisor_states` row with its supervisor task in
+/// `tasks` and projects the columns described on [`SupervisorStatusInfo`].
+/// The task must belong to `owner_id`; otherwise, or when the contract has no
+/// supervisor state, `Ok(None)` is returned.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn get_supervisor_status(
+    pool: &PgPool,
+    contract_id: Uuid,
+    owner_id: Uuid,
+) -> Result<Option<SupervisorStatusInfo>, sqlx::Error> {
+    // Supervisor state joined with its task; the aliases match the field
+    // names of `SupervisorStatusInfo`.
+    const SQL: &str = r#"
+        SELECT
+            ss.task_id,
+            COALESCE(t.status, 'unknown') as supervisor_state,
+            ss.phase,
+            t.progress_summary as current_activity,
+            ss.pending_task_ids,
+            ss.last_activity as last_heartbeat,
+            t.status as task_status,
+            t.daemon_id IS NOT NULL as is_running
+        FROM supervisor_states ss
+        JOIN tasks t ON t.id = ss.task_id
+        WHERE ss.contract_id = $1
+          AND t.owner_id = $2
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorStatusInfo>(SQL)
+        .bind(contract_id)
+        .bind(owner_id);
+
+    query.fetch_optional(pool).await
+}
+
+/// Internal struct to hold supervisor status query result.
+///
+/// One row of the query issued by `get_supervisor_status`; field names match
+/// the column aliases of that query so `sqlx::FromRow` can map them by name.
+#[derive(Debug, Clone, sqlx::FromRow)]
+pub struct SupervisorStatusInfo {
+ /// `supervisor_states.task_id`: the supervisor's task.
+ pub task_id: Uuid,
+ /// `tasks.status`, or `'unknown'` when that column is NULL.
+ pub supervisor_state: String,
+ /// `supervisor_states.phase`.
+ pub phase: String,
+ /// `tasks.progress_summary`; `None` when the column is NULL.
+ pub current_activity: Option<String>,
+ /// `supervisor_states.pending_task_ids`.
+ // NOTE(review): the source and target types of this `try_from` are the same,
+ // so the attribute looks redundant — confirm before removing.
+ #[sqlx(try_from = "Vec<Uuid>")]
+ pub pending_task_ids: Vec<Uuid>,
+ /// `supervisor_states.last_activity`, exposed as the last heartbeat time.
+ pub last_heartbeat: chrono::DateTime<chrono::Utc>,
+ /// Raw `tasks.status` (no fallback applied).
+ pub task_status: String,
+ /// `true` when the task has a non-NULL `daemon_id`.
+ pub is_running: bool,
+}
+
+/// Get supervisor activity history from the `history_events` table.
+///
+/// Returns one page of events for `contract_id` whose `event_type` is
+/// `'supervisor'`, `'phase'` or `'task'`, newest first, projected into
+/// [`SupervisorActivityEntry`]. This timeline can serve as a record of
+/// supervisor "heartbeats".
+///
+/// * `limit`: maximum number of entries to return.
+/// * `offset`: number of entries to skip from the newest.
+///
+/// `pending_task_ids` is only expanded when `event_data->'pending_task_ids'`
+/// is actually a JSON array. A missing key, a JSON `null`, or any other
+/// non-array value yields an empty list; previously a non-array value made
+/// `jsonb_array_elements_text` raise and failed the whole query.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error`, including cast failures for a non-integer
+/// `progress` or a non-UUID entry in `pending_task_ids`.
+pub async fn get_supervisor_activity_history(
+    pool: &PgPool,
+    contract_id: Uuid,
+    limit: i32,
+    offset: i32,
+) -> Result<Vec<SupervisorActivityEntry>, sqlx::Error> {
+    sqlx::query_as::<_, SupervisorActivityEntry>(
+        r#"
+        SELECT
+            created_at as timestamp,
+            COALESCE(event_subtype, 'activity') as state,
+            event_data->>'activity' as activity,
+            (event_data->>'progress')::INTEGER as progress,
+            COALESCE(phase, 'unknown') as phase,
+            CASE
+                WHEN jsonb_typeof(event_data->'pending_task_ids') = 'array'
+                THEN ARRAY(SELECT jsonb_array_elements_text(event_data->'pending_task_ids'))::UUID[]
+                ELSE ARRAY[]::UUID[]
+            END as pending_task_ids
+        FROM history_events
+        WHERE contract_id = $1
+          AND event_type IN ('supervisor', 'phase', 'task')
+        ORDER BY created_at DESC
+        LIMIT $2 OFFSET $3
+        "#,
+    )
+    .bind(contract_id)
+    .bind(limit)
+    .bind(offset)
+    .fetch_all(pool)
+    .await
+}
+
+/// Internal struct to hold supervisor activity entry.
+///
+/// One row of the query issued by `get_supervisor_activity_history`; field
+/// names match the column aliases of that query so `sqlx::FromRow` can map
+/// them by name.
+#[derive(Debug, Clone, sqlx::FromRow)]
+pub struct SupervisorActivityEntry {
+ /// `history_events.created_at`.
+ pub timestamp: chrono::DateTime<chrono::Utc>,
+ /// `event_subtype`, or `'activity'` when that column is NULL.
+ pub state: String,
+ /// `event_data->>'activity'`; `None` when absent.
+ pub activity: Option<String>,
+ /// `event_data->>'progress'` cast to an integer; `None` when absent.
+ pub progress: Option<i32>,
+ /// `phase`, or `'unknown'` when that column is NULL.
+ pub phase: String,
+ /// Task IDs extracted from `event_data->'pending_task_ids'`, or empty.
+ // NOTE(review): the source and target types of this `try_from` are the same,
+ // so the attribute looks redundant — confirm before removing.
+ #[sqlx(try_from = "Vec<Uuid>")]
+ pub pending_task_ids: Vec<Uuid>,
+}
+
+/// Count the activity-history events recorded for a contract.
+///
+/// Applies the same filter as `get_supervisor_activity_history` (`event_type`
+/// in `'supervisor'`, `'phase'`, `'task'`), so the result can be used as the
+/// total for paginating that listing.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the query.
+pub async fn count_supervisor_activity_history(
+    pool: &PgPool,
+    contract_id: Uuid,
+) -> Result<i64, sqlx::Error> {
+    const SQL: &str = r#"
+        SELECT COUNT(*)
+        FROM history_events
+        WHERE contract_id = $1
+          AND event_type IN ('supervisor', 'phase', 'task')
+        "#;
+
+    // COUNT(*) comes back as a single-column row; unpack it directly.
+    let (total,): (i64,) = sqlx::query_as(SQL)
+        .bind(contract_id)
+        .fetch_one(pool)
+        .await?;
+
+    Ok(total)
+}
+
+/// Touch a supervisor state row to refresh its activity timestamps.
+///
+/// Sets `last_activity` and `updated_at` to `NOW()` for `contract_id` and
+/// returns the updated row, acting as a "sync" that refreshes the supervisor's
+/// heartbeat. Returns `Ok(None)` when the contract has no supervisor state.
+///
+/// # Errors
+///
+/// Propagates any `sqlx::Error` raised while executing the statement.
+pub async fn sync_supervisor_state(
+    pool: &PgPool,
+    contract_id: Uuid,
+) -> Result<Option<SupervisorState>, sqlx::Error> {
+    const SQL: &str = r#"
+        UPDATE supervisor_states
+        SET last_activity = NOW(),
+            updated_at = NOW()
+        WHERE contract_id = $1
+        RETURNING *
+        "#;
+
+    let query = sqlx::query_as::<_, SupervisorState>(SQL).bind(contract_id);
+
+    query.fetch_optional(pool).await
+}