diff options
| author | soryu <soryu@soryu.co> | 2026-01-21 17:31:46 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-01-21 17:31:46 +0000 |
| commit | 94e5604e770d6589f786ea71e51738e21492f301 (patch) | |
| tree | 6c9b0f32a8d77464bc1a5131ba0828d252851abc /makima/src/db | |
| parent | da246c4c4e23c9ad976705f9a3fa80e0d75b4425 (diff) | |
| download | soryu-94e5604e770d6589f786ea71e51738e21492f301.tar.gz soryu-94e5604e770d6589f786ea71e51738e21492f301.zip | |
Add task branching feature (#15)
Diffstat (limited to 'makima/src/db')
| -rw-r--r-- | makima/src/db/models.rs | 45 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 54 |
2 files changed, 87 insertions, 12 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 291fad7..bf95a3a 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -519,6 +519,12 @@ pub struct Task { /// When the task was last interrupted due to daemon disconnect #[serde(skip_serializing_if = "Option::is_none")] pub interrupted_at: Option<DateTime<Utc>>, + + // Task branching + /// Source task ID when this task was branched from another task's conversation. + /// Used to track the origin of "what if" explorations. + #[serde(skip_serializing_if = "Option::is_none")] + pub branched_from_task_id: Option<Uuid>, } impl Task { @@ -598,8 +604,8 @@ pub struct TaskListResponse { #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct CreateTaskRequest { - /// Contract this task belongs to (required) - pub contract_id: Uuid, + /// Contract this task belongs to (optional for branched/anonymous tasks) + pub contract_id: Option<Uuid>, /// Name of the task pub name: String, /// Optional description @@ -633,6 +639,10 @@ pub struct CreateTaskRequest { pub copy_files: Option<Vec<String>>, /// Checkpoint SHA to branch from (optional) pub checkpoint_sha: Option<String>, + /// Source task ID when branching from another task's conversation + pub branched_from_task_id: Option<Uuid>, + /// Conversation history to initialize the task with (JSON array of messages) + pub conversation_history: Option<serde_json::Value>, } /// Request payload for updating a task @@ -681,6 +691,37 @@ pub struct SendMessageRequest { pub message: String, } +/// Default for include_conversation field in BranchTaskRequest +fn default_include_conversation() -> bool { + true +} + +/// Request to branch a task from an existing task's conversation. +/// Creates a new anonymous task that continues from the source task's state. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct BranchTaskRequest { + /// The initial message/instructions for the branched task + pub message: String, + /// Optional name for the branched task (auto-generated if not provided) + pub name: Option<String>, + /// Whether to include conversation history from the source task (default: true) + #[serde(default = "default_include_conversation")] + pub include_conversation: bool, +} + +/// Response from branching a task. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct BranchTaskResponse { + /// The newly created branched task + pub task: Task, + /// Number of conversation messages included from source task + pub message_count: usize, + /// Daemon ID if the task was started (None if no daemon available) + pub daemon_id: Option<Uuid>, +} + // ============================================================================= // Daemon Types // ============================================================================= diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index 536bc9b..7387735 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -654,8 +654,8 @@ pub async fn create_task(pool: &PgPool, req: CreateTaskRequest) -> Result<Task, let new_depth = parent.depth + 1; - // Subtasks inherit contract_id from parent - let contract_id = parent.contract_id.unwrap_or(req.contract_id); + // Subtasks inherit contract_id from parent (or use request contract_id if parent has none) + let contract_id = parent.contract_id.or(req.contract_id); // Inherit repo settings if not provided let repo_url = req.repository_url.clone().or(parent.repository_url); @@ -669,7 +669,7 @@ pub async fn create_task(pool: &PgPool, req: CreateTaskRequest) -> Result<Task, (new_depth, contract_id, repo_url, base_branch, target_branch, merge_mode, target_repo_path, completion_action) } else { - // Top-level task: depth 0, use contract_id from request + // Top-level task: depth 0, use contract_id from request (may be None for branched tasks) ( 0, req.contract_id, @@ -689,9 +689,10 @@ pub async fn create_task(pool: &PgPool, req: CreateTaskRequest) -> Result<Task, INSERT INTO tasks ( contract_id, parent_task_id, depth, name, description, plan, priority, is_supervisor, repository_url, base_branch, target_branch, merge_mode, - target_repo_path, completion_action, continue_from_task_id, copy_files + target_repo_path, completion_action, continue_from_task_id, copy_files, + branched_from_task_id, conversation_state ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18) RETURNING * "#, ) @@ -711,6 +712,8 @@ pub async fn create_task(pool: &PgPool, req: CreateTaskRequest) -> Result<Task, .bind(&completion_action) .bind(&req.continue_from_task_id) .bind(©_files_json) + .bind(&req.branched_from_task_id) + .bind(&req.conversation_history) .fetch_one(pool) .await } @@ -1041,8 +1044,8 @@ pub async fn create_task_for_owner( ))); } - // Subtasks inherit contract_id from parent - let contract_id = parent.contract_id.unwrap_or(req.contract_id); + // Subtasks inherit contract_id from parent (or use request contract_id if parent has none) + let contract_id = parent.contract_id.or(req.contract_id); // Inherit repo settings if not provided let repo_url = req.repository_url.clone().or(parent.repository_url); @@ -1056,7 +1059,7 @@ pub async fn create_task_for_owner( (new_depth, contract_id, repo_url, base_branch, target_branch, merge_mode, target_repo_path, completion_action) } else { - // Top-level task: depth 0, use contract_id from request + // Top-level task: depth 0, use contract_id from request (may be None for branched tasks) ( 0, req.contract_id, @@ -1076,9 +1079,10 @@ pub async fn create_task_for_owner( INSERT INTO tasks ( owner_id, contract_id, parent_task_id, depth, name, description, plan, priority, is_supervisor, repository_url, base_branch, target_branch, merge_mode, - target_repo_path, completion_action, continue_from_task_id, copy_files + target_repo_path, completion_action, continue_from_task_id, copy_files, + branched_from_task_id, conversation_state ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19) RETURNING * "#, ) @@ -1099,6 +1103,8 @@ pub async fn create_task_for_owner( .bind(&completion_action) .bind(&req.continue_from_task_id) .bind(©_files_json) + .bind(&req.branched_from_task_id) + .bind(&req.conversation_history) .fetch_one(pool) .await } @@ -3678,3 +3684,31 @@ pub async fn get_supervisor_conversation_full( ) -> Result<Option<SupervisorState>, sqlx::Error> { get_supervisor_state(pool, contract_id).await } + +// ============================================================================= +// Anonymous Task Cleanup Functions +// ============================================================================= + +/// Delete stale anonymous tasks (tasks with contract_id = NULL) that: +/// - Are in a terminal state (done, failed, merged) +/// - Are older than the specified number of days +/// +/// Returns the number of deleted tasks. +pub async fn cleanup_stale_anonymous_tasks( + pool: &PgPool, + max_age_days: i32, +) -> Result<i64, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM tasks + WHERE contract_id IS NULL + AND status IN ('done', 'failed', 'merged') + AND created_at < NOW() - INTERVAL '1 day' * $1 + "#, + ) + .bind(max_age_days) + .execute(pool) + .await?; + + Ok(result.rows_affected() as i64) +} |
