diff options
Diffstat (limited to 'makima/src/db/repository.rs')
| -rw-r--r-- | makima/src/db/repository.rs | 360 |
1 files changed, 356 insertions, 4 deletions
diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index 92b2048..cb9d52f 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -1,16 +1,17 @@ //! Repository pattern for file database operations. use chrono::Utc; +use serde::Deserialize; use sqlx::PgPool; use uuid::Uuid; use super::models::{ Contract, ContractChatConversation, ContractChatMessageRecord, ContractEvent, ContractRepository, - ContractSummary, CreateContractRequest, CreateFileRequest, + ContractSummary, ConversationMessage, ConversationSnapshot, CreateContractRequest, CreateFileRequest, CreateTaskRequest, Daemon, DaemonTaskAssignment, DaemonWithCapacity, File, FileSummary, - FileVersion, MeshChatConversation, MeshChatMessageRecord, SupervisorState, Task, TaskCheckpoint, - TaskEvent, TaskSummary, UpdateContractRequest, UpdateFileRequest, - UpdateTaskRequest, + FileVersion, HistoryEvent, HistoryQueryFilters, MeshChatConversation, MeshChatMessageRecord, + SupervisorState, Task, TaskCheckpoint, TaskEvent, TaskSummary, UpdateContractRequest, + UpdateFileRequest, UpdateTaskRequest, }; /// Repository error types. @@ -3203,3 +3204,354 @@ pub async fn delete_repository_history( Ok(result.rows_affected() > 0) } + +// ============================================================================ +// Conversation Snapshots +// ============================================================================ + +/// Create a new conversation snapshot +pub async fn create_conversation_snapshot( + pool: &PgPool, + task_id: Uuid, + checkpoint_id: Option<Uuid>, + snapshot_type: &str, + message_count: i32, + conversation_state: serde_json::Value, + metadata: Option<serde_json::Value>, +) -> Result<ConversationSnapshot, sqlx::Error> { + sqlx::query_as::<_, ConversationSnapshot>( + r#" + INSERT INTO conversation_snapshots (task_id, checkpoint_id, snapshot_type, message_count, conversation_state, metadata) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING * + "# + ) + .bind(task_id) + .bind(checkpoint_id) + .bind(snapshot_type) + .bind(message_count) + .bind(conversation_state) + .bind(metadata) + .fetch_one(pool) + .await +} + +/// Get a conversation snapshot by ID +pub async fn get_conversation_snapshot( + pool: &PgPool, + id: Uuid, +) -> Result<Option<ConversationSnapshot>, sqlx::Error> { + sqlx::query_as::<_, ConversationSnapshot>( + "SELECT * FROM conversation_snapshots WHERE id = $1" + ) + .bind(id) + .fetch_optional(pool) + .await +} + +/// Get conversation snapshot at a specific checkpoint +pub async fn get_conversation_at_checkpoint( + pool: &PgPool, + checkpoint_id: Uuid, +) -> Result<Option<ConversationSnapshot>, sqlx::Error> { + sqlx::query_as::<_, ConversationSnapshot>( + "SELECT * FROM conversation_snapshots WHERE checkpoint_id = $1 ORDER BY created_at DESC LIMIT 1" + ) + .bind(checkpoint_id) + .fetch_optional(pool) + .await +} + +/// List conversation snapshots for a task +pub async fn list_conversation_snapshots( + pool: &PgPool, + task_id: Uuid, + limit: Option<i32>, +) -> Result<Vec<ConversationSnapshot>, sqlx::Error> { + let limit = limit.unwrap_or(100); + sqlx::query_as::<_, ConversationSnapshot>( + "SELECT * FROM conversation_snapshots WHERE task_id = $1 ORDER BY created_at DESC LIMIT $2" + ) + .bind(task_id) + .bind(limit) + .fetch_all(pool) + .await +} + +/// Delete conversation snapshots older than retention period +pub async fn cleanup_old_snapshots( + pool: &PgPool, + retention_days: i32, +) -> Result<u64, sqlx::Error> { + let result = sqlx::query( + "DELETE FROM conversation_snapshots WHERE created_at < NOW() - INTERVAL '1 day' * $1" + ) + .bind(retention_days) + .execute(pool) + .await?; + Ok(result.rows_affected()) +} + +// ============================================================================ +// History Events +// ============================================================================ + +/// Record a new history event +#[allow(clippy::too_many_arguments)] +pub async fn record_history_event( + pool: &PgPool, + owner_id: Uuid, + contract_id: Option<Uuid>, + task_id: Option<Uuid>, + event_type: &str, + event_subtype: Option<&str>, + phase: Option<&str>, + event_data: serde_json::Value, +) -> Result<HistoryEvent, sqlx::Error> { + sqlx::query_as::<_, HistoryEvent>( + r#" + INSERT INTO history_events (owner_id, contract_id, task_id, event_type, event_subtype, phase, event_data) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING * + "# + ) + .bind(owner_id) + .bind(contract_id) + .bind(task_id) + .bind(event_type) + .bind(event_subtype) + .bind(phase) + .bind(event_data) + .fetch_one(pool) + .await +} + +/// Get contract history timeline +pub async fn get_contract_history( + pool: &PgPool, + contract_id: Uuid, + owner_id: Uuid, + filters: &HistoryQueryFilters, +) -> Result<(Vec<HistoryEvent>, i64), sqlx::Error> { + let limit = filters.limit.unwrap_or(100); + + let mut query = String::from( + "SELECT * FROM history_events WHERE contract_id = $1 AND owner_id = $2" + ); + let mut count_query = String::from( + "SELECT COUNT(*) FROM history_events WHERE contract_id = $1 AND owner_id = $2" + ); + + let mut param_count = 2; + + if filters.phase.is_some() { + param_count += 1; + query.push_str(&format!(" AND phase = ${}" , param_count)); + count_query.push_str(&format!(" AND phase = ${}", param_count)); + } + + if filters.from.is_some() { + param_count += 1; + query.push_str(&format!(" AND created_at >= ${}", param_count)); + count_query.push_str(&format!(" AND created_at >= ${}", param_count)); + } + + if filters.to.is_some() { + param_count += 1; + query.push_str(&format!(" AND created_at <= ${}", param_count)); + count_query.push_str(&format!(" AND created_at <= ${}", param_count)); + } + + query.push_str(" ORDER BY created_at DESC"); + query.push_str(&format!(" LIMIT {}", limit)); + + // Build and execute the query dynamically + let mut q = sqlx::query_as::<_, HistoryEvent>(&query) + .bind(contract_id) + .bind(owner_id); + + if let Some(ref phase) = filters.phase { + q = q.bind(phase); + } + if let Some(ref from) = filters.from { + q = q.bind(from); + } + if let Some(ref to) = filters.to { + q = q.bind(to); + } + + let events = q.fetch_all(pool).await?; + + // Get total count + let mut cq = sqlx::query_scalar::<_, i64>(&count_query) + .bind(contract_id) + .bind(owner_id); + + if let Some(ref phase) = filters.phase { + cq = cq.bind(phase); + } + if let Some(ref from) = filters.from { + cq = cq.bind(from); + } + if let Some(ref to) = filters.to { + cq = cq.bind(to); + } + + let count = cq.fetch_one(pool).await?; + + Ok((events, count)) +} + +/// Get task history +pub async fn get_task_history( + pool: &PgPool, + task_id: Uuid, + owner_id: Uuid, + filters: &HistoryQueryFilters, +) -> Result<(Vec<HistoryEvent>, i64), sqlx::Error> { + let limit = filters.limit.unwrap_or(100); + + let events = sqlx::query_as::<_, HistoryEvent>( + r#" + SELECT * FROM history_events + WHERE task_id = $1 AND owner_id = $2 + ORDER BY created_at DESC + LIMIT $3 + "# + ) + .bind(task_id) + .bind(owner_id) + .bind(limit) + .fetch_all(pool) + .await?; + + let count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM history_events WHERE task_id = $1 AND owner_id = $2" + ) + .bind(task_id) + .bind(owner_id) + .fetch_one(pool) + .await?; + + Ok((events, count)) +} + +/// Get unified timeline for an owner +pub async fn get_timeline( + pool: &PgPool, + owner_id: Uuid, + filters: &HistoryQueryFilters, +) -> Result<(Vec<HistoryEvent>, i64), sqlx::Error> { + let limit = filters.limit.unwrap_or(100); + + let events = sqlx::query_as::<_, HistoryEvent>( + r#" + SELECT * FROM history_events + WHERE owner_id = $1 + ORDER BY created_at DESC + LIMIT $2 + "# + ) + .bind(owner_id) + .bind(limit) + .fetch_all(pool) + .await?; + + let count: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM history_events WHERE owner_id = $1" + ) + .bind(owner_id) + .fetch_one(pool) + .await?; + + Ok((events, count)) +} + +// ============================================================================ +// Task Conversation Retrieval +// ============================================================================ + +// Helper struct for parsing task output events +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct TaskOutputEvent { + message_type: String, + content: Option<String>, + tool_name: Option<String>, + tool_input: Option<serde_json::Value>, + is_error: Option<bool>, + cost_usd: Option<f32>, +} + +/// Get task conversation messages (reconstructed from task_events) +pub async fn get_task_conversation( + pool: &PgPool, + task_id: Uuid, + include_tool_calls: bool, + include_tool_results: bool, + limit: Option<i32>, +) -> Result<Vec<ConversationMessage>, sqlx::Error> { + let limit = limit.unwrap_or(1000); + + // Get output events that represent conversation turns + let events = sqlx::query_as::<_, TaskEvent>( + r#" + SELECT * FROM task_events + WHERE task_id = $1 AND event_type = 'output' + ORDER BY created_at ASC + LIMIT $2 + "# + ) + .bind(task_id) + .bind(limit) + .fetch_all(pool) + .await?; + + // Convert task events to conversation messages + let mut messages = Vec::new(); + for event in events { + if let Some(data) = event.event_data { + // Parse the event data to extract message info + if let Ok(output) = serde_json::from_value::<TaskOutputEvent>(data.clone()) { + let should_include = match output.message_type.as_str() { + "tool_use" => include_tool_calls, + "tool_result" => include_tool_results, + _ => true, + }; + + if should_include { + messages.push(ConversationMessage { + id: event.id.to_string(), + role: match output.message_type.as_str() { + "assistant" => "assistant".to_string(), + "tool_use" => "assistant".to_string(), + "tool_result" => "tool".to_string(), + "system" => "system".to_string(), + "error" => "system".to_string(), + _ => "user".to_string(), + }, + content: output.content.unwrap_or_default(), + timestamp: event.created_at, + tool_calls: None, + tool_name: output.tool_name, + tool_input: output.tool_input, + tool_result: None, + is_error: output.is_error, + token_count: None, + cost_usd: output.cost_usd.map(|c| c as f64), + }); + } + } + } + } + + Ok(messages) +} + +/// Get supervisor conversation (from supervisor_states) +pub async fn get_supervisor_conversation_full( + pool: &PgPool, + contract_id: Uuid, +) -> Result<Option<SupervisorState>, sqlx::Error> { + get_supervisor_state(pool, contract_id).await +} |
