diff options
Diffstat (limited to 'makima/src/db')
| -rw-r--r-- | makima/src/db/models.rs | 233 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 360 |
2 files changed, 589 insertions, 4 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 40d4109..4419580 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -1559,3 +1559,236 @@ pub struct RepositorySuggestionsQuery { /// Limit results (default: 10) pub limit: Option<i32>, } + +// ============================================================================= +// Resume and History System Types +// ============================================================================= + +/// Conversation snapshot for task resumption +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ConversationSnapshot { + pub id: Uuid, + pub task_id: Uuid, + pub checkpoint_id: Option<Uuid>, + /// Snapshot type: 'auto', 'manual', 'checkpoint' + pub snapshot_type: String, + pub message_count: i32, + #[sqlx(json)] + pub conversation_state: serde_json::Value, + #[sqlx(json)] + pub metadata: Option<serde_json::Value>, + pub created_at: DateTime<Utc>, +} + +/// History event for contract/task history tracking +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct HistoryEvent { + pub id: Uuid, + pub owner_id: Uuid, + pub contract_id: Option<Uuid>, + pub task_id: Option<Uuid>, + pub event_type: String, + pub event_subtype: Option<String>, + pub phase: Option<String>, + #[sqlx(json)] + pub event_data: serde_json::Value, + pub created_at: DateTime<Utc>, +} + +/// Unified conversation message for API responses +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ConversationMessage { + pub id: String, + /// Message role: 'user', 'assistant', 'system', 'tool' + pub role: String, + pub content: String, + pub timestamp: DateTime<Utc>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Vec<ToolCallInfo>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_input: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_result: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option<bool>, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_count: Option<i32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_usd: Option<f64>, +} + +/// Tool call information within a conversation message +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ToolCallInfo { + pub id: String, + pub name: String, + pub input: serde_json::Value, +} + +/// Query filters for history endpoints +#[derive(Debug, Deserialize, ToSchema, Default)] +#[serde(rename_all = "camelCase")] +pub struct HistoryQueryFilters { + pub phase: Option<String>, + pub event_types: Option<Vec<String>>, + pub from: Option<DateTime<Utc>>, + pub to: Option<DateTime<Utc>>, + pub limit: Option<i32>, + pub cursor: Option<String>, +} + +/// Request to resume a supervisor +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ResumeSupervisorRequest { + pub target_daemon_id: Option<Uuid>, + /// Resume mode: 'continue', 'restart_phase', 'from_checkpoint' + pub resume_mode: String, + pub checkpoint_id: Option<Uuid>, + pub additional_context: Option<String>, +} + +/// Request to resume from a checkpoint +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ResumeFromCheckpointRequest { + pub task_name: Option<String>, + pub plan: String, + pub include_conversation: Option<bool>, + pub target_daemon_id: Option<Uuid>, +} + +/// Request to rewind a task to a checkpoint +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RewindTaskRequest { + pub checkpoint_id: Option<Uuid>, + pub checkpoint_sha: Option<String>, + /// Preserve mode: 'discard', 'create_branch', 'stash' + pub preserve_mode: String, + pub branch_name: Option<String>, +} + +/// Request to rewind a conversation +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RewindConversationRequest { + pub to_message_id: Option<String>, + pub to_timestamp: Option<DateTime<Utc>>, + pub by_message_count: Option<i32>, + pub rewind_code: Option<bool>, +} + +/// Request to fork a task +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ForkTaskRequest { + /// Fork from type: 'checkpoint', 'timestamp', 'message_id' + pub fork_from_type: String, + pub fork_from_value: String, + pub new_task_name: String, + pub new_task_plan: String, + pub include_conversation: Option<bool>, + pub create_branch: Option<bool>, + pub branch_name: Option<String>, +} + +/// Response for contract history endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ContractHistoryResponse { + pub contract_id: Uuid, + pub entries: Vec<HistoryEvent>, + pub total_count: i64, + pub cursor: Option<String>, +} + +/// Response for task conversation endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskConversationResponse { + pub task_id: Uuid, + pub task_name: String, + pub status: String, + pub messages: Vec<ConversationMessage>, + pub total_tokens: Option<i32>, + pub total_cost: Option<f64>, +} + +/// Response for supervisor conversation endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SupervisorConversationResponse { + pub contract_id: Uuid, + pub supervisor_task_id: Uuid, + pub phase: String, + pub last_activity: DateTime<Utc>, + pub pending_task_ids: Vec<Uuid>, + pub messages: Vec<ConversationMessage>, + pub spawned_tasks: Vec<TaskReference>, +} + +/// Reference to a task for history/conversation responses +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskReference { + pub task_id: Uuid, + pub task_name: String, + pub status: String, + pub created_at: DateTime<Utc>, + pub completed_at: Option<DateTime<Utc>>, +} + +/// Response for task rewind operation +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RewindTaskResponse { + pub task_id: Uuid, + pub rewinded_to: CheckpointInfo, + pub preserved_as: Option<PreservedState>, +} + +/// Checkpoint information in rewind response +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CheckpointInfo { + pub checkpoint_number: i32, + pub sha: String, + pub message: String, +} + +/// Preserved state information in rewind response +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PreservedState { + /// State type: 'branch' or 'stash' + pub state_type: String, + pub reference: String, +} + +/// Response for task fork operation +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ForkTaskResponse { + pub new_task_id: Uuid, + pub source_task_id: Uuid, + pub fork_point: ForkPoint, + pub branch_name: Option<String>, + pub conversation_included: bool, + pub message_count: Option<i32>, +} + +/// Fork point information in fork response +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ForkPoint { + pub fork_type: String, + pub checkpoint: Option<TaskCheckpoint>, + pub timestamp: DateTime<Utc>, +} diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index 92b2048..cb9d52f 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -1,16 +1,17 @@ //! Repository pattern for file database operations. use chrono::Utc; +use serde::Deserialize; use sqlx::PgPool; use uuid::Uuid; use super::models::{ Contract, ContractChatConversation, ContractChatMessageRecord, ContractEvent, ContractRepository, - ContractSummary, CreateContractRequest, CreateFileRequest, + ContractSummary, ConversationMessage, ConversationSnapshot, CreateContractRequest, CreateFileRequest, CreateTaskRequest, Daemon, DaemonTaskAssignment, DaemonWithCapacity, File, FileSummary, - FileVersion, MeshChatConversation, MeshChatMessageRecord, SupervisorState, Task, TaskCheckpoint, - TaskEvent, TaskSummary, UpdateContractRequest, UpdateFileRequest, - UpdateTaskRequest, + FileVersion, HistoryEvent, HistoryQueryFilters, MeshChatConversation, MeshChatMessageRecord, + SupervisorState, Task, TaskCheckpoint, TaskEvent, TaskSummary, UpdateContractRequest, + UpdateFileRequest, UpdateTaskRequest, }; /// Repository error types. @@ -3203,3 +3204,354 @@ pub async fn delete_repository_history( Ok(result.rows_affected() > 0) } + +// ============================================================================ +// Conversation Snapshots +// ============================================================================ + +/// Create a new conversation snapshot +pub async fn create_conversation_snapshot( + pool: &PgPool, + task_id: Uuid, + checkpoint_id: Option<Uuid>, + snapshot_type: &str, + message_count: i32, + conversation_state: serde_json::Value, + metadata: Option<serde_json::Value>, +) -> Result<ConversationSnapshot, sqlx::Error> { + sqlx::query_as::<_, ConversationSnapshot>( + r#" + INSERT INTO conversation_snapshots (task_id, checkpoint_id, snapshot_type, message_count, conversation_state, metadata) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING * + "# + ) + .bind(task_id) + .bind(checkpoint_id) + .bind(snapshot_type) + .bind(message_count) + .bind(conversation_state) + .bind(metadata) + .fetch_one(pool) + .await +} + +/// Get a conversation snapshot by ID +pub async fn get_conversation_snapshot( + pool: &PgPool, + id: Uuid, +) -> Result<Option<ConversationSnapshot>, sqlx::Error> { + sqlx::query_as::<_, ConversationSnapshot>( + "SELECT * FROM conversation_snapshots WHERE id = $1" + ) + .bind(id) + .fetch_optional(pool) + .await +} + +/// Get conversation snapshot at a specific checkpoint +pub async fn get_conversation_at_checkpoint( + pool: &PgPool, + checkpoint_id: Uuid, +) -> Result<Option<ConversationSnapshot>, sqlx::Error> { + sqlx::query_as::<_, ConversationSnapshot>( + "SELECT * FROM conversation_snapshots WHERE checkpoint_id = $1 ORDER BY created_at DESC LIMIT 1" + ) + .bind(checkpoint_id) + .fetch_optional(pool) + .await +} + +/// List conversation snapshots for a task +pub async fn list_conversation_snapshots( + pool: &PgPool, + task_id: Uuid, + limit: Option<i32>, +) -> Result<Vec<ConversationSnapshot>, sqlx::Error> { + let limit = limit.unwrap_or(100); + sqlx::query_as::<_, ConversationSnapshot>( + "SELECT * FROM conversation_snapshots WHERE task_id = $1 ORDER BY created_at DESC LIMIT $2" + ) + .bind(task_id) + .bind(limit) + .fetch_all(pool) + .await +} + +/// Delete conversation snapshots older than retention period +pub async fn cleanup_old_snapshots( + pool: &PgPool, + retention_days: i32, +) -> Result<u64, sqlx::Error> { + let result = sqlx::query( + "DELETE FROM conversation_snapshots WHERE created_at < NOW() - INTERVAL '1 day' * $1" + ) + .bind(retention_days) + .execute(pool) + .await?; + Ok(result.rows_affected()) +} + +// ============================================================================ +// History Events +// ============================================================================ + +/// Record a new history event +#[allow(clippy::too_many_arguments)] +pub async fn record_history_event( + pool: &PgPool, + owner_id: Uuid, + contract_id: Option<Uuid>, + task_id: Option<Uuid>, + event_type: &str, + event_subtype: Option<&str>, + phase: Option<&str>, + event_data: serde_json::Value, +) -> Result<HistoryEvent, sqlx::Error> { + sqlx::query_as::<_, HistoryEvent>( + r#" + INSERT INTO history_events (owner_id, contract_id, task_id, event_type, event_subtype, phase, event_data) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING * + "# + ) + .bind(owner_id) + .bind(contract_id) + .bind(task_id) + .bind(event_type) + .bind(event_subtype) + .bind(phase) + .bind(event_data) + .fetch_one(pool) + .await +} + +/// Get contract history timeline +pub async fn get_contract_history( + pool: &PgPool, + contract_id: Uuid, + owner_id: Uuid, + filters: &HistoryQueryFilters, +) -> Result<(Vec<HistoryEvent>, i64), sqlx::Error> { + let limit = filters.limit.unwrap_or(100); + + let mut query = String::from( + "SELECT * FROM history_events WHERE contract_id = $1 AND owner_id = $2" + ); + let mut count_query = String::from( + "SELECT COUNT(*) FROM history_events WHERE contract_id = $1 AND owner_id = $2" + ); + + let mut param_count = 2; + + if filters.phase.is_some() { + param_count += 1; + query.push_str(&format!(" AND phase = ${}" , param_count)); + count_query.push_str(&format!(" AND phase = ${}", param_count)); + } + + if filters.from.is_some() { + param_count += 1; + query.push_str(&format!(" AND created_at >= ${}", param_count)); + count_query.push_str(&format!(" AND created_at >= ${}", param_count)); + } + + if filters.to.is_some() { + param_count += 1; + query.push_str(&format!(" AND created_at <= ${}", param_count)); + count_query.push_str(&format!(" AND created_at <= ${}", param_count)); + } + + query.push_str(" ORDER BY created_at DESC"); + query.push_str(&format!(" LIMIT {}", limit)); + + // Build and execute the query dynamically + let mut q = sqlx::query_as::<_, HistoryEvent>(&query) + .bind(contract_id) + .bind(owner_id); + + if let Some(ref phase) = filters.phase { + q = q.bind(phase); + } + if let Some(ref from) = filters.from { + q = q.bind(from); + } + if let Some(ref to) = filters.to { + q = q.bind(to); + } + + let events = q.fetch_all(pool).await?; + + // Get total count + let mut cq = sqlx::query_scalar::<_, i64>(&count_query) + .bind(contract_id) + .bind(owner_id); + + if let Some(ref phase) = filters.phase { + cq = cq.bind(phase); + } + if let Some(ref from) = filters.from { + cq = cq.bind(from); + } + if let Some(ref to) = filters.to { + cq = cq.bind(to); + } + + let count = cq.fetch_one(pool).await?; + + Ok((events, count)) +} + +/// Get task history +pub async fn get_task_history( + pool: &PgPool, + task_id: Uuid, + owner_id: Uuid, + filters: &HistoryQueryFilters, +) -> Result<(Vec<HistoryEvent>, i64), sqlx::Error> { + let limit = filters.limit.unwrap_or(100); + + let events = sqlx::query_as::<_, HistoryEvent>( + r#" + SELECT * FROM history_events + WHERE task_id = $1 AND owner_id = $2 + ORDER BY created_at DESC + LIMIT $3 + "# + ) + .bind(task_id) + .bind(owner_id) + .bind(limit) + .fetch_all(pool) + .await?; + + let count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM history_events WHERE task_id = $1 AND owner_id = $2" + ) + .bind(task_id) + .bind(owner_id) + .fetch_one(pool) + .await?; + + Ok((events, count)) +} + +/// Get unified timeline for an owner +pub async fn get_timeline( + pool: &PgPool, + owner_id: Uuid, + filters: &HistoryQueryFilters, +) -> Result<(Vec<HistoryEvent>, i64), sqlx::Error> { + let limit = filters.limit.unwrap_or(100); + + let events = sqlx::query_as::<_, HistoryEvent>( + r#" + SELECT * FROM history_events + WHERE owner_id = $1 + ORDER BY created_at DESC + LIMIT $2 + "# + ) + .bind(owner_id) + .bind(limit) + .fetch_all(pool) + .await?; + + let count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM history_events WHERE owner_id = $1" + ) + .bind(owner_id) + .fetch_one(pool) + .await?; + + Ok((events, count)) +} + +// ============================================================================ +// Task Conversation Retrieval +// ============================================================================ + +// Helper struct for parsing task output events +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct TaskOutputEvent { + message_type: String, + content: Option<String>, + tool_name: Option<String>, + tool_input: Option<serde_json::Value>, + is_error: Option<bool>, + cost_usd: Option<f32>, +} + +/// Get task conversation messages (reconstructed from task_events) +pub async fn get_task_conversation( + pool: &PgPool, + task_id: Uuid, + include_tool_calls: bool, + include_tool_results: bool, + limit: Option<i32>, +) -> Result<Vec<ConversationMessage>, sqlx::Error> { + let limit = limit.unwrap_or(1000); + + // Get output events that represent conversation turns + let events = sqlx::query_as::<_, TaskEvent>( + r#" + SELECT * FROM task_events + WHERE task_id = $1 AND event_type = 'output' + ORDER BY created_at ASC + LIMIT $2 + "# + ) + .bind(task_id) + .bind(limit) + .fetch_all(pool) + .await?; + + // Convert task events to conversation messages + let mut messages = Vec::new(); + for event in events { + if let Some(data) = event.event_data { + // Parse the event data to extract message info + if let Ok(output) = serde_json::from_value::<TaskOutputEvent>(data.clone()) { + let should_include = match output.message_type.as_str() { + "tool_use" => include_tool_calls, + "tool_result" => include_tool_results, + _ => true, + }; + + if should_include { + messages.push(ConversationMessage { + id: event.id.to_string(), + role: match output.message_type.as_str() { + "assistant" => "assistant".to_string(), + "tool_use" => "assistant".to_string(), + "tool_result" => "tool".to_string(), + "system" => "system".to_string(), + "error" => "system".to_string(), + _ => "user".to_string(), + }, + content: output.content.unwrap_or_default(), + timestamp: event.created_at, + tool_calls: None, + tool_name: output.tool_name, + tool_input: output.tool_input, + tool_result: None, + is_error: output.is_error, + token_count: None, + cost_usd: output.cost_usd.map(|c| c as f64), + }); + } + } + } + } + + Ok(messages) +} + +/// Get supervisor conversation (from supervisor_states) +pub async fn get_supervisor_conversation_full( + pool: &PgPool, + contract_id: Uuid, +) -> Result<Option<SupervisorState>, sqlx::Error> { + get_supervisor_state(pool, contract_id).await +} |
