diff options
Diffstat (limited to 'makima/src/db')
| -rw-r--r-- | makima/src/db/models.rs | 379 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 599 |
2 files changed, 975 insertions, 3 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 636d81a..abdcce6 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -1971,6 +1971,50 @@ pub struct SupervisorState { pub last_activity: DateTime<Utc>, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, + /// Current supervisor state (initializing, idle, working, waiting_for_user, etc.) + pub state: String, + /// Human-readable description of current activity + pub current_activity: Option<String>, + /// Progress percentage (0-100) + pub progress: i32, + /// Error message when state is failed or blocked + pub error_message: Option<String>, + /// Tasks spawned by this supervisor + #[sqlx(try_from = "Vec<Uuid>")] + pub spawned_task_ids: Vec<Uuid>, + /// Pending questions awaiting user response + #[sqlx(json)] + pub pending_questions: serde_json::Value, + /// Number of times this supervisor has been restored + pub restoration_count: i32, + /// Timestamp of last restoration + pub last_restored_at: Option<DateTime<Utc>>, + /// Source of last restoration (daemon_restart, task_reassignment, manual) + pub restoration_source: Option<String>, +} + +/// Pending question structure for supervisor state +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PendingQuestion { + /// Unique question ID + pub id: Uuid, + /// The question text + pub question: String, + /// Optional choices (empty for free-form) + #[serde(default)] + pub choices: Vec<String>, + /// Optional context + pub context: Option<String>, + /// Question type: general, phase_confirmation, contract_complete + #[serde(default = "default_question_type")] + pub question_type: String, + /// When the question was asked + pub asked_at: DateTime<Utc>, +} + +fn default_question_type() -> String { + "general".to_string() } /// Request to update supervisor state @@ -1983,6 +2027,64 @@ pub struct UpdateSupervisorStateRequest { pub pending_task_ids: Option<Vec<Uuid>>, /// Current contract phase pub phase: Option<String>, + /// Current supervisor state + pub state: Option<String>, + /// Current activity description + pub current_activity: Option<String>, + /// Progress percentage + pub progress: Option<i32>, + /// Error message + pub error_message: Option<String>, + /// Spawned task IDs + pub spawned_task_ids: Option<Vec<Uuid>>, + /// Pending questions + pub pending_questions: Option<serde_json::Value>, +} + +/// Restoration context returned when restoring a supervisor +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorRestorationContext { + /// Whether restoration was successful + pub success: bool, + /// Previous state before restoration + pub previous_state: SupervisorStateEnum, + /// Restored conversation history + pub conversation_history: serde_json::Value, + /// Pending questions that need re-delivery + pub pending_questions: Vec<PendingQuestion>, + /// Tasks still being waited on + pub waiting_task_ids: Vec<Uuid>, + /// Spawned tasks to check status of + pub spawned_task_ids: Vec<Uuid>, + /// Restoration count (incremented) + pub restoration_count: i32, + /// Context message for Claude + pub restoration_context_message: String, + /// Any warnings during restoration + pub warnings: Vec<String>, +} + +/// Validation result for supervisor state consistency +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StateValidationResult { + pub is_valid: bool, + pub issues: Vec<String>, + /// Suggested recovery action + pub recovery_action: StateRecoveryAction, +} + +/// Action to take when state validation fails +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum StateRecoveryAction { + /// State is valid, proceed with restoration + Proceed, + /// Start from last checkpoint + UseCheckpoint, + /// Start fresh + StartFresh, + /// Manual intervention required + ManualIntervention, } // ============================================================================ @@ -2339,6 +2441,111 @@ pub struct CheckpointPatchInfo { // Red Team Types // ============================================================================ +// ============================================================================= +// Supervisor State and Heartbeat Types +// ============================================================================= + +/// Supervisor state for contract supervisor tasks. +/// Captures detailed activity state for monitoring and restoration. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum SupervisorStateEnum { + /// Supervisor is starting up + Initializing, + /// Supervisor is idle, waiting for work + Idle, + /// Supervisor is actively working + Working, + /// Supervisor is waiting for user input/confirmation + WaitingForUser, + /// Supervisor is waiting for spawned tasks to complete + WaitingForTasks, + /// Supervisor is blocked (external dependency, error, etc.) + Blocked, + /// Supervisor has completed its contract + Completed, + /// Supervisor has failed + Failed, + /// Supervisor was interrupted (daemon crash, etc.) + Interrupted, +} + +impl Default for SupervisorStateEnum { + fn default() -> Self { + SupervisorStateEnum::Initializing + } +} + +impl std::fmt::Display for SupervisorStateEnum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SupervisorStateEnum::Initializing => write!(f, "initializing"), + SupervisorStateEnum::Idle => write!(f, "idle"), + SupervisorStateEnum::Working => write!(f, "working"), + SupervisorStateEnum::WaitingForUser => write!(f, "waiting_for_user"), + SupervisorStateEnum::WaitingForTasks => write!(f, "waiting_for_tasks"), + SupervisorStateEnum::Blocked => write!(f, "blocked"), + SupervisorStateEnum::Completed => write!(f, "completed"), + SupervisorStateEnum::Failed => write!(f, "failed"), + SupervisorStateEnum::Interrupted => write!(f, "interrupted"), + } + } +} + +impl std::str::FromStr for SupervisorStateEnum { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.to_lowercase().as_str() { + "initializing" => Ok(SupervisorStateEnum::Initializing), + "idle" => Ok(SupervisorStateEnum::Idle), + "working" => Ok(SupervisorStateEnum::Working), + "waiting_for_user" | "waitingforuser" => Ok(SupervisorStateEnum::WaitingForUser), + "waiting_for_tasks" | "waitingfortasks" => Ok(SupervisorStateEnum::WaitingForTasks), + "blocked" => Ok(SupervisorStateEnum::Blocked), + "completed" => Ok(SupervisorStateEnum::Completed), + "failed" => Ok(SupervisorStateEnum::Failed), + "interrupted" => Ok(SupervisorStateEnum::Interrupted), + _ => Err(format!("Unknown supervisor state: {}", s)), + } + } +} + +/// Enhanced heartbeat record for supervisor task monitoring. +/// Stored in the database for historical analysis and dead supervisor detection. +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorHeartbeatRecord { + pub id: Uuid, + pub supervisor_task_id: Uuid, + pub contract_id: Uuid, + pub state: String, + pub phase: String, + pub current_activity: Option<String>, + pub progress: i32, + #[sqlx(try_from = "Vec<Uuid>")] + pub pending_task_ids: Vec<Uuid>, + pub timestamp: DateTime<Utc>, +} + +/// Request payload for sending a supervisor heartbeat. +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorHeartbeatRequest { + pub task_id: Uuid, + pub contract_id: Uuid, + pub state: SupervisorStateEnum, + pub phase: String, + pub current_activity: Option<String>, + /// Progress percentage (0-100) + pub progress: u8, + pub pending_task_ids: Vec<Uuid>, +} + +// ============================================================================= +// Red Team Types +// ============================================================================= + /// Red Team notification record #[derive(Debug, Clone, FromRow, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] @@ -2395,3 +2602,175 @@ impl std::str::FromStr for NotificationSeverity { } } } + +// ============================================================================ +// Supervisor Status API Types +// ============================================================================ + +/// Response for supervisor status endpoint +#[derive(Debug, Clone, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorStatusResponse { + /// The supervisor task ID + pub task_id: Uuid, + /// Current supervisor state (from supervisor_states table) + pub state: String, + /// Current contract phase + pub phase: String, + /// Description of current activity (from task progress_summary) + pub current_activity: Option<String>, + /// Progress percentage (0-100) + pub progress: Option<u8>, + /// When the supervisor last updated its state + pub last_heartbeat: DateTime<Utc>, + /// Task IDs the supervisor is currently waiting on + pub pending_task_ids: Vec<Uuid>, + /// Whether the supervisor is currently running + pub is_running: bool, +} + +/// Individual heartbeat entry for history +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorHeartbeatEntry { + /// Timestamp of this heartbeat + pub timestamp: DateTime<Utc>, + /// Supervisor state at this time + pub state: String, + /// Activity description at this time + pub activity: Option<String>, + /// Progress at this time + pub progress: Option<u8>, + /// Contract phase at this time + pub phase: String, + /// Pending task IDs at this time + pub pending_task_ids: Vec<Uuid>, +} + +/// Response for supervisor heartbeat history endpoint +#[derive(Debug, Clone, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorHeartbeatHistoryResponse { + /// List of heartbeat entries + pub heartbeats: Vec<SupervisorHeartbeatEntry>, + /// Total count of heartbeats (for pagination) + pub total: i64, +} + +/// Response for supervisor sync endpoint +#[derive(Debug, Clone, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorSyncResponse { + /// Whether the sync was successful + pub synced: bool, + /// Current supervisor state after sync + pub state: String, + /// Optional message about the sync result + pub message: Option<String>, +} + +/// Query parameters for heartbeat history endpoint +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct HeartbeatHistoryQuery { + /// Maximum number of heartbeats to return (default: 10) + pub limit: Option<i32>, + /// Offset for pagination (default: 0) + pub offset: Option<i32>, +} + +// ============================================================================= +// Unit Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + use uuid::Uuid; + + #[test] + fn test_supervisor_state_enum_display() { + assert_eq!(SupervisorStateEnum::Initializing.to_string(), "initializing"); + assert_eq!(SupervisorStateEnum::Idle.to_string(), "idle"); + assert_eq!(SupervisorStateEnum::Working.to_string(), "working"); + assert_eq!(SupervisorStateEnum::WaitingForUser.to_string(), "waiting_for_user"); + assert_eq!(SupervisorStateEnum::WaitingForTasks.to_string(), "waiting_for_tasks"); + assert_eq!(SupervisorStateEnum::Blocked.to_string(), "blocked"); + assert_eq!(SupervisorStateEnum::Completed.to_string(), "completed"); + assert_eq!(SupervisorStateEnum::Failed.to_string(), "failed"); + assert_eq!(SupervisorStateEnum::Interrupted.to_string(), "interrupted"); + } + + #[test] + fn test_supervisor_state_enum_from_str() { + // Standard lowercase + assert_eq!("initializing".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Initializing); + assert_eq!("idle".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Idle); + assert_eq!("working".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Working); + assert_eq!("waiting_for_user".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::WaitingForUser); + assert_eq!("waiting_for_tasks".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::WaitingForTasks); + assert_eq!("blocked".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Blocked); + assert_eq!("completed".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Completed); + assert_eq!("failed".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Failed); + assert_eq!("interrupted".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Interrupted); + + // Case insensitive + assert_eq!("WORKING".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Working); + assert_eq!("Working".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::Working); + + // Alternative formats + assert_eq!("waitingforuser".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::WaitingForUser); + assert_eq!("waitingfortasks".parse::<SupervisorStateEnum>().unwrap(), SupervisorStateEnum::WaitingForTasks); + + // Invalid state + assert!("invalid_state".parse::<SupervisorStateEnum>().is_err()); + } + + #[test] + fn test_supervisor_state_enum_serialization() { + // Test JSON serialization + let state = SupervisorStateEnum::Working; + let json = serde_json::to_string(&state).unwrap(); + assert_eq!(json, "\"working\""); + + // Test JSON deserialization + let deserialized: SupervisorStateEnum = serde_json::from_str("\"working\"").unwrap(); + assert_eq!(deserialized, SupervisorStateEnum::Working); + + // Test underscore variants + let json = "\"waiting_for_user\""; + let deserialized: SupervisorStateEnum = serde_json::from_str(json).unwrap(); + assert_eq!(deserialized, SupervisorStateEnum::WaitingForUser); + } + + #[test] + fn test_supervisor_state_enum_default() { + let default_state = SupervisorStateEnum::default(); + assert_eq!(default_state, SupervisorStateEnum::Initializing); + } + + #[test] + fn test_supervisor_heartbeat_request_serialization() { + let request = SupervisorHeartbeatRequest { + task_id: Uuid::nil(), + contract_id: Uuid::nil(), + state: SupervisorStateEnum::Working, + phase: "execute".to_string(), + current_activity: Some("Implementing feature".to_string()), + progress: 50, + pending_task_ids: vec![Uuid::nil()], + }; + + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("\"state\":\"working\"")); + assert!(json.contains("\"phase\":\"execute\"")); + assert!(json.contains("\"progress\":50")); + assert!(json.contains("\"currentActivity\":\"Implementing feature\"")); + + // Test deserialization + let deserialized: SupervisorHeartbeatRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.state, SupervisorStateEnum::Working); + assert_eq!(deserialized.phase, "execute"); + assert_eq!(deserialized.progress, 50); + } +} diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index b7c5af1..e308df7 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -12,9 +12,9 @@ use super::models::{ CreateFileRequest, CreateTaskRequest, CreateTemplateRequest, Daemon, DaemonTaskAssignment, DaemonWithCapacity, DeliverableDefinition, File, FileSummary, FileVersion, HistoryEvent, HistoryQueryFilters, MeshChatConversation, MeshChatMessageRecord, PhaseChangeResult, - PhaseConfig, PhaseDefinition, RedTeamNotification, SupervisorState, Task, TaskCheckpoint, - TaskEvent, TaskSummary, UpdateContractRequest, UpdateFileRequest, UpdateTaskRequest, - UpdateTemplateRequest, + PhaseConfig, PhaseDefinition, RedTeamNotification, SupervisorHeartbeatRecord, SupervisorState, + Task, TaskCheckpoint, TaskEvent, TaskSummary, UpdateContractRequest, UpdateFileRequest, + UpdateTaskRequest, UpdateTemplateRequest, }; /// Repository error types. @@ -3404,6 +3404,464 @@ pub async fn update_supervisor_pending_tasks( .await } +/// Update supervisor state with detailed activity tracking. +/// Called at key save points: LLM response, task spawn, question asked, phase change. +pub async fn update_supervisor_detailed_state( + pool: &PgPool, + contract_id: Uuid, + state: &str, + current_activity: Option<&str>, + progress: i32, + error_message: Option<&str>, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET state = $1, + current_activity = $2, + progress = $3, + error_message = $4, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $5 + RETURNING * + "#, + ) + .bind(state) + .bind(current_activity) + .bind(progress) + .bind(error_message) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Add a spawned task ID to the supervisor's list. +pub async fn add_supervisor_spawned_task( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET spawned_task_ids = array_append(spawned_task_ids, $1), + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(task_id) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Add a pending question to the supervisor state. +pub async fn add_supervisor_pending_question( + pool: &PgPool, + contract_id: Uuid, + question: serde_json::Value, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_questions = pending_questions || $1::jsonb, + state = 'waiting_for_user', + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(question) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Remove a pending question by ID. +pub async fn remove_supervisor_pending_question( + pool: &PgPool, + contract_id: Uuid, + question_id: Uuid, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET pending_questions = ( + SELECT COALESCE(jsonb_agg(elem), '[]'::jsonb) + FROM jsonb_array_elements(pending_questions) elem + WHERE (elem->>'id')::uuid != $1 + ), + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(question_id) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Comprehensive state save - used at major save points. +pub async fn save_supervisor_state_full( + pool: &PgPool, + contract_id: Uuid, + task_id: Uuid, + conversation_history: serde_json::Value, + pending_task_ids: &[Uuid], + phase: &str, + state: &str, + current_activity: Option<&str>, + progress: i32, + error_message: Option<&str>, + spawned_task_ids: &[Uuid], + pending_questions: serde_json::Value, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + INSERT INTO supervisor_states ( + contract_id, task_id, conversation_history, pending_task_ids, phase, + state, current_activity, progress, error_message, spawned_task_ids, + pending_questions, last_activity + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, NOW()) + ON CONFLICT (contract_id) DO UPDATE SET + task_id = EXCLUDED.task_id, + conversation_history = EXCLUDED.conversation_history, + pending_task_ids = EXCLUDED.pending_task_ids, + phase = EXCLUDED.phase, + state = EXCLUDED.state, + current_activity = EXCLUDED.current_activity, + progress = EXCLUDED.progress, + error_message = EXCLUDED.error_message, + spawned_task_ids = EXCLUDED.spawned_task_ids, + pending_questions = EXCLUDED.pending_questions, + last_activity = NOW(), + updated_at = NOW() + RETURNING * + "#, + ) + .bind(contract_id) + .bind(task_id) + .bind(conversation_history) + .bind(pending_task_ids) + .bind(phase) + .bind(state) + .bind(current_activity) + .bind(progress) + .bind(error_message) + .bind(spawned_task_ids) + .bind(pending_questions) + .fetch_one(pool) + .await +} + +/// Mark supervisor as restored from a crash/interruption. +pub async fn mark_supervisor_restored( + pool: &PgPool, + contract_id: Uuid, + restoration_source: &str, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET restoration_count = restoration_count + 1, + last_restored_at = NOW(), + restoration_source = $1, + state = 'initializing', + error_message = NULL, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(restoration_source) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Get supervisors with pending questions (for re-delivery after restoration). +pub async fn get_supervisors_with_pending_questions( + pool: &PgPool, + owner_id: Uuid, +) -> Result<Vec<SupervisorState>, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + SELECT ss.* + FROM supervisor_states ss + JOIN contracts c ON c.id = ss.contract_id + WHERE c.owner_id = $1 + AND ss.pending_questions != '[]'::jsonb + AND c.status = 'active' + ORDER BY ss.last_activity DESC + "#, + ) + .bind(owner_id) + .fetch_all(pool) + .await +} + +/// Get supervisor state with full details for restoration. +/// Includes validation info. +pub async fn get_supervisor_state_for_restoration( + pool: &PgPool, + contract_id: Uuid, +) -> Result<Option<SupervisorState>, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + SELECT * FROM supervisor_states WHERE contract_id = $1 + "#, + ) + .bind(contract_id) + .fetch_optional(pool) + .await +} + +/// Validate spawned tasks are in expected states. +/// Returns map of task_id -> (status, updated_at). +pub async fn validate_spawned_tasks( + pool: &PgPool, + task_ids: &[Uuid], +) -> Result<std::collections::HashMap<Uuid, (String, chrono::DateTime<Utc>)>, sqlx::Error> { + use sqlx::Row; + + let rows = sqlx::query( + r#" + SELECT id, status, updated_at + FROM tasks + WHERE id = ANY($1) + "#, + ) + .bind(task_ids) + .fetch_all(pool) + .await?; + + let mut result = std::collections::HashMap::new(); + for row in rows { + let id: Uuid = row.get("id"); + let status: String = row.get("status"); + let updated_at: chrono::DateTime<Utc> = row.get("updated_at"); + result.insert(id, (status, updated_at)); + } + Ok(result) +} + +/// Update supervisor state when phase changes. +pub async fn update_supervisor_phase( + pool: &PgPool, + contract_id: Uuid, + new_phase: &str, +) -> Result<SupervisorState, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET phase = $1, + state = 'working', + current_activity = 'Phase changed to ' || $1, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $2 + RETURNING * + "#, + ) + .bind(new_phase) + .bind(contract_id) + .fetch_one(pool) + .await +} + +/// Update supervisor state on heartbeat (lightweight update). +pub async fn update_supervisor_heartbeat_state( + pool: &PgPool, + contract_id: Uuid, + state: &str, + current_activity: Option<&str>, + progress: i32, + pending_task_ids: &[Uuid], +) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + UPDATE supervisor_states + SET state = $1, + current_activity = $2, + progress = $3, + pending_task_ids = $4, + last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $5 + "#, + ) + .bind(state) + .bind(current_activity) + .bind(progress) + .bind(pending_task_ids) + .bind(contract_id) + .execute(pool) + .await?; + Ok(()) +} + +// ============================================================================ +// Supervisor Heartbeats +// ============================================================================ + +/// Record a supervisor heartbeat. +/// This creates a historical record for monitoring and dead supervisor detection. +pub async fn create_supervisor_heartbeat( + pool: &PgPool, + supervisor_task_id: Uuid, + contract_id: Uuid, + state: &str, + phase: &str, + current_activity: Option<&str>, + progress: i32, + pending_task_ids: &[Uuid], +) -> Result<SupervisorHeartbeatRecord, sqlx::Error> { + sqlx::query_as::<_, SupervisorHeartbeatRecord>( + r#" + INSERT INTO supervisor_heartbeats ( + supervisor_task_id, contract_id, state, phase, current_activity, progress, pending_task_ids, timestamp + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) + RETURNING * + "#, + ) + .bind(supervisor_task_id) + .bind(contract_id) + .bind(state) + .bind(phase) + .bind(current_activity) + .bind(progress) + .bind(pending_task_ids) + .fetch_one(pool) + .await +} + +/// Get the latest heartbeat for a supervisor task. +pub async fn get_latest_supervisor_heartbeat( + pool: &PgPool, + supervisor_task_id: Uuid, +) -> Result<Option<SupervisorHeartbeatRecord>, sqlx::Error> { + sqlx::query_as::<_, SupervisorHeartbeatRecord>( + r#" + SELECT * FROM supervisor_heartbeats + WHERE supervisor_task_id = $1 + ORDER BY timestamp DESC + LIMIT 1 + "#, + ) + .bind(supervisor_task_id) + .fetch_optional(pool) + .await +} + +/// Get recent heartbeats for a supervisor task. +pub async fn get_supervisor_heartbeats( + pool: &PgPool, + supervisor_task_id: Uuid, + limit: i64, +) -> Result<Vec<SupervisorHeartbeatRecord>, sqlx::Error> { + sqlx::query_as::<_, SupervisorHeartbeatRecord>( + r#" + SELECT * FROM supervisor_heartbeats + WHERE supervisor_task_id = $1 + ORDER BY timestamp DESC + LIMIT $2 + "#, + ) + .bind(supervisor_task_id) + .bind(limit) + .fetch_all(pool) + .await +} + +/// Get recent heartbeats for a contract. +pub async fn get_contract_supervisor_heartbeats( + pool: &PgPool, + contract_id: Uuid, + limit: i64, +) -> Result<Vec<SupervisorHeartbeatRecord>, sqlx::Error> { + sqlx::query_as::<_, SupervisorHeartbeatRecord>( + r#" + SELECT * FROM supervisor_heartbeats + WHERE contract_id = $1 + ORDER BY timestamp DESC + LIMIT $2 + "#, + ) + .bind(contract_id) + .bind(limit) + .fetch_all(pool) + .await +} + +/// Delete old heartbeats beyond the TTL (24 hours by default). +/// Returns the number of deleted records. +pub async fn cleanup_old_heartbeats( + pool: &PgPool, + ttl_hours: i64, +) -> Result<u64, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM supervisor_heartbeats + WHERE timestamp < NOW() - ($1 || ' hours')::INTERVAL + "#, + ) + .bind(ttl_hours.to_string()) + .execute(pool) + .await?; + + Ok(result.rows_affected()) +} + +/// Find supervisors that have not sent a heartbeat within the timeout period. +/// Returns list of (supervisor_task_id, contract_id, last_heartbeat_timestamp). +pub async fn find_stale_supervisors( + pool: &PgPool, + timeout_seconds: i64, +) -> Result<Vec<(Uuid, Uuid, chrono::DateTime<Utc>)>, sqlx::Error> { + let rows = sqlx::query( + r#" + WITH latest_heartbeats AS ( + SELECT DISTINCT ON (supervisor_task_id) + supervisor_task_id, + contract_id, + timestamp + FROM supervisor_heartbeats + ORDER BY supervisor_task_id, timestamp DESC + ) + SELECT + lh.supervisor_task_id, + lh.contract_id, + lh.timestamp + FROM latest_heartbeats lh + JOIN tasks t ON t.id = lh.supervisor_task_id + WHERE t.status = 'running' + AND lh.timestamp < NOW() - ($1 || ' seconds')::INTERVAL + "#, + ) + .bind(timeout_seconds.to_string()) + .fetch_all(pool) + .await?; + + let mut result = Vec::new(); + for row in rows { + use sqlx::Row; + let supervisor_task_id: Uuid = row.get("supervisor_task_id"); + let contract_id: Uuid = row.get("contract_id"); + let timestamp: chrono::DateTime<Utc> = row.get("timestamp"); + result.push((supervisor_task_id, contract_id, timestamp)); + } + Ok(result) +} + // ============================================================================ // Contract Supervisor // ============================================================================ @@ -4402,3 +4860,138 @@ pub async fn get_notification_count_for_task( .map_err(RepositoryError::Database)?; Ok(result.0) } + +// ============================================================================= +// Supervisor Status API Helpers +// ============================================================================= + +/// Get supervisor status for a contract. +/// Returns combined information from supervisor_states and tasks tables. +pub async fn get_supervisor_status( + pool: &PgPool, + contract_id: Uuid, + owner_id: Uuid, +) -> Result<Option<SupervisorStatusInfo>, sqlx::Error> { + // Query to get supervisor status by joining supervisor_states with tasks + sqlx::query_as::<_, SupervisorStatusInfo>( + r#" + SELECT + ss.task_id, + COALESCE(t.status, 'unknown') as supervisor_state, + ss.phase, + t.progress_summary as current_activity, + ss.pending_task_ids, + ss.last_activity as last_heartbeat, + t.status as task_status, + t.daemon_id IS NOT NULL as is_running + FROM supervisor_states ss + JOIN tasks t ON t.id = ss.task_id + WHERE ss.contract_id = $1 + AND t.owner_id = $2 + "#, + ) + .bind(contract_id) + .bind(owner_id) + .fetch_optional(pool) + .await +} + +/// Internal struct to hold supervisor status query result +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct SupervisorStatusInfo { + pub task_id: Uuid, + pub supervisor_state: String, + pub phase: String, + pub current_activity: Option<String>, + #[sqlx(try_from = "Vec<Uuid>")] + pub pending_task_ids: Vec<Uuid>, + pub last_heartbeat: chrono::DateTime<chrono::Utc>, + pub task_status: String, + pub is_running: bool, +} + +/// Get supervisor activity history from history_events table. +/// This provides a timeline of supervisor activities that can serve as "heartbeats". +pub async fn get_supervisor_activity_history( + pool: &PgPool, + contract_id: Uuid, + limit: i32, + offset: i32, +) -> Result<Vec<SupervisorActivityEntry>, sqlx::Error> { + sqlx::query_as::<_, SupervisorActivityEntry>( + r#" + SELECT + created_at as timestamp, + COALESCE(event_subtype, 'activity') as state, + event_data->>'activity' as activity, + (event_data->>'progress')::INTEGER as progress, + COALESCE(phase, 'unknown') as phase, + CASE + WHEN event_data->'pending_task_ids' IS NOT NULL + THEN ARRAY(SELECT jsonb_array_elements_text(event_data->'pending_task_ids'))::UUID[] + ELSE ARRAY[]::UUID[] + END as pending_task_ids + FROM history_events + WHERE contract_id = $1 + AND event_type IN ('supervisor', 'phase', 'task') + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + "#, + ) + .bind(contract_id) + .bind(limit) + .bind(offset) + .fetch_all(pool) + .await +} + +/// Internal struct to hold supervisor activity entry +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct SupervisorActivityEntry { + pub timestamp: chrono::DateTime<chrono::Utc>, + pub state: String, + pub activity: Option<String>, + pub progress: Option<i32>, + pub phase: String, + #[sqlx(try_from = "Vec<Uuid>")] + pub pending_task_ids: Vec<Uuid>, +} + +/// Count total supervisor activity history entries for a contract. +pub async fn count_supervisor_activity_history( + pool: &PgPool, + contract_id: Uuid, +) -> Result<i64, sqlx::Error> { + let result: (i64,) = sqlx::query_as( + r#" + SELECT COUNT(*) + FROM history_events + WHERE contract_id = $1 + AND event_type IN ('supervisor', 'phase', 'task') + "#, + ) + .bind(contract_id) + .fetch_one(pool) + .await?; + Ok(result.0) +} + +/// Update supervisor state last_activity timestamp. +/// This acts as a "sync" operation to refresh the supervisor's heartbeat. +pub async fn sync_supervisor_state( + pool: &PgPool, + contract_id: Uuid, +) -> Result<Option<SupervisorState>, sqlx::Error> { + sqlx::query_as::<_, SupervisorState>( + r#" + UPDATE supervisor_states + SET last_activity = NOW(), + updated_at = NOW() + WHERE contract_id = $1 + RETURNING * + "#, + ) + .bind(contract_id) + .fetch_optional(pool) + .await +} |
