diff options
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/db/models.rs | 589 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 1393 | ||||
| -rw-r--r-- | makima/src/llm/mesh_tools.rs | 1080 | ||||
| -rw-r--r-- | makima/src/llm/mod.rs | 2 | ||||
| -rw-r--r-- | makima/src/llm/tools.rs | 206 | ||||
| -rw-r--r-- | makima/src/server/auth.rs | 1238 | ||||
| -rw-r--r-- | makima/src/server/handlers/api_keys.rs | 282 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 115 | ||||
| -rw-r--r-- | makima/src/server/handlers/files.rs | 53 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh.rs | 1679 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_chat.rs | 2088 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_daemon.rs | 959 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_merge.rs | 441 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_ws.rs | 346 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 7 | ||||
| -rw-r--r-- | makima/src/server/handlers/users.rs | 972 | ||||
| -rw-r--r-- | makima/src/server/mod.rs | 59 | ||||
| -rw-r--r-- | makima/src/server/openapi.rs | 96 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 467 |
19 files changed, 12019 insertions, 53 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 617e590..5064b97 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -36,6 +36,16 @@ pub enum BodyElement { Heading { level: u8, text: String }, /// Paragraph text Paragraph { text: String }, + /// Code block with optional language + Code { + language: Option<String>, + content: String, + }, + /// List (ordered or unordered) + List { + ordered: bool, + items: Vec<String>, + }, /// Chart visualization Chart { #[serde(rename = "chartType")] @@ -245,3 +255,582 @@ pub struct RestoreVersionRequest { /// The current version (for optimistic locking) pub current_version: i32, } + +// ============================================================================= +// Mesh/Task Types +// ============================================================================= + +/// Task status for orchestrating Claude Code instances +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum TaskStatus { + Pending, + Running, + Paused, + Blocked, + Done, + Failed, + Merged, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "pending"), + TaskStatus::Running => write!(f, "running"), + TaskStatus::Paused => write!(f, "paused"), + TaskStatus::Blocked => write!(f, "blocked"), + TaskStatus::Done => write!(f, "done"), + TaskStatus::Failed => write!(f, "failed"), + TaskStatus::Merged => write!(f, "merged"), + } + } +} + +impl std::str::FromStr for TaskStatus { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.to_lowercase().as_str() { + "pending" => Ok(TaskStatus::Pending), + "running" => Ok(TaskStatus::Running), + "paused" => Ok(TaskStatus::Paused), + "blocked" => Ok(TaskStatus::Blocked), + "done" => Ok(TaskStatus::Done), + "failed" => Ok(TaskStatus::Failed), + "merged" => Ok(TaskStatus::Merged), + _ => Err(format!("Unknown task status: {}", s)), + } + } +} + +/// Merge mode for task completion +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum MergeMode { + /// Create a PR for review + Pr, + /// Auto-merge to target branch + Auto, + /// Manual merge by user + Manual, +} + +impl std::fmt::Display for MergeMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MergeMode::Pr => write!(f, "pr"), + MergeMode::Auto => write!(f, "auto"), + MergeMode::Manual => write!(f, "manual"), + } + } +} + +impl std::str::FromStr for MergeMode { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.to_lowercase().as_str() { + "pr" => Ok(MergeMode::Pr), + "auto" => Ok(MergeMode::Auto), + "manual" => Ok(MergeMode::Manual), + _ => Err(format!("Unknown merge mode: {}", s)), + } + } +} + +/// Task record from the database +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct Task { + pub id: Uuid, + pub owner_id: Uuid, + pub parent_task_id: Option<Uuid>, + /// Depth in task hierarchy: 0=orchestrator (top-level), 1=subtask (max) + pub depth: i32, + pub name: String, + pub description: Option<String>, + pub status: String, + pub priority: i32, + pub plan: String, + + // Daemon/container info + pub daemon_id: Option<Uuid>, + pub container_id: Option<String>, + pub overlay_path: Option<String>, + + // Repository info + pub repository_url: Option<String>, + pub base_branch: Option<String>, + pub target_branch: Option<String>, + + // Merge settings + pub merge_mode: Option<String>, + pub pr_url: Option<String>, + + // Completion action settings + /// Path to user's local repository (outside ~/.makima) + pub target_repo_path: Option<String>, + /// Action on completion: "none", "branch", "merge", "pr" + pub completion_action: Option<String>, + + // Progress tracking + pub progress_summary: Option<String>, + pub last_output: Option<String>, + pub error_message: Option<String>, + + // Timestamps + pub started_at: Option<DateTime<Utc>>, + pub completed_at: Option<DateTime<Utc>>, + pub version: i32, + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, + + // Task continuation + /// Task ID to continue from (copy worktree from this task when starting). + /// Used for sequential subtask dependencies. + #[serde(skip_serializing_if = "Option::is_none")] + pub continue_from_task_id: Option<Uuid>, + /// Files to copy from parent task's worktree when starting. + #[serde(skip_serializing_if = "Option::is_none")] + pub copy_files: Option<serde_json::Value>, +} + +impl Task { + /// Parse status string to TaskStatus enum + pub fn status_enum(&self) -> Result<TaskStatus, String> { + self.status.parse() + } + + /// Parse merge_mode string to MergeMode enum + pub fn merge_mode_enum(&self) -> Option<Result<MergeMode, String>> { + self.merge_mode.as_ref().map(|s| s.parse()) + } +} + +/// Summary of a task for list views +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskSummary { + pub id: Uuid, + pub parent_task_id: Option<Uuid>, + /// Depth in task hierarchy: 0=orchestrator (top-level), 1=subtask (max) + pub depth: i32, + pub name: String, + pub status: String, + pub priority: i32, + pub progress_summary: Option<String>, + pub subtask_count: i64, + pub version: i32, + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, +} + +/// Response for task list endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskListResponse { + pub tasks: Vec<TaskSummary>, + pub total: i64, +} + +/// Request payload for creating a new task +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateTaskRequest { + /// Name of the task + pub name: String, + /// Optional description + pub description: Option<String>, + /// The plan/instructions for Claude Code + pub plan: String, + /// Parent task ID (for subtasks) + pub parent_task_id: Option<Uuid>, + /// Priority (higher = more urgent) + #[serde(default)] + pub priority: i32, + /// Repository URL + pub repository_url: Option<String>, + /// Base branch for overlay + pub base_branch: Option<String>, + /// Target branch to merge into + pub target_branch: Option<String>, + /// Merge mode (pr, auto, manual) + pub merge_mode: Option<String>, + /// Path to user's local repository (outside ~/.makima) + pub target_repo_path: Option<String>, + /// Action on completion: "none", "branch", "merge", "pr" + pub completion_action: Option<String>, + /// Task ID to continue from (copy worktree from this task when starting) + pub continue_from_task_id: Option<Uuid>, + /// Files to copy from parent task's worktree when starting + #[serde(skip_serializing_if = "Option::is_none")] + pub copy_files: Option<Vec<String>>, +} + +/// Request payload for updating a task +#[derive(Debug, Default, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct UpdateTaskRequest { + pub name: Option<String>, + pub description: Option<String>, + pub plan: Option<String>, + pub status: Option<String>, + pub priority: Option<i32>, + pub progress_summary: Option<String>, + pub last_output: Option<String>, + pub error_message: Option<String>, + pub merge_mode: Option<String>, + pub pr_url: Option<String>, + /// Path to user's local repository (outside ~/.makima) + pub target_repo_path: Option<String>, + /// Action on completion: "none", "branch", "merge", "pr" + pub completion_action: Option<String>, + /// The daemon currently running this task + pub daemon_id: Option<Uuid>, + /// Explicitly clear daemon_id (set to NULL) + #[serde(default)] + pub clear_daemon_id: bool, + /// Version for optimistic locking + pub version: Option<i32>, +} + +/// Task with its subtasks for detail view +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskWithSubtasks { + #[serde(flatten)] + pub task: Task, + pub subtasks: Vec<TaskSummary>, +} + +/// Request to send a message to a running task's stdin. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SendMessageRequest { + /// The message to send to the task's stdin. + pub message: String, +} + +// ============================================================================= +// Daemon Types +// ============================================================================= + +/// Daemon status +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum DaemonStatus { + Connected, + Disconnected, + Unhealthy, +} + +impl std::fmt::Display for DaemonStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DaemonStatus::Connected => write!(f, "connected"), + DaemonStatus::Disconnected => write!(f, "disconnected"), + DaemonStatus::Unhealthy => write!(f, "unhealthy"), + } + } +} + +impl std::str::FromStr for DaemonStatus { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s.to_lowercase().as_str() { + "connected" => Ok(DaemonStatus::Connected), + "disconnected" => Ok(DaemonStatus::Disconnected), + "unhealthy" => Ok(DaemonStatus::Unhealthy), + _ => Err(format!("Unknown daemon status: {}", s)), + } + } +} + +/// Connected daemon record from the database +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct Daemon { + pub id: Uuid, + pub owner_id: Uuid, + pub connection_id: String, + pub hostname: Option<String>, + pub machine_id: Option<String>, + pub max_concurrent_tasks: i32, + pub current_task_count: i32, + pub status: String, + pub last_heartbeat_at: DateTime<Utc>, + pub connected_at: DateTime<Utc>, + pub disconnected_at: Option<DateTime<Utc>>, +} + +impl Daemon { + /// Parse status string to DaemonStatus enum + pub fn status_enum(&self) -> Result<DaemonStatus, String> { + self.status.parse() + } +} + +/// Response for daemon list endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DaemonListResponse { + pub daemons: Vec<Daemon>, + pub total: i64, +} + +/// Response for daemon directories endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DaemonDirectoriesResponse { + /// List of suggested directories from connected daemons + pub directories: Vec<DaemonDirectory>, +} + +/// A suggested directory from a daemon +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DaemonDirectory { + /// Path to the directory + pub path: String, + /// Display label for the directory + pub label: String, + /// Type of directory: "working", "makima", "worktrees" + pub directory_type: String, + /// Daemon hostname this directory is from + pub hostname: Option<String>, + /// Whether the directory already exists (for validation) + #[serde(skip_serializing_if = "Option::is_none")] + pub exists: Option<bool>, +} + +// ============================================================================= +// Task Event Types +// ============================================================================= + +/// Task event record from the database +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskEvent { + pub id: Uuid, + pub task_id: Uuid, + pub event_type: String, + pub previous_status: Option<String>, + pub new_status: Option<String>, + #[sqlx(json)] + pub event_data: Option<serde_json::Value>, + pub created_at: DateTime<Utc>, +} + +/// Response for task events list endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskEventListResponse { + pub events: Vec<TaskEvent>, + pub total: i64, +} + +/// A single output entry from a Claude Code task +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskOutputEntry { + pub id: Uuid, + pub task_id: Uuid, + /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + pub message_type: String, + /// Main text content + pub content: String, + /// Tool name if tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option<String>, + /// Tool input JSON if tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_input: Option<serde_json::Value>, + /// Whether tool result was an error + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option<bool>, + /// Cost in USD if result message + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_usd: Option<f64>, + /// Duration in ms if result message + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option<u64>, + /// Timestamp when this output was recorded + pub created_at: DateTime<Utc>, +} + +impl TaskOutputEntry { + /// Convert a TaskEvent with event_type='output' to a TaskOutputEntry + pub fn from_task_event(event: TaskEvent) -> Option<Self> { + if event.event_type != "output" { + return None; + } + let data = event.event_data?; + Some(Self { + id: event.id, + task_id: event.task_id, + message_type: data.get("messageType")?.as_str()?.to_string(), + content: data.get("content")?.as_str().unwrap_or("").to_string(), + tool_name: data.get("toolName").and_then(|v| v.as_str()).map(|s| s.to_string()), + tool_input: data.get("toolInput").cloned(), + is_error: data.get("isError").and_then(|v| v.as_bool()), + cost_usd: data.get("costUsd").and_then(|v| v.as_f64()), + duration_ms: data.get("durationMs").and_then(|v| v.as_u64()), + created_at: event.created_at, + }) + } +} + +/// Response for task output history endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct TaskOutputResponse { + pub entries: Vec<TaskOutputEntry>, + pub total: usize, + pub task_id: Uuid, +} + +// ============================================================================= +// Mesh Chat History Types +// ============================================================================= + +/// Mesh chat conversation for persisting history +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatConversation { + pub id: Uuid, + pub owner_id: Uuid, + pub name: Option<String>, + pub is_active: bool, + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, +} + +/// Individual message in a mesh chat conversation +#[derive(Debug, Clone, FromRow, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatMessageRecord { + pub id: Uuid, + pub conversation_id: Uuid, + pub role: String, + pub content: String, + pub context_type: String, + pub context_task_id: Option<Uuid>, + /// Tool calls made during this message (JSON, nullable) + pub tool_calls: Option<serde_json::Value>, + /// Pending questions requiring user response (JSON, nullable) + pub pending_questions: Option<serde_json::Value>, + pub created_at: DateTime<Utc>, +} + +/// Response for chat history endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatHistoryResponse { + pub conversation_id: Uuid, + pub messages: Vec<MeshChatMessageRecord>, +} + +// ============================================================================= +// Merge API Types +// ============================================================================= + +/// Information about a task branch +#[derive(Debug, Clone, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct BranchInfo { + /// Full branch name + pub name: String, + /// Task ID extracted from branch name (if parseable) + pub task_id: Option<Uuid>, + /// Whether this branch has been merged + pub is_merged: bool, + /// Short SHA of the last commit + pub last_commit: String, + /// Subject line of the last commit + pub last_commit_message: String, +} + +/// Response for branch list endpoint +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct BranchListResponse { + pub branches: Vec<BranchInfo>, +} + +/// Request to start a merge +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeStartRequest { + /// Branch name to merge + pub source_branch: String, +} + +/// Current merge state +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeStatusResponse { + /// Whether a merge is in progress + pub in_progress: bool, + /// Branch being merged (if in progress) + pub source_branch: Option<String>, + /// Files with unresolved conflicts + pub conflicted_files: Vec<String>, +} + +/// Request to resolve a conflict +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeResolveRequest { + /// File path to resolve + pub file: String, + /// Resolution strategy: "ours" or "theirs" + pub strategy: String, +} + +/// Request to commit a merge +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeCommitRequest { + /// Commit message + pub message: String, +} + +/// Request to skip a subtask branch +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeSkipRequest { + /// Subtask ID to skip + pub subtask_id: Uuid, + /// Reason for skipping + pub reason: String, +} + +/// Result of a merge operation +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeResultResponse { + /// Whether the operation succeeded + pub success: bool, + /// Human-readable message + pub message: String, + /// Commit SHA (if a commit was created) + pub commit_sha: Option<String>, + /// Conflicted files (if conflicts occurred) + pub conflicts: Option<Vec<String>>, +} + +/// Response to check if all branches are merged +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MergeCompleteCheckResponse { + /// Whether the orchestrator can mark itself as complete + pub can_complete: bool, + /// Branches not yet merged or skipped + pub unmerged_branches: Vec<String>, + /// Count of merged branches + pub merged_count: u32, + /// Count of skipped branches + pub skipped_count: u32, +} diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index 4137ba6..ce1e97d 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -4,10 +4,10 @@ use chrono::Utc; use sqlx::PgPool; use uuid::Uuid; -use super::models::{CreateFileRequest, File, FileVersion, UpdateFileRequest}; - -/// Default owner ID for anonymous users. -pub const ANONYMOUS_OWNER_ID: Uuid = Uuid::from_u128(0x00000000_0000_0000_0000_000000000002); +use super::models::{ + CreateFileRequest, CreateTaskRequest, Daemon, File, FileVersion, MeshChatConversation, + MeshChatMessageRecord, Task, TaskEvent, TaskSummary, UpdateFileRequest, UpdateTaskRequest, +}; /// Repository error types. #[derive(Debug)] @@ -60,12 +60,11 @@ pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, sqlx::query_as::<_, File>( r#" - INSERT INTO files (owner_id, name, description, transcript, location, summary, body) - VALUES ($1, $2, $3, $4, $5, NULL, $6) + INSERT INTO files (name, description, transcript, location, summary, body) + VALUES ($1, $2, $3, $4, NULL, $5) RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at "#, ) - .bind(ANONYMOUS_OWNER_ID) .bind(&name) .bind(&req.description) .bind(&transcript_json) @@ -81,26 +80,23 @@ pub async fn get_file(pool: &PgPool, id: Uuid) -> Result<Option<File>, sqlx::Err r#" SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at FROM files - WHERE id = $1 AND owner_id = $2 + WHERE id = $1 "#, ) .bind(id) - .bind(ANONYMOUS_OWNER_ID) .fetch_optional(pool) .await } -/// List all files for the owner, ordered by created_at DESC. +/// List all files, ordered by created_at DESC. pub async fn list_files(pool: &PgPool) -> Result<Vec<File>, sqlx::Error> { sqlx::query_as::<_, File>( r#" SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at FROM files - WHERE owner_id = $1 ORDER BY created_at DESC "#, ) - .bind(ANONYMOUS_OWNER_ID) .fetch_all(pool) .await } @@ -146,13 +142,12 @@ pub async fn update_file( sqlx::query_as::<_, File>( r#" UPDATE files - SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() - WHERE id = $1 AND owner_id = $2 AND version = $8 + SET name = $2, description = $3, transcript = $4, summary = $5, body = $6, updated_at = NOW() + WHERE id = $1 AND version = $7 RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at "#, ) .bind(id) - .bind(ANONYMOUS_OWNER_ID) .bind(&name) .bind(&description) .bind(&transcript_json) @@ -166,13 +161,12 @@ pub async fn update_file( sqlx::query_as::<_, File>( r#" UPDATE files - SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() - WHERE id = $1 AND owner_id = $2 + SET name = $2, description = $3, transcript = $4, summary = $5, body = $6, updated_at = NOW() + WHERE id = $1 RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at "#, ) .bind(id) - .bind(ANONYMOUS_OWNER_ID) .bind(&name) .bind(&description) .bind(&transcript_json) @@ -201,21 +195,19 @@ pub async fn delete_file(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> { let result = sqlx::query( r#" DELETE FROM files - WHERE id = $1 AND owner_id = $2 + WHERE id = $1 "#, ) .bind(id) - .bind(ANONYMOUS_OWNER_ID) .execute(pool) .await?; Ok(result.rows_affected() > 0) } -/// Count total files for owner. +/// Count total files. pub async fn count_files(pool: &PgPool) -> Result<i64, sqlx::Error> { - let result: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM files WHERE owner_id = $1") - .bind(ANONYMOUS_OWNER_ID) + let result: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM files") .fetch_one(pool) .await?; @@ -223,6 +215,178 @@ pub async fn count_files(pool: &PgPool) -> Result<i64, sqlx::Error> { } // ============================================================================= +// Owner-Scoped File Functions +// ============================================================================= + +/// Create a new file record for a specific owner. +pub async fn create_file_for_owner( + pool: &PgPool, + owner_id: Uuid, + req: CreateFileRequest, +) -> Result<File, sqlx::Error> { + let name = req.name.unwrap_or_else(generate_default_name); + let transcript_json = serde_json::to_value(&req.transcript).unwrap_or_default(); + let body_json = serde_json::to_value::<Vec<super::models::BodyElement>>(vec![]).unwrap(); + + sqlx::query_as::<_, File>( + r#" + INSERT INTO files (owner_id, name, description, transcript, location, summary, body) + VALUES ($1, $2, $3, $4, $5, NULL, $6) + RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + "#, + ) + .bind(owner_id) + .bind(&name) + .bind(&req.description) + .bind(&transcript_json) + .bind(&req.location) + .bind(&body_json) + .fetch_one(pool) + .await +} + +/// Get a file by ID, scoped to owner. +pub async fn get_file_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, +) -> Result<Option<File>, sqlx::Error> { + sqlx::query_as::<_, File>( + r#" + SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + FROM files + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .fetch_optional(pool) + .await +} + +/// List all files for an owner, ordered by created_at DESC. +pub async fn list_files_for_owner(pool: &PgPool, owner_id: Uuid) -> Result<Vec<File>, sqlx::Error> { + sqlx::query_as::<_, File>( + r#" + SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + FROM files + WHERE owner_id = $1 + ORDER BY created_at DESC + "#, + ) + .bind(owner_id) + .fetch_all(pool) + .await +} + +/// Update a file by ID with optimistic locking, scoped to owner. +pub async fn update_file_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, + req: UpdateFileRequest, +) -> Result<Option<File>, RepositoryError> { + // Get the existing file first (scoped to owner) + let existing = get_file_for_owner(pool, id, owner_id).await?; + let Some(existing) = existing else { + return Ok(None); + }; + + // Check version if provided (optimistic locking) + if let Some(expected_version) = req.version { + if existing.version != expected_version { + return Err(RepositoryError::VersionConflict { + expected: expected_version, + actual: existing.version, + }); + } + } + + // Apply updates + let name = req.name.unwrap_or(existing.name); + let description = req.description.or(existing.description); + let transcript = req.transcript.unwrap_or(existing.transcript); + let transcript_json = serde_json::to_value(&transcript).unwrap_or_default(); + let summary = req.summary.or(existing.summary); + let body = req.body.unwrap_or(existing.body); + let body_json = serde_json::to_value(&body).unwrap_or_default(); + + // Update with version check in WHERE clause for race condition safety + let result = if req.version.is_some() { + sqlx::query_as::<_, File>( + r#" + UPDATE files + SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() + WHERE id = $1 AND owner_id = $2 AND version = $8 + RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + "#, + ) + .bind(id) + .bind(owner_id) + .bind(&name) + .bind(&description) + .bind(&transcript_json) + .bind(&summary) + .bind(&body_json) + .bind(req.version.unwrap()) + .fetch_optional(pool) + .await? + } else { + // No version check for internal updates + sqlx::query_as::<_, File>( + r#" + UPDATE files + SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() + WHERE id = $1 AND owner_id = $2 + RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + "#, + ) + .bind(id) + .bind(owner_id) + .bind(&name) + .bind(&description) + .bind(&transcript_json) + .bind(&summary) + .bind(&body_json) + .fetch_optional(pool) + .await? + }; + + // If versioned update returned None, there was a race condition + if result.is_none() && req.version.is_some() { + // Re-fetch to get the actual version + if let Some(current) = get_file_for_owner(pool, id, owner_id).await? { + return Err(RepositoryError::VersionConflict { + expected: req.version.unwrap(), + actual: current.version, + }); + } + } + + Ok(result) +} + +/// Delete a file by ID, scoped to owner. +pub async fn delete_file_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, +) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM files + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +// ============================================================================= // Version History Functions // ============================================================================= @@ -363,3 +527,1186 @@ pub async fn count_file_versions(pool: &PgPool, file_id: Uuid) -> Result<i64, sq Ok(result.0) } + +// ============================================================================= +// Task Functions +// ============================================================================= + +/// Create a new task. +/// +/// If creating a subtask (parent_task_id is set) and repository settings are not provided, +/// the subtask will inherit repository_url, base_branch, target_branch, merge_mode, +/// and target_repo_path from the parent task. Depth is calculated from parent and limited +/// to max 1 (2 levels: orchestrator at depth 0, subtasks at depth 1). +/// +/// NOTE: completion_action is NOT inherited - subtasks should not auto-merge unless +/// explicitly configured. The orchestrator controls when completion steps happen. +pub async fn create_task(pool: &PgPool, req: CreateTaskRequest) -> Result<Task, sqlx::Error> { + // Calculate depth and inherit settings from parent if applicable + let (depth, repo_url, base_branch, target_branch, merge_mode, target_repo_path, completion_action) = + if let Some(parent_id) = req.parent_task_id { + // Fetch parent task to get depth and inherit repo settings + let parent = get_task(pool, parent_id).await? + .ok_or_else(|| sqlx::Error::RowNotFound)?; + + let new_depth = parent.depth + 1; + + // Validate max depth (must be < 2, i.e., 0 or 1 only) + // Orchestrators are at depth 0, subtasks at depth 1 + // Subtasks cannot have their own children + if new_depth >= 2 { + return Err(sqlx::Error::Protocol(format!( + "Maximum task depth exceeded. Cannot create subtask at depth {} (max is 1). Subtasks cannot have children.", + new_depth + ))); + } + + // Inherit repo settings if not provided + let repo_url = req.repository_url.clone().or(parent.repository_url); + let base_branch = req.base_branch.clone().or(parent.base_branch); + let target_branch = req.target_branch.clone().or(parent.target_branch); + let merge_mode = req.merge_mode.clone().or(parent.merge_mode); + let target_repo_path = req.target_repo_path.clone().or(parent.target_repo_path); + // NOTE: completion_action is NOT inherited - subtasks should not auto-merge. + // The orchestrator integrates subtask work from their worktrees. + let completion_action = req.completion_action.clone(); + + (new_depth, repo_url, base_branch, target_branch, merge_mode, target_repo_path, completion_action) + } else { + // Top-level task: depth 0 + ( + 0, + req.repository_url.clone(), + req.base_branch.clone(), + req.target_branch.clone(), + req.merge_mode.clone(), + req.target_repo_path.clone(), + req.completion_action.clone(), + ) + }; + + let copy_files_json = req.copy_files.as_ref().map(|f| serde_json::to_value(f).unwrap_or_default()); + + sqlx::query_as::<_, Task>( + r#" + INSERT INTO tasks ( + parent_task_id, depth, name, description, plan, priority, + repository_url, base_branch, target_branch, merge_mode, + target_repo_path, completion_action, continue_from_task_id, copy_files + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + RETURNING * + "#, + ) + .bind(req.parent_task_id) + .bind(depth) + .bind(&req.name) + .bind(&req.description) + .bind(&req.plan) + .bind(req.priority) + .bind(&repo_url) + .bind(&base_branch) + .bind(&target_branch) + .bind(&merge_mode) + .bind(&target_repo_path) + .bind(&completion_action) + .bind(&req.continue_from_task_id) + .bind(©_files_json) + .fetch_one(pool) + .await +} + +/// Get a task by ID. +pub async fn get_task(pool: &PgPool, id: Uuid) -> Result<Option<Task>, sqlx::Error> { + sqlx::query_as::<_, Task>( + r#" + SELECT * + FROM tasks + WHERE id = $1 + "#, + ) + .bind(id) + .fetch_optional(pool) + .await +} + +/// List all top-level tasks (no parent), ordered by created_at DESC. +pub async fn list_tasks(pool: &PgPool) -> Result<Vec<TaskSummary>, sqlx::Error> { + sqlx::query_as::<_, TaskSummary>( + r#" + SELECT + t.id, t.parent_task_id, t.depth, t.name, t.status, t.priority, + t.progress_summary, t.version, t.created_at, t.updated_at, + (SELECT COUNT(*) FROM tasks WHERE parent_task_id = t.id) as subtask_count + FROM tasks t + WHERE t.parent_task_id IS NULL + ORDER BY t.priority DESC, t.created_at DESC + "#, + ) + .fetch_all(pool) + .await +} + +/// List subtasks of a parent task. +pub async fn list_subtasks(pool: &PgPool, parent_id: Uuid) -> Result<Vec<TaskSummary>, sqlx::Error> { + sqlx::query_as::<_, TaskSummary>( + r#" + SELECT + t.id, t.parent_task_id, t.depth, t.name, t.status, t.priority, + t.progress_summary, t.version, t.created_at, t.updated_at, + (SELECT COUNT(*) FROM tasks WHERE parent_task_id = t.id) as subtask_count + FROM tasks t + WHERE t.parent_task_id = $1 + ORDER BY t.priority DESC, t.created_at DESC + "#, + ) + .bind(parent_id) + .fetch_all(pool) + .await +} + +/// Update a task by ID with optimistic locking. +pub async fn update_task( + pool: &PgPool, + id: Uuid, + req: UpdateTaskRequest, +) -> Result<Option<Task>, RepositoryError> { + // Get the existing task first + let existing = get_task(pool, id).await?; + let Some(existing) = existing else { + return Ok(None); + }; + + // Check version if provided (optimistic locking) + if let Some(expected_version) = req.version { + if existing.version != expected_version { + return Err(RepositoryError::VersionConflict { + expected: expected_version, + actual: existing.version, + }); + } + } + + // Apply updates + let name = req.name.unwrap_or(existing.name); + let description = req.description.or(existing.description); + let plan = req.plan.unwrap_or(existing.plan); + let status = req.status.unwrap_or(existing.status); + let priority = req.priority.unwrap_or(existing.priority); + let progress_summary = req.progress_summary.or(existing.progress_summary); + let last_output = req.last_output.or(existing.last_output); + let error_message = req.error_message.or(existing.error_message); + let merge_mode = req.merge_mode.or(existing.merge_mode); + let pr_url = req.pr_url.or(existing.pr_url); + let target_repo_path = req.target_repo_path.or(existing.target_repo_path); + let completion_action = req.completion_action.or(existing.completion_action); + // Handle clear_daemon_id: if true, set to NULL; otherwise use provided value or keep existing + let daemon_id = if req.clear_daemon_id { + None + } else { + req.daemon_id.or(existing.daemon_id) + }; + + // Update with version check in WHERE clause for race condition safety + let result = if req.version.is_some() { + sqlx::query_as::<_, Task>( + r#" + UPDATE tasks + SET name = $2, description = $3, plan = $4, status = $5, priority = $6, + progress_summary = $7, last_output = $8, error_message = $9, + merge_mode = $10, pr_url = $11, daemon_id = $12, + target_repo_path = $13, completion_action = $14, updated_at = NOW() + WHERE id = $1 AND version = $15 + RETURNING * + "#, + ) + .bind(id) + .bind(&name) + .bind(&description) + .bind(&plan) + .bind(&status) + .bind(priority) + .bind(&progress_summary) + .bind(&last_output) + .bind(&error_message) + .bind(&merge_mode) + .bind(&pr_url) + .bind(daemon_id) + .bind(&target_repo_path) + .bind(&completion_action) + .bind(req.version.unwrap()) + .fetch_optional(pool) + .await? + } else { + sqlx::query_as::<_, Task>( + r#" + UPDATE tasks + SET name = $2, description = $3, plan = $4, status = $5, priority = $6, + progress_summary = $7, last_output = $8, error_message = $9, + merge_mode = $10, pr_url = $11, daemon_id = $12, + target_repo_path = $13, completion_action = $14, updated_at = NOW() + WHERE id = $1 + RETURNING * + "#, + ) + .bind(id) + .bind(&name) + .bind(&description) + .bind(&plan) + .bind(&status) + .bind(priority) + .bind(&progress_summary) + .bind(&last_output) + .bind(&error_message) + .bind(&merge_mode) + .bind(&pr_url) + .bind(daemon_id) + .bind(&target_repo_path) + .bind(&completion_action) + .fetch_optional(pool) + .await? + }; + + // If versioned update returned None, there was a race condition + if result.is_none() && req.version.is_some() { + if let Some(current) = get_task(pool, id).await? { + return Err(RepositoryError::VersionConflict { + expected: req.version.unwrap(), + actual: current.version, + }); + } + } + + Ok(result) +} + +/// Delete a task by ID. +pub async fn delete_task(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM tasks + WHERE id = $1 + "#, + ) + .bind(id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Count total tasks. +pub async fn count_tasks(pool: &PgPool) -> Result<i64, sqlx::Error> { + let result: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM tasks WHERE parent_task_id IS NULL", + ) + .fetch_one(pool) + .await?; + + Ok(result.0) +} + +// ============================================================================= +// Owner-Scoped Task Functions +// ============================================================================= + +/// Create a new task for a specific owner. +pub async fn create_task_for_owner( + pool: &PgPool, + owner_id: Uuid, + req: CreateTaskRequest, +) -> Result<Task, sqlx::Error> { + // Calculate depth and inherit settings from parent if applicable + let (depth, repo_url, base_branch, target_branch, merge_mode, target_repo_path, completion_action) = + if let Some(parent_id) = req.parent_task_id { + // Fetch parent task to get depth and inherit repo settings (must belong to same owner) + let parent = get_task_for_owner(pool, parent_id, owner_id).await? + .ok_or_else(|| sqlx::Error::RowNotFound)?; + + let new_depth = parent.depth + 1; + + // Validate max depth + if new_depth >= 2 { + return Err(sqlx::Error::Protocol(format!( + "Maximum task depth exceeded. Cannot create subtask at depth {} (max is 1). Subtasks cannot have children.", + new_depth + ))); + } + + // Inherit repo settings if not provided + let repo_url = req.repository_url.clone().or(parent.repository_url); + let base_branch = req.base_branch.clone().or(parent.base_branch); + let target_branch = req.target_branch.clone().or(parent.target_branch); + let merge_mode = req.merge_mode.clone().or(parent.merge_mode); + let target_repo_path = req.target_repo_path.clone().or(parent.target_repo_path); + // NOTE: completion_action is NOT inherited - subtasks should not auto-merge. + // The orchestrator integrates subtask work from their worktrees. + let completion_action = req.completion_action.clone(); + + (new_depth, repo_url, base_branch, target_branch, merge_mode, target_repo_path, completion_action) + } else { + // Top-level task: depth 0 + ( + 0, + req.repository_url.clone(), + req.base_branch.clone(), + req.target_branch.clone(), + req.merge_mode.clone(), + req.target_repo_path.clone(), + req.completion_action.clone(), + ) + }; + + let copy_files_json = req.copy_files.as_ref().map(|f| serde_json::to_value(f).unwrap_or_default()); + + sqlx::query_as::<_, Task>( + r#" + INSERT INTO tasks ( + owner_id, parent_task_id, depth, name, description, plan, priority, + repository_url, base_branch, target_branch, merge_mode, + target_repo_path, completion_action, continue_from_task_id, copy_files + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + RETURNING * + "#, + ) + .bind(owner_id) + .bind(req.parent_task_id) + .bind(depth) + .bind(&req.name) + .bind(&req.description) + .bind(&req.plan) + .bind(req.priority) + .bind(&repo_url) + .bind(&base_branch) + .bind(&target_branch) + .bind(&merge_mode) + .bind(&target_repo_path) + .bind(&completion_action) + .bind(&req.continue_from_task_id) + .bind(©_files_json) + .fetch_one(pool) + .await +} + +/// Get a task by ID, scoped to owner. +pub async fn get_task_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, +) -> Result<Option<Task>, sqlx::Error> { + sqlx::query_as::<_, Task>( + r#" + SELECT * + FROM tasks + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .fetch_optional(pool) + .await +} + +/// List all top-level tasks (no parent) for an owner, ordered by created_at DESC. +pub async fn list_tasks_for_owner( + pool: &PgPool, + owner_id: Uuid, +) -> Result<Vec<TaskSummary>, sqlx::Error> { + sqlx::query_as::<_, TaskSummary>( + r#" + SELECT + t.id, t.parent_task_id, t.depth, t.name, t.status, t.priority, + t.progress_summary, t.version, t.created_at, t.updated_at, + (SELECT COUNT(*) FROM tasks WHERE parent_task_id = t.id) as subtask_count + FROM tasks t + WHERE t.owner_id = $1 AND t.parent_task_id IS NULL + ORDER BY t.priority DESC, t.created_at DESC + "#, + ) + .bind(owner_id) + .fetch_all(pool) + .await +} + +/// List subtasks of a parent task, scoped to owner. +pub async fn list_subtasks_for_owner( + pool: &PgPool, + parent_id: Uuid, + owner_id: Uuid, +) -> Result<Vec<TaskSummary>, sqlx::Error> { + sqlx::query_as::<_, TaskSummary>( + r#" + SELECT + t.id, t.parent_task_id, t.depth, t.name, t.status, t.priority, + t.progress_summary, t.version, t.created_at, t.updated_at, + (SELECT COUNT(*) FROM tasks WHERE parent_task_id = t.id) as subtask_count + FROM tasks t + WHERE t.owner_id = $1 AND t.parent_task_id = $2 + ORDER BY t.priority DESC, t.created_at DESC + "#, + ) + .bind(owner_id) + .bind(parent_id) + .fetch_all(pool) + .await +} + +/// Update a task by ID with optimistic locking, scoped to owner. +pub async fn update_task_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, + req: UpdateTaskRequest, +) -> Result<Option<Task>, RepositoryError> { + // Get the existing task first (scoped to owner) + let existing = get_task_for_owner(pool, id, owner_id).await?; + let Some(existing) = existing else { + return Ok(None); + }; + + // Check version if provided (optimistic locking) + if let Some(expected_version) = req.version { + if existing.version != expected_version { + return Err(RepositoryError::VersionConflict { + expected: expected_version, + actual: existing.version, + }); + } + } + + // Apply updates + let name = req.name.unwrap_or(existing.name); + let description = req.description.or(existing.description); + let plan = req.plan.unwrap_or(existing.plan); + let status = req.status.unwrap_or(existing.status); + let priority = req.priority.unwrap_or(existing.priority); + let progress_summary = req.progress_summary.or(existing.progress_summary); + let last_output = req.last_output.or(existing.last_output); + let error_message = req.error_message.or(existing.error_message); + let merge_mode = req.merge_mode.or(existing.merge_mode); + let pr_url = req.pr_url.or(existing.pr_url); + let target_repo_path = req.target_repo_path.or(existing.target_repo_path); + let completion_action = req.completion_action.or(existing.completion_action); + let daemon_id = if req.clear_daemon_id { + None + } else { + req.daemon_id.or(existing.daemon_id) + }; + + // Update with version check in WHERE clause for race condition safety + let result = if req.version.is_some() { + sqlx::query_as::<_, Task>( + r#" + UPDATE tasks + SET name = $3, description = $4, plan = $5, status = $6, priority = $7, + progress_summary = $8, last_output = $9, error_message = $10, + merge_mode = $11, pr_url = $12, daemon_id = $13, + target_repo_path = $14, completion_action = $15, updated_at = NOW() + WHERE id = $1 AND owner_id = $2 AND version = $16 + RETURNING * + "#, + ) + .bind(id) + .bind(owner_id) + .bind(&name) + .bind(&description) + .bind(&plan) + .bind(&status) + .bind(priority) + .bind(&progress_summary) + .bind(&last_output) + .bind(&error_message) + .bind(&merge_mode) + .bind(&pr_url) + .bind(daemon_id) + .bind(&target_repo_path) + .bind(&completion_action) + .bind(req.version.unwrap()) + .fetch_optional(pool) + .await? + } else { + sqlx::query_as::<_, Task>( + r#" + UPDATE tasks + SET name = $3, description = $4, plan = $5, status = $6, priority = $7, + progress_summary = $8, last_output = $9, error_message = $10, + merge_mode = $11, pr_url = $12, daemon_id = $13, + target_repo_path = $14, completion_action = $15, updated_at = NOW() + WHERE id = $1 AND owner_id = $2 + RETURNING * + "#, + ) + .bind(id) + .bind(owner_id) + .bind(&name) + .bind(&description) + .bind(&plan) + .bind(&status) + .bind(priority) + .bind(&progress_summary) + .bind(&last_output) + .bind(&error_message) + .bind(&merge_mode) + .bind(&pr_url) + .bind(daemon_id) + .bind(&target_repo_path) + .bind(&completion_action) + .fetch_optional(pool) + .await? + }; + + // If versioned update returned None, there was a race condition + if result.is_none() && req.version.is_some() { + if let Some(current) = get_task_for_owner(pool, id, owner_id).await? { + return Err(RepositoryError::VersionConflict { + expected: req.version.unwrap(), + actual: current.version, + }); + } + } + + Ok(result) +} + +/// Delete a task by ID, scoped to owner. +pub async fn delete_task_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, +) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM tasks + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Update task status and record event. +pub async fn update_task_status( + pool: &PgPool, + id: Uuid, + new_status: &str, + event_data: Option<serde_json::Value>, +) -> Result<Option<Task>, sqlx::Error> { + // Get existing status + let existing = get_task(pool, id).await?; + let Some(existing) = existing else { + return Ok(None); + }; + + let previous_status = existing.status.clone(); + + // Update task status + let task = sqlx::query_as::<_, Task>( + r#" + UPDATE tasks + SET status = $2, updated_at = NOW(), + started_at = CASE WHEN $2 = 'running' AND started_at IS NULL THEN NOW() ELSE started_at END, + completed_at = CASE WHEN $2 IN ('done', 'failed', 'merged') THEN NOW() ELSE completed_at END + WHERE id = $1 + RETURNING * + "#, + ) + .bind(id) + .bind(new_status) + .fetch_optional(pool) + .await?; + + // Record event + if task.is_some() { + let _ = create_task_event( + pool, + id, + "status_change", + Some(&previous_status), + Some(new_status), + event_data, + ) + .await; + } + + Ok(task) +} + +// ============================================================================= +// Task Event Functions +// ============================================================================= + +/// Create a task event. +pub async fn create_task_event( + pool: &PgPool, + task_id: Uuid, + event_type: &str, + previous_status: Option<&str>, + new_status: Option<&str>, + event_data: Option<serde_json::Value>, +) -> Result<TaskEvent, sqlx::Error> { + sqlx::query_as::<_, TaskEvent>( + r#" + INSERT INTO task_events (task_id, event_type, previous_status, new_status, event_data) + VALUES ($1, $2, $3, $4, $5) + RETURNING * + "#, + ) + .bind(task_id) + .bind(event_type) + .bind(previous_status) + .bind(new_status) + .bind(event_data) + .fetch_one(pool) + .await +} + +/// List events for a task. +pub async fn list_task_events( + pool: &PgPool, + task_id: Uuid, + limit: Option<i64>, +) -> Result<Vec<TaskEvent>, sqlx::Error> { + let limit = limit.unwrap_or(100); + sqlx::query_as::<_, TaskEvent>( + r#" + SELECT * + FROM task_events + WHERE task_id = $1 + ORDER BY created_at DESC + LIMIT $2 + "#, + ) + .bind(task_id) + .bind(limit) + .fetch_all(pool) + .await +} + +// ============================================================================= +// Daemon Functions +// ============================================================================= + +/// Register a new daemon connection. +pub async fn register_daemon( + pool: &PgPool, + owner_id: Uuid, + connection_id: &str, + hostname: Option<&str>, + machine_id: Option<&str>, + max_concurrent_tasks: i32, +) -> Result<Daemon, sqlx::Error> { + sqlx::query_as::<_, Daemon>( + r#" + INSERT INTO daemons (owner_id, connection_id, hostname, machine_id, max_concurrent_tasks) + VALUES ($1, $2, $3, $4, $5) + RETURNING * + "#, + ) + .bind(owner_id) + .bind(connection_id) + .bind(hostname) + .bind(machine_id) + .bind(max_concurrent_tasks) + .fetch_one(pool) + .await +} + +/// Get a daemon by ID. +pub async fn get_daemon(pool: &PgPool, id: Uuid) -> Result<Option<Daemon>, sqlx::Error> { + sqlx::query_as::<_, Daemon>( + r#" + SELECT * + FROM daemons + WHERE id = $1 + "#, + ) + .bind(id) + .fetch_optional(pool) + .await +} + +/// Get a daemon by connection ID. +pub async fn get_daemon_by_connection( + pool: &PgPool, + connection_id: &str, +) -> Result<Option<Daemon>, sqlx::Error> { + sqlx::query_as::<_, Daemon>( + r#" + SELECT * + FROM daemons + WHERE connection_id = $1 + "#, + ) + .bind(connection_id) + .fetch_optional(pool) + .await +} + +/// List all daemons. +pub async fn list_daemons(pool: &PgPool) -> Result<Vec<Daemon>, sqlx::Error> { + sqlx::query_as::<_, Daemon>( + r#" + SELECT * + FROM daemons + ORDER BY connected_at DESC + "#, + ) + .fetch_all(pool) + .await +} + +/// List daemons for a specific owner. +pub async fn list_daemons_for_owner(pool: &PgPool, owner_id: Uuid) -> Result<Vec<Daemon>, sqlx::Error> { + sqlx::query_as::<_, Daemon>( + r#" + SELECT * + FROM daemons + WHERE owner_id = $1 + ORDER BY connected_at DESC + "#, + ) + .bind(owner_id) + .fetch_all(pool) + .await +} + +/// Get a daemon by ID for a specific owner. +pub async fn get_daemon_for_owner(pool: &PgPool, id: Uuid, owner_id: Uuid) -> Result<Option<Daemon>, sqlx::Error> { + sqlx::query_as::<_, Daemon>( + r#" + SELECT * + FROM daemons + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .fetch_optional(pool) + .await +} + +/// Update daemon heartbeat. +pub async fn update_daemon_heartbeat(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + UPDATE daemons + SET last_heartbeat_at = NOW(), status = 'connected' + WHERE id = $1 + "#, + ) + .bind(id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Update daemon status. +pub async fn update_daemon_status( + pool: &PgPool, + id: Uuid, + status: &str, +) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + UPDATE daemons + SET status = $2, + disconnected_at = CASE WHEN $2 = 'disconnected' THEN NOW() ELSE disconnected_at END + WHERE id = $1 + "#, + ) + .bind(id) + .bind(status) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Update daemon task count. +pub async fn update_daemon_task_count( + pool: &PgPool, + id: Uuid, + delta: i32, +) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + UPDATE daemons + SET current_task_count = GREATEST(0, current_task_count + $2) + WHERE id = $1 + "#, + ) + .bind(id) + .bind(delta) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Delete a daemon by ID. +pub async fn delete_daemon(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM daemons + WHERE id = $1 + "#, + ) + .bind(id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Delete a daemon by connection ID. +pub async fn delete_daemon_by_connection( + pool: &PgPool, + connection_id: &str, +) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM daemons + WHERE connection_id = $1 + "#, + ) + .bind(connection_id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Count connected daemons. +pub async fn count_daemons(pool: &PgPool) -> Result<i64, sqlx::Error> { + let result: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM daemons WHERE status = 'connected'", + ) + .fetch_one(pool) + .await?; + + Ok(result.0) +} + +// ============================================================================= +// Sibling Awareness Functions +// ============================================================================= + +/// List sibling tasks (tasks with the same parent, excluding the given task). +pub async fn list_sibling_tasks( + pool: &PgPool, + task_id: Uuid, + parent_id: Option<Uuid>, +) -> Result<Vec<TaskSummary>, sqlx::Error> { + match parent_id { + Some(parent) => { + sqlx::query_as::<_, TaskSummary>( + r#" + SELECT + t.id, t.parent_task_id, t.depth, t.name, t.status, t.priority, + t.progress_summary, t.version, t.created_at, t.updated_at, + (SELECT COUNT(*) FROM tasks WHERE parent_task_id = t.id) as subtask_count + FROM tasks t + WHERE t.parent_task_id = $1 AND t.id != $2 + ORDER BY t.priority DESC, t.created_at DESC + "#, + ) + .bind(parent) + .bind(task_id) + .fetch_all(pool) + .await + } + None => { + // Top-level tasks (no parent) - siblings are other top-level tasks + sqlx::query_as::<_, TaskSummary>( + r#" + SELECT + t.id, t.parent_task_id, t.depth, t.name, t.status, t.priority, + t.progress_summary, t.version, t.created_at, t.updated_at, + (SELECT COUNT(*) FROM tasks WHERE parent_task_id = t.id) as subtask_count + FROM tasks t + WHERE t.parent_task_id IS NULL AND t.id != $1 + ORDER BY t.priority DESC, t.created_at DESC + "#, + ) + .bind(task_id) + .fetch_all(pool) + .await + } + } +} + +/// Get running sibling tasks (for context injection). +pub async fn get_running_siblings( + pool: &PgPool, + owner_id: Uuid, + task_id: Uuid, + parent_id: Option<Uuid>, +) -> Result<Vec<Task>, sqlx::Error> { + match parent_id { + Some(parent) => { + sqlx::query_as::<_, Task>( + r#" + SELECT * + FROM tasks t + WHERE t.owner_id = $1 + AND t.parent_task_id = $2 + AND t.id != $3 + AND t.status = 'running' + ORDER BY t.priority DESC + "#, + ) + .bind(owner_id) + .bind(parent) + .bind(task_id) + .fetch_all(pool) + .await + } + None => { + sqlx::query_as::<_, Task>( + r#" + SELECT * + FROM tasks t + WHERE t.owner_id = $1 + AND t.parent_task_id IS NULL + AND t.id != $2 + AND t.status = 'running' + ORDER BY t.priority DESC + "#, + ) + .bind(owner_id) + .bind(task_id) + .fetch_all(pool) + .await + } + } +} + +/// Get task with its siblings for context awareness. +pub async fn get_task_with_siblings( + pool: &PgPool, + id: Uuid, +) -> Result<Option<(Task, Vec<TaskSummary>)>, sqlx::Error> { + let task = get_task(pool, id).await?; + let Some(task) = task else { + return Ok(None); + }; + + let siblings = list_sibling_tasks(pool, id, task.parent_task_id).await?; + Ok(Some((task, siblings))) +} + +// ============================================================================= +// Task Output Persistence Functions +// ============================================================================= + +/// Save task output to the database. +/// This stores output in the task_events table with event_type='output'. +pub async fn save_task_output( + pool: &PgPool, + task_id: Uuid, + message_type: &str, + content: &str, + tool_name: Option<&str>, + tool_input: Option<serde_json::Value>, + is_error: Option<bool>, + cost_usd: Option<f64>, + duration_ms: Option<u64>, +) -> Result<TaskEvent, sqlx::Error> { + let event_data = serde_json::json!({ + "messageType": message_type, + "content": content, + "toolName": tool_name, + "toolInput": tool_input, + "isError": is_error, + "costUsd": cost_usd, + "durationMs": duration_ms, + }); + + create_task_event(pool, task_id, "output", None, None, Some(event_data)).await +} + +/// Get task output from the database. +/// Retrieves all output events for a task, ordered by creation time. +pub async fn get_task_output( + pool: &PgPool, + task_id: Uuid, + limit: Option<i64>, +) -> Result<Vec<TaskEvent>, sqlx::Error> { + let limit = limit.unwrap_or(1000); + sqlx::query_as::<_, TaskEvent>( + r#" + SELECT * + FROM task_events + WHERE task_id = $1 AND event_type = 'output' + ORDER BY created_at ASC + LIMIT $2 + "#, + ) + .bind(task_id) + .bind(limit) + .fetch_all(pool) + .await +} + +/// Update task completion status with error message. +/// Sets the task status to 'done' or 'failed' and records completion time. +pub async fn complete_task( + pool: &PgPool, + task_id: Uuid, + success: bool, + error_message: Option<&str>, +) -> Result<Option<Task>, sqlx::Error> { + let status = if success { "done" } else { "failed" }; + + let task = sqlx::query_as::<_, Task>( + r#" + UPDATE tasks + SET status = $2, + error_message = COALESCE($3, error_message), + completed_at = NOW(), + updated_at = NOW() + WHERE id = $1 + RETURNING * + "#, + ) + .bind(task_id) + .bind(status) + .bind(error_message) + .fetch_optional(pool) + .await?; + + // Record completion event + if task.is_some() { + let event_data = serde_json::json!({ + "success": success, + "errorMessage": error_message, + }); + let _ = create_task_event( + pool, + task_id, + "complete", + Some("running"), + Some(status), + Some(event_data), + ) + .await; + } + + Ok(task) +} + +// ============================================================================= +// Mesh Chat History Functions +// ============================================================================= + +/// Get or create the active conversation for an owner. +pub async fn get_or_create_active_conversation( + pool: &PgPool, + owner_id: Uuid, +) -> Result<MeshChatConversation, sqlx::Error> { + // Try to get existing active conversation for this owner + let existing = sqlx::query_as::<_, MeshChatConversation>( + r#" + SELECT * + FROM mesh_chat_conversations + WHERE is_active = true AND owner_id = $1 + LIMIT 1 + "#, + ) + .bind(owner_id) + .fetch_optional(pool) + .await?; + + if let Some(conv) = existing { + return Ok(conv); + } + + // Create new conversation + sqlx::query_as::<_, MeshChatConversation>( + r#" + INSERT INTO mesh_chat_conversations (owner_id, is_active) + VALUES ($1, true) + RETURNING * + "#, + ) + .bind(owner_id) + .fetch_one(pool) + .await +} + +/// List messages for a conversation. +pub async fn list_chat_messages( + pool: &PgPool, + conversation_id: Uuid, + limit: Option<i32>, +) -> Result<Vec<MeshChatMessageRecord>, sqlx::Error> { + let limit = limit.unwrap_or(100); + sqlx::query_as::<_, MeshChatMessageRecord>( + r#" + SELECT * + FROM mesh_chat_messages + WHERE conversation_id = $1 + ORDER BY created_at ASC + LIMIT $2 + "#, + ) + .bind(conversation_id) + .bind(limit) + .fetch_all(pool) + .await +} + +/// Add a message to a conversation. +#[allow(clippy::too_many_arguments)] +pub async fn add_chat_message( + pool: &PgPool, + conversation_id: Uuid, + role: &str, + content: &str, + context_type: &str, + context_task_id: Option<Uuid>, + tool_calls: Option<serde_json::Value>, + pending_questions: Option<serde_json::Value>, +) -> Result<MeshChatMessageRecord, sqlx::Error> { + sqlx::query_as::<_, MeshChatMessageRecord>( + r#" + INSERT INTO mesh_chat_messages + (conversation_id, role, content, context_type, context_task_id, tool_calls, pending_questions) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING * + "#, + ) + .bind(conversation_id) + .bind(role) + .bind(content) + .bind(context_type) + .bind(context_task_id) + .bind(tool_calls) + .bind(pending_questions) + .fetch_one(pool) + .await +} + +/// Clear conversation (archive existing and create new). +pub async fn clear_conversation(pool: &PgPool, owner_id: Uuid) -> Result<MeshChatConversation, sqlx::Error> { + // Mark existing as inactive for this owner + sqlx::query( + r#" + UPDATE mesh_chat_conversations + SET is_active = false, updated_at = NOW() + WHERE is_active = true AND owner_id = $1 + "#, + ) + .bind(owner_id) + .execute(pool) + .await?; + + // Create new active conversation + get_or_create_active_conversation(pool, owner_id).await +} diff --git a/makima/src/llm/mesh_tools.rs b/makima/src/llm/mesh_tools.rs new file mode 100644 index 0000000..1d12c66 --- /dev/null +++ b/makima/src/llm/mesh_tools.rs @@ -0,0 +1,1080 @@ +//! Tool definitions for task mesh orchestration via LLM. +//! +//! These tools allow the LLM to create, manage, and coordinate tasks across +//! connected daemons running Claude Code containers. + +use serde_json::json; +use uuid::Uuid; + +use super::tools::Tool; + +/// Available tools for mesh/task orchestration +pub static MESH_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = once_cell::sync::Lazy::new(|| { + vec![ + // ============================================================================= + // Task Lifecycle Tools + // ============================================================================= + Tool { + name: "create_task".to_string(), + description: "Create a new task (or subtask if parent_task_id provided). The task will be in 'pending' status until run_task is called.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the task" + }, + "plan": { + "type": "string", + "description": "Detailed instructions/plan for what the task should accomplish" + }, + "parent_task_id": { + "type": "string", + "description": "Optional parent task ID to create this as a subtask" + }, + "repository_url": { + "type": "string", + "description": "Git repository URL or local path for the task (required)" + }, + "base_branch": { + "type": "string", + "description": "Optional base branch to start from (default: main)" + }, + "merge_mode": { + "type": "string", + "enum": ["pr", "auto", "manual"], + "description": "How to handle completion: 'pr' creates PR, 'auto' auto-merges, 'manual' leaves changes for review" + }, + "priority": { + "type": "integer", + "description": "Task priority (higher = more important, default: 0)" + } + }, + "required": ["name", "plan", "repository_url"] + }), + }, + Tool { + name: "run_task".to_string(), + description: "Start executing a pending task on an available daemon. The task must be in 'pending' or 'paused' status.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to run" + }, + "daemon_id": { + "type": "string", + "description": "Optional specific daemon ID to run on. If not specified, an available daemon will be selected." + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "pause_task".to_string(), + description: "Pause a running task. The container state will be preserved.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to pause" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "resume_task".to_string(), + description: "Resume a paused task.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to resume" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "interrupt_task".to_string(), + description: "Interrupt a running task. Use graceful=true to allow current operation to complete.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to interrupt" + }, + "graceful": { + "type": "boolean", + "description": "If true, wait for current operation to complete before stopping" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "discard_task".to_string(), + description: "Discard a task and delete its overlay. All changes will be lost. Use with caution.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to discard" + }, + "confirm": { + "type": "boolean", + "description": "Must be true to confirm deletion" + } + }, + "required": ["task_id", "confirm"] + }), + }, + // ============================================================================= + // Task Query Tools + // ============================================================================= + Tool { + name: "query_task_status".to_string(), + description: "Get detailed status and information about a task.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to query" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "list_tasks".to_string(), + description: "List all tasks, optionally filtered by status or parent.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "status_filter": { + "type": "string", + "enum": ["pending", "running", "paused", "blocked", "done", "failed", "merged"], + "description": "Optional filter by task status" + }, + "parent_task_id": { + "type": "string", + "description": "Optional filter to list only subtasks of this parent" + } + } + }), + }, + Tool { + name: "list_subtasks".to_string(), + description: "List all subtasks of a specific task.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the parent task" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "list_siblings".to_string(), + description: "List sibling tasks (tasks with the same parent) of a specific task.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to find siblings for" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "list_daemons".to_string(), + description: "List all connected daemons and their current status.".to_string(), + parameters: json!({ + "type": "object", + "properties": {} + }), + }, + Tool { + name: "list_daemon_directories".to_string(), + description: "List all available directories from connected daemons. Use this to find existing repositories and suggested working directories when creating tasks. Returns directories like the daemon's working directory and home directory where repos can be cloned.".to_string(), + parameters: json!({ + "type": "object", + "properties": {} + }), + }, + // ============================================================================= + // File Access Tools + // ============================================================================= + Tool { + name: "list_files".to_string(), + description: "List all files available in the system. Returns file IDs, names, and descriptions.".to_string(), + parameters: json!({ + "type": "object", + "properties": {} + }), + }, + Tool { + name: "read_file".to_string(), + description: "Read the contents of a file from the files system. Returns the file's name, description, summary, body content (headings and paragraphs), and transcript entries with speaker and timing information.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "ID of the file to read" + } + }, + "required": ["file_id"] + }), + }, + // ============================================================================= + // Task Communication Tools + // ============================================================================= + Tool { + name: "send_message_to_task".to_string(), + description: "Send a message to a running task's Claude Code instance. Use this to provide additional context, answer questions, or give new instructions.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the running task" + }, + "message": { + "type": "string", + "description": "Message to send to the task" + } + }, + "required": ["task_id", "message"] + }), + }, + Tool { + name: "update_task_plan".to_string(), + description: "Update the plan/instructions for a task. Can optionally interrupt a running task to apply new plan.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to update" + }, + "new_plan": { + "type": "string", + "description": "New plan/instructions for the task" + }, + "interrupt_if_running": { + "type": "boolean", + "description": "If true and task is running, interrupt it to apply new plan" + } + }, + "required": ["task_id", "new_plan"] + }), + }, + // ============================================================================= + // Overlay/Merge Tools + // ============================================================================= + Tool { + name: "peek_sibling_overlay".to_string(), + description: "View the changes made by a sibling task's overlay. Useful for understanding what other tasks have done before merging.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "sibling_task_id": { + "type": "string", + "description": "ID of the sibling task to peek at" + } + }, + "required": ["sibling_task_id"] + }), + }, + Tool { + name: "get_overlay_diff".to_string(), + description: "Get a git diff of all changes in a task's overlay.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "preview_merge".to_string(), + description: "Preview what a merge would look like without actually merging. Shows potential conflicts.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to preview merge for" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "merge_subtask".to_string(), + description: "Merge a completed subtask's changes to its parent branch.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the subtask to merge" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "complete_task".to_string(), + description: "Mark a task as complete and trigger the merge flow based on merge_mode. For 'pr' mode, creates a pull request. For 'auto' mode, merges directly.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task to complete" + } + }, + "required": ["task_id"] + }), + }, + Tool { + name: "set_merge_mode".to_string(), + description: "Change the merge mode for a task.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "ID of the task" + }, + "mode": { + "type": "string", + "enum": ["pr", "auto", "manual"], + "description": "New merge mode: 'pr' (create PR), 'auto' (auto-merge), 'manual' (leave for manual review)" + } + }, + "required": ["task_id", "mode"] + }), + }, + // ============================================================================= + // Interactive Tools + // ============================================================================= + Tool { + name: "ask_user".to_string(), + description: "Ask the user one or more questions. Use this when you need clarification, want to offer choices, or need user input before proceeding.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "questions": { + "type": "array", + "description": "List of questions to ask the user", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for this question" + }, + "question": { + "type": "string", + "description": "The question to ask the user" + }, + "options": { + "type": "array", + "items": { "type": "string" }, + "description": "Multiple choice options for the user to select from" + }, + "allowMultiple": { + "type": "boolean", + "description": "If true, user can select multiple options" + }, + "allowCustom": { + "type": "boolean", + "description": "If true, user can provide a custom answer" + } + }, + "required": ["id", "question", "options"] + } + } + }, + "required": ["questions"] + }), + }, + ] +}); + +/// Request for mesh tool operations that require async database/daemon access +#[derive(Debug, Clone)] +pub enum MeshToolRequest { + // Task lifecycle + CreateTask { + name: String, + plan: String, + parent_task_id: Option<Uuid>, + repository_url: Option<String>, + base_branch: Option<String>, + merge_mode: Option<String>, + priority: Option<i32>, + }, + RunTask { + task_id: Uuid, + daemon_id: Option<Uuid>, + }, + PauseTask { + task_id: Uuid, + }, + ResumeTask { + task_id: Uuid, + }, + InterruptTask { + task_id: Uuid, + graceful: bool, + }, + DiscardTask { + task_id: Uuid, + }, + + // Task queries + QueryTaskStatus { + task_id: Uuid, + }, + ListTasks { + status_filter: Option<String>, + parent_task_id: Option<Uuid>, + }, + ListSubtasks { + task_id: Uuid, + }, + ListSiblings { + task_id: Uuid, + }, + ListDaemons, + ListDaemonDirectories, + + // File access + ListFiles, + ReadFile { + file_id: Uuid, + }, + + // Task communication + SendMessageToTask { + task_id: Uuid, + message: String, + }, + UpdateTaskPlan { + task_id: Uuid, + new_plan: String, + interrupt_if_running: bool, + }, + + // Overlay/merge operations + PeekSiblingOverlay { + sibling_task_id: Uuid, + }, + GetOverlayDiff { + task_id: Uuid, + }, + PreviewMerge { + task_id: Uuid, + }, + MergeSubtask { + task_id: Uuid, + }, + CompleteTask { + task_id: Uuid, + }, + SetMergeMode { + task_id: Uuid, + mode: String, + }, +} + +/// Result from executing a mesh tool +#[derive(Debug)] +pub struct MeshToolExecutionResult { + pub success: bool, + pub message: String, + pub data: Option<serde_json::Value>, + /// Request for async operations (handled by mesh_chat handler) + pub request: Option<MeshToolRequest>, + /// Questions to ask the user (pauses conversation) + pub pending_questions: Option<Vec<super::tools::UserQuestion>>, +} + +/// Parse and validate a mesh tool call, returning a MeshToolRequest for async handling +pub fn parse_mesh_tool_call( + call: &super::tools::ToolCall, +) -> MeshToolExecutionResult { + match call.name.as_str() { + // Task lifecycle + "create_task" => parse_create_task(call), + "run_task" => parse_run_task(call), + "pause_task" => parse_pause_task(call), + "resume_task" => parse_resume_task(call), + "interrupt_task" => parse_interrupt_task(call), + "discard_task" => parse_discard_task(call), + + // Task queries + "query_task_status" => parse_query_task_status(call), + "list_tasks" => parse_list_tasks(call), + "list_subtasks" => parse_list_subtasks(call), + "list_siblings" => parse_list_siblings(call), + "list_daemons" => parse_list_daemons(), + "list_daemon_directories" => parse_list_daemon_directories(), + + // File access + "list_files" => parse_list_files(), + "read_file" => parse_read_file(call), + + // Task communication + "send_message_to_task" => parse_send_message_to_task(call), + "update_task_plan" => parse_update_task_plan(call), + + // Overlay/merge operations + "peek_sibling_overlay" => parse_peek_sibling_overlay(call), + "get_overlay_diff" => parse_get_overlay_diff(call), + "preview_merge" => parse_preview_merge(call), + "merge_subtask" => parse_merge_subtask(call), + "complete_task" => parse_complete_task(call), + "set_merge_mode" => parse_set_merge_mode(call), + + // Interactive tools + "ask_user" => parse_ask_user(call), + + _ => MeshToolExecutionResult { + success: false, + message: format!("Unknown mesh tool: {}", call.name), + data: None, + request: None, + pending_questions: None, + }, + } +} + +// ============================================================================= +// Tool Parsing Functions +// ============================================================================= + +fn parse_create_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let name = call.arguments.get("name").and_then(|v| v.as_str()); + let plan = call.arguments.get("plan").and_then(|v| v.as_str()); + let repository_url = call + .arguments + .get("repository_url") + .and_then(|v| v.as_str()); + + let Some(name) = name else { + return error_result("Missing required parameter: name"); + }; + let Some(plan) = plan else { + return error_result("Missing required parameter: plan"); + }; + let Some(repository_url) = repository_url else { + return error_result("Missing required parameter: repository_url"); + }; + + let parent_task_id = call + .arguments + .get("parent_task_id") + .and_then(|v| v.as_str()) + .and_then(|s| Uuid::parse_str(s).ok()); + + let repository_url = Some(repository_url.to_string()); + + let base_branch = call + .arguments + .get("base_branch") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let merge_mode = call + .arguments + .get("merge_mode") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let priority = call + .arguments + .get("priority") + .and_then(|v| v.as_i64()) + .map(|v| v as i32); + + MeshToolExecutionResult { + success: true, + message: "Creating task...".to_string(), + data: None, + request: Some(MeshToolRequest::CreateTask { + name: name.to_string(), + plan: plan.to_string(), + parent_task_id, + repository_url, + base_branch, + merge_mode, + priority, + }), + pending_questions: None, + } +} + +fn parse_run_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + let daemon_id = call + .arguments + .get("daemon_id") + .and_then(|v| v.as_str()) + .and_then(|s| Uuid::parse_str(s).ok()); + + MeshToolExecutionResult { + success: true, + message: "Starting task...".to_string(), + data: None, + request: Some(MeshToolRequest::RunTask { task_id, daemon_id }), + pending_questions: None, + } +} + +fn parse_pause_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Pausing task...".to_string(), + data: None, + request: Some(MeshToolRequest::PauseTask { task_id }), + pending_questions: None, + } +} + +fn parse_resume_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Resuming task...".to_string(), + data: None, + request: Some(MeshToolRequest::ResumeTask { task_id }), + pending_questions: None, + } +} + +fn parse_interrupt_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + let graceful = call + .arguments + .get("graceful") + .and_then(|v| v.as_bool()) + .unwrap_or(true); + + MeshToolExecutionResult { + success: true, + message: if graceful { + "Gracefully interrupting task...".to_string() + } else { + "Force interrupting task...".to_string() + }, + data: None, + request: Some(MeshToolRequest::InterruptTask { task_id, graceful }), + pending_questions: None, + } +} + +fn parse_discard_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + let confirm = call + .arguments + .get("confirm") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + if !confirm { + return error_result("Must set confirm=true to discard a task"); + } + + MeshToolExecutionResult { + success: true, + message: "Discarding task...".to_string(), + data: None, + request: Some(MeshToolRequest::DiscardTask { task_id }), + pending_questions: None, + } +} + +fn parse_query_task_status(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Querying task status...".to_string(), + data: None, + request: Some(MeshToolRequest::QueryTaskStatus { task_id }), + pending_questions: None, + } +} + +fn parse_list_tasks(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let status_filter = call + .arguments + .get("status_filter") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let parent_task_id = call + .arguments + .get("parent_task_id") + .and_then(|v| v.as_str()) + .and_then(|s| Uuid::parse_str(s).ok()); + + MeshToolExecutionResult { + success: true, + message: "Listing tasks...".to_string(), + data: None, + request: Some(MeshToolRequest::ListTasks { + status_filter, + parent_task_id, + }), + pending_questions: None, + } +} + +fn parse_list_subtasks(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Listing subtasks...".to_string(), + data: None, + request: Some(MeshToolRequest::ListSubtasks { task_id }), + pending_questions: None, + } +} + +fn parse_list_siblings(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Listing sibling tasks...".to_string(), + data: None, + request: Some(MeshToolRequest::ListSiblings { task_id }), + pending_questions: None, + } +} + +fn parse_list_daemons() -> MeshToolExecutionResult { + MeshToolExecutionResult { + success: true, + message: "Listing daemons...".to_string(), + data: None, + request: Some(MeshToolRequest::ListDaemons), + pending_questions: None, + } +} + +fn parse_list_daemon_directories() -> MeshToolExecutionResult { + MeshToolExecutionResult { + success: true, + message: "Listing daemon directories...".to_string(), + data: None, + request: Some(MeshToolRequest::ListDaemonDirectories), + pending_questions: None, + } +} + +fn parse_list_files() -> MeshToolExecutionResult { + MeshToolExecutionResult { + success: true, + message: "Listing files...".to_string(), + data: None, + request: Some(MeshToolRequest::ListFiles), + pending_questions: None, + } +} + +fn parse_read_file(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let file_id = parse_uuid_arg(call, "file_id"); + let Some(file_id) = file_id else { + return error_result("Missing or invalid required parameter: file_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Reading file...".to_string(), + data: None, + request: Some(MeshToolRequest::ReadFile { file_id }), + pending_questions: None, + } +} + +fn parse_send_message_to_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + let message = call.arguments.get("message").and_then(|v| v.as_str()); + let Some(message) = message else { + return error_result("Missing required parameter: message"); + }; + + MeshToolExecutionResult { + success: true, + message: "Sending message to task...".to_string(), + data: None, + request: Some(MeshToolRequest::SendMessageToTask { + task_id, + message: message.to_string(), + }), + pending_questions: None, + } +} + +fn parse_update_task_plan(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + let new_plan = call.arguments.get("new_plan").and_then(|v| v.as_str()); + let Some(new_plan) = new_plan else { + return error_result("Missing required parameter: new_plan"); + }; + + let interrupt_if_running = call + .arguments + .get("interrupt_if_running") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + MeshToolExecutionResult { + success: true, + message: "Updating task plan...".to_string(), + data: None, + request: Some(MeshToolRequest::UpdateTaskPlan { + task_id, + new_plan: new_plan.to_string(), + interrupt_if_running, + }), + pending_questions: None, + } +} + +fn parse_peek_sibling_overlay(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let sibling_task_id = parse_uuid_arg(call, "sibling_task_id"); + let Some(sibling_task_id) = sibling_task_id else { + return error_result("Missing or invalid required parameter: sibling_task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Peeking at sibling overlay...".to_string(), + data: None, + request: Some(MeshToolRequest::PeekSiblingOverlay { sibling_task_id }), + pending_questions: None, + } +} + +fn parse_get_overlay_diff(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Getting overlay diff...".to_string(), + data: None, + request: Some(MeshToolRequest::GetOverlayDiff { task_id }), + pending_questions: None, + } +} + +fn parse_preview_merge(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Previewing merge...".to_string(), + data: None, + request: Some(MeshToolRequest::PreviewMerge { task_id }), + pending_questions: None, + } +} + +fn parse_merge_subtask(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Merging subtask...".to_string(), + data: None, + request: Some(MeshToolRequest::MergeSubtask { task_id }), + pending_questions: None, + } +} + +fn parse_complete_task(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + MeshToolExecutionResult { + success: true, + message: "Completing task...".to_string(), + data: None, + request: Some(MeshToolRequest::CompleteTask { task_id }), + pending_questions: None, + } +} + +fn parse_set_merge_mode(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let task_id = parse_uuid_arg(call, "task_id"); + let Some(task_id) = task_id else { + return error_result("Missing or invalid required parameter: task_id"); + }; + + let mode = call.arguments.get("mode").and_then(|v| v.as_str()); + let Some(mode) = mode else { + return error_result("Missing required parameter: mode"); + }; + + if !["pr", "auto", "manual"].contains(&mode) { + return error_result("Invalid mode. Must be 'pr', 'auto', or 'manual'"); + } + + MeshToolExecutionResult { + success: true, + message: format!("Setting merge mode to '{}'...", mode), + data: None, + request: Some(MeshToolRequest::SetMergeMode { + task_id, + mode: mode.to_string(), + }), + pending_questions: None, + } +} + +fn parse_ask_user(call: &super::tools::ToolCall) -> MeshToolExecutionResult { + let questions_value = call.arguments.get("questions"); + + let Some(questions_array) = questions_value.and_then(|v| v.as_array()) else { + return error_result("Missing or invalid 'questions' parameter"); + }; + + let mut questions: Vec<super::tools::UserQuestion> = Vec::new(); + + for q in questions_array { + let id = q.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let question = q.get("question").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let options: Vec<String> = q + .get("options") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|o| o.as_str()) + .map(|s| s.to_string()) + .collect() + }) + .unwrap_or_default(); + let allow_multiple = q.get("allowMultiple").and_then(|v| v.as_bool()).unwrap_or(false); + let allow_custom = q.get("allowCustom").and_then(|v| v.as_bool()).unwrap_or(true); + + if id.is_empty() || question.is_empty() || options.is_empty() { + continue; + } + + questions.push(super::tools::UserQuestion { + id, + question, + options, + allow_multiple, + allow_custom, + }); + } + + if questions.is_empty() { + return error_result("No valid questions provided"); + } + + let question_count = questions.len(); + MeshToolExecutionResult { + success: true, + message: format!("Asking user {} question(s). Waiting for response...", question_count), + data: None, + request: None, + pending_questions: Some(questions), + } +} + +// ============================================================================= +// Helper Functions +// ============================================================================= + +fn parse_uuid_arg(call: &super::tools::ToolCall, key: &str) -> Option<Uuid> { + call.arguments + .get(key) + .and_then(|v| v.as_str()) + .and_then(|s| Uuid::parse_str(s).ok()) +} + +fn error_result(message: &str) -> MeshToolExecutionResult { + MeshToolExecutionResult { + success: false, + message: message.to_string(), + data: None, + request: None, + pending_questions: None, + } +} diff --git a/makima/src/llm/mod.rs b/makima/src/llm/mod.rs index 1001854..39cdbdd 100644 --- a/makima/src/llm/mod.rs +++ b/makima/src/llm/mod.rs @@ -2,10 +2,12 @@ pub mod claude; pub mod groq; +pub mod mesh_tools; pub mod tools; pub use claude::{ClaudeClient, ClaudeModel}; pub use groq::GroqClient; +pub use mesh_tools::{parse_mesh_tool_call, MeshToolExecutionResult, MeshToolRequest, MESH_TOOLS}; pub use tools::{ execute_tool_call, Tool, ToolCall, ToolResult, UserAnswer, UserQuestion, VersionToolRequest, AVAILABLE_TOOLS, diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs index 77fc8c6..649633e 100644 --- a/makima/src/llm/tools.rs +++ b/makima/src/llm/tools.rs @@ -73,6 +73,51 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = }), }, Tool { + name: "add_code".to_string(), + description: "Add a code block element to the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The code content" + }, + "language": { + "type": "string", + "description": "Optional programming language for syntax highlighting (e.g., 'javascript', 'python', 'rust')" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["content"] + }), + }, + Tool { + name: "add_list".to_string(), + description: "Add a list element (ordered or unordered) to the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { "type": "string" }, + "description": "Array of list item strings" + }, + "ordered": { + "type": "boolean", + "description": "If true, creates a numbered list; if false (default), creates a bullet list" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["items"] + }), + }, + Tool { name: "add_chart".to_string(), description: "Add a chart visualization to the file body. Supports line, bar, pie, and area charts.".to_string(), parameters: json!({ @@ -122,7 +167,7 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = }, Tool { name: "update_element".to_string(), - description: "Update an existing element in the file body. IMPORTANT: You must provide ALL required fields. For heading: type, level (1-6), text. For paragraph: type, text. For chart: type, chartType (line/bar/pie/area), data (array of objects).".to_string(), + description: "Update an existing element in the file body. IMPORTANT: You must provide ALL required fields. For heading: type, level (1-6), text. For paragraph: type, text. For code: type, content, language (optional). For list: type, items (array of strings), ordered (boolean). For chart: type, chartType (line/bar/pie/area), data (array of objects).".to_string(), parameters: json!({ "type": "object", "properties": { @@ -132,7 +177,7 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = }, "element_type": { "type": "string", - "enum": ["heading", "paragraph", "chart"], + "enum": ["heading", "paragraph", "code", "list", "chart"], "description": "Type of element" }, "text": { @@ -143,6 +188,23 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "type": "integer", "description": "Heading level 1-6 (required for heading)" }, + "content": { + "type": "string", + "description": "Code content (required for code)" + }, + "language": { + "type": "string", + "description": "Programming language for syntax highlighting (optional for code)" + }, + "items": { + "type": "array", + "items": { "type": "string" }, + "description": "List items (required for list)" + }, + "ordered": { + "type": "boolean", + "description": "If true, numbered list; if false, bullet list (for list)" + }, "chartType": { "type": "string", "enum": ["line", "bar", "pie", "area"], @@ -418,6 +480,8 @@ pub fn execute_tool_call( match call.name.as_str() { "add_heading" => execute_add_heading(call, current_body), "add_paragraph" => execute_add_paragraph(call, current_body), + "add_code" => execute_add_code(call, current_body), + "add_list" => execute_add_list(call, current_body), "add_chart" => execute_add_chart(call, current_body), "remove_element" => execute_remove_element(call, current_body), "update_element" => execute_update_element(call, current_body), @@ -605,6 +669,103 @@ fn execute_add_paragraph(call: &ToolCall, current_body: &[BodyElement]) -> ToolE } } +fn execute_add_code(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let language = call + .arguments + .get("language") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let content = call + .arguments + .get("content") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Code { + language: language.clone(), + content: content.clone(), + }; + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + let lang_str = language.as_deref().unwrap_or("plain"); + let preview: String = content.chars().take(50).collect(); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Added code block ({}): {}", lang_str, preview), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + version_request: None, + pending_questions: None, + } +} + +fn execute_add_list(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let ordered = call + .arguments + .get("ordered") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let items: Vec<String> = call + .arguments + .get("items") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }) + .unwrap_or_default(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::List { + ordered, + items: items.clone(), + }; + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + let list_type = if ordered { "ordered" } else { "unordered" }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Added {} list with {} items", list_type, items.len()), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + version_request: None, + pending_questions: None, + } +} + fn execute_add_chart(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { let chart_type_str = call .arguments @@ -778,6 +939,19 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); BodyElement::Paragraph { text } } + "code" => { + let language = call.arguments.get("language").and_then(|v| v.as_str()).map(|s| s.to_string()); + let content = call.arguments.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Code { language, content } + } + "list" => { + let ordered = call.arguments.get("ordered").and_then(|v| v.as_bool()).unwrap_or(false); + let items = call.arguments.get("items") + .and_then(|v| v.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect()) + .unwrap_or_default(); + BodyElement::List { ordered, items } + } "chart" => { let chart_type_str = call.arguments.get("chartType").and_then(|v| v.as_str()).unwrap_or("bar"); let chart_type = match chart_type_str { @@ -796,7 +970,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool return ToolExecutionResult { result: ToolResult { success: false, - message: format!("Unknown element_type: {}. Must be heading, paragraph, or chart.", element_type), + message: format!("Unknown element_type: {}. Must be heading, paragraph, code, list, or chart.", element_type), }, new_body: None, new_summary: None, @@ -1149,6 +1323,18 @@ fn execute_view_body(current_body: &[BodyElement]) -> ToolExecutionResult { "type": "paragraph", "text": text }), + BodyElement::Code { language, content } => json!({ + "index": i, + "type": "code", + "language": language, + "content": content + }), + BodyElement::List { ordered, items } => json!({ + "index": i, + "type": "list", + "ordered": ordered, + "items": items + }), BodyElement::Chart { chart_type, title, data, config } => json!({ "index": i, "type": "chart", @@ -1226,6 +1412,18 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx "type": "paragraph", "text": text }), + BodyElement::Code { language, content } => json!({ + "index": index, + "type": "code", + "language": language, + "content": content + }), + BodyElement::List { ordered, items } => json!({ + "index": index, + "type": "list", + "ordered": ordered, + "items": items + }), BodyElement::Chart { chart_type, title, data, config } => json!({ "index": index, "type": "chart", @@ -1246,6 +1444,8 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx let type_str = match element { BodyElement::Heading { .. } => "heading", BodyElement::Paragraph { .. } => "paragraph", + BodyElement::Code { .. } => "code", + BodyElement::List { .. } => "list", BodyElement::Chart { .. } => "chart", BodyElement::Image { .. } => "image", }; diff --git a/makima/src/server/auth.rs b/makima/src/server/auth.rs new file mode 100644 index 0000000..b694df6 --- /dev/null +++ b/makima/src/server/auth.rs @@ -0,0 +1,1238 @@ +//! Authentication module for Makima server. +//! +//! Supports multiple authentication methods: +//! - Supabase JWT tokens for web clients (ES256 or RS256 public key verification) +//! - API keys for programmatic access (daemons, CLI) +//! - Tool keys for orchestrator internal access + +use axum::{ + extract::FromRequestParts, + http::{header::AUTHORIZATION, request::Parts, HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use chrono::{DateTime, Utc}; +use dashmap::DashMap; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use sqlx::{FromRow, PgPool, Row}; +use std::time::{Duration, Instant}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +// ============================================================================= +// Configuration +// ============================================================================= + +/// JWT algorithm configuration. +#[derive(Debug, Clone)] +pub enum JwtAlgorithm { + /// RS256 with RSA public key + Rs256 { public_key: String }, + /// ES256 with ECDSA public key (Supabase projects with JWT Signing Keys) + Es256 { public_key: String }, +} + +/// Authentication configuration loaded from environment. +#[derive(Debug, Clone)] +pub struct AuthConfig { + /// Supabase project URL (e.g., https://your-project.supabase.co) + pub supabase_url: String, + /// JWT algorithm and key material + pub algorithm: JwtAlgorithm, +} + +impl AuthConfig { + /// Load auth config from environment variables. + /// + /// Supports two modes (checked in order): + /// - ES256: Set SUPABASE_URL and SUPABASE_JWT_PUBLIC_KEY (Supabase with ECDSA) + /// - RS256: Set SUPABASE_URL and SUPABASE_JWT_RSA_PUBLIC_KEY (RSA public key) + /// + /// Returns None if auth is not configured. + pub fn from_env() -> Option<Self> { + let supabase_url = std::env::var("SUPABASE_URL").ok()?; + + // Try ES256 first (default for Supabase), then RS256 + let algorithm = if let Ok(public_key) = std::env::var("SUPABASE_JWT_PUBLIC_KEY") { + tracing::info!("Using ES256 JWT verification with ECDSA public key"); + JwtAlgorithm::Es256 { public_key } + } else if let Ok(public_key) = std::env::var("SUPABASE_JWT_RSA_PUBLIC_KEY") { + tracing::info!("Using RS256 JWT verification with RSA public key"); + JwtAlgorithm::Rs256 { public_key } + } else { + return None; + }; + + Some(Self { + supabase_url, + algorithm, + }) + } +} + +// ============================================================================= +// JWT Claims +// ============================================================================= + +/// JWT claims from Supabase Auth tokens. +#[derive(Debug, Serialize, Deserialize)] +pub struct SupabaseClaims { + /// Audience (e.g., "authenticated") + pub aud: String, + /// Expiration time (Unix timestamp) + pub exp: i64, + /// Issued at (Unix timestamp) + pub iat: i64, + /// Issuer (Supabase project URL + /auth/v1) + pub iss: String, + /// Subject (user ID) + pub sub: Uuid, + /// User's email + pub email: Option<String>, + /// User's phone + pub phone: Option<String>, + /// App metadata (set by server/admin) + pub app_metadata: Option<serde_json::Value>, + /// User metadata (set by user) + pub user_metadata: Option<serde_json::Value>, + /// Role (e.g., "authenticated") + pub role: Option<String>, + /// Session ID + pub session_id: Option<Uuid>, +} + +// ============================================================================= +// JWT Verifier +// ============================================================================= + +/// JWT verifier for Supabase tokens. +pub struct JwtVerifier { + supabase_url: String, + decoding_key: DecodingKey, + algorithm: Algorithm, +} + +impl JwtVerifier { + /// Create a new JWT verifier from auth config. + /// + /// Supports multiple key formats: + /// - JWK (JSON Web Key) - detected by presence of `{` + /// - PEM - detected by `-----BEGIN` + /// - Base64-encoded DER - fallback + pub fn new(config: AuthConfig) -> Result<Self, AuthError> { + let (decoding_key, algorithm) = match &config.algorithm { + JwtAlgorithm::Rs256 { public_key } => { + let key = Self::parse_public_key(public_key, "RSA")?; + (key, Algorithm::RS256) + } + JwtAlgorithm::Es256 { public_key } => { + let key = Self::parse_public_key(public_key, "EC")?; + (key, Algorithm::ES256) + } + }; + + Ok(Self { + supabase_url: config.supabase_url, + decoding_key, + algorithm, + }) + } + + /// Parse a public key from various formats (JWK, JWKS, PEM, or base64 DER). + fn parse_public_key(key_data: &str, key_type: &str) -> Result<DecodingKey, AuthError> { + let trimmed = key_data.trim(); + + // Check for JSON format (JWK or JWKS) + if trimmed.starts_with('{') { + // First try to parse as a generic JSON value to inspect structure + let mut json_value: serde_json::Value = serde_json::from_str(trimmed) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JSON: {}", e)))?; + + // Check if it's a JWKS (has "keys" array) + if let Some(keys) = json_value.get_mut("keys").and_then(|k| k.as_array_mut()) { + // Find the first signing key (or just use the first key) + let jwk_value = keys.first_mut() + .ok_or_else(|| AuthError::InvalidToken("JWKS has no keys".to_string()))?; + + // Remove private key component if present (user may have pasted full keypair) + if let Some(obj) = jwk_value.as_object_mut() { + if obj.remove("d").is_some() { + tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); + } + } + + let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(jwk_value.clone()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK in JWKS: {}", e)))?; + + tracing::info!("Loaded JWT public key from JWKS (first key)"); + return DecodingKey::from_jwk(&jwk) + .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))); + } + + // Remove private key component if present (user may have pasted full keypair) + if let Some(obj) = json_value.as_object_mut() { + if obj.remove("d").is_some() { + tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); + } + } + + // Try as single JWK + let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(json_value) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK: {}", e)))?; + + tracing::info!("Loaded JWT public key from JWK"); + DecodingKey::from_jwk(&jwk) + .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))) + } + // Check for PEM format + else if trimmed.contains("-----BEGIN") { + tracing::info!("Loaded JWT public key from PEM"); + match key_type { + "RSA" => DecodingKey::from_rsa_pem(trimmed.as_bytes()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid RSA PEM key: {}", e))), + "EC" => DecodingKey::from_ec_pem(trimmed.as_bytes()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid EC PEM key: {}", e))), + _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), + } + } + // Assume base64-encoded DER + else { + tracing::info!("Loaded JWT public key from base64 DER"); + let der_bytes = base64::engine::general_purpose::STANDARD + .decode(trimmed) + .map_err(|e| AuthError::InvalidToken(format!("Invalid base64 key: {}", e)))?; + + match key_type { + "RSA" => Ok(DecodingKey::from_rsa_der(&der_bytes)), + "EC" => Ok(DecodingKey::from_ec_der(&der_bytes)), + _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), + } + } + } + + /// Verify a JWT token and return claims. + pub fn verify(&self, token: &str) -> Result<SupabaseClaims, AuthError> { + // Decode header to check algorithm mismatch + let header = jsonwebtoken::decode_header(token) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWT header: {}", e)))?; + + tracing::debug!( + "JWT header: algorithm={:?}, typ={:?}, kid={:?}", + header.alg, + header.typ, + header.kid + ); + + if header.alg != self.algorithm { + let hint = match header.alg { + Algorithm::ES256 => "Set SUPABASE_JWT_PUBLIC_KEY with the EC public key from Supabase Dashboard → Project Settings → API → JWT Settings", + Algorithm::RS256 => "Set SUPABASE_JWT_RSA_PUBLIC_KEY with the RSA public key", + _ => "Check your Supabase JWT configuration - only ES256 and RS256 are supported", + }; + tracing::warn!( + "JWT algorithm mismatch: token uses {:?}, server configured for {:?}. {}", + header.alg, + self.algorithm, + hint + ); + return Err(AuthError::InvalidToken(format!( + "Algorithm mismatch: token is {:?}, expected {:?}", + header.alg, self.algorithm + ))); + } + + let mut validation = Validation::new(self.algorithm); + validation.set_audience(&["authenticated"]); + validation.set_issuer(&[format!("{}/auth/v1", self.supabase_url)]); + + // First try with full validation + let token_data = match decode::<SupabaseClaims>(token, &self.decoding_key, &validation) { + Ok(data) => data, + Err(e) => { + // Log detailed error info + tracing::warn!( + "JWT verification failed: {} (algorithm: {:?}, issuer: {}/auth/v1)", + e, + self.algorithm, + self.supabase_url + ); + + // If it's InvalidAlgorithm, try to understand why by decoding payload manually + if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidAlgorithm) { + // Decode the payload part of the JWT manually (base64) + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() >= 2 { + if let Ok(payload_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) { + if let Ok(payload_str) = String::from_utf8(payload_bytes) { + if let Ok(claims) = serde_json::from_str::<serde_json::Value>(&payload_str) { + tracing::warn!( + "JWT payload (unverified): iss={:?}, aud={:?}, sub={:?}", + claims.get("iss"), + claims.get("aud"), + claims.get("sub") + ); + } + } + } + } + } + + return Err(AuthError::InvalidToken(e.to_string())); + } + }; + + Ok(token_data.claims) + } + + /// Extract user ID from a token. + pub fn get_user_id(&self, token: &str) -> Result<Uuid, AuthError> { + let claims = self.verify(token)?; + Ok(claims.sub) + } +} + +// ============================================================================= +// Auth Error +// ============================================================================= + +/// Authentication error types. +#[derive(Debug)] +pub enum AuthError { + /// No authentication token provided + MissingToken, + /// Token format is invalid + InvalidToken(String), + /// Token has expired + ExpiredToken, + /// User not found in database + UserNotFound, + /// API key is invalid or revoked + InvalidApiKey, + /// Database error during auth lookup + DatabaseError(String), + /// Authentication is not configured + NotConfigured, + /// Insufficient permissions for the operation + InsufficientPermissions, +} + +impl std::fmt::Display for AuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthError::MissingToken => write!(f, "Missing authentication token"), + AuthError::InvalidToken(msg) => write!(f, "Invalid token: {}", msg), + AuthError::ExpiredToken => write!(f, "Token has expired"), + AuthError::UserNotFound => write!(f, "User not found"), + AuthError::InvalidApiKey => write!(f, "Invalid or revoked API key"), + AuthError::DatabaseError(msg) => write!(f, "Database error: {}", msg), + AuthError::NotConfigured => write!(f, "Authentication not configured"), + AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"), + } + } +} + +impl std::error::Error for AuthError {} + +impl IntoResponse for AuthError { + fn into_response(self) -> axum::response::Response { + let (status, code, message) = match &self { + AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "MISSING_TOKEN", "Authentication required"), + AuthError::InvalidToken(_) => (StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid authentication token"), + AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "EXPIRED_TOKEN", "Token has expired"), + AuthError::UserNotFound => (StatusCode::UNAUTHORIZED, "USER_NOT_FOUND", "User not found"), + AuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, "INVALID_API_KEY", "Invalid or revoked API key"), + AuthError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR", "Database error"), + AuthError::NotConfigured => (StatusCode::SERVICE_UNAVAILABLE, "AUTH_NOT_CONFIGURED", "Authentication not configured"), + AuthError::InsufficientPermissions => (StatusCode::FORBIDDEN, "FORBIDDEN", "Insufficient permissions"), + }; + + (status, Json(ApiError::new(code, message))).into_response() + } +} + +// ============================================================================= +// Auth Source +// ============================================================================= + +/// Source of authentication. +#[derive(Debug, Clone)] +pub enum AuthSource { + /// Authenticated via Supabase JWT (web client) + Jwt, + /// Authenticated via API key (daemon, CLI, integrations) + ApiKey, + /// Authenticated via tool key (orchestrator internal access) + ToolKey(Uuid), +} + +// ============================================================================= +// Authenticated User +// ============================================================================= + +/// Authenticated user context extracted from request. +/// +/// Contains the resolved user_id and owner_id for database operations. +#[derive(Debug, Clone)] +pub struct AuthenticatedUser { + /// Supabase auth user ID (from auth.users) + pub user_id: Uuid, + /// Owner ID for data isolation (from users.default_owner_id) + pub owner_id: Uuid, + /// How the user was authenticated + pub auth_source: AuthSource, + /// User's email (if available) + pub email: Option<String>, +} + +// ============================================================================= +// Header Constants +// ============================================================================= + +/// Header name for tool key authentication (orchestrators). +pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key"; + +/// Header name for API key authentication. +pub const API_KEY_HEADER: &str = "x-makima-api-key"; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Hash an API key for database lookup. +pub fn hash_api_key(key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + hex::encode(hasher.finalize()) +} + +// ============================================================================= +// API Key Generation +// ============================================================================= + +/// API key prefix for identification. +pub const API_KEY_PREFIX: &str = "mk_"; + +/// Result of generating an API key. +pub struct GeneratedApiKey { + /// The full API key (shown only once to user) + pub full_key: String, + /// SHA-256 hash of the key (stored in database) + pub key_hash: String, + /// Prefix for display (first 8 chars after mk_) + pub key_prefix: String, +} + +/// Generate a new API key with mk_ prefix. +/// +/// Returns the full key (to show once), hash (to store), and prefix (for display). +pub fn generate_api_key() -> GeneratedApiKey { + let mut rng = rand::thread_rng(); + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + + let key_bytes = URL_SAFE_NO_PAD.encode(bytes); + let full_key = format!("{}{}", API_KEY_PREFIX, key_bytes); + + let key_hash = hash_api_key(&full_key); + let key_prefix = format!("{}{}", API_KEY_PREFIX, &key_bytes[..8]); + + GeneratedApiKey { + full_key, + key_hash, + key_prefix, + } +} + +// ============================================================================= +// API Key Cache +// ============================================================================= + +/// Cache entry for validated API keys. +struct ApiKeyCacheEntry { + user_id: Uuid, + owner_id: Uuid, + cached_at: Instant, +} + +/// In-memory cache for API key validation to avoid database lookups on every request. +pub struct ApiKeyCache { + /// key_hash -> (user_id, owner_id, cached_at) + cache: DashMap<String, ApiKeyCacheEntry>, + /// Time-to-live for cache entries + ttl: Duration, +} + +impl ApiKeyCache { + /// Create a new cache with the specified TTL in seconds. + pub fn new(ttl_seconds: u64) -> Self { + Self { + cache: DashMap::new(), + ttl: Duration::from_secs(ttl_seconds), + } + } + + /// Get cached user_id and owner_id for a key hash, if not expired. + pub fn get(&self, key_hash: &str) -> Option<(Uuid, Uuid)> { + self.cache.get(key_hash).and_then(|entry| { + if entry.cached_at.elapsed() < self.ttl { + Some((entry.user_id, entry.owner_id)) + } else { + None + } + }) + } + + /// Cache a validated API key. + pub fn set(&self, key_hash: String, user_id: Uuid, owner_id: Uuid) { + self.cache.insert( + key_hash, + ApiKeyCacheEntry { + user_id, + owner_id, + cached_at: Instant::now(), + }, + ); + } + + /// Invalidate a cache entry (e.g., on key revocation). + pub fn invalidate(&self, key_hash: &str) { + self.cache.remove(key_hash); + } + + /// Clear all cache entries. + pub fn clear(&self) { + self.cache.clear(); + } +} + +impl Default for ApiKeyCache { + fn default() -> Self { + // Default TTL: 5 minutes + Self::new(300) + } +} + +// ============================================================================= +// API Key Models +// ============================================================================= + +/// API key record from the database. +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ApiKey { + pub id: Uuid, + pub user_id: Uuid, + #[serde(skip)] + pub key_hash: String, + pub key_prefix: String, + pub name: Option<String>, + pub last_used_at: Option<DateTime<Utc>>, + pub created_at: DateTime<Utc>, + pub revoked_at: Option<DateTime<Utc>>, +} + +/// Request to create a new API key. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateApiKeyRequest { + /// User-provided label for the key + pub name: Option<String>, +} + +/// Response after creating an API key (includes the full key - shown only once). +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateApiKeyResponse { + pub id: Uuid, + /// The full API key - save this, it won't be shown again! + pub key: String, + pub prefix: String, + pub name: Option<String>, + pub created_at: DateTime<Utc>, +} + +/// Response for getting API key info (excludes the full key). +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ApiKeyInfoResponse { + pub id: Uuid, + pub prefix: String, + pub name: Option<String>, + pub last_used_at: Option<DateTime<Utc>>, + pub created_at: DateTime<Utc>, +} + +impl From<ApiKey> for ApiKeyInfoResponse { + fn from(key: ApiKey) -> Self { + Self { + id: key.id, + prefix: key.key_prefix, + name: key.name, + last_used_at: key.last_used_at, + created_at: key.created_at, + } + } +} + +/// Request to refresh an API key. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RefreshApiKeyRequest { + /// New name for the refreshed key + pub name: Option<String>, +} + +/// Response after refreshing an API key. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RefreshApiKeyResponse { + pub id: Uuid, + /// The new API key - save this, it won't be shown again! + pub key: String, + pub prefix: String, + pub name: Option<String>, + pub created_at: DateTime<Utc>, + pub previous_key_revoked: bool, +} + +/// Response after revoking an API key. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RevokeApiKeyResponse { + pub message: String, + pub revoked_key_prefix: String, +} + +/// API key event types for audit logging. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ApiKeyEventType { + Created, + Used, + Revoked, + Refreshed, +} + +impl std::fmt::Display for ApiKeyEventType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeyEventType::Created => write!(f, "created"), + ApiKeyEventType::Used => write!(f, "used"), + ApiKeyEventType::Revoked => write!(f, "revoked"), + ApiKeyEventType::Refreshed => write!(f, "refreshed"), + } + } +} + +// ============================================================================= +// API Keys Repository +// ============================================================================= + +/// Repository error for API key operations. +#[derive(Debug)] +pub enum ApiKeyError { + /// Database error + Database(sqlx::Error), + /// An active API key already exists for this user + KeyAlreadyExists, + /// No active API key found for this user + KeyNotFound, +} + +impl std::fmt::Display for ApiKeyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeyError::Database(e) => write!(f, "Database error: {}", e), + ApiKeyError::KeyAlreadyExists => write!(f, "An active API key already exists"), + ApiKeyError::KeyNotFound => write!(f, "No active API key found"), + } + } +} + +impl std::error::Error for ApiKeyError {} + +impl From<sqlx::Error> for ApiKeyError { + fn from(e: sqlx::Error) -> Self { + ApiKeyError::Database(e) + } +} + +/// Get the active API key for a user (if any). +pub async fn get_active_api_key(pool: &PgPool, user_id: Uuid) -> Result<Option<ApiKey>, sqlx::Error> { + sqlx::query_as::<_, ApiKey>( + r#" + SELECT id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + FROM api_keys + WHERE user_id = $1 AND revoked_at IS NULL + "#, + ) + .bind(user_id) + .fetch_optional(pool) + .await +} + +/// Create a new API key for a user. +/// +/// Returns an error if the user already has an active key. +/// The `generated` parameter should be created using `generate_api_key()`. +pub async fn create_api_key( + pool: &PgPool, + user_id: Uuid, + generated: &GeneratedApiKey, + name: Option<&str>, +) -> Result<ApiKey, ApiKeyError> { + // Check if user already has an active key + if let Some(_) = get_active_api_key(pool, user_id).await? { + return Err(ApiKeyError::KeyAlreadyExists); + } + + let key = sqlx::query_as::<_, ApiKey>( + r#" + INSERT INTO api_keys (user_id, key_hash, key_prefix, name) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(user_id) + .bind(&generated.key_hash) + .bind(&generated.key_prefix) + .bind(name) + .fetch_one(pool) + .await?; + + // Log the creation event + let _ = log_api_key_event(pool, key.id, ApiKeyEventType::Created, None, None).await; + + Ok(key) +} + +/// Revoke an API key by marking it with revoked_at timestamp. +pub async fn revoke_api_key(pool: &PgPool, user_id: Uuid) -> Result<ApiKey, ApiKeyError> { + // Get the active key first + let key = get_active_api_key(pool, user_id) + .await? + .ok_or(ApiKeyError::KeyNotFound)?; + + // Revoke it + let revoked = sqlx::query_as::<_, ApiKey>( + r#" + UPDATE api_keys + SET revoked_at = NOW() + WHERE id = $1 + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(key.id) + .fetch_one(pool) + .await?; + + // Log the revocation event + let _ = log_api_key_event(pool, revoked.id, ApiKeyEventType::Revoked, None, None).await; + + Ok(revoked) +} + +/// Refresh an API key: revoke the old one and create a new one atomically. +/// +/// Returns the new key. The caller should use `generate_api_key()` to create +/// the `new_generated` parameter. +pub async fn refresh_api_key( + pool: &PgPool, + user_id: Uuid, + new_generated: &GeneratedApiKey, + new_name: Option<&str>, +) -> Result<(ApiKey, Option<String>), ApiKeyError> { + // Get and revoke the old key (if exists) + let old_prefix = if let Some(old_key) = get_active_api_key(pool, user_id).await? { + let old_prefix = old_key.key_prefix.clone(); + + // Revoke the old key + sqlx::query( + r#" + UPDATE api_keys + SET revoked_at = NOW() + WHERE id = $1 + "#, + ) + .bind(old_key.id) + .execute(pool) + .await?; + + // Log the refresh event on the old key + let _ = log_api_key_event(pool, old_key.id, ApiKeyEventType::Refreshed, None, None).await; + + Some(old_prefix) + } else { + None + }; + + // Create the new key + let new_key = sqlx::query_as::<_, ApiKey>( + r#" + INSERT INTO api_keys (user_id, key_hash, key_prefix, name) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(user_id) + .bind(&new_generated.key_hash) + .bind(&new_generated.key_prefix) + .bind(new_name) + .fetch_one(pool) + .await?; + + // Log the creation event on the new key + let _ = log_api_key_event(pool, new_key.id, ApiKeyEventType::Created, None, None).await; + + Ok((new_key, old_prefix)) +} + +/// Update last_used_at timestamp for an API key. +pub async fn update_api_key_last_used(pool: &PgPool, key_hash: &str) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + UPDATE api_keys + SET last_used_at = NOW() + WHERE key_hash = $1 AND revoked_at IS NULL + "#, + ) + .bind(key_hash) + .execute(pool) + .await?; + + Ok(()) +} + +/// Log an API key event for audit purposes. +pub async fn log_api_key_event( + pool: &PgPool, + api_key_id: Uuid, + event_type: ApiKeyEventType, + ip_address: Option<&str>, + user_agent: Option<&str>, +) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + INSERT INTO api_key_events (api_key_id, event_type, ip_address, user_agent) + VALUES ($1, $2, $3::inet, $4) + "#, + ) + .bind(api_key_id) + .bind(event_type.to_string()) + .bind(ip_address) + .bind(user_agent) + .execute(pool) + .await?; + + Ok(()) +} + +// ============================================================================= +// Internal Helper Functions +// ============================================================================= + +/// Resolve owner_id from user_id by looking up the users table. +/// If the user doesn't exist, auto-creates them on first login. +/// Uses ON CONFLICT to handle race conditions when multiple requests arrive simultaneously. +async fn resolve_owner_id(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result<Uuid, AuthError> { + // First, try to get existing user + let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + if let Some(row) = row { + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + return owner_id.ok_or(AuthError::UserNotFound); + } + + // User doesn't exist - auto-create on first login + tracing::info!("Creating new user record for {}", user_id); + + // Create owner first (use ON CONFLICT to handle race conditions) + let owner_id = Uuid::new_v4(); + sqlx::query("INSERT INTO owners (id, name) VALUES ($1, $2) ON CONFLICT DO NOTHING") + .bind(owner_id) + .bind(email.unwrap_or("Unknown")) + .execute(pool) + .await + .map_err(|e| AuthError::DatabaseError(format!("Failed to create owner: {}", e)))?; + + // Create user with reference to owner (use ON CONFLICT to handle race conditions) + sqlx::query( + "INSERT INTO users (id, email, default_owner_id) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING" + ) + .bind(user_id) + .bind(email) + .bind(owner_id) + .execute(pool) + .await + .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?; + + // Re-fetch the user to get the actual owner_id (in case another request created it first) + let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + match row { + Some(row) => { + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + owner_id.ok_or(AuthError::UserNotFound) + } + None => Err(AuthError::DatabaseError("Failed to create user record".to_string())) + } +} + +/// Validate an API key and return (user_id, owner_id). +async fn validate_api_key(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> { + let key_hash = hash_api_key(key); + + // Look up the API key and join with users to get owner_id + let row = sqlx::query( + r#" + SELECT ak.user_id, u.default_owner_id + FROM api_keys ak + JOIN users u ON u.id = ak.user_id + WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL + "#, + ) + .bind(&key_hash) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + match row { + Some(row) => { + let user_id: Uuid = row.try_get("user_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + let owner_id = owner_id.ok_or(AuthError::UserNotFound)?; + + // Update last_used_at asynchronously (fire and forget) + let pool_clone = pool.clone(); + let key_hash_clone = key_hash.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") + .bind(&key_hash_clone) + .execute(&pool_clone) + .await; + }); + + Ok((user_id, owner_id)) + } + None => Err(AuthError::InvalidApiKey), + } +} + +/// Extract authentication from request headers. +/// +/// Tries authentication methods in order: +/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators +/// 2. API Key (X-Makima-API-Key) - for daemons/CLI +/// 3. JWT (Authorization: Bearer) - for web clients +async fn extract_auth( + state: &SharedState, + headers: &HeaderMap, +) -> Result<AuthenticatedUser, AuthError> { + // 1. Check for tool key (orchestrator access) + if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) { + if let Ok(key_str) = tool_key.to_str() { + if let Some(task_id) = state.validate_tool_key(key_str) { + // Tool keys are trusted - use a placeholder user/owner for orchestrator actions + // The orchestrator inherits the owner_id from its task + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + + // Get owner_id from the task + let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))? + .ok_or(AuthError::UserNotFound)?; + + let task_owner: Uuid = row.try_get("owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + return Ok(AuthenticatedUser { + user_id: Uuid::nil(), // Tool keys don't have a user + owner_id: task_owner, + auth_source: AuthSource::ToolKey(task_id), + email: None, + }); + } + tracing::warn!("Invalid tool key provided"); + } + } + + // 2. Check for API key + if let Some(api_key) = headers.get(API_KEY_HEADER) { + if let Ok(key_str) = api_key.to_str() { + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + let (user_id, owner_id) = validate_api_key(pool, key_str).await?; + + return Ok(AuthenticatedUser { + user_id, + owner_id, + auth_source: AuthSource::ApiKey, + email: None, + }); + } + } + + // 3. Check for JWT (Bearer token) + if let Some(auth_header) = headers.get(AUTHORIZATION) { + if let Ok(auth_str) = auth_header.to_str() { + if let Some(token) = auth_str.strip_prefix("Bearer ") { + let verifier = state + .jwt_verifier + .as_ref() + .ok_or(AuthError::NotConfigured)?; + + let claims = verifier.verify(token)?; + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + let owner_id = resolve_owner_id(pool, claims.sub, claims.email.as_deref()).await?; + + return Ok(AuthenticatedUser { + user_id: claims.sub, + owner_id, + auth_source: AuthSource::Jwt, + email: claims.email, + }); + } + } + } + + Err(AuthError::MissingToken) +} + +// ============================================================================= +// Extractors +// ============================================================================= + +/// Extractor for authenticated requests. +/// +/// Tries authentication methods in order: +/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators +/// 2. API Key (X-Makima-API-Key) - for daemons/CLI +/// 3. JWT (Authorization: Bearer) - for web clients +/// +/// Returns 401 Unauthorized if no valid authentication is found. +/// +/// # Example +/// ```ignore +/// async fn protected_handler( +/// Authenticated(user): Authenticated, +/// ) -> impl IntoResponse { +/// Json(format!("Hello user {}", user.user_id)) +/// } +/// ``` +pub struct Authenticated(pub AuthenticatedUser); + +impl FromRequestParts<SharedState> for Authenticated { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await?; + Ok(Authenticated(user)) + } +} + +/// Extractor for user-only authentication (JWT or API key, no tool keys). +/// +/// Use this for endpoints that should only be accessible to actual users, +/// not orchestrators with tool keys. +/// +/// Returns 401 Unauthorized if no valid user authentication is found. +/// Returns 403 Forbidden if a tool key is used. +/// +/// # Example +/// ```ignore +/// async fn user_profile( +/// UserOnly(user): UserOnly, +/// ) -> impl IntoResponse { +/// // Only actual users can access this +/// Json(format!("User profile for {}", user.user_id)) +/// } +/// ``` +pub struct UserOnly(pub AuthenticatedUser); + +impl FromRequestParts<SharedState> for UserOnly { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await?; + + // Reject tool key authentication + if matches!(user.auth_source, AuthSource::ToolKey(_)) { + return Err(AuthError::InsufficientPermissions); + } + + Ok(UserOnly(user)) + } +} + +/// Extractor for optional authentication. +/// +/// Returns Some(AuthenticatedUser) if valid auth is provided, None otherwise. +/// Never returns an error - invalid auth is treated as no auth. +/// +/// # Example +/// ```ignore +/// async fn public_or_private( +/// MaybeAuthenticated(user): MaybeAuthenticated, +/// ) -> impl IntoResponse { +/// match user { +/// Some(u) => Json(format!("Hello {}", u.user_id)), +/// None => Json("Hello anonymous".to_string()), +/// } +/// } +/// ``` +pub struct MaybeAuthenticated(pub Option<AuthenticatedUser>); + +impl FromRequestParts<SharedState> for MaybeAuthenticated { + type Rejection = std::convert::Infallible; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await.ok(); + Ok(MaybeAuthenticated(user)) + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash_api_key() { + let key = "mk_test123456789"; + let hash = hash_api_key(key); + + // Hash should be consistent + assert_eq!(hash, hash_api_key(key)); + + // Hash should be 64 characters (SHA-256 hex) + assert_eq!(hash.len(), 64); + } + + #[test] + fn test_auth_error_display() { + assert_eq!( + AuthError::MissingToken.to_string(), + "Missing authentication token" + ); + assert_eq!( + AuthError::InvalidToken("bad".to_string()).to_string(), + "Invalid token: bad" + ); + } + + #[test] + fn test_generate_api_key_format() { + let generated = generate_api_key(); + + // Full key should start with mk_ prefix + assert!(generated.full_key.starts_with(API_KEY_PREFIX)); + + // Full key should be mk_ + 43 chars (32 bytes base64url encoded) + assert_eq!(generated.full_key.len(), 3 + 43); // "mk_" + 43 + + // Prefix should be mk_ + first 8 chars + assert!(generated.key_prefix.starts_with(API_KEY_PREFIX)); + assert_eq!(generated.key_prefix.len(), 3 + 8); + + // Hash should be 64 hex chars (SHA-256) + assert_eq!(generated.key_hash.len(), 64); + } + + #[test] + fn test_generate_api_key_uniqueness() { + let key1 = generate_api_key(); + let key2 = generate_api_key(); + + // Keys should be unique + assert_ne!(key1.full_key, key2.full_key); + assert_ne!(key1.key_hash, key2.key_hash); + } + + #[test] + fn test_api_key_cache_basic() { + let cache = ApiKeyCache::new(300); + let user_id = Uuid::new_v4(); + let owner_id = Uuid::new_v4(); + let key_hash = "test_hash_123"; + + // Cache miss initially + assert!(cache.get(key_hash).is_none()); + + // Set and verify cache hit + cache.set(key_hash.to_string(), user_id, owner_id); + let result = cache.get(key_hash); + assert!(result.is_some()); + let (cached_user, cached_owner) = result.unwrap(); + assert_eq!(cached_user, user_id); + assert_eq!(cached_owner, owner_id); + } + + #[test] + fn test_api_key_cache_invalidate() { + let cache = ApiKeyCache::new(300); + let user_id = Uuid::new_v4(); + let owner_id = Uuid::new_v4(); + let key_hash = "test_hash_456"; + + cache.set(key_hash.to_string(), user_id, owner_id); + assert!(cache.get(key_hash).is_some()); + + cache.invalidate(key_hash); + assert!(cache.get(key_hash).is_none()); + } + + #[test] + fn test_api_key_cache_clear() { + let cache = ApiKeyCache::new(300); + + cache.set("hash1".to_string(), Uuid::new_v4(), Uuid::new_v4()); + cache.set("hash2".to_string(), Uuid::new_v4(), Uuid::new_v4()); + + assert!(cache.get("hash1").is_some()); + assert!(cache.get("hash2").is_some()); + + cache.clear(); + + assert!(cache.get("hash1").is_none()); + assert!(cache.get("hash2").is_none()); + } + + #[test] + fn test_api_key_event_type_display() { + assert_eq!(ApiKeyEventType::Created.to_string(), "created"); + assert_eq!(ApiKeyEventType::Used.to_string(), "used"); + assert_eq!(ApiKeyEventType::Revoked.to_string(), "revoked"); + assert_eq!(ApiKeyEventType::Refreshed.to_string(), "refreshed"); + } +} diff --git a/makima/src/server/handlers/api_keys.rs b/makima/src/server/handlers/api_keys.rs new file mode 100644 index 0000000..5a678a2 --- /dev/null +++ b/makima/src/server/handlers/api_keys.rs @@ -0,0 +1,282 @@ +//! HTTP handlers for API key management. +//! +//! These endpoints allow users to create, view, refresh, and revoke their API keys. +//! API keys are used for daemon authentication and programmatic access. + +use axum::{ + extract::State, + http::StatusCode, + response::IntoResponse, + Json, +}; + +use crate::server::auth::{ + create_api_key, generate_api_key, get_active_api_key, refresh_api_key, revoke_api_key, + ApiKeyError, ApiKeyInfoResponse, CreateApiKeyRequest, CreateApiKeyResponse, + RefreshApiKeyRequest, RefreshApiKeyResponse, RevokeApiKeyResponse, UserOnly, +}; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +/// Create a new API key for the authenticated user. +/// +/// Each user can only have one active API key at a time. If an existing key +/// exists, this will return a 409 Conflict error - use the refresh endpoint +/// to replace the existing key, or revoke it first. +#[utoipa::path( + post, + path = "/api/v1/auth/api-keys", + request_body = CreateApiKeyRequest, + responses( + (status = 201, description = "API key created", body = CreateApiKeyResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 409, description = "API key already exists", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn create_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, + Json(req): Json<CreateApiKeyRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Generate a new API key + let generated = generate_api_key(); + + match create_api_key(pool, user.user_id, &generated, req.name.as_deref()).await { + Ok(key) => { + let response = CreateApiKeyResponse { + id: key.id, + key: generated.full_key, + prefix: key.key_prefix, + name: key.name, + created_at: key.created_at, + }; + (StatusCode::CREATED, Json(response)).into_response() + } + Err(ApiKeyError::KeyAlreadyExists) => ( + StatusCode::CONFLICT, + Json(ApiError::new( + "KEY_EXISTS", + "An active API key already exists. Revoke it first or use refresh.", + )), + ) + .into_response(), + Err(ApiKeyError::Database(e)) => { + tracing::error!("Failed to create API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + Err(e) => { + tracing::error!("Failed to create API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get information about the current active API key. +/// +/// Returns the key's ID, prefix (for identification), name, and timestamps. +/// The full key is never returned - it was only shown once when created. +#[utoipa::path( + get, + path = "/api/v1/auth/api-keys", + responses( + (status = 200, description = "API key info", body = ApiKeyInfoResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 404, description = "No active API key", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn get_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match get_active_api_key(pool, user.user_id).await { + Ok(Some(key)) => { + let response: ApiKeyInfoResponse = key.into(); + Json(response).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NO_KEY", "No active API key found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to get API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Refresh the current API key. +/// +/// This revokes the existing key (if any) and creates a new one atomically. +/// Use this for key rotation without downtime. +#[utoipa::path( + post, + path = "/api/v1/auth/api-keys/refresh", + request_body = RefreshApiKeyRequest, + responses( + (status = 200, description = "API key refreshed", body = RefreshApiKeyResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn refresh_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, + Json(req): Json<RefreshApiKeyRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Generate a new API key + let generated = generate_api_key(); + + match refresh_api_key(pool, user.user_id, &generated, req.name.as_deref()).await { + Ok((key, old_prefix)) => { + // Invalidate cache for the old key if we had a cache + // (The cache lookup is by hash, but we revoked the old key in DB so it won't match) + + let response = RefreshApiKeyResponse { + id: key.id, + key: generated.full_key, + prefix: key.key_prefix, + name: key.name, + created_at: key.created_at, + previous_key_revoked: old_prefix.is_some(), + }; + Json(response).into_response() + } + Err(ApiKeyError::Database(e)) => { + tracing::error!("Failed to refresh API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + Err(e) => { + tracing::error!("Failed to refresh API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Revoke the current active API key. +/// +/// After revocation, the key can no longer be used for authentication. +/// A new key can be created after revocation. +#[utoipa::path( + delete, + path = "/api/v1/auth/api-keys", + responses( + (status = 200, description = "API key revoked", body = RevokeApiKeyResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 404, description = "No active API key", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn revoke_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match revoke_api_key(pool, user.user_id).await { + Ok(key) => { + let response = RevokeApiKeyResponse { + message: "API key revoked successfully".to_string(), + revoked_key_prefix: key.key_prefix, + }; + Json(response).into_response() + } + Err(ApiKeyError::KeyNotFound) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NO_KEY", "No active API key found")), + ) + .into_response(), + Err(ApiKeyError::Database(e)) => { + tracing::error!("Failed to revoke API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + Err(e) => { + tracing::error!("Failed to revoke API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", e.to_string())), + ) + .into_response() + } + } +} diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index 51f17c1..dfdb64e 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -53,6 +53,9 @@ pub struct ChatRequest { /// Optional conversation history for context continuity #[serde(default)] pub history: Option<Vec<ChatHistoryMessage>>, + /// Optional focused element index (for targeted editing) + #[serde(default)] + pub focused_element_index: Option<usize>, } #[derive(Debug, Serialize, ToSchema)] @@ -232,6 +235,9 @@ pub async fn chat_handler( // Build context about the file let file_context = build_file_context(&file); + // Build focused element context if specified + let focused_context = build_focused_element_context(&file.body, request.focused_element_index); + // Build agentic system prompt let system_prompt = format!( r#"You are an intelligent document editing agent. You help users view, analyze, and modify document files. @@ -274,13 +280,14 @@ You have access to tools for: ## Current Document Context {file_context} - +{focused_context} ## Important Notes - Body element indices are 0-based - When updating elements, provide ALL required fields for that element type - The transcript is read-only (you cannot modify it, only read it) - Changes are saved automatically after tool execution"#, - file_context = file_context + file_context = file_context, + focused_context = focused_context ); // Build initial messages (Groq/OpenAI format - will be converted for Claude) @@ -690,12 +697,25 @@ fn build_file_context(file: &crate::db::models::File) -> String { let desc = match element { BodyElement::Heading { level, text } => format!("H{}: {}", level, text), BodyElement::Paragraph { text } => { - let preview = if text.len() > 50 { - format!("{}...", &text[..50]) + let preview: String = text.chars().take(50).collect(); + if text.chars().count() > 50 { + format!("Paragraph: {}...", preview) } else { - text.clone() - }; - format!("Paragraph: {}", preview) + format!("Paragraph: {}", preview) + } + } + BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or("plain"); + let preview: String = content.chars().take(50).collect(); + if content.chars().count() > 50 { + format!("Code ({}): {}...", lang, preview) + } else { + format!("Code ({}): {}", lang, preview) + } + } + BodyElement::List { ordered, items } => { + let list_type = if *ordered { "ordered" } else { "unordered" }; + format!("List ({}): {} items", list_type, items.len()) } BodyElement::Chart { chart_type, title, .. } => { format!( @@ -726,6 +746,64 @@ fn build_file_context(file: &crate::db::models::File) -> String { context } +/// Build context for a focused element +fn build_focused_element_context(body: &[BodyElement], focused_index: Option<usize>) -> String { + let Some(index) = focused_index else { + return String::new(); + }; + + let Some(element) = body.get(index) else { + return format!( + "\n## Focused Element\nNote: User focused on element [{}] but it doesn't exist (document has {} elements).\n", + index, + body.len() + ); + }; + + let (element_type, full_content) = match element { + BodyElement::Heading { level, text } => { + (format!("Heading (level {})", level), text.clone()) + } + BodyElement::Paragraph { text } => { + ("Paragraph".to_string(), text.clone()) + } + BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or("plain"); + (format!("Code ({})", lang), content.clone()) + } + BodyElement::List { ordered, items } => { + let list_type = if *ordered { "Ordered list" } else { "Unordered list" }; + let content = items.iter() + .enumerate() + .map(|(i, item)| format!("{}. {}", i + 1, item)) + .collect::<Vec<_>>() + .join("\n"); + (list_type.to_string(), content) + } + BodyElement::Chart { chart_type, title, .. } => { + let title_str = title.as_deref().unwrap_or("untitled"); + (format!("Chart ({:?})", chart_type), title_str.to_string()) + } + BodyElement::Image { alt, caption, .. } => { + let desc = alt.as_deref().or(caption.as_deref()).unwrap_or("no description"); + ("Image".to_string(), desc.to_string()) + } + }; + + format!( + r#" +## Focused Element +The user is focusing on element [{}]: {} +Full content of focused element: +--- +{} +--- +When the user's request is ambiguous about which element to modify, prioritize this focused element. +"#, + index, element_type, full_content + ) +} + /// Result of handling a version tool request struct VersionRequestResult { result: ToolResult, @@ -795,12 +873,25 @@ async fn handle_version_request( let desc = match element { BodyElement::Heading { level, text } => format!("H{}: {}", level, text), BodyElement::Paragraph { text } => { - let preview = if text.len() > 100 { - format!("{}...", &text[..100]) + let preview: String = text.chars().take(100).collect(); + if text.chars().count() > 100 { + format!("Paragraph: {}...", preview) } else { - text.clone() - }; - format!("Paragraph: {}", preview) + format!("Paragraph: {}", preview) + } + } + BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or("plain"); + let preview: String = content.chars().take(100).collect(); + if content.chars().count() > 100 { + format!("Code ({}): {}...", lang, preview) + } else { + format!("Code ({}): {}", lang, preview) + } + } + BodyElement::List { ordered, items } => { + let list_type = if *ordered { "ordered" } else { "unordered" }; + format!("List ({}): {} items", list_type, items.len()) } BodyElement::Chart { chart_type, title, .. } => { format!( diff --git a/makima/src/server/handlers/files.rs b/makima/src/server/handlers/files.rs index c65eed5..9634b73 100644 --- a/makima/src/server/handlers/files.rs +++ b/makima/src/server/handlers/files.rs @@ -10,21 +10,30 @@ use uuid::Uuid; use crate::db::models::{CreateFileRequest, FileListResponse, FileSummary, UpdateFileRequest}; use crate::db::repository::{self, RepositoryError}; +use crate::server::auth::Authenticated; use crate::server::messages::ApiError; use crate::server::state::{FileUpdateNotification, SharedState}; -/// List all files for the current owner. +/// List all files for the authenticated user's owner. #[utoipa::path( get, path = "/api/v1/files", responses( (status = 200, description = "List of files", body = FileListResponse), + (status = 401, description = "Unauthorized", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] -pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { +pub async fn list_files( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { return ( StatusCode::SERVICE_UNAVAILABLE, @@ -33,7 +42,7 @@ pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { .into_response(); }; - match repository::list_files(pool).await { + match repository::list_files_for_owner(pool, auth.owner_id).await { Ok(files) => { let summaries: Vec<FileSummary> = files.into_iter().map(FileSummary::from).collect(); let total = summaries.len() as i64; @@ -54,7 +63,7 @@ pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { } } -/// Get a single file by ID. +/// Get a single file by ID (scoped by owner). #[utoipa::path( get, path = "/api/v1/files/{id}", @@ -63,14 +72,20 @@ pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { ), responses( (status = 200, description = "File details", body = crate::db::models::File), + (status = 401, description = "Unauthorized", body = ApiError), (status = 404, description = "File not found", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn get_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Path(id): Path<Uuid>, ) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { @@ -81,7 +96,7 @@ pub async fn get_file( .into_response(); }; - match repository::get_file(pool, id).await { + match repository::get_file_for_owner(pool, id, auth.owner_id).await { Ok(Some(file)) => Json(file).into_response(), Ok(None) => ( StatusCode::NOT_FOUND, @@ -107,13 +122,19 @@ pub async fn get_file( responses( (status = 201, description = "File created", body = crate::db::models::File), (status = 400, description = "Invalid request", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn create_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Json(req): Json<CreateFileRequest>, ) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { @@ -124,7 +145,7 @@ pub async fn create_file( .into_response(); }; - match repository::create_file(pool, req).await { + match repository::create_file_for_owner(pool, auth.owner_id, req).await { Ok(file) => (StatusCode::CREATED, Json(file)).into_response(), Err(e) => { tracing::error!("Failed to create file: {}", e); @@ -137,7 +158,7 @@ pub async fn create_file( } } -/// Update an existing file. +/// Update an existing file (scoped by owner). #[utoipa::path( put, path = "/api/v1/files/{id}", @@ -147,15 +168,21 @@ pub async fn create_file( request_body = UpdateFileRequest, responses( (status = 200, description = "File updated", body = crate::db::models::File), + (status = 401, description = "Unauthorized", body = ApiError), (status = 404, description = "File not found", body = ApiError), (status = 409, description = "Version conflict", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn update_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Path(id): Path<Uuid>, Json(req): Json<UpdateFileRequest>, ) -> impl IntoResponse { @@ -185,7 +212,7 @@ pub async fn update_file( updated_fields.push("body".to_string()); } - match repository::update_file(pool, id, req).await { + match repository::update_file_for_owner(pool, id, auth.owner_id, req).await { Ok(Some(file)) => { // Broadcast update notification state.broadcast_file_update(FileUpdateNotification { @@ -233,7 +260,7 @@ pub async fn update_file( } } -/// Delete a file. +/// Delete a file (scoped by owner). #[utoipa::path( delete, path = "/api/v1/files/{id}", @@ -242,14 +269,20 @@ pub async fn update_file( ), responses( (status = 204, description = "File deleted"), + (status = 401, description = "Unauthorized", body = ApiError), (status = 404, description = "File not found", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn delete_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Path(id): Path<Uuid>, ) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { @@ -260,7 +293,7 @@ pub async fn delete_file( .into_response(); }; - match repository::delete_file(pool, id).await { + match repository::delete_file_for_owner(pool, id, auth.owner_id).await { Ok(true) => StatusCode::NO_CONTENT.into_response(), Ok(false) => ( StatusCode::NOT_FOUND, diff --git a/makima/src/server/handlers/mesh.rs b/makima/src/server/handlers/mesh.rs new file mode 100644 index 0000000..760740c --- /dev/null +++ b/makima/src/server/handlers/mesh.rs @@ -0,0 +1,1679 @@ +//! HTTP handlers for task and daemon mesh operations. + +use axum::{ + extract::{Path, State}, + http::{HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use uuid::Uuid; + +use crate::db::models::{ + CreateTaskRequest, DaemonDirectory, DaemonDirectoriesResponse, DaemonListResponse, + SendMessageRequest, Task, TaskEventListResponse, TaskListResponse, TaskOutputEntry, + TaskOutputResponse, TaskWithSubtasks, UpdateTaskRequest, +}; +use crate::db::repository::{self, RepositoryError}; +use crate::server::auth::Authenticated; +use crate::server::messages::ApiError; +use crate::server::state::{DaemonCommand, SharedState, TaskUpdateNotification}; + +// ============================================================================= +// Authentication Types +// ============================================================================= + +/// Source of authentication for mesh endpoints. +#[derive(Debug, Clone)] +pub enum AuthSource { + /// Authenticated via tool key (orchestrator accessing API). + /// Contains the task ID that owns this key. + ToolKey(Uuid), + /// Authenticated via user token (web client). + /// Contains the user ID. (Not implemented yet) + #[allow(dead_code)] + UserToken(Uuid), + /// No authentication provided (anonymous access). + Anonymous, +} + +/// Header name for tool key authentication. +pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key"; + +/// Extract authentication source from request headers. +/// +/// Checks for: +/// 1. `X-Makima-Tool-Key` header for orchestrator tool access +/// 2. `Authorization: Bearer` header for user access (future) +/// 3. Falls back to Anonymous if no auth provided +pub fn extract_auth(state: &SharedState, headers: &HeaderMap) -> AuthSource { + // Check for tool key header first + if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) { + if let Ok(key_str) = tool_key.to_str() { + if let Some(task_id) = state.validate_tool_key(key_str) { + return AuthSource::ToolKey(task_id); + } + tracing::warn!("Invalid tool key provided"); + } + } + + // Check for Authorization header (future user auth) + if let Some(auth_header) = headers.get("authorization") { + if let Ok(auth_str) = auth_header.to_str() { + if auth_str.starts_with("Bearer ") { + // Future: validate JWT and extract user ID + tracing::debug!("Bearer token auth not yet implemented"); + } + } + } + + // Default to anonymous + AuthSource::Anonymous +} + +// ============================================================================= +// Task Handlers +// ============================================================================= + +/// List all tasks for the current owner. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks", + responses( + (status = 200, description = "List of tasks", body = TaskListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_tasks( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::list_tasks_for_owner(pool, auth.owner_id).await { + Ok(tasks) => { + let total = tasks.len() as i64; + Json(TaskListResponse { tasks, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list tasks: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get a single task by ID with its subtasks (scoped by owner). +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task details with subtasks", body = TaskWithSubtasks), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(task)) => { + // Get subtasks for this task (also scoped by owner) + match repository::list_subtasks_for_owner(pool, id, auth.owner_id).await { + Ok(subtasks) => Json(TaskWithSubtasks { task, subtasks }).into_response(), + Err(e) => { + tracing::error!("Failed to get subtasks for task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Create a new task. +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks", + request_body = CreateTaskRequest, + responses( + (status = 201, description = "Task created", body = Task), + (status = 400, description = "Invalid request", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn create_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Json(req): Json<CreateTaskRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::create_task_for_owner(pool, auth.owner_id, req).await { + Ok(task) => (StatusCode::CREATED, Json(task)).into_response(), + Err(e) => { + tracing::error!("Failed to create task: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Update an existing task (scoped by owner). +#[utoipa::path( + put, + path = "/api/v1/mesh/tasks/{id}", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = UpdateTaskRequest, + responses( + (status = 200, description = "Task updated", body = Task), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 409, description = "Version conflict", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn update_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(req): Json<UpdateTaskRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Track which fields are being updated for the notification + let mut updated_fields = Vec::new(); + if req.name.is_some() { + updated_fields.push("name".to_string()); + } + if req.description.is_some() { + updated_fields.push("description".to_string()); + } + if req.status.is_some() { + updated_fields.push("status".to_string()); + } + if req.priority.is_some() { + updated_fields.push("priority".to_string()); + } + if req.plan.is_some() { + updated_fields.push("plan".to_string()); + } + if req.progress_summary.is_some() { + updated_fields.push("progress_summary".to_string()); + } + if req.error_message.is_some() { + updated_fields.push("error_message".to_string()); + } + + match repository::update_task_for_owner(pool, id, auth.owner_id, req).await { + Ok(Some(task)) => { + // Broadcast task update notification + state.broadcast_task_update(TaskUpdateNotification { + task_id: task.id, + owner_id: Some(auth.owner_id), + version: task.version, + status: task.status.clone(), + updated_fields, + updated_by: "user".to_string(), + }); + Json(task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(RepositoryError::VersionConflict { expected, actual }) => { + tracing::info!( + "Version conflict on task {}: expected {}, actual {}", + id, + expected, + actual + ); + ( + StatusCode::CONFLICT, + Json(serde_json::json!({ + "code": "VERSION_CONFLICT", + "message": format!( + "Task was modified by another user. Expected version {}, actual version {}", + expected, actual + ), + "expectedVersion": expected, + "actualVersion": actual, + })), + ) + .into_response() + } + Err(RepositoryError::Database(e)) => { + tracing::error!("Failed to update task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Delete a task (scoped by owner). +#[utoipa::path( + delete, + path = "/api/v1/mesh/tasks/{id}", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 204, description = "Task deleted"), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn delete_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task first to check if it's running and needs to be stopped + if let Ok(Some(task)) = repository::get_task_for_owner(pool, id, auth.owner_id).await { + let is_active = matches!( + task.status.as_str(), + "running" | "starting" | "initializing" | "paused" + ); + + // If task is active and has a daemon, send interrupt command + if is_active { + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::InterruptTask { + task_id: id, + graceful: false, + }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + tracing::warn!( + task_id = %id, + daemon_id = %daemon_id, + "Failed to send InterruptTask before delete: {}", + e + ); + } else { + tracing::info!( + task_id = %id, + daemon_id = %daemon_id, + "Sent InterruptTask before delete" + ); + } + } + } + } + + match repository::delete_task_for_owner(pool, id, auth.owner_id).await { + Ok(true) => StatusCode::NO_CONTENT.into_response(), + Ok(false) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to delete task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Start a task by sending it to an available daemon (scoped by owner). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/start", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task started", body = Task), + (status = 400, description = "Task cannot be started", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or no daemons available", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn start_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + headers: HeaderMap, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + // Extract authentication to log who is starting the task + let legacy_auth = extract_auth(&state, &headers); + match &legacy_auth { + AuthSource::ToolKey(orchestrator_id) => { + tracing::info!( + task_id = %id, + orchestrator_task_id = %orchestrator_id, + owner_id = %auth.owner_id, + "Orchestrator starting subtask via tool key" + ); + } + AuthSource::Anonymous => { + tracing::info!( + task_id = %id, + owner_id = %auth.owner_id, + "Starting task (user request)" + ); + } + AuthSource::UserToken(user_id) => { + tracing::info!( + task_id = %id, + user_id = %user_id, + owner_id = %auth.owner_id, + "Starting task via user token" + ); + } + } + + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task can be started (allow pending, failed, interrupted, done, or merged) + let startable_statuses = ["pending", "failed", "interrupted", "done", "merged"]; + if !startable_statuses.contains(&task.status.as_str()) { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!("Task cannot be started from status: {}", task.status), + )), + ) + .into_response(); + } + + // Find an available daemon belonging to this owner + let target_daemon_id = match state.daemon_connections + .iter() + .find(|d| d.value().owner_id == auth.owner_id) + { + Some(d) => d.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "No daemons connected for your account. Cannot start task.", + )), + ) + .into_response(); + } + }; + + // Check if this is an orchestrator (depth 0 with subtasks) + let subtask_count = match repository::list_subtasks_for_owner(pool, id, auth.owner_id).await { + Ok(subtasks) => { + tracing::info!( + task_id = %id, + subtask_count = subtasks.len(), + subtask_ids = ?subtasks.iter().map(|s| s.id.to_string()).collect::<Vec<_>>(), + "Counted subtasks for orchestrator check" + ); + subtasks.len() + }, + Err(e) => { + tracing::warn!("Failed to check subtasks for {}: {}", id, e); + 0 + } + }; + let is_orchestrator = task.depth == 0 && subtask_count > 0; + + tracing::info!( + task_id = %id, + task_depth = task.depth, + subtask_count = subtask_count, + is_orchestrator = is_orchestrator, + "Starting task with orchestrator determination" + ); + + // IMPORTANT: Update database FIRST to assign daemon_id before sending command + // This prevents race conditions where the task starts but daemon_id is not set + let update_req = UpdateTaskRequest { + status: Some("starting".to_string()), + daemon_id: Some(target_daemon_id), + version: Some(task.version), + ..Default::default() + }; + + let updated_task = match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Send SpawnTask command to daemon + let command = DaemonCommand::SpawnTask { + task_id: id, + task_name: task.name.clone(), + plan: task.plan.clone(), + repo_url: task.repository_url.clone(), + base_branch: task.base_branch.clone(), + target_branch: task.target_branch.clone(), + parent_task_id: task.parent_task_id, + depth: task.depth, + is_orchestrator, + target_repo_path: task.target_repo_path.clone(), + completion_action: task.completion_action.clone(), + continue_from_task_id: task.continue_from_task_id, + copy_files: task.copy_files.as_ref().and_then(|v| serde_json::from_value(v.clone()).ok()), + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::error!("Failed to send SpawnTask command: {}", e); + // Rollback: clear daemon_id and reset status since command failed + let rollback_req = UpdateTaskRequest { + status: Some("pending".to_string()), + clear_daemon_id: true, // Explicitly clear daemon_id + ..Default::default() + }; + let _ = repository::update_task_for_owner(pool, id, auth.owner_id, rollback_req).await; + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + // Broadcast task update notification + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "starting".to_string(), + updated_fields: vec!["status".to_string(), "daemon_id".to_string()], + updated_by: "system".to_string(), + }); + + Json(updated_task).into_response() +} + +/// Stop a running task (scoped by owner). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/stop", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task stopped", body = Task), + (status = 400, description = "Task is not running", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn stop_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task is running/active + let is_active = matches!( + task.status.as_str(), + "running" | "starting" | "initializing" | "paused" + ); + if !is_active { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!("Task cannot be stopped from status: {}", task.status), + )), + ) + .into_response(); + } + + // Find the daemon running this task + let target_daemon_id = if let Some(daemon_id) = task.daemon_id { + daemon_id + } else { + // No daemon assigned, just update status directly + let update_req = UpdateTaskRequest { + status: Some("failed".to_string()), + error_message: Some("Task stopped by user".to_string()), + version: Some(task.version), + ..Default::default() + }; + + return match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "failed".to_string(), + updated_fields: vec!["status".to_string(), "error_message".to_string()], + updated_by: "user".to_string(), + }); + Json(updated_task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + }; + }; + + // Send InterruptTask command to daemon + let command = DaemonCommand::InterruptTask { + task_id: id, + graceful: false, + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::warn!("Failed to send InterruptTask command: {}", e); + // Daemon might be disconnected - update task status directly + let update_req = UpdateTaskRequest { + status: Some("failed".to_string()), + error_message: Some("Task stopped by user (daemon unavailable)".to_string()), + version: Some(task.version), + ..Default::default() + }; + + return match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "failed".to_string(), + updated_fields: vec!["status".to_string(), "error_message".to_string()], + updated_by: "user".to_string(), + }); + Json(updated_task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + }; + } + + // Update task status to "failed" (stopped) + let update_req = UpdateTaskRequest { + status: Some("failed".to_string()), + error_message: Some("Task stopped by user".to_string()), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(updated_task)) => { + // Broadcast task update notification + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "failed".to_string(), + updated_fields: vec!["status".to_string(), "error_message".to_string()], + updated_by: "user".to_string(), + }); + + Json(updated_task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Send a message to a running task's stdin (scoped by owner). +/// +/// This can be used to provide input to Claude Code when it's waiting for user input, +/// or to inject context/instructions into a running task. +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/message", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = SendMessageRequest, + responses( + (status = 200, description = "Message sent successfully"), + (status = 400, description = "Task is not running", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn send_message( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(req): Json<SendMessageRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task is running + if task.status != "running" { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!( + "Cannot send message to task in status: {}. Task must be running.", + task.status + ), + )), + ) + .into_response(); + } + + // Find the daemon running this task + let target_daemon_id = if let Some(daemon_id) = task.daemon_id { + daemon_id + } else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "Task has no assigned daemon. Cannot send message.", + )), + ) + .into_response(); + }; + + // Send SendMessage command to daemon + let command = DaemonCommand::SendMessage { + task_id: id, + message: req.message.clone(), + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::error!("Failed to send SendMessage command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + tracing::info!(task_id = %id, message_len = req.message.len(), "Message sent to task"); + + // Return success + ( + StatusCode::OK, + Json(serde_json::json!({ + "success": true, + "taskId": id, + "messageLength": req.message.len() + })), + ) + .into_response() +} + +/// Get task output history (scoped by owner). +/// +/// Retrieves all recorded output from a task's Claude Code process. +/// This allows the frontend to fetch missed output when subscribing late +/// or reconnecting after a disconnect. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/output", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task output history", body = TaskOutputResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_task_output( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Verify task exists and belongs to owner + match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + } + + // Get output history (task already verified to belong to owner) + match repository::get_task_output(pool, id, None).await { + Ok(events) => { + let entries: Vec<TaskOutputEntry> = events + .into_iter() + .filter_map(TaskOutputEntry::from_task_event) + .collect(); + let total = entries.len(); + + Json(TaskOutputResponse { + entries, + total, + task_id: id, + }) + .into_response() + } + Err(e) => { + tracing::error!("Failed to get task output: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// List subtasks for a parent task (scoped by owner). +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/subtasks", + params( + ("id" = Uuid, Path, description = "Parent task ID") + ), + responses( + (status = 200, description = "List of subtasks", body = TaskListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_subtasks( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::list_subtasks_for_owner(pool, id, auth.owner_id).await { + Ok(tasks) => { + let total = tasks.len() as i64; + Json(TaskListResponse { tasks, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list subtasks for task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// List events for a task (scoped by owner). +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/events", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "List of task events", body = TaskEventListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_task_events( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Verify task exists and belongs to owner + match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + } + + match repository::list_task_events(pool, id, None).await { + Ok(events) => { + let total = events.len() as i64; + Json(TaskEventListResponse { events, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list events for task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Retry completion action for a completed task (scoped by owner). +/// +/// This allows retrying a completion action (push branch, merge, create PR) +/// after filling in the target_repo_path if it wasn't set when the task completed. +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/retry-completion", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Completion action initiated"), + (status = 400, description = "Invalid request (task not completed, no completion action, etc.)", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn retry_completion_action( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task is in a terminal state + let terminal_statuses = ["done", "failed", "merged"]; + if !terminal_statuses.contains(&task.status.as_str()) { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!( + "Task must be completed to retry completion action. Current status: {}", + task.status + ), + )), + ) + .into_response(); + } + + // Check if completion action is set + let action = match &task.completion_action { + Some(action) if action != "none" => action.clone(), + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "NO_COMPLETION_ACTION", + "Task has no completion action configured (or is set to 'none')", + )), + ) + .into_response(); + } + }; + + // Check if target_repo_path is set + let target_repo_path = match &task.target_repo_path { + Some(path) if !path.is_empty() => path.clone(), + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "NO_TARGET_REPO", + "Target repository path must be set before retrying completion action", + )), + ) + .into_response(); + } + }; + + // Note: We don't check overlay_path here because the server may not have it + // The daemon will scan its worktrees directory to find the worktree by task ID + + // Find a daemon to execute the action (must belong to this owner) + // Prefer the daemon that ran the task, but fall back to any available daemon for this owner + let target_daemon_id = if let Some(daemon_id) = task.daemon_id { + // Check if this daemon is still connected and belongs to this owner + if state.daemon_connections.iter().any(|d| d.value().id == daemon_id && d.value().owner_id == auth.owner_id) { + daemon_id + } else { + // Fall back to any connected daemon for this owner + match state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id) { + Some(d) => d.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "No daemons connected for your account. Cannot execute completion action.", + )), + ) + .into_response(); + } + } + } + } else { + // No daemon assigned - use any available for this owner + match state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id) { + Some(d) => d.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "No daemons connected for your account. Cannot execute completion action.", + )), + ) + .into_response(); + } + } + }; + + // Send RetryCompletionAction command to daemon + let command = DaemonCommand::RetryCompletionAction { + task_id: id, + task_name: task.name.clone(), + action: action.clone(), + target_repo_path: target_repo_path.clone(), + target_branch: task.target_branch.clone(), + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::error!("Failed to send RetryCompletionAction command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + tracing::info!( + task_id = %id, + action = %action, + target_repo = %target_repo_path, + "Retry completion action initiated" + ); + + ( + StatusCode::OK, + Json(serde_json::json!({ + "success": true, + "taskId": id, + "action": action, + "targetRepoPath": target_repo_path, + "message": "Completion action initiated. Check task output for results." + })), + ) + .into_response() +} + +// ============================================================================= +// Daemon Handlers +// ============================================================================= + +/// List all connected daemons (requires authentication). +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons", + responses( + (status = 200, description = "List of daemons", body = DaemonListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_daemons( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Only list daemons belonging to this owner + match repository::list_daemons_for_owner(pool, auth.owner_id).await { + Ok(daemons) => { + let total = daemons.len() as i64; + Json(DaemonListResponse { daemons, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list daemons: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get a single daemon by ID (requires authentication). +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/{id}", + params( + ("id" = Uuid, Path, description = "Daemon ID") + ), + responses( + (status = 200, description = "Daemon details", body = crate::db::models::Daemon), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Daemon not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_daemon( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Only get daemon if it belongs to this owner + match repository::get_daemon_for_owner(pool, id, auth.owner_id).await { + Ok(Some(daemon)) => Json(daemon).into_response(), + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Daemon not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to get daemon {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get suggested directories from connected daemons (requires authentication). +/// +/// Returns directories that can be used as target_repo_path for completion actions. +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/directories", + responses( + (status = 200, description = "List of suggested directories", body = DaemonDirectoriesResponse), + (status = 401, description = "Unauthorized", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_daemon_directories( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let mut directories = Vec::new(); + + // Iterate over connected daemons belonging to this owner and collect their directories + for entry in state.daemon_connections.iter() { + let daemon = entry.value(); + + // Only include daemons belonging to this owner + if daemon.owner_id != auth.owner_id { + continue; + } + + // Add working directory if available + if let Some(ref working_dir) = daemon.working_directory { + directories.push(DaemonDirectory { + path: working_dir.clone(), + label: "Working Directory".to_string(), + directory_type: "working".to_string(), + hostname: daemon.hostname.clone(), + exists: None, + }); + } + + // Add home directory if available (for cloning completed work) + if let Some(ref home_dir) = daemon.home_directory { + directories.push(DaemonDirectory { + path: home_dir.clone(), + label: "Makima Home".to_string(), + directory_type: "home".to_string(), + hostname: daemon.hostname.clone(), + exists: None, + }); + } + } + + Json(DaemonDirectoriesResponse { directories }) +} + +/// Request to clone a worktree to a target directory. +#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CloneWorktreeRequest { + /// Path to the target directory. + pub target_dir: String, +} + +/// Clone a task's worktree to a target directory (scoped by owner). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/clone", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = CloneWorktreeRequest, + responses( + (status = 200, description = "Clone command sent"), + (status = 400, description = "Invalid request or task not completed", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn clone_worktree( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(body): Json<CloneWorktreeRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Verify task is in a completed state + let is_completed = matches!(task.status.as_str(), "done" | "failed" | "merged"); + if !is_completed { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!("Task must be completed to clone (current status: {})", task.status), + )), + ) + .into_response(); + } + + // Find a connected daemon belonging to this owner to send the command + let daemon_entry = state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id); + let daemon_id = match daemon_entry { + Some(entry) => entry.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("NO_DAEMON", "No daemon connected for your account")), + ) + .into_response(); + } + }; + + // Send CloneWorktree command to daemon + let command = DaemonCommand::CloneWorktree { + task_id: id, + target_dir: body.target_dir.clone(), + }; + + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + tracing::error!("Failed to send CloneWorktree command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + Json(serde_json::json!({ + "status": "cloning", + "taskId": id.to_string(), + "targetDir": body.target_dir, + })) + .into_response() +} + +/// Request to check if a target directory exists. +#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CheckTargetExistsRequest { + /// Path to check. + pub target_dir: String, +} + +/// Response for check target exists. +#[derive(Debug, serde::Serialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CheckTargetExistsResponse { + /// Whether the target directory exists. + pub exists: bool, + /// The path that was checked (expanded). + pub target_dir: String, +} + +/// Check if a target directory exists (for clone validation, requires authentication). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/check-target", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = CheckTargetExistsRequest, + responses( + (status = 200, description = "Check result", body = CheckTargetExistsResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "No daemon connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn check_target_exists( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(body): Json<CheckTargetExistsRequest>, +) -> impl IntoResponse { + // Find a connected daemon belonging to this owner to send the command + let daemon_entry = state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id); + let daemon_id = match daemon_entry { + Some(entry) => entry.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("NO_DAEMON", "No daemon connected for your account")), + ) + .into_response(); + } + }; + + // Send CheckTargetExists command to daemon + let command = DaemonCommand::CheckTargetExists { + task_id: id, + target_dir: body.target_dir.clone(), + }; + + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + tracing::error!("Failed to send CheckTargetExists command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + // The actual result will be sent back via WebSocket + // For now, just acknowledge the request was sent + Json(serde_json::json!({ + "status": "checking", + "taskId": id.to_string(), + "targetDir": body.target_dir, + })) + .into_response() +} diff --git a/makima/src/server/handlers/mesh_chat.rs b/makima/src/server/handlers/mesh_chat.rs new file mode 100644 index 0000000..5d6d2ee --- /dev/null +++ b/makima/src/server/handlers/mesh_chat.rs @@ -0,0 +1,2088 @@ +//! Chat endpoint for LLM-powered task orchestration. +//! +//! This handler provides an agentic loop for managing tasks, daemons, and +//! overlay operations through LLM tool calling. + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::db::{models::CreateTaskRequest, repository}; +use crate::llm::{ + claude::{self, ClaudeClient, ClaudeError, ClaudeModel}, + groq::{GroqClient, GroqError, Message, ToolCallResponse}, + parse_mesh_tool_call, LlmModel, MeshToolRequest, ToolCall, ToolResult, UserQuestion, + MESH_TOOLS, +}; +use crate::server::auth::Authenticated; +use crate::server::state::{DaemonCommand, SharedState, TaskUpdateNotification}; + +/// Maximum number of tool-calling rounds to prevent infinite loops +const MAX_TOOL_ROUNDS: usize = 30; + +#[derive(Debug, Clone, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatHistoryMessage { + /// Role: "user" or "assistant" + pub role: String, + /// Message content + pub content: String, +} + +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatRequest { + /// The user's message/instruction + pub message: String, + /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq" + #[serde(default)] + pub model: Option<String>, + /// Optional conversation history for context continuity (deprecated - now loaded from DB) + #[serde(default)] + pub history: Option<Vec<MeshChatHistoryMessage>>, + /// Context type: "mesh", "task", or "subtask" + #[serde(default)] + pub context_type: Option<String>, + /// Task ID if context is task/subtask + #[serde(default)] + pub context_task_id: Option<Uuid>, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatResponse { + /// The LLM's response message + pub response: String, + /// Tool calls that were executed + pub tool_calls: Vec<MeshToolCallInfo>, + /// Questions pending user answers (pauses conversation) + #[serde(skip_serializing_if = "Option::is_none")] + pub pending_questions: Option<Vec<UserQuestion>>, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshToolCallInfo { + pub name: String, + pub result: ToolResult, +} + +/// Enum to hold LLM clients +enum LlmClient { + Groq(GroqClient), + Claude(ClaudeClient), +} + +/// Unified result from LLM call +struct LlmResult { + content: Option<String>, + tool_calls: Vec<ToolCall>, + raw_tool_calls: Vec<ToolCallResponse>, + finish_reason: String, +} + +/// Chat with mesh orchestrator at the top level (no specific task context) +#[utoipa::path( + post, + path = "/api/v1/mesh/chat", + request_body = MeshChatRequest, + responses( + (status = 200, description = "Chat completed successfully", body = MeshChatResponse), + (status = 401, description = "Unauthorized"), + (status = 500, description = "Internal server error") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn mesh_toplevel_chat_handler( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Json(request): Json<MeshChatRequest>, +) -> impl IntoResponse { + // Check if database is configured + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + // Parse model selection (default to Claude Sonnet) + let model = request + .model + .as_ref() + .and_then(|m| LlmModel::from_str(m)) + .unwrap_or(LlmModel::ClaudeSonnet); + + tracing::info!("Mesh top-level chat using LLM model: {:?}", model); + + // Initialize the appropriate LLM client + let llm_client = match model { + LlmModel::ClaudeSonnet => match ClaudeClient::from_env(ClaudeModel::Sonnet) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::ClaudeOpus => match ClaudeClient::from_env(ClaudeModel::Opus) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::GroqKimi => match GroqClient::from_env() { + Ok(client) => LlmClient::Groq(client), + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "GROQ_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Groq client error: {}", e) })), + ) + .into_response(); + } + }, + }; + + // Build context about all tasks and daemons + let mesh_context = build_mesh_overview_context(pool, &state, auth.owner_id).await; + + // Build agentic system prompt for top-level mesh orchestration + let system_prompt = format!( + r#"You are an intelligent task orchestration agent. You help users manage and coordinate tasks running on connected daemons with Claude Code containers. + +## Your Capabilities +You have access to tools for: +- **Task Lifecycle**: create_task, run_task, pause_task, resume_task, interrupt_task, discard_task +- **Task Queries**: query_task_status, list_tasks, list_subtasks, list_siblings, list_daemons +- **File Access**: list_files, read_file (read documents from the files system) +- **Task Communication**: send_message_to_task, update_task_plan +- **Overlay/Merge Operations**: peek_sibling_overlay, get_overlay_diff, preview_merge, merge_subtask, complete_task, set_merge_mode + +## Current Mesh Overview +{mesh_context} + +## Agentic Behavior Guidelines + +### 1. Analyze Before Acting +- For complex orchestration requests, first gather information using query_task_status, list_tasks, or list_daemons +- Understand the current state before making changes +- For simple, direct requests (e.g., "create a new task"), you can act immediately + +### 2. Plan Multi-Step Operations +- Break complex orchestration into logical steps +- For parallel execution: create multiple subtasks, then run them on different daemons +- For sequential execution: create subtasks and run them in order + +### 3. Create and Manage Tasks +- Use create_task to create new top-level tasks or subtasks +- Assign appropriate priorities and plans +- **Repository Default**: When creating tasks, use the daemon's working directory as the repository_url by default (shown as "Default Repository" above). Only omit repository_url if the task doesn't involve code, or use a different URL if the user explicitly requests it. +- If a working directory is a git repository, use it as the repository_url for code-related tasks + +### 4. Coordinate Multiple Tasks +- Use list_tasks to see all tasks and their statuses +- Use list_daemons to see available compute resources +- Balance workload across daemons + +### 5. Be Efficient +- Don't over-analyze simple requests +- Use the minimum number of tool calls needed +- Provide clear summaries of actions taken + +## Important Notes +- Task IDs are UUIDs - ensure you use the correct format +- Running a task requires at least one connected daemon +- When creating subtasks, specify the parent_task_id +- Always confirm destructive operations (discard_task) with the user"#, + mesh_context = mesh_context + ); + + // Run the shared agentic loop + run_mesh_agentic_loop(pool, &state, &llm_client, system_prompt, &request, auth.owner_id).await +} + +/// Chat with task mesh orchestrator using LLM tool calling (scoped by owner) +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/chat", + request_body = MeshChatRequest, + responses( + (status = 200, description = "Chat completed successfully", body = MeshChatResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Task not found"), + (status = 500, description = "Internal server error") + ), + params( + ("id" = Uuid, Path, description = "Task ID (context for orchestration)") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn mesh_chat_handler( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(task_id): Path<Uuid>, + Json(request): Json<MeshChatRequest>, +) -> impl IntoResponse { + // Check if database is configured + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + // Get the context task (scoped by owner) + let task = match repository::get_task_for_owner(pool, task_id, auth.owner_id).await { + Ok(Some(task)) => task, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({ "error": "Task not found" })), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Database error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Database error: {}", e) })), + ) + .into_response(); + } + }; + + // Parse model selection (default to Claude Sonnet) + let model = request + .model + .as_ref() + .and_then(|m| LlmModel::from_str(m)) + .unwrap_or(LlmModel::ClaudeSonnet); + + tracing::info!("Mesh chat using LLM model: {:?}", model); + + // Initialize the appropriate LLM client + let llm_client = match model { + LlmModel::ClaudeSonnet => match ClaudeClient::from_env(ClaudeModel::Sonnet) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::ClaudeOpus => match ClaudeClient::from_env(ClaudeModel::Opus) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::GroqKimi => match GroqClient::from_env() { + Ok(client) => LlmClient::Groq(client), + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "GROQ_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Groq client error: {}", e) })), + ) + .into_response(); + } + }, + }; + + // Build context about the current task and mesh state + let task_context = build_task_context(&task); + + // Build agentic system prompt for task orchestration + let system_prompt = format!( + r#"You are an intelligent task orchestration agent. You help users manage and coordinate tasks running on connected daemons with Claude Code containers. + +## Your Capabilities +You have access to tools for: +- **Task Lifecycle**: create_task, run_task, pause_task, resume_task, interrupt_task, discard_task +- **Task Queries**: query_task_status, list_tasks, list_subtasks, list_siblings, list_daemons +- **File Access**: list_files, read_file (read documents from the files system) +- **Task Communication**: send_message_to_task, update_task_plan +- **Overlay/Merge Operations**: peek_sibling_overlay, get_overlay_diff, preview_merge, merge_subtask, complete_task, set_merge_mode + +## Current Context +{task_context} + +## Agentic Behavior Guidelines + +### 1. Analyze Before Acting +- For complex orchestration requests, first gather information using query_task_status, list_tasks, or list_daemons +- Understand the current state before making changes +- For simple, direct requests (e.g., "pause this task"), you can act immediately + +### 2. Plan Multi-Step Operations +- Break complex orchestration into logical steps +- For parallel execution: create multiple subtasks, then run them on different daemons +- For sequential execution: create subtasks and run them in order + +### 3. Monitor Task Progress +- Use query_task_status to check on running tasks +- Watch for status changes and react accordingly +- Handle failures gracefully (retry, escalate, or report) + +### 4. Coordinate Sibling Tasks +- Use peek_sibling_overlay to see what other tasks have changed +- Preview merges before completing to catch conflicts +- Coordinate timing when multiple tasks need to merge + +### 5. Be Efficient +- Don't over-analyze simple requests +- Use the minimum number of tool calls needed +- Provide clear summaries of actions taken + +## Important Notes +- Task IDs are UUIDs - ensure you use the correct format +- Running a task requires at least one connected daemon +- Overlay operations require the task to have been run at least once +- Always confirm destructive operations (discard_task) with the user +- When creating subtasks for this task, use parent_task_id: {task_id}"#, + task_context = task_context, + task_id = task_id + ); + + // Run the shared agentic loop + run_mesh_agentic_loop(pool, &state, &llm_client, system_prompt, &request, auth.owner_id).await +} + +fn build_task_context(task: &crate::db::models::Task) -> String { + let mut context = format!( + "Current Task: {} (ID: {})\n", + task.name, task.id + ); + context.push_str(&format!("Status: {}\n", task.status)); + context.push_str(&format!("Priority: {}\n", task.priority)); + + if let Some(ref desc) = task.description { + context.push_str(&format!("Description: {}\n", desc)); + } + + // Truncate plan preview if too long + let plan_preview = if task.plan.len() > 200 { + format!("{}...", &task.plan[..200]) + } else { + task.plan.clone() + }; + context.push_str(&format!("Plan: {}\n", plan_preview)); + + if let Some(ref summary) = task.progress_summary { + context.push_str(&format!("Progress: {}\n", summary)); + } + + if let Some(ref error) = task.error_message { + context.push_str(&format!("Error: {}\n", error)); + } + + // Repository info + if let Some(ref url) = task.repository_url { + context.push_str(&format!("Repository: {}\n", url)); + } + if let Some(ref branch) = task.base_branch { + context.push_str(&format!("Base branch: {}\n", branch)); + } + + context +} + +/// Build overview context for top-level mesh orchestration +async fn build_mesh_overview_context(pool: &sqlx::PgPool, state: &SharedState, owner_id: Uuid) -> String { + let mut context = String::new(); + + // Get task counts by status + match repository::list_tasks_for_owner(pool, owner_id).await { + Ok(tasks) => { + let total = tasks.len(); + let pending = tasks.iter().filter(|t| t.status == "pending").count(); + let running = tasks.iter().filter(|t| t.status == "running").count(); + let paused = tasks.iter().filter(|t| t.status == "paused").count(); + let done = tasks.iter().filter(|t| t.status == "done").count(); + let failed = tasks.iter().filter(|t| t.status == "failed").count(); + + context.push_str(&format!( + "Tasks: {} total ({} pending, {} running, {} paused, {} done, {} failed)\n", + total, pending, running, paused, done, failed + )); + + // List recent/active tasks + if !tasks.is_empty() { + context.push_str("\nRecent Tasks:\n"); + for task in tasks.iter().take(5) { + context.push_str(&format!( + " - {} (ID: {}, Status: {})\n", + task.name, task.id, task.status + )); + } + if tasks.len() > 5 { + context.push_str(&format!(" ... and {} more\n", tasks.len() - 5)); + } + } + } + Err(e) => { + context.push_str(&format!("Error fetching tasks: {}\n", e)); + } + } + + // Get connected daemons for this owner + let owner_daemons: Vec<_> = state.daemon_connections.iter() + .filter(|e| e.value().owner_id == owner_id) + .collect(); + let daemon_count = owner_daemons.len(); + context.push_str(&format!("\nConnected Daemons: {}\n", daemon_count)); + + for entry in owner_daemons.iter().take(3) { + let daemon = entry.value(); + let working_dir = daemon.working_directory.as_deref().unwrap_or("not set"); + context.push_str(&format!( + " - {} (ID: {}, Working Directory: {})\n", + daemon.hostname.as_deref().unwrap_or("unknown"), + daemon.id, + working_dir + )); + } + + // Add default repository guidance if there's exactly one daemon with a working directory + let daemons_with_working_dir: Vec<_> = owner_daemons.iter() + .filter(|e| e.value().working_directory.is_some()) + .collect(); + + if daemons_with_working_dir.len() == 1 { + if let Some(dir) = &daemons_with_working_dir[0].value().working_directory { + context.push_str(&format!( + "\nDefault Repository: {} (use this as repository_url when creating tasks unless user specifies otherwise)\n", + dir + )); + } + } + + context +} + +/// Run the shared agentic loop for mesh chat +async fn run_mesh_agentic_loop( + pool: &sqlx::PgPool, + state: &SharedState, + llm_client: &LlmClient, + system_prompt: String, + request: &MeshChatRequest, + owner_id: Uuid, +) -> axum::response::Response { + // Get or create conversation for storing messages + let conversation = match repository::get_or_create_active_conversation(pool, owner_id).await { + Ok(c) => c, + Err(e) => { + tracing::error!("Failed to get/create conversation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Failed to initialize conversation: {}", e) })), + ) + .into_response(); + } + }; + + // Build initial messages + let mut messages = vec![Message { + role: "system".to_string(), + content: Some(system_prompt), + tool_calls: None, + tool_call_id: None, + }]; + + // Load conversation history from database (or use provided for backwards compatibility) + if let Some(history) = &request.history { + // Legacy: use provided history + for hist_msg in history { + messages.push(Message { + role: hist_msg.role.clone(), + content: Some(hist_msg.content.clone()), + tool_calls: None, + tool_call_id: None, + }); + } + tracing::info!( + history_messages = history.len(), + "Loaded mesh conversation history from request (legacy)" + ); + } else { + // New: load from database + match repository::list_chat_messages(pool, conversation.id, Some(50)).await { + Ok(db_messages) => { + for msg in db_messages { + messages.push(Message { + role: msg.role.clone(), + content: Some(msg.content.clone()), + tool_calls: None, + tool_call_id: None, + }); + } + tracing::info!( + history_messages = messages.len() - 1, // minus system message + "Loaded mesh conversation history from database" + ); + } + Err(e) => { + tracing::warn!("Failed to load chat history: {}", e); + // Continue without history + } + } + } + + // Add current user message + messages.push(Message { + role: "user".to_string(), + content: Some(request.message.clone()), + tool_calls: None, + tool_call_id: None, + }); + + // State for tracking + let mut all_tool_call_infos: Vec<MeshToolCallInfo> = Vec::new(); + let mut final_response: Option<String> = None; + let mut consecutive_failures = 0; + const MAX_CONSECUTIVE_FAILURES: usize = 3; + let mut pending_questions: Option<Vec<UserQuestion>> = None; + + // Multi-turn agentic tool calling loop + for round in 0..MAX_TOOL_ROUNDS { + tracing::info!( + round = round, + total_tool_calls = all_tool_call_infos.len(), + "Mesh agentic loop iteration" + ); + + // Check consecutive failures + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES { + tracing::warn!( + "Breaking mesh loop due to {} consecutive failures", + consecutive_failures + ); + final_response = Some( + "I encountered multiple consecutive errors and stopped. \ + Please check the task state and try again." + .to_string(), + ); + break; + } + + // Call the appropriate LLM API + let result = match llm_client { + LlmClient::Groq(groq) => { + match groq.chat_with_tools(messages.clone(), &MESH_TOOLS).await { + Ok(r) => LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls: r.raw_tool_calls, + finish_reason: r.finish_reason, + }, + Err(e) => { + tracing::error!("Groq API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("LLM API error: {}", e) })), + ) + .into_response(); + } + } + } + LlmClient::Claude(claude_client) => { + let claude_messages = claude::groq_messages_to_claude(&messages); + match claude_client + .chat_with_tools(claude_messages, &MESH_TOOLS) + .await + { + Ok(r) => { + let raw_tool_calls: Vec<ToolCallResponse> = r + .tool_calls + .iter() + .map(|tc| ToolCallResponse { + id: tc.id.clone(), + call_type: "function".to_string(), + function: crate::llm::groq::FunctionCall { + name: tc.name.clone(), + arguments: tc.arguments.to_string(), + }, + }) + .collect(); + + LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls, + finish_reason: r.stop_reason, + } + } + Err(e) => { + tracing::error!("Claude API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("LLM API error: {}", e) })), + ) + .into_response(); + } + } + } + }; + + // Check if there are tool calls to execute + if result.tool_calls.is_empty() { + final_response = result.content; + break; + } + + // Add assistant message with tool calls to conversation + messages.push(Message { + role: "assistant".to_string(), + content: result.content.clone(), + tool_calls: Some(result.raw_tool_calls.clone()), + tool_call_id: None, + }); + + // Execute each tool call + for (i, tool_call) in result.tool_calls.iter().enumerate() { + tracing::info!(tool = %tool_call.name, round = round, "Executing mesh tool call"); + + // Parse the tool call + let mut execution_result = parse_mesh_tool_call(tool_call); + + // Handle async mesh tool requests + if let Some(mesh_request) = execution_result.request.take() { + let async_result = handle_mesh_request(pool, state, mesh_request, owner_id).await; + execution_result.success = async_result.success; + execution_result.message = async_result.message; + execution_result.data = async_result.data; + } + + // Track consecutive failures + if execution_result.success { + consecutive_failures = 0; + } else { + consecutive_failures += 1; + tracing::warn!( + tool = %tool_call.name, + consecutive_failures = consecutive_failures, + "Mesh tool call failed" + ); + } + + // Check for pending user questions + if let Some(questions) = execution_result.pending_questions { + tracing::info!( + question_count = questions.len(), + "Mesh LLM requesting user input" + ); + pending_questions = Some(questions); + all_tool_call_infos.push(MeshToolCallInfo { + name: tool_call.name.clone(), + result: ToolResult { + success: execution_result.success, + message: execution_result.message.clone(), + }, + }); + break; + } + + // Build tool result message + let result_content = if let Some(data) = &execution_result.data { + json!({ + "success": execution_result.success, + "message": execution_result.message, + "data": data + }) + .to_string() + } else { + json!({ + "success": execution_result.success, + "message": execution_result.message + }) + .to_string() + }; + + // Add tool result message + let tool_call_id = match llm_client { + LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(), + LlmClient::Claude(_) => tool_call.id.clone(), + }; + + messages.push(Message { + role: "tool".to_string(), + content: Some(result_content), + tool_calls: None, + tool_call_id: Some(tool_call_id), + }); + + // Track for response + all_tool_call_infos.push(MeshToolCallInfo { + name: tool_call.name.clone(), + result: ToolResult { + success: execution_result.success, + message: execution_result.message, + }, + }); + } + + // If user questions are pending, pause + if pending_questions.is_some() { + final_response = result.content; + break; + } + + // If finish reason indicates completion, exit loop + let finish_lower = result.finish_reason.to_lowercase(); + if finish_lower == "stop" || finish_lower == "end_turn" { + final_response = result.content; + break; + } + } + + // Build response + let response_text = final_response.unwrap_or_else(|| { + if all_tool_call_infos.is_empty() { + "I couldn't understand your request. Please try rephrasing.".to_string() + } else { + format!( + "Done! Executed {} tool{}.", + all_tool_call_infos.len(), + if all_tool_call_infos.len() == 1 { + "" + } else { + "s" + } + ) + } + }); + + // Save messages to database (only if not using legacy history mode) + if request.history.is_none() { + let context_type = request.context_type.clone().unwrap_or_else(|| "mesh".to_string()); + + // Validate context_task_id exists before using it (to avoid FK constraint violation) + let context_task_id = if let Some(task_id) = request.context_task_id { + match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(_)) => Some(task_id), + Ok(None) => { + tracing::warn!("context_task_id {} not found, ignoring", task_id); + None + } + Err(e) => { + tracing::warn!("Failed to validate context_task_id {}: {}", task_id, e); + None + } + } + } else { + None + }; + + // Save user message + if let Err(e) = repository::add_chat_message( + pool, + conversation.id, + "user", + &request.message, + &context_type, + context_task_id, + None, + None, + ) + .await + { + tracing::warn!("Failed to save user message to DB: {}", e); + } + + // Serialize tool calls for storage + let tool_calls_json = if all_tool_call_infos.is_empty() { + None + } else { + Some(serde_json::to_value(&all_tool_call_infos).unwrap_or_default()) + }; + + // Serialize pending questions for storage + let pending_questions_json = pending_questions + .as_ref() + .map(|q| serde_json::to_value(q).unwrap_or_default()); + + // Save assistant message + if let Err(e) = repository::add_chat_message( + pool, + conversation.id, + "assistant", + &response_text, + &context_type, + context_task_id, + tool_calls_json, + pending_questions_json, + ) + .await + { + tracing::warn!("Failed to save assistant message to DB: {}", e); + } + + tracing::info!( + conversation_id = %conversation.id, + context_type = %context_type, + "Saved mesh chat messages to database" + ); + } + + ( + StatusCode::OK, + Json(MeshChatResponse { + response: response_text, + tool_calls: all_tool_call_infos, + pending_questions, + }), + ) + .into_response() +} + +/// Result from handling an async mesh tool request +struct MeshRequestResult { + success: bool, + message: String, + data: Option<serde_json::Value>, +} + +/// Handle async mesh tool requests that require database/daemon access +async fn handle_mesh_request( + pool: &sqlx::PgPool, + state: &SharedState, + request: MeshToolRequest, + owner_id: Uuid, +) -> MeshRequestResult { + match request { + MeshToolRequest::CreateTask { + name, + plan, + parent_task_id, + repository_url, + base_branch, + merge_mode, + priority, + } => { + // Check if repository_url matches a daemon's working directory (for this owner) + let is_daemon_working_dir = repository_url.as_ref().map(|url| { + state.daemon_connections.iter().any(|entry| { + entry.value().owner_id == owner_id && + entry.value().working_directory.as_ref() == Some(url) + }) + }).unwrap_or(false); + + // Derive completion_action from merge_mode, or default to "branch" if using daemon working dir + let (completion_action, target_repo_path) = if let Some(ref mode) = merge_mode { + // Explicit merge_mode provided - derive from it + let action = match mode.as_str() { + "pr" => "pr".to_string(), + "auto" => "merge".to_string(), + "manual" => "branch".to_string(), + _ => "none".to_string(), + }; + // If using daemon working dir and action involves the repo, set target_repo_path + let target = if is_daemon_working_dir && action != "none" { + repository_url.clone() + } else { + None + }; + (Some(action), target) + } else if is_daemon_working_dir { + // No merge_mode but using daemon working dir - default to "branch" + (Some("branch".to_string()), repository_url.clone()) + } else { + (None, None) + }; + + let create_req = CreateTaskRequest { + name: name.clone(), + description: None, + plan, + parent_task_id, + repository_url, + base_branch, + target_branch: None, + merge_mode, + priority: priority.unwrap_or(0), + target_repo_path, + completion_action, + continue_from_task_id: None, + copy_files: None, + }; + + match repository::create_task_for_owner(pool, owner_id, create_req).await { + Ok(task) => MeshRequestResult { + success: true, + message: format!("Created task '{}' with ID {}", name, task.id), + data: Some(json!({ + "taskId": task.id, + "name": task.name, + "status": task.status, + })), + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to create task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::RunTask { task_id, daemon_id } => { + // Get task to check status + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "pending" && task.status != "paused" { + return MeshRequestResult { + success: false, + message: format!( + "Task cannot be run - status is '{}' (must be 'pending' or 'paused')", + task.status + ), + data: None, + }; + } + + // Find a daemon to run on (must belong to this owner) + let target_daemon_id = if let Some(id) = daemon_id { + // Verify the specified daemon belongs to this owner + if !state.daemon_connections.iter().any(|d| d.value().id == id && d.value().owner_id == owner_id) { + return MeshRequestResult { + success: false, + message: "Specified daemon not found or not accessible.".to_string(), + data: None, + }; + } + id + } else { + // Find any connected daemon for this owner + let daemon = state.daemon_connections.iter().find(|d| d.value().owner_id == owner_id); + match daemon { + Some(d) => d.value().id, + None => { + return MeshRequestResult { + success: false, + message: "No daemons connected for your account. Cannot run task.".to_string(), + data: None, + } + } + } + }; + + // Check if this is an orchestrator (depth 0 with subtasks) + let subtask_count = match repository::list_subtasks_for_owner(pool, task_id, owner_id).await { + Ok(subtasks) => subtasks.len(), + Err(_) => 0, + }; + let is_orchestrator = task.depth == 0 && subtask_count > 0; + + // Send SpawnTask command to daemon + let command = DaemonCommand::SpawnTask { + task_id, + task_name: task.name.clone(), + plan: task.plan.clone(), + repo_url: task.repository_url.clone(), + base_branch: task.base_branch.clone(), + target_branch: task.target_branch.clone(), + parent_task_id: task.parent_task_id, + depth: task.depth, + is_orchestrator, + target_repo_path: task.target_repo_path.clone(), + completion_action: task.completion_action.clone(), + continue_from_task_id: task.continue_from_task_id, + copy_files: task.copy_files.as_ref().and_then(|v| serde_json::from_value(v.clone()).ok()), + }; + + match state.send_daemon_command(target_daemon_id, command).await { + Ok(()) => { + // Update task status to running + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("running".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "running".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!("Task {} is now running on daemon {}", task_id, target_daemon_id), + data: Some(json!({ + "taskId": task_id, + "daemonId": target_daemon_id, + "status": "running", + })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to start task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::PauseTask { task_id } => { + // Get task and its daemon + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "running" { + return MeshRequestResult { + success: false, + message: format!("Task is not running (status: {})", task.status), + data: None, + }; + } + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::PauseTask { task_id }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + return MeshRequestResult { + success: false, + message: format!("Failed to pause task: {}", e), + data: None, + }; + } + } + + // Update status + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("paused".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "paused".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!("Task {} paused", task_id), + data: Some(json!({ "taskId": task_id, "status": "paused" })), + } + } + + MeshToolRequest::ResumeTask { task_id } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "paused" { + return MeshRequestResult { + success: false, + message: format!("Task is not paused (status: {})", task.status), + data: None, + }; + } + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::ResumeTask { task_id }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + return MeshRequestResult { + success: false, + message: format!("Failed to resume task: {}", e), + data: None, + }; + } + } + + // Update status + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("running".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "running".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!("Task {} resumed", task_id), + data: Some(json!({ "taskId": task_id, "status": "running" })), + } + } + + MeshToolRequest::InterruptTask { task_id, graceful } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::InterruptTask { task_id, graceful }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + return MeshRequestResult { + success: false, + message: format!("Failed to interrupt task: {}", e), + data: None, + }; + } + } + + // Update status + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("paused".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "paused".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!( + "Task {} {}interrupted", + task_id, + if graceful { "gracefully " } else { "" } + ), + data: Some(json!({ "taskId": task_id, "status": "paused" })), + } + } + + MeshToolRequest::DiscardTask { task_id } => { + match repository::delete_task_for_owner(pool, task_id, owner_id).await { + Ok(true) => MeshRequestResult { + success: true, + message: format!("Task {} discarded", task_id), + data: Some(json!({ "taskId": task_id, "deleted": true })), + }, + Ok(false) => MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to delete task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::QueryTaskStatus { task_id } => { + match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(task)) => MeshRequestResult { + success: true, + message: format!("Task '{}' is {}", task.name, task.status), + data: Some(json!({ + "taskId": task.id, + "name": task.name, + "status": task.status, + "priority": task.priority, + "description": task.description, + "plan": task.plan, + "progressSummary": task.progress_summary, + "errorMessage": task.error_message, + "repositoryUrl": task.repository_url, + "baseBranch": task.base_branch, + "targetBranch": task.target_branch, + "mergeMode": task.merge_mode, + "prUrl": task.pr_url, + "daemonId": task.daemon_id, + "containerId": task.container_id, + "createdAt": task.created_at, + "startedAt": task.started_at, + "completedAt": task.completed_at, + })), + }, + Ok(None) => MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListTasks { + status_filter, + parent_task_id, + } => { + // TODO: Add filtering support to repository + match repository::list_tasks_for_owner(pool, owner_id).await { + Ok(mut tasks) => { + // Apply filters + if let Some(ref status) = status_filter { + tasks.retain(|t| &t.status == status); + } + if let Some(ref parent_id) = parent_task_id { + tasks.retain(|t| t.parent_task_id.as_ref() == Some(parent_id)); + } + + let task_data: Vec<serde_json::Value> = tasks + .iter() + .map(|t| { + json!({ + "taskId": t.id, + "name": t.name, + "status": t.status, + "priority": t.priority, + "parentTaskId": t.parent_task_id, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} tasks", tasks.len()), + data: Some(json!({ "tasks": task_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListSubtasks { task_id } => { + match repository::list_subtasks_for_owner(pool, task_id, owner_id).await { + Ok(subtasks) => { + let subtask_data: Vec<serde_json::Value> = subtasks + .iter() + .map(|t| { + json!({ + "taskId": t.id, + "name": t.name, + "status": t.status, + "priority": t.priority, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} subtasks", subtasks.len()), + data: Some(json!({ "subtasks": subtask_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListSiblings { task_id } => { + // Get task to find parent + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + let Some(parent_id) = task.parent_task_id else { + return MeshRequestResult { + success: true, + message: "Task has no parent, so no siblings".to_string(), + data: Some(json!({ "siblings": [] })), + }; + }; + + // Get all subtasks of parent, excluding current task + match repository::list_subtasks_for_owner(pool, parent_id, owner_id).await { + Ok(siblings) => { + let sibling_data: Vec<serde_json::Value> = siblings + .iter() + .filter(|t| t.id != task_id) + .map(|t| { + json!({ + "taskId": t.id, + "name": t.name, + "status": t.status, + "priority": t.priority, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} sibling tasks", sibling_data.len()), + data: Some(json!({ "siblings": sibling_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListDaemons => { + // Only list daemons belonging to this owner + let daemons: Vec<serde_json::Value> = state + .daemon_connections + .iter() + .filter(|entry| entry.value().owner_id == owner_id) + .map(|entry| { + let d = entry.value(); + json!({ + "daemonId": d.id, + "connectionId": d.connection_id, + "hostname": d.hostname, + "machineId": d.machine_id, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("{} daemon(s) connected", daemons.len()), + data: Some(json!({ "daemons": daemons })), + } + } + + MeshToolRequest::ListDaemonDirectories => { + let mut directories: Vec<serde_json::Value> = Vec::new(); + + // Only list directories from daemons belonging to this owner + for entry in state.daemon_connections.iter() { + let daemon = entry.value(); + + // Only include daemons belonging to this owner + if daemon.owner_id != owner_id { + continue; + } + + // Add working directory if available + if let Some(ref working_dir) = daemon.working_directory { + directories.push(json!({ + "path": working_dir, + "label": "Working Directory", + "directoryType": "working", + "hostname": daemon.hostname, + })); + } + + // Add home directory if available + if let Some(ref home_dir) = daemon.home_directory { + directories.push(json!({ + "path": home_dir, + "label": "Makima Home", + "directoryType": "home", + "hostname": daemon.hostname, + })); + } + } + + MeshRequestResult { + success: true, + message: format!("Found {} available directories", directories.len()), + data: Some(json!({ "directories": directories })), + } + } + + MeshToolRequest::ListFiles => { + match repository::list_files_for_owner(pool, owner_id).await { + Ok(files) => { + let file_data: Vec<serde_json::Value> = files + .iter() + .map(|f| { + json!({ + "fileId": f.id, + "name": f.name, + "description": f.description, + "createdAt": f.created_at, + "updatedAt": f.updated_at, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} files", files.len()), + data: Some(json!({ "files": file_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ReadFile { file_id } => { + match repository::get_file_for_owner(pool, file_id, owner_id).await { + Ok(Some(file)) => { + // Convert body elements to readable text + let body_content: Vec<serde_json::Value> = file + .body + .iter() + .map(|elem| { + match elem { + crate::db::models::BodyElement::Heading { level, text } => { + json!({ "type": "heading", "level": level, "text": text }) + } + crate::db::models::BodyElement::Paragraph { text } => { + json!({ "type": "paragraph", "text": text }) + } + crate::db::models::BodyElement::Code { language, content } => { + json!({ "type": "code", "language": language, "content": content }) + } + crate::db::models::BodyElement::List { ordered, items } => { + json!({ "type": "list", "ordered": ordered, "items": items }) + } + crate::db::models::BodyElement::Chart { chart_type, title, data, config: _ } => { + let data_count = data.as_array().map(|arr| arr.len()).unwrap_or(0); + json!({ "type": "chart", "chartType": chart_type, "title": title, "dataPoints": data_count }) + } + crate::db::models::BodyElement::Image { src, alt, caption } => { + json!({ "type": "image", "src": src, "alt": alt, "caption": caption }) + } + } + }) + .collect(); + + // Also build a plain text version for easier reading + let plain_text: String = file + .body + .iter() + .filter_map(|elem| { + match elem { + crate::db::models::BodyElement::Heading { level, text } => { + Some(format!("{} {}", "#".repeat(*level as usize), text)) + } + crate::db::models::BodyElement::Paragraph { text } => { + Some(text.clone()) + } + crate::db::models::BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or(""); + Some(format!("```{}\n{}\n```", lang, content)) + } + crate::db::models::BodyElement::List { ordered, items } => { + let list_text: Vec<String> = items.iter().enumerate().map(|(i, item)| { + if *ordered { + format!("{}. {}", i + 1, item) + } else { + format!("- {}", item) + } + }).collect(); + Some(list_text.join("\n")) + } + _ => None, + } + }) + .collect::<Vec<_>>() + .join("\n\n"); + + // Convert transcript entries to JSON + let transcript: Vec<serde_json::Value> = file + .transcript + .iter() + .map(|entry| { + json!({ + "id": entry.id, + "speaker": entry.speaker, + "start": entry.start, + "end": entry.end, + "text": entry.text, + }) + }) + .collect(); + + // Build a plain text transcript for easier reading + let transcript_text: String = file + .transcript + .iter() + .map(|entry| { + format!("[{:.1}s] {}: {}", entry.start, entry.speaker, entry.text) + }) + .collect::<Vec<_>>() + .join("\n"); + + MeshRequestResult { + success: true, + message: format!("Read file '{}'", file.name), + data: Some(json!({ + "fileId": file.id, + "name": file.name, + "description": file.description, + "summary": file.summary, + "body": body_content, + "plainText": plain_text, + "transcript": transcript, + "transcriptText": transcript_text, + "transcriptCount": file.transcript.len(), + "createdAt": file.created_at, + "updatedAt": file.updated_at, + })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: format!("File {} not found", file_id), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::SendMessageToTask { task_id, message } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "running" { + return MeshRequestResult { + success: false, + message: format!("Task is not running (status: {})", task.status), + data: None, + }; + } + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::SendMessage { task_id, message }; + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => MeshRequestResult { + success: true, + message: "Message sent to task".to_string(), + data: Some(json!({ "taskId": task_id })), + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to send message: {}", e), + data: None, + }, + } + } else { + MeshRequestResult { + success: false, + message: "Task has no daemon assigned".to_string(), + data: None, + } + } + } + + MeshToolRequest::UpdateTaskPlan { + task_id, + new_plan, + interrupt_if_running, + } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + // Interrupt if running and requested + if task.status == "running" && interrupt_if_running { + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::InterruptTask { + task_id, + graceful: true, + }; + let _ = state.send_daemon_command(daemon_id, command).await; + } + } + + let update_req = crate::db::models::UpdateTaskRequest { + plan: Some(new_plan), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + Ok(Some(updated)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: updated.status.clone(), + updated_fields: vec!["plan".to_string()], + updated_by: "system".to_string(), + }); + MeshRequestResult { + success: true, + message: "Task plan updated".to_string(), + data: Some(json!({ "taskId": task_id })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: "Task not found".to_string(), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to update task: {}", e), + data: None, + }, + } + } + + // Overlay operations - these require daemon communication + // For now, return placeholder responses since daemon implementation is separate + MeshToolRequest::PeekSiblingOverlay { sibling_task_id } => MeshRequestResult { + success: false, + message: format!( + "Overlay operations require a connected daemon. Task {} may not have overlay data yet.", + sibling_task_id + ), + data: None, + }, + + MeshToolRequest::GetOverlayDiff { task_id } => MeshRequestResult { + success: false, + message: format!( + "Overlay operations require a connected daemon. Task {} may not have overlay data yet.", + task_id + ), + data: None, + }, + + MeshToolRequest::PreviewMerge { task_id } => MeshRequestResult { + success: false, + message: format!( + "Merge preview requires a connected daemon. Task {} may not have overlay data yet.", + task_id + ), + data: None, + }, + + MeshToolRequest::MergeSubtask { task_id } => MeshRequestResult { + success: false, + message: format!( + "Merge operations require a connected daemon. Task {}", + task_id + ), + data: None, + }, + + MeshToolRequest::CompleteTask { task_id } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + // Update status to done + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("done".to_string()), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + Ok(Some(updated)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "done".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + let merge_mode = task.merge_mode.unwrap_or_else(|| "pr".to_string()); + MeshRequestResult { + success: true, + message: format!( + "Task {} completed. Merge mode: {}", + task_id, + &merge_mode + ), + data: Some(json!({ + "taskId": task_id, + "status": "done", + "mergeMode": merge_mode, + })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: "Task not found".to_string(), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to complete task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::SetMergeMode { task_id, mode } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + let update_req = crate::db::models::UpdateTaskRequest { + merge_mode: Some(mode.clone()), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + Ok(Some(updated)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: updated.status, + updated_fields: vec!["merge_mode".to_string()], + updated_by: "system".to_string(), + }); + MeshRequestResult { + success: true, + message: format!("Merge mode set to '{}'", mode), + data: Some(json!({ "taskId": task_id, "mergeMode": mode })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: "Task not found".to_string(), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to update merge mode: {}", e), + data: None, + }, + } + } + } +} + +// ============================================================================= +// Chat History Endpoints +// ============================================================================= + +use crate::db::models::MeshChatHistoryResponse; + +/// Get chat history for the current conversation (requires authentication) +#[utoipa::path( + get, + path = "/api/v1/mesh/chat/history", + responses( + (status = 200, description = "Chat history", body = MeshChatHistoryResponse), + (status = 401, description = "Unauthorized"), + (status = 503, description = "Database not configured"), + (status = 500, description = "Internal server error") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_chat_history( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + let conversation = match repository::get_or_create_active_conversation(pool, auth.owner_id).await { + Ok(c) => c, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() })), + ) + .into_response() + } + }; + + let messages = match repository::list_chat_messages(pool, conversation.id, None).await { + Ok(m) => m, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() })), + ) + .into_response() + } + }; + + ( + StatusCode::OK, + Json(MeshChatHistoryResponse { + conversation_id: conversation.id, + messages, + }), + ) + .into_response() +} + +/// Clear chat history (archives current conversation and starts new, requires authentication) +#[utoipa::path( + delete, + path = "/api/v1/mesh/chat/history", + responses( + (status = 200, description = "History cleared"), + (status = 401, description = "Unauthorized"), + (status = 503, description = "Database not configured"), + (status = 500, description = "Internal server error") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn clear_chat_history( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + match repository::clear_conversation(pool, auth.owner_id).await { + Ok(new_conv) => ( + StatusCode::OK, + Json(json!({ "success": true, "conversationId": new_conv.id })), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() })), + ) + .into_response(), + } +} diff --git a/makima/src/server/handlers/mesh_daemon.rs b/makima/src/server/handlers/mesh_daemon.rs new file mode 100644 index 0000000..644d0bc --- /dev/null +++ b/makima/src/server/handlers/mesh_daemon.rs @@ -0,0 +1,959 @@ +//! WebSocket handler for daemon connections. +//! +//! Daemons connect to report task progress, stream output, and receive commands. +//! Each daemon manages Claude Code containers on its local machine. +//! +//! ## Authentication +//! +//! Daemons authenticate via the `X-Api-Key` header in the WebSocket upgrade request. +//! The API key is validated against the database and the daemon is associated with +//! the corresponding owner_id for data isolation. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; +use futures::{SinkExt, StreamExt}; +use serde::Deserialize; +use sqlx::Row; +use tokio::sync::mpsc; +use uuid::Uuid; + +use crate::db::repository; +use crate::server::auth::{hash_api_key, API_KEY_HEADER}; +use crate::server::messages::ApiError; +use crate::server::state::{ + DaemonCommand, SharedState, TaskOutputNotification, TaskUpdateNotification, +}; + +// ============================================================================= +// Claude Code JSON Output Parsing +// ============================================================================= + +/// Claude Code stream-json message structure +#[derive(Debug, Deserialize)] +struct ClaudeMessage { + #[serde(rename = "type")] + msg_type: String, + subtype: Option<String>, + message: Option<ClaudeMessageContent>, + tool_name: Option<String>, + tool_input: Option<serde_json::Value>, + tool_result: Option<ClaudeToolResult>, + result: Option<String>, + cost_usd: Option<f64>, + duration_ms: Option<u64>, + error: Option<String>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeMessageContent { + content: Option<Vec<ClaudeContentBlock>>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeContentBlock { + #[serde(rename = "type")] + block_type: String, + text: Option<String>, + name: Option<String>, + input: Option<serde_json::Value>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeToolResult { + content: Option<String>, + is_error: Option<bool>, +} + +/// Parse a line of Claude Code output into a structured notification +fn parse_claude_output(task_id: Uuid, owner_id: Uuid, line: &str, is_partial: bool) -> Option<TaskOutputNotification> { + let trimmed = line.trim(); + if trimmed.is_empty() { + return None; + } + + // Try to parse as JSON + if trimmed.starts_with('{') { + if let Ok(msg) = serde_json::from_str::<ClaudeMessage>(trimmed) { + return parse_claude_message(task_id, owner_id, msg, is_partial); + } + } + + // Not JSON or failed to parse - treat as raw output + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "raw".to_string(), + content: trimmed.to_string(), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) +} + +fn parse_claude_message(task_id: Uuid, owner_id: Uuid, msg: ClaudeMessage, is_partial: bool) -> Option<TaskOutputNotification> { + match msg.msg_type.as_str() { + "system" => { + // System messages (init, etc.) - include subtype info + let content = match msg.subtype.as_deref() { + Some("init") => "Session started".to_string(), + Some(sub) => format!("System: {}", sub), + None => "System message".to_string(), + }; + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content, + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + "assistant" => { + // Extract text content from message blocks + if let Some(message) = msg.message { + if let Some(blocks) = message.content { + // Check for text blocks + let text_content: Vec<String> = blocks + .iter() + .filter(|b| b.block_type == "text") + .filter_map(|b| b.text.clone()) + .collect(); + + if !text_content.is_empty() { + return Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "assistant".to_string(), + content: text_content.join("\n"), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }); + } + + // Check for tool_use blocks + if let Some(tool_block) = blocks.iter().find(|b| b.block_type == "tool_use") { + return Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_use".to_string(), + content: format!("Using tool: {}", tool_block.name.as_deref().unwrap_or("unknown")), + tool_name: tool_block.name.clone(), + tool_input: tool_block.input.clone(), + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }); + } + } + } + None + } + + "tool_use" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_use".to_string(), + content: format!("Using tool: {}", msg.tool_name.as_deref().unwrap_or("unknown")), + tool_name: msg.tool_name, + tool_input: msg.tool_input, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + "tool_result" => { + if let Some(result) = msg.tool_result { + let content = result.content.unwrap_or_default(); + // Truncate long results + let content = if content.len() > 500 { + format!("{}...", &content[..500]) + } else { + content + }; + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_result".to_string(), + content, + tool_name: None, + tool_input: None, + is_error: result.is_error, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } else { + None + } + } + + "result" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "result".to_string(), + content: msg.result.unwrap_or_else(|| "Task completed".to_string()), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: msg.cost_usd, + duration_ms: msg.duration_ms, + is_partial, + }) + } + + "error" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "error".to_string(), + content: msg.error.unwrap_or_else(|| "An error occurred".to_string()), + tool_name: None, + tool_input: None, + is_error: Some(true), + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + _ => None, // Skip unknown message types + } +} + +/// Message from daemon to server. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonMessage { + /// Authentication request (first message required) + Authenticate { + #[serde(rename = "apiKey")] + api_key: String, + #[serde(rename = "machineId")] + machine_id: String, + hostname: String, + #[serde(rename = "maxConcurrentTasks")] + max_concurrent_tasks: i32, + }, + /// Periodic heartbeat with current status + Heartbeat { + #[serde(rename = "activeTasks")] + active_tasks: Vec<Uuid>, + }, + /// Task output streaming (stdout/stderr from Claude Code) + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + output: String, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + /// Task status change notification + TaskStatusChange { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "oldStatus")] + old_status: String, + #[serde(rename = "newStatus")] + new_status: String, + }, + /// Task progress update with summary + TaskProgress { + #[serde(rename = "taskId")] + task_id: Uuid, + summary: String, + }, + /// Task completion notification + TaskComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + error: Option<String>, + }, + /// Register a tool key for orchestrator API access + RegisterToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + /// The API key for this orchestrator to use when calling mesh endpoints + key: String, + }, + /// Revoke a tool key when task completes + RevokeToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Response to RetryCompletionAction command + CompletionActionResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// PR URL if action was "pr" and successful + #[serde(rename = "prUrl")] + pr_url: Option<String>, + }, + /// Report daemon's available directories for task output + DaemonDirectories { + /// Current working directory of the daemon + #[serde(rename = "workingDirectory")] + working_directory: String, + /// Path to ~/.makima/home directory (for cloning completed work) + #[serde(rename = "homeDirectory")] + home_directory: String, + /// Path to worktrees directory (~/.makima/worktrees) + #[serde(rename = "worktreesDirectory")] + worktrees_directory: String, + }, + /// Response to CloneWorktree command + CloneWorktreeResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// The path where the worktree was cloned + #[serde(rename = "targetDir")] + target_dir: Option<String>, + }, + /// Response to CheckTargetExists command + CheckTargetExistsResult { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Whether the target directory exists + exists: bool, + /// The path that was checked + #[serde(rename = "targetDir")] + target_dir: String, + }, +} + +/// Validated daemon authentication result. +#[derive(Debug, Clone)] +struct DaemonAuthResult { + /// User ID from the API key + user_id: Uuid, + /// Owner ID for data isolation + owner_id: Uuid, +} + +/// Validate an API key and return (user_id, owner_id). +async fn validate_daemon_api_key(pool: &sqlx::PgPool, key: &str) -> Result<DaemonAuthResult, String> { + let key_hash = hash_api_key(key); + + // Look up the API key and join with users to get owner_id + let row = sqlx::query( + r#" + SELECT ak.user_id, u.default_owner_id + FROM api_keys ak + JOIN users u ON u.id = ak.user_id + WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL + "#, + ) + .bind(&key_hash) + .fetch_optional(pool) + .await + .map_err(|e| format!("Database error: {}", e))?; + + match row { + Some(row) => { + let user_id: Uuid = row.try_get("user_id") + .map_err(|e| format!("Failed to get user_id: {}", e))?; + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| format!("Failed to get owner_id: {}", e))?; + let owner_id = owner_id.ok_or_else(|| "User has no default owner".to_string())?; + + // Update last_used_at asynchronously (fire and forget) + let pool_clone = pool.clone(); + let key_hash_clone = key_hash.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") + .bind(&key_hash_clone) + .execute(&pool_clone) + .await; + }); + + Ok(DaemonAuthResult { user_id, owner_id }) + } + None => Err("Invalid or revoked API key".to_string()), + } +} + +/// WebSocket upgrade handler for daemon connections. +/// +/// Daemons must authenticate via the `X-Api-Key` header in the WebSocket upgrade request. +/// The API key is validated against the database and used to determine the owner_id +/// for data isolation. +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/connect", + params( + ("X-Api-Key" = String, Header, description = "API key for daemon authentication"), + ), + responses( + (status = 101, description = "WebSocket connection established"), + (status = 401, description = "Missing or invalid API key"), + (status = 503, description = "Database not configured"), + ), + tag = "Mesh" +)] +pub async fn daemon_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, + headers: HeaderMap, +) -> Response { + // Extract API key from headers + let api_key = match headers.get(API_KEY_HEADER).or_else(|| headers.get("x-api-key")) { + Some(value) => match value.to_str() { + Ok(key) if !key.is_empty() => key.to_string(), + _ => { + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("INVALID_API_KEY", "Invalid API key header value")), + ) + .into_response(); + } + }, + None => { + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("MISSING_API_KEY", "X-Api-Key header required")), + ) + .into_response(); + } + }; + + // Validate API key against database + let pool = match state.db_pool.as_ref() { + Some(pool) => pool, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + axum::Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + } + }; + + let auth_result = match validate_daemon_api_key(pool, &api_key).await { + Ok(result) => result, + Err(e) => { + tracing::warn!("Daemon authentication failed: {}", e); + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("AUTH_FAILED", e)), + ) + .into_response(); + } + }; + + tracing::info!( + user_id = %auth_result.user_id, + owner_id = %auth_result.owner_id, + "Daemon authenticated via API key" + ); + + ws.on_upgrade(move |socket| handle_daemon_connection(socket, state, auth_result)) +} + +async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_result: DaemonAuthResult) { + let (mut sender, mut receiver) = socket.split(); + + // Generate a unique connection ID and daemon ID + let connection_id = Uuid::new_v4().to_string(); + let daemon_id = Uuid::new_v4(); + let owner_id = auth_result.owner_id; + + // Create command channel for sending commands to this daemon + let (cmd_tx, mut cmd_rx) = mpsc::channel::<DaemonCommand>(64); + + // Wait for the daemon to send its registration info (hostname, machine_id, etc.) + // The daemon is already authenticated via API key header, but we need metadata + #[allow(unused_assignments)] + let mut registered = false; + + // Wait for registration message with metadata + loop { + tokio::select! { + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<DaemonMessage>(&text) { + Ok(DaemonMessage::Authenticate { api_key: _, machine_id, hostname, max_concurrent_tasks }) => { + // API key was already validated via headers, but we use this message + // for backward compatibility to get the machine_id and hostname + + tracing::info!( + daemon_id = %daemon_id, + owner_id = %owner_id, + hostname = %hostname, + machine_id = %machine_id, + max_concurrent_tasks = max_concurrent_tasks, + "Daemon registered" + ); + + // Register daemon in state with owner_id + state.register_daemon( + connection_id.clone(), + daemon_id, + owner_id, + Some(hostname), + Some(machine_id), + cmd_tx.clone(), + ); + + registered = true; + + // Send authentication confirmation + let response = DaemonCommand::Authenticated { daemon_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + + break; // Exit registration loop, continue to main loop + } + Ok(_) => { + // Non-auth message before registration - still requires registration message + let response = DaemonCommand::Error { + code: "NOT_REGISTERED".into(), + message: "Must send registration message (Authenticate) first".into(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + Err(e) => { + let response = DaemonCommand::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Daemon disconnected during registration"); + return; + } + Some(Err(e)) => { + tracing::warn!("Daemon WebSocket error during registration: {}", e); + return; + } + _ => {} + } + } + } + } + + if !registered { + return; + } + + let daemon_uuid = daemon_id; + + // Main message loop after authentication + loop { + tokio::select! { + // Handle incoming messages from daemon + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<DaemonMessage>(&text) { + Ok(DaemonMessage::Heartbeat { active_tasks }) => { + tracing::trace!( + "Daemon {} heartbeat: {} active tasks", + daemon_uuid, active_tasks.len() + ); + // TODO: Update daemon last_heartbeat_at in DB + } + Ok(DaemonMessage::TaskOutput { task_id, output, is_partial }) => { + // Parse the output line and broadcast structured data + if let Some(notification) = parse_claude_output(task_id, owner_id, &output, is_partial) { + // Broadcast to connected clients + state.broadcast_task_output(notification.clone()); + + // Persist to database (fire and forget) + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let notification = notification.clone(); + tokio::spawn(async move { + if let Err(e) = repository::save_task_output( + &pool, + notification.task_id, + ¬ification.message_type, + ¬ification.content, + notification.tool_name.as_deref(), + notification.tool_input.clone(), + notification.is_error, + notification.cost_usd, + notification.duration_ms, + ).await { + tracing::warn!( + task_id = %notification.task_id, + "Failed to persist task output: {}", + e + ); + } + }); + } + } + } + Ok(DaemonMessage::TaskStatusChange { task_id, old_status, new_status }) => { + tracing::info!( + "Task {} status change: {} -> {}", + task_id, old_status, new_status + ); + + // Update task status in database and broadcast + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let state = state.clone(); + let new_status_owned = new_status.clone(); + tokio::spawn(async move { + match repository::update_task_status( + &pool, + task_id, + &new_status_owned, + None, + ).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: updated_task.version, + status: new_status_owned, + updated_fields: vec!["status".into()], + updated_by: "daemon".into(), + }); + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + "Task not found when updating status" + ); + } + Err(e) => { + tracing::error!( + task_id = %task_id, + "Failed to update task status: {}", + e + ); + } + } + }); + } else { + // No DB, just broadcast + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: new_status, + updated_fields: vec!["status".into()], + updated_by: "daemon".into(), + }); + } + } + Ok(DaemonMessage::TaskProgress { task_id, summary }) => { + tracing::debug!("Task {} progress: {}", task_id, summary); + // TODO: Update task progress_summary in database + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: "running".into(), + updated_fields: vec!["progress_summary".into()], + updated_by: "daemon".into(), + }); + } + Ok(DaemonMessage::TaskComplete { task_id, success, error }) => { + let status = if success { "done" } else { "failed" }; + tracing::info!( + "Task {} completed: success={}, error={:?}", + task_id, success, error + ); + + // Revoke any tool keys for this task + state.revoke_tool_key(task_id); + + // Update task in database with completion info + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let state = state.clone(); + let error_clone = error.clone(); + tokio::spawn(async move { + match repository::complete_task( + &pool, + task_id, + success, + error_clone.as_deref(), + ).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: updated_task.version, + status: updated_task.status.clone(), + updated_fields: vec![ + "status".into(), + "completed_at".into(), + "error_message".into(), + ], + updated_by: "daemon".into(), + }); + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + "Task not found when completing" + ); + } + Err(e) => { + tracing::error!( + task_id = %task_id, + "Failed to complete task: {}", + e + ); + } + } + }); + } else { + // No DB, just broadcast + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: status.into(), + updated_fields: vec!["status".into(), "completed_at".into()], + updated_by: "daemon".into(), + }); + } + } + Ok(DaemonMessage::Authenticate { .. }) => { + // Already authenticated, ignore + } + Ok(DaemonMessage::RegisterToolKey { task_id, key }) => { + tracing::info!( + task_id = %task_id, + "Registering tool key for orchestrator" + ); + state.register_tool_key(key, task_id); + } + Ok(DaemonMessage::RevokeToolKey { task_id }) => { + tracing::info!( + task_id = %task_id, + "Revoking tool key for task" + ); + state.revoke_tool_key(task_id); + } + Ok(DaemonMessage::DaemonDirectories { working_directory, home_directory, worktrees_directory }) => { + tracing::info!( + daemon_id = %daemon_uuid, + working_directory = %working_directory, + home_directory = %home_directory, + worktrees_directory = %worktrees_directory, + "Daemon directories received" + ); + state.update_daemon_directories( + &connection_id, + working_directory, + home_directory, + worktrees_directory, + ); + } + Ok(DaemonMessage::CompletionActionResult { task_id, success, message, pr_url }) => { + tracing::info!( + task_id = %task_id, + success = success, + message = %message, + pr_url = ?pr_url, + "Completion action result received" + ); + + // Update task with PR URL if created + if let Some(ref url) = pr_url { + if let Some(ref pool) = state.db_pool { + let update_req = crate::db::models::UpdateTaskRequest { + pr_url: Some(url.clone()), + ..Default::default() + }; + if let Err(e) = crate::db::repository::update_task(pool, task_id, update_req).await { + tracing::error!("Failed to update task PR URL: {}", e); + } + } + } + + // Broadcast as task output so UI can see the result + let output_text = if success { + format!("✓ Completion action succeeded: {}", message) + } else { + format!("✗ Completion action failed: {}", message) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: Some(!success), + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Ok(DaemonMessage::CloneWorktreeResult { task_id, success, message, target_dir }) => { + tracing::info!( + task_id = %task_id, + success = success, + message = %message, + target_dir = ?target_dir, + "Clone worktree result received" + ); + + // Broadcast as task output so UI can see the result + let output_text = if success { + format!("✓ Clone succeeded: {}", message) + } else { + format!("✗ Clone failed: {}", message) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: Some(!success), + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Ok(DaemonMessage::CheckTargetExistsResult { task_id, exists, target_dir }) => { + tracing::debug!( + task_id = %task_id, + exists = exists, + target_dir = %target_dir, + "Check target exists result received" + ); + + // Broadcast as task output so UI can use the result + let output_text = if exists { + format!("Target directory exists: {}", target_dir) + } else { + format!("Target directory does not exist: {}", target_dir) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Err(e) => { + tracing::warn!("Failed to parse daemon message: {}", e); + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::info!("Daemon {} disconnected", daemon_uuid); + break; + } + Some(Err(e)) => { + tracing::warn!("Daemon {} WebSocket error: {}", daemon_uuid, e); + break; + } + _ => {} + } + } + + // Handle commands to send to daemon + cmd = cmd_rx.recv() => { + match cmd { + Some(command) => { + let json = serde_json::to_string(&command).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + tracing::warn!("Failed to send command to daemon {}", daemon_uuid); + break; + } + } + None => { + // Channel closed + break; + } + } + } + } + } + + // Cleanup on disconnect + state.unregister_daemon(&connection_id); + + // Clear daemon_id from any tasks that were running on this daemon + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + tokio::spawn(async move { + // Find tasks assigned to this daemon that are still active + if let Err(e) = clear_daemon_from_tasks(&pool, daemon_uuid).await { + tracing::error!( + daemon_id = %daemon_uuid, + error = %e, + "Failed to clear daemon from tasks on disconnect" + ); + } + }); + } +} + +/// Clear daemon_id from tasks when daemon disconnects +async fn clear_daemon_from_tasks(pool: &sqlx::PgPool, daemon_id: Uuid) -> Result<(), sqlx::Error> { + // Update tasks that were running on this daemon to failed state + let result = sqlx::query( + r#" + UPDATE tasks + SET daemon_id = NULL, + status = 'failed', + error_message = 'Daemon disconnected', + updated_at = NOW() + WHERE daemon_id = $1 + AND status IN ('starting', 'running', 'initializing') + "#, + ) + .bind(daemon_id) + .execute(pool) + .await?; + + if result.rows_affected() > 0 { + tracing::warn!( + daemon_id = %daemon_id, + tasks_affected = result.rows_affected(), + "Marked tasks as failed due to daemon disconnect" + ); + } + + Ok(()) +} diff --git a/makima/src/server/handlers/mesh_merge.rs b/makima/src/server/handlers/mesh_merge.rs new file mode 100644 index 0000000..2d7c742 --- /dev/null +++ b/makima/src/server/handlers/mesh_merge.rs @@ -0,0 +1,441 @@ +//! Merge operation handlers for orchestrator tasks. +//! +//! These endpoints allow orchestrators to merge subtask branches. +//! Commands are forwarded to the daemon via WebSocket; the daemon +//! responds asynchronously through the WebSocket channel. + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use uuid::Uuid; + +use crate::db::models::{ + BranchListResponse, MergeCommitRequest, MergeCompleteCheckResponse, MergeResolveRequest, + MergeResultResponse, MergeSkipRequest, MergeStartRequest, MergeStatusResponse, +}; +use crate::db::repository; +use crate::server::messages::ApiError; +use crate::server::state::{DaemonCommand, SharedState}; + +/// Get the daemon ID for a task, returning error if not found. +async fn get_task_daemon_id( + state: &SharedState, + task_id: Uuid, +) -> Result<Uuid, (StatusCode, Json<ApiError>)> { + let pool = state.db_pool.as_ref().ok_or_else(|| { + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("service_unavailable", "Database not configured")), + ) + })?; + + // Get task and its daemon_id + let task = repository::get_task(pool, task_id) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("database_error", format!("Database error: {}", e))), + ) + })? + .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(ApiError::new("not_found", format!("Task {} not found", task_id))), + ) + })?; + + task.daemon_id.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("bad_request", "Task has no assigned daemon")), + ) + }) +} + +/// List all subtask branches for a task. +/// +/// GET /api/v1/mesh/tasks/{id}/branches +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/branches", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Command sent to daemon"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured") + ), + tag = "Mesh" +)] +pub async fn list_branches( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::ListBranches { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(BranchListResponse { branches: vec![] }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Start merging a subtask branch. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/start +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/start", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeStartRequest, + responses( + (status = 202, description = "Merge command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_start( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeStartRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeStart { + task_id, + source_branch: req.source_branch, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Merge command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Get current merge status. +/// +/// GET /api/v1/mesh/tasks/{id}/merge/status +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/merge/status", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Status request sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_status( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeStatus { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeStatusResponse { + in_progress: false, + source_branch: None, + conflicted_files: vec![], + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Resolve a merge conflict. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/resolve +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/resolve", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeResolveRequest, + responses( + (status = 202, description = "Resolve command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_resolve( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeResolveRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeResolve { + task_id, + file: req.file, + strategy: req.strategy, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Resolve command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Commit the current merge. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/commit +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/commit", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeCommitRequest, + responses( + (status = 202, description = "Commit command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_commit( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeCommitRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeCommit { + task_id, + message: req.message, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Commit command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Abort the current merge. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/abort +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/abort", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Abort command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_abort( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeAbort { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Abort command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Skip merging a subtask branch. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/skip +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/skip", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeSkipRequest, + responses( + (status = 202, description = "Skip command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_skip( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeSkipRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeSkip { + task_id, + subtask_id: req.subtask_id, + reason: req.reason, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Skip command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Check if all branches are merged or skipped. +/// +/// GET /api/v1/mesh/tasks/{id}/merge/check +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/merge/check", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Check command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_check( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::CheckMergeComplete { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeCompleteCheckResponse { + can_complete: true, + unmerged_branches: vec![], + merged_count: 0, + skipped_count: 0, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} diff --git a/makima/src/server/handlers/mesh_ws.rs b/makima/src/server/handlers/mesh_ws.rs new file mode 100644 index 0000000..d15fba7 --- /dev/null +++ b/makima/src/server/handlers/mesh_ws.rs @@ -0,0 +1,346 @@ +//! WebSocket handler for task change subscriptions and output streaming. +//! +//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications +//! when tasks are updated. They can also subscribe to task output for live terminal streaming. +//! +//! ## Owner-scoped filtering +//! +//! Notifications are filtered by owner_id. If a notification has an owner_id set, +//! it will only be delivered to clients who are subscribed to tasks belonging to that owner. +//! The task's owner_id is looked up from the database when the client subscribes. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use sqlx::Row; +use std::collections::HashMap; +use uuid::Uuid; + +use crate::server::state::SharedState; + +/// Client message for task subscription management. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum TaskClientMessage { + /// Subscribe to updates for a specific task + Subscribe { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscribe from updates for a specific task + Unsubscribe { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Subscribe to all task updates + SubscribeAll, + /// Unsubscribe from all task updates + UnsubscribeAll, + /// Subscribe to live output streaming for a specific task + SubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscribe from output streaming for a specific task + UnsubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, +} + +/// Server message for task subscription WebSocket. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum TaskServerMessage { + /// Subscription confirmed for specific task + Subscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscription confirmed for specific task + Unsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Subscribed to all task updates + SubscribedAll, + /// Unsubscribed from all task updates + UnsubscribedAll, + /// Task was updated + TaskUpdated { + #[serde(rename = "taskId")] + task_id: Uuid, + version: i32, + status: String, + #[serde(rename = "updatedFields")] + updated_fields: Vec<String>, + #[serde(rename = "updatedBy")] + updated_by: String, + }, + /// Live output from Claude Code container (parsed and structured) + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + #[serde(rename = "messageType")] + message_type: String, + /// Main text content + content: String, + /// Tool name if tool_use message + #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")] + tool_name: Option<String>, + /// Tool input JSON if tool_use message + #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")] + tool_input: Option<serde_json::Value>, + /// Whether tool result was an error + #[serde(rename = "isError", skip_serializing_if = "Option::is_none")] + is_error: Option<bool>, + /// Cost in USD if result message + #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")] + cost_usd: Option<f64>, + /// Duration in ms if result message + #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")] + duration_ms: Option<u64>, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + /// Output subscription confirmed + OutputSubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Output unsubscription confirmed + OutputUnsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Error occurred + Error { code: String, message: String }, +} + +/// WebSocket upgrade handler for task subscriptions. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/subscribe", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "Mesh" +)] +pub async fn task_subscription_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, +) -> Response { + ws.on_upgrade(|socket| handle_task_subscription(socket, state)) +} + +/// Look up the owner_id for a task from the database. +async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option<Uuid> { + let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(pool) + .await + .ok()??; + + row.try_get("owner_id").ok() +} + +async fn handle_task_subscription(socket: WebSocket, state: SharedState) { + let (mut sender, mut receiver) = socket.split(); + + // Map of task IDs to their owner_ids for this client's subscriptions + let mut task_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new(); + // Whether client is subscribed to all task updates (not owner-scoped) + let mut subscribed_all = false; + // Map of task IDs to their owner_ids for output streaming subscriptions + let mut output_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new(); + + // Subscribe to broadcast channels + let mut task_update_rx = state.task_updates.subscribe(); + let mut task_output_rx = state.task_output.subscribe(); + + loop { + tokio::select! { + // Handle incoming WebSocket messages from client + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<TaskClientMessage>(&text) { + Ok(TaskClientMessage::Subscribe { task_id }) => { + // Look up owner_id for this task + let owner_id = if let Some(ref pool) = state.db_pool { + get_task_owner_id(pool, task_id).await + } else { + None + }; + task_subscriptions.insert(task_id, owner_id); + let response = TaskServerMessage::Subscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id); + } + Ok(TaskClientMessage::Unsubscribe { task_id }) => { + task_subscriptions.remove(&task_id); + let response = TaskServerMessage::Unsubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from task {}", task_id); + } + Ok(TaskClientMessage::SubscribeAll) => { + subscribed_all = true; + let response = TaskServerMessage::SubscribedAll; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to all tasks"); + } + Ok(TaskClientMessage::UnsubscribeAll) => { + subscribed_all = false; + let response = TaskServerMessage::UnsubscribedAll; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from all tasks"); + } + Ok(TaskClientMessage::SubscribeOutput { task_id }) => { + // Look up owner_id for this task + let owner_id = if let Some(ref pool) = state.db_pool { + get_task_owner_id(pool, task_id).await + } else { + None + }; + output_subscriptions.insert(task_id, owner_id); + let response = TaskServerMessage::OutputSubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id); + } + Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => { + output_subscriptions.remove(&task_id); + let response = TaskServerMessage::OutputUnsubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from output for task {}", task_id); + } + Err(e) => { + let response = TaskServerMessage::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Client disconnected from task subscription"); + break; + } + Some(Err(e)) => { + tracing::warn!("Task WebSocket error: {}", e); + break; + } + _ => {} + } + } + + // Handle task update broadcasts + notification = task_update_rx.recv() => { + match notification { + Ok(notification) => { + // Check if client should receive this notification + let should_forward = if subscribed_all { + // SubscribeAll gets all notifications (typically for admin views) + true + } else if let Some(subscribed_owner) = task_subscriptions.get(¬ification.task_id) { + // Client is subscribed to this specific task + // Verify owner_id matches (if set on both sides) + match (notification.owner_id, subscribed_owner) { + (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, + _ => true, // Allow if owner_id not set on either side + } + } else { + false + }; + + if should_forward { + let response = TaskServerMessage::TaskUpdated { + task_id: notification.task_id, + version: notification.version, + status: notification.status, + updated_fields: notification.updated_fields, + updated_by: notification.updated_by, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Task subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + + // Handle task output broadcasts + output = task_output_rx.recv() => { + match output { + Ok(output) => { + // Check if client should receive this output + let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) { + // Client is subscribed to output for this task + // Verify owner_id matches (if set on both sides) + match (output.owner_id, subscribed_owner) { + (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, + _ => true, // Allow if owner_id not set on either side + } + } else { + false + }; + + if should_forward { + let response = TaskServerMessage::TaskOutput { + task_id: output.task_id, + message_type: output.message_type, + content: output.content, + tool_name: output.tool_name, + tool_input: output.tool_input, + is_error: output.is_error, + cost_usd: output.cost_usd, + duration_ms: output.duration_ms, + is_partial: output.is_partial, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Task output subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + } + } +} diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index 3211f94..8681104 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -1,7 +1,14 @@ //! HTTP and WebSocket request handlers. +pub mod api_keys; pub mod chat; pub mod file_ws; pub mod files; pub mod listen; +pub mod mesh; +pub mod mesh_chat; +pub mod mesh_daemon; +pub mod mesh_merge; +pub mod mesh_ws; +pub mod users; pub mod versions; diff --git a/makima/src/server/handlers/users.rs b/makima/src/server/handlers/users.rs new file mode 100644 index 0000000..0b2ccdd --- /dev/null +++ b/makima/src/server/handlers/users.rs @@ -0,0 +1,972 @@ +//! HTTP handlers for user account management. +//! +//! These endpoints allow users to manage their account settings: +//! - Change password +//! - Change email +//! - Delete account + +use axum::{ + extract::State, + http::{HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +use crate::server::auth::UserOnly; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +// ============================================================================= +// Request/Response Types +// ============================================================================= + +/// Request to change password. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangePasswordRequest { + /// The user's current password for verification + pub current_password: String, + /// The new password to set + pub new_password: String, +} + +/// Response after changing password. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangePasswordResponse { + pub success: bool, + pub message: String, +} + +/// Request to change email. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangeEmailRequest { + /// The user's password for verification + pub password: String, + /// The new email address to set + pub new_email: String, +} + +/// Response after changing email. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangeEmailResponse { + pub success: bool, + pub message: String, + /// Whether a verification email was sent to the new address + pub verification_sent: bool, +} + +/// Request to delete account. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DeleteAccountRequest { + /// The user's password for verification + pub password: String, + /// Confirmation text - must match the user's email + pub confirmation: String, +} + +/// Response after deleting account. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DeleteAccountResponse { + pub success: bool, + pub message: String, +} + +// ============================================================================= +// Password Validation +// ============================================================================= + +/// Password strength validation result. +#[derive(Debug)] +pub struct PasswordValidation { + pub is_valid: bool, + pub errors: Vec<String>, +} + +/// Validate password strength. +/// Requirements: +/// - At least 6 characters (matches login form) +fn validate_password_strength(password: &str) -> PasswordValidation { + let mut errors = Vec::new(); + + if password.len() < 6 { + errors.push("Password must be at least 6 characters long".to_string()); + } + + PasswordValidation { + is_valid: errors.is_empty(), + errors, + } +} + +/// Validate email format. +fn validate_email(email: &str) -> bool { + // Basic email validation - must contain @ and at least one . after @ + let parts: Vec<&str> = email.split('@').collect(); + if parts.len() != 2 { + return false; + } + let local = parts[0]; + let domain = parts[1]; + // Local part must not be empty + if local.is_empty() { + return false; + } + // Domain must have at least one dot and not start/end with dot + domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.') +} + +// ============================================================================= +// Supabase Admin Client +// ============================================================================= + +/// Supabase Admin API client for user management operations. +/// Uses the service role key for admin-level operations. +pub struct SupabaseAdminClient { + base_url: String, + secret_api_key: String, + client: reqwest::Client, +} + +impl SupabaseAdminClient { + /// Create a new Supabase admin client from environment variables. + pub fn from_env() -> Option<Self> { + let base_url = std::env::var("SUPABASE_URL").ok()?; + let secret_api_key = std::env::var("SUPABASE_SECRET_API_KEY").ok()?; + + Some(Self { + base_url, + secret_api_key, + client: reqwest::Client::new(), + }) + } + + /// Verify a user's password by attempting to sign in. + pub async fn verify_password(&self, email: &str, password: &str) -> Result<bool, String> { + let url = format!("{}/auth/v1/token?grant_type=password", self.base_url); + + let response = self + .client + .post(&url) + .header("apikey", &self.secret_api_key) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": email, + "password": password + })) + .send() + .await + .map_err(|e| format!("Failed to verify password: {}", e))?; + + Ok(response.status().is_success()) + } + + /// Update a user's password. + pub async fn update_password( + &self, + user_id: &str, + new_password: &str, + ) -> Result<(), String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .put(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "password": new_password + })) + .send() + .await + .map_err(|e| format!("Failed to update password: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update password: {}", error_text)) + } + } + + /// Update a user's email. + pub async fn update_email( + &self, + user_id: &str, + new_email: &str, + ) -> Result<(), String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .put(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": new_email, + "email_confirm": true + })) + .send() + .await + .map_err(|e| format!("Failed to update email: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update email: {}", error_text)) + } + } + + /// Delete a user from Supabase Auth. + pub async fn delete_user(&self, user_id: &str) -> Result<(), String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .delete(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .send() + .await + .map_err(|e| format!("Failed to delete user: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to delete user: {}", error_text)) + } + } + + /// Get user info including email. + pub async fn get_user(&self, user_id: &str) -> Result<Option<String>, String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .get(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .send() + .await + .map_err(|e| format!("Failed to get user: {}", e))?; + + if response.status().is_success() { + let json: serde_json::Value = response + .json() + .await + .map_err(|e| format!("Failed to parse user data: {}", e))?; + Ok(json.get("email").and_then(|e| e.as_str()).map(String::from)) + } else if response.status() == reqwest::StatusCode::NOT_FOUND { + Ok(None) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to get user: {}", error_text)) + } + } +} + +// ============================================================================= +// Supabase User Client (uses user's JWT, no admin required) +// ============================================================================= + +/// Supabase User API client for self-service operations. +/// Uses the user's JWT token - no admin/service role key required. +pub struct SupabaseUserClient { + base_url: String, + anon_key: String, + jwt_token: String, + client: reqwest::Client, +} + +impl SupabaseUserClient { + /// Create a new Supabase user client from environment and JWT token. + pub fn new(jwt_token: String) -> Option<Self> { + let base_url = std::env::var("SUPABASE_URL").ok()?; + let anon_key = std::env::var("SUPABASE_ANON_KEY").ok()?; + + Some(Self { + base_url, + anon_key, + jwt_token, + client: reqwest::Client::new(), + }) + } + + /// Update the user's password using their own JWT. + pub async fn update_password(&self, new_password: &str) -> Result<(), String> { + let url = format!("{}/auth/v1/user", self.base_url); + + let response = self + .client + .put(&url) + .header("apikey", &self.anon_key) + .header("Authorization", format!("Bearer {}", self.jwt_token)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "password": new_password + })) + .send() + .await + .map_err(|e| format!("Failed to update password: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update password: {}", error_text)) + } + } + + /// Update the user's email using their own JWT. + pub async fn update_email(&self, new_email: &str) -> Result<(), String> { + let url = format!("{}/auth/v1/user", self.base_url); + + let response = self + .client + .put(&url) + .header("apikey", &self.anon_key) + .header("Authorization", format!("Bearer {}", self.jwt_token)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": new_email + })) + .send() + .await + .map_err(|e| format!("Failed to update email: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update email: {}", error_text)) + } + } + + /// Verify current password by attempting to sign in. + pub async fn verify_password(&self, email: &str, password: &str) -> Result<bool, String> { + let url = format!("{}/auth/v1/token?grant_type=password", self.base_url); + + let response = self + .client + .post(&url) + .header("apikey", &self.anon_key) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": email, + "password": password + })) + .send() + .await + .map_err(|e| format!("Failed to verify password: {}", e))?; + + Ok(response.status().is_success()) + } +} + +// ============================================================================= +// Handlers +// ============================================================================= + +/// Change the authenticated user's password. +/// +/// Requires verification of the current password before allowing the change. +/// The new password must meet strength requirements. +#[utoipa::path( + put, + path = "/api/v1/users/me/password", + request_body = ChangePasswordRequest, + responses( + (status = 200, description = "Password changed successfully", body = ChangePasswordResponse), + (status = 400, description = "Invalid request (weak password, wrong current password)", body = ApiError), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Supabase not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "Users" +)] +pub async fn change_password_handler( + State(_state): State<SharedState>, + headers: HeaderMap, + UserOnly(user): UserOnly, + Json(req): Json<ChangePasswordRequest>, +) -> impl IntoResponse { + // Validate new password strength + let validation = validate_password_strength(&req.new_password); + if !validation.is_valid { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "WEAK_PASSWORD", + &validation.errors.join("; "), + )), + ) + .into_response(); + } + + // Get user's email (required for password verification) + let email = match &user.email { + Some(email) => email.clone(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("EMAIL_REQUIRED", "User email not available")), + ) + .into_response(); + } + }; + + // Extract JWT from Authorization header for user-level API calls + let jwt_token = headers + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .map(|s| s.to_string()); + + // Try user client first (uses JWT, no admin required), fall back to admin client + if let Some(token) = jwt_token { + if let Some(user_client) = SupabaseUserClient::new(token) { + // Verify current password + match user_client.verify_password(&email, &req.current_password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Current password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update password using user's JWT + return match user_client.update_password(&req.new_password).await { + Ok(()) => { + tracing::info!("Password changed for user {}", user.user_id); + Json(ChangePasswordResponse { + success: true, + message: "Password changed successfully".to_string(), + }) + .into_response() + } + Err(e) => { + tracing::error!("Failed to update password: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update password")), + ) + .into_response() + } + }; + } + } + + // Fall back to admin client if user client not available + let admin_client = match SupabaseAdminClient::from_env() { + Some(client) => client, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "SUPABASE_NOT_CONFIGURED", + "Supabase not configured. Ensure SUPABASE_URL and SUPABASE_ANON_KEY are set.", + )), + ) + .into_response(); + } + }; + + // Verify current password + match admin_client.verify_password(&email, &req.current_password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Current password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update password in Supabase + match admin_client + .update_password(&user.user_id.to_string(), &req.new_password) + .await + { + Ok(()) => { + tracing::info!("Password changed for user {}", user.user_id); + Json(ChangePasswordResponse { + success: true, + message: "Password changed successfully".to_string(), + }) + .into_response() + } + Err(e) => { + tracing::error!("Failed to update password: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update password")), + ) + .into_response() + } + } +} + +/// Change the authenticated user's email address. +/// +/// Requires password verification before allowing the change. +/// The new email will be updated directly (Supabase handles verification if configured). +#[utoipa::path( + put, + path = "/api/v1/users/me/email", + request_body = ChangeEmailRequest, + responses( + (status = 200, description = "Email changed successfully", body = ChangeEmailResponse), + (status = 400, description = "Invalid request (invalid email, wrong password)", body = ApiError), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Supabase not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "Users" +)] +pub async fn change_email_handler( + State(state): State<SharedState>, + headers: HeaderMap, + UserOnly(user): UserOnly, + Json(req): Json<ChangeEmailRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Validate new email format + if !validate_email(&req.new_email) { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_EMAIL", "Invalid email format")), + ) + .into_response(); + } + + // Get user's current email (required for password verification) + let current_email = match &user.email { + Some(email) => email.clone(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("EMAIL_REQUIRED", "User email not available")), + ) + .into_response(); + } + }; + + // Extract JWT from Authorization header for user-level API calls + let jwt_token = headers + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .map(|s| s.to_string()); + + // Try user client first (uses JWT, no admin required), fall back to admin client + if let Some(token) = jwt_token { + if let Some(user_client) = SupabaseUserClient::new(token) { + // Verify password + match user_client.verify_password(¤t_email, &req.password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update email using user's JWT + if let Err(e) = user_client.update_email(&req.new_email).await { + tracing::error!("Failed to update email in Supabase: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update email")), + ) + .into_response(); + } + + // Update email in our database + if let Err(e) = sqlx::query("UPDATE users SET email = $1, updated_at = NOW() WHERE id = $2") + .bind(&req.new_email) + .bind(user.user_id) + .execute(pool) + .await + { + tracing::error!("Failed to update email in database: {}", e); + } + + tracing::info!( + "Email changed for user {} from {} to {}", + user.user_id, + current_email, + req.new_email + ); + + return Json(ChangeEmailResponse { + success: true, + message: "Email changed successfully".to_string(), + verification_sent: false, + }) + .into_response(); + } + } + + // Fall back to admin client if user client not available + let admin_client = match SupabaseAdminClient::from_env() { + Some(client) => client, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "SUPABASE_NOT_CONFIGURED", + "Supabase not configured. Ensure SUPABASE_URL and SUPABASE_ANON_KEY are set.", + )), + ) + .into_response(); + } + }; + + // Verify password + match admin_client.verify_password(¤t_email, &req.password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update email in Supabase + if let Err(e) = admin_client + .update_email(&user.user_id.to_string(), &req.new_email) + .await + { + tracing::error!("Failed to update email in Supabase: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update email")), + ) + .into_response(); + } + + // Update email in our database + if let Err(e) = sqlx::query("UPDATE users SET email = $1, updated_at = NOW() WHERE id = $2") + .bind(&req.new_email) + .bind(user.user_id) + .execute(pool) + .await + { + tracing::error!("Failed to update email in database: {}", e); + } + + tracing::info!( + "Email changed for user {} from {} to {}", + user.user_id, + current_email, + req.new_email + ); + + Json(ChangeEmailResponse { + success: true, + message: "Email changed successfully".to_string(), + verification_sent: false, + }) + .into_response() +} + +/// Delete the authenticated user's account. +/// +/// This permanently deletes: +/// - The user's Supabase Auth account +/// - The user's record in our database +/// - All associated data (API keys, files, tasks, etc. via CASCADE) +/// +/// Requires password verification and confirmation text matching the user's email. +#[utoipa::path( + delete, + path = "/api/v1/users/me", + request_body = DeleteAccountRequest, + responses( + (status = 200, description = "Account deleted successfully", body = DeleteAccountResponse), + (status = 400, description = "Invalid request (wrong password, wrong confirmation)", body = ApiError), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Supabase not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "Users" +)] +pub async fn delete_account_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, + Json(req): Json<DeleteAccountRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get Supabase admin client - required for full account deletion + let admin_client = match SupabaseAdminClient::from_env() { + Some(client) => client, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "SUPABASE_ADMIN_NOT_CONFIGURED", + "Account deletion requires SUPABASE_SECRET_API_KEY to be configured", + )), + ) + .into_response(); + } + }; + + // Verify confirmation is "DELETE MY ACCOUNT" + const REQUIRED_CONFIRMATION: &str = "DELETE MY ACCOUNT"; + if req.confirmation != REQUIRED_CONFIRMATION { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_CONFIRMATION", + format!("Confirmation text must be exactly: {}", REQUIRED_CONFIRMATION), + )), + ) + .into_response(); + } + + // Get user's email (required for password verification) + let email = match &user.email { + Some(e) => e.clone(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("EMAIL_REQUIRED", "User email not available")), + ) + .into_response(); + } + }; + + // Verify password + match admin_client.verify_password(&email, &req.password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Delete from our database first (this will cascade to related records) + // Get the owner_id before deleting + let owner_id = user.owner_id; + + // Delete API keys for this user (explicit deletion for audit purposes) + if let Err(e) = sqlx::query("UPDATE api_keys SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL") + .bind(user.user_id) + .execute(pool) + .await + { + tracing::warn!("Failed to revoke API keys during account deletion: {}", e); + } + + // Delete user record + if let Err(e) = sqlx::query("DELETE FROM users WHERE id = $1") + .bind(user.user_id) + .execute(pool) + .await + { + tracing::error!("Failed to delete user from database: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to delete account")), + ) + .into_response(); + } + + // Delete files owned by this user + if let Err(e) = sqlx::query("DELETE FROM files WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete user files: {}", e); + } + + // Delete tasks owned by this user + if let Err(e) = sqlx::query("DELETE FROM tasks WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete user tasks: {}", e); + } + + // Delete mesh chat conversations owned by this user + if let Err(e) = sqlx::query("DELETE FROM mesh_chat_conversations WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete mesh chat conversations: {}", e); + } + + // Delete daemons owned by this user + if let Err(e) = sqlx::query("DELETE FROM daemons WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete user daemons: {}", e); + } + + // Delete owner record + if let Err(e) = sqlx::query("DELETE FROM owners WHERE id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete owner record: {}", e); + } + + // Delete from Supabase Auth + if let Err(e) = admin_client.delete_user(&user.user_id.to_string()).await { + tracing::error!("Failed to delete user from Supabase Auth: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new( + "SUPABASE_DELETE_FAILED", + "Failed to delete user from authentication system", + )), + ) + .into_response(); + } + + tracing::info!("Account deleted for user {} ({})", user.user_id, email); + + Json(DeleteAccountResponse { + success: true, + message: "Account deleted successfully".to_string(), + }) + .into_response() +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_password_validation_success() { + // Minimum 6 characters + let result = validate_password_strength("abcdef"); + assert!(result.is_valid); + assert!(result.errors.is_empty()); + + let result = validate_password_strength("Password123"); + assert!(result.is_valid); + assert!(result.errors.is_empty()); + } + + #[test] + fn test_password_validation_too_short() { + let result = validate_password_strength("12345"); + assert!(!result.is_valid); + assert!(result.errors.iter().any(|e| e.contains("6 characters"))); + } + + #[test] + fn test_email_validation_valid() { + assert!(validate_email("user@example.com")); + assert!(validate_email("user.name@example.co.uk")); + assert!(validate_email("user+tag@example.org")); + } + + #[test] + fn test_email_validation_invalid() { + assert!(!validate_email("userexample.com")); + assert!(!validate_email("user@")); + assert!(!validate_email("@example.com")); + assert!(!validate_email("user@.com")); + assert!(!validate_email("user@example.")); + } +} diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs index ee5e9bd..a096a5c 100644 --- a/makima/src/server/mod.rs +++ b/makima/src/server/mod.rs @@ -1,5 +1,6 @@ //! Web server module for the makima audio API. +pub mod auth; pub mod handlers; pub mod messages; pub mod openapi; @@ -17,7 +18,7 @@ use tower_http::trace::TraceLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use crate::server::handlers::{chat, file_ws, files, listen, versions}; +use crate::server::handlers::{api_keys, chat, file_ws, files, listen, mesh, mesh_chat, mesh_daemon, mesh_merge, mesh_ws, users, versions}; use crate::server::openapi::ApiDoc; use crate::server::state::SharedState; @@ -56,6 +57,62 @@ pub fn make_router(state: SharedState) -> Router { .route("/files/{id}/versions", get(versions::list_versions)) .route("/files/{id}/versions/{version}", get(versions::get_version)) .route("/files/{id}/versions/restore", post(versions::restore_version)) + // Mesh/task orchestration endpoints + .route( + "/mesh/tasks", + get(mesh::list_tasks).post(mesh::create_task), + ) + .route( + "/mesh/tasks/{id}", + get(mesh::get_task) + .put(mesh::update_task) + .delete(mesh::delete_task), + ) + .route("/mesh/tasks/{id}/subtasks", get(mesh::list_subtasks)) + .route("/mesh/tasks/{id}/events", get(mesh::list_task_events)) + .route("/mesh/tasks/{id}/output", get(mesh::get_task_output)) + .route("/mesh/tasks/{id}/start", post(mesh::start_task)) + .route("/mesh/tasks/{id}/stop", post(mesh::stop_task)) + .route("/mesh/tasks/{id}/message", post(mesh::send_message)) + .route("/mesh/tasks/{id}/retry-completion", post(mesh::retry_completion_action)) + .route("/mesh/tasks/{id}/clone", post(mesh::clone_worktree)) + .route("/mesh/tasks/{id}/check-target", post(mesh::check_target_exists)) + .route("/mesh/chat", post(mesh_chat::mesh_toplevel_chat_handler)) + .route( + "/mesh/chat/history", + get(mesh_chat::get_chat_history).delete(mesh_chat::clear_chat_history), + ) + .route("/mesh/tasks/{id}/chat", post(mesh_chat::mesh_chat_handler)) + .route("/mesh/daemons", get(mesh::list_daemons)) + .route("/mesh/daemons/directories", get(mesh::get_daemon_directories)) + .route("/mesh/daemons/{id}", get(mesh::get_daemon)) + // Merge endpoints for orchestrators + .route("/mesh/tasks/{id}/branches", get(mesh_merge::list_branches)) + .route("/mesh/tasks/{id}/merge/start", post(mesh_merge::merge_start)) + .route("/mesh/tasks/{id}/merge/status", get(mesh_merge::merge_status)) + .route("/mesh/tasks/{id}/merge/resolve", post(mesh_merge::merge_resolve)) + .route("/mesh/tasks/{id}/merge/commit", post(mesh_merge::merge_commit)) + .route("/mesh/tasks/{id}/merge/abort", post(mesh_merge::merge_abort)) + .route("/mesh/tasks/{id}/merge/skip", post(mesh_merge::merge_skip)) + .route("/mesh/tasks/{id}/merge/check", get(mesh_merge::merge_check)) + // Mesh WebSocket endpoints + .route("/mesh/tasks/subscribe", get(mesh_ws::task_subscription_handler)) + .route("/mesh/daemons/connect", get(mesh_daemon::daemon_handler)) + // API key management endpoints + .route( + "/auth/api-keys", + post(api_keys::create_api_key_handler) + .get(api_keys::get_api_key_handler) + .delete(api_keys::revoke_api_key_handler), + ) + .route("/auth/api-keys/refresh", post(api_keys::refresh_api_key_handler)) + // User account management endpoints + .route( + "/users/me", + axum::routing::delete(users::delete_account_handler), + ) + .route("/users/me/password", axum::routing::put(users::change_password_handler)) + .route("/users/me/email", axum::routing::put(users::change_email_handler)) .with_state(state); let swagger = SwaggerUi::new("/swagger-ui") diff --git a/makima/src/server/openapi.rs b/makima/src/server/openapi.rs index b946ff3..425c466 100644 --- a/makima/src/server/openapi.rs +++ b/makima/src/server/openapi.rs @@ -3,9 +3,19 @@ use utoipa::OpenApi; use crate::db::models::{ - CreateFileRequest, File, FileListResponse, FileSummary, TranscriptEntry, UpdateFileRequest, + BranchInfo, BranchListResponse, CreateFileRequest, CreateTaskRequest, Daemon, + DaemonDirectoriesResponse, DaemonDirectory, DaemonListResponse, File, FileListResponse, + FileSummary, MergeCommitRequest, MergeCompleteCheckResponse, MergeResolveRequest, + MergeResultResponse, MergeSkipRequest, MergeStartRequest, MergeStatusResponse, + MeshChatConversation, MeshChatHistoryResponse, MeshChatMessageRecord, SendMessageRequest, + Task, TaskEventListResponse, TaskListResponse, TaskSummary, TaskWithSubtasks, TranscriptEntry, + UpdateFileRequest, UpdateTaskRequest, }; -use crate::server::handlers::{files, listen}; +use crate::server::auth::{ + ApiKey, ApiKeyInfoResponse, CreateApiKeyRequest, CreateApiKeyResponse, + RefreshApiKeyRequest, RefreshApiKeyResponse, RevokeApiKeyResponse, +}; +use crate::server::handlers::{api_keys, files, listen, mesh, mesh_chat, mesh_merge, users}; use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage, TranscriptMessage}; #[derive(OpenApi)] @@ -23,6 +33,44 @@ use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage files::create_file, files::update_file, files::delete_file, + // Mesh endpoints + mesh::list_tasks, + mesh::get_task, + mesh::create_task, + mesh::update_task, + mesh::delete_task, + mesh::list_subtasks, + mesh::list_task_events, + mesh::get_task_output, + mesh::start_task, + mesh::stop_task, + mesh::send_message, + mesh::retry_completion_action, + mesh::list_daemons, + mesh::get_daemon, + mesh::get_daemon_directories, + mesh::clone_worktree, + mesh::check_target_exists, + mesh_chat::get_chat_history, + mesh_chat::clear_chat_history, + // Merge endpoints + mesh_merge::list_branches, + mesh_merge::merge_start, + mesh_merge::merge_status, + mesh_merge::merge_resolve, + mesh_merge::merge_commit, + mesh_merge::merge_abort, + mesh_merge::merge_skip, + mesh_merge::merge_check, + // API key endpoints + api_keys::create_api_key_handler, + api_keys::get_api_key_handler, + api_keys::refresh_api_key_handler, + api_keys::revoke_api_key_handler, + // User account management endpoints + users::change_password_handler, + users::change_email_handler, + users::delete_account_handler, ), components( schemas( @@ -38,11 +86,55 @@ use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage CreateFileRequest, UpdateFileRequest, TranscriptEntry, + // Mesh/Task schemas + Task, + TaskSummary, + TaskListResponse, + TaskWithSubtasks, + CreateTaskRequest, + UpdateTaskRequest, + SendMessageRequest, + TaskEventListResponse, + Daemon, + DaemonListResponse, + DaemonDirectoriesResponse, + DaemonDirectory, + MeshChatConversation, + MeshChatMessageRecord, + MeshChatHistoryResponse, + // Merge schemas + BranchInfo, + BranchListResponse, + MergeStartRequest, + MergeStatusResponse, + MergeResolveRequest, + MergeCommitRequest, + MergeSkipRequest, + MergeResultResponse, + MergeCompleteCheckResponse, + // API key schemas + ApiKey, + ApiKeyInfoResponse, + CreateApiKeyRequest, + CreateApiKeyResponse, + RefreshApiKeyRequest, + RefreshApiKeyResponse, + RevokeApiKeyResponse, + // User account management schemas + users::ChangePasswordRequest, + users::ChangePasswordResponse, + users::ChangeEmailRequest, + users::ChangeEmailResponse, + users::DeleteAccountRequest, + users::DeleteAccountResponse, ) ), tags( (name = "Listen", description = "Speech-to-text streaming endpoints"), (name = "Files", description = "Transcript file management"), + (name = "Mesh", description = "Task orchestration for Claude Code instances"), + (name = "API Keys", description = "API key management for programmatic access"), + (name = "Users", description = "User account management"), ) )] pub struct ApiDoc; diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index 239ab77..e89197a 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -1,11 +1,13 @@ //! Application state holding shared ML models and database pool. use std::sync::Arc; +use dashmap::DashMap; use sqlx::PgPool; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::{broadcast, mpsc, Mutex}; use uuid::Uuid; use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; +use crate::server::auth::{AuthConfig, JwtVerifier}; /// Notification payload for file updates (broadcast to WebSocket subscribers). #[derive(Debug, Clone)] @@ -20,6 +22,262 @@ pub struct FileUpdateNotification { pub updated_by: String, } +// ============================================================================= +// Task/Mesh Notifications +// ============================================================================= + +/// Notification payload for task updates (broadcast to WebSocket subscribers). +#[derive(Debug, Clone)] +pub struct TaskUpdateNotification { + /// ID of the updated task + pub task_id: Uuid, + /// Owner ID for data isolation (notifications are scoped to owner) + pub owner_id: Option<Uuid>, + /// New version number after update + pub version: i32, + /// Current task status + pub status: String, + /// List of fields that were updated + pub updated_fields: Vec<String>, + /// Source of the update: "user", "daemon", or "system" + pub updated_by: String, +} + +/// Notification for streaming task output from Claude Code containers. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskOutputNotification { + /// ID of the task producing output + pub task_id: Uuid, + /// Owner ID for data isolation (notifications are scoped to owner) + #[serde(skip)] + pub owner_id: Option<Uuid>, + /// Type of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + pub message_type: String, + /// Main text content of the message + pub content: String, + /// Tool name if this is a tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option<String>, + /// Tool input (JSON) if this is a tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_input: Option<serde_json::Value>, + /// Whether tool result was an error + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option<bool>, + /// Cost in USD if this is a result message + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_usd: Option<f64>, + /// Duration in milliseconds if this is a result message + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option<u64>, + /// Whether this is a partial line (more coming) or complete + pub is_partial: bool, +} + +/// Command sent from server to daemon. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonCommand { + /// Confirm successful authentication + Authenticated { + #[serde(rename = "daemonId")] + daemon_id: Uuid, + }, + /// Spawn a new task in a container + SpawnTask { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages) + #[serde(rename = "taskName")] + task_name: String, + plan: String, + #[serde(rename = "repoUrl")] + repo_url: Option<String>, + #[serde(rename = "baseBranch")] + base_branch: Option<String>, + /// Target branch to merge into (used for completion actions) + #[serde(rename = "targetBranch")] + target_branch: Option<String>, + /// Parent task ID if this is a subtask + #[serde(rename = "parentTaskId")] + parent_task_id: Option<Uuid>, + /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask) + depth: i32, + /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks) + #[serde(rename = "isOrchestrator")] + is_orchestrator: bool, + /// Path to user's local repository (outside ~/.makima) for completion actions + #[serde(rename = "targetRepoPath")] + target_repo_path: Option<String>, + /// Action on completion: "none", "branch", "merge", "pr" + #[serde(rename = "completionAction")] + completion_action: Option<String>, + /// Task ID to continue from (copy worktree from this task) + #[serde(rename = "continueFromTaskId")] + continue_from_task_id: Option<Uuid>, + /// Files to copy from parent task's worktree + #[serde(rename = "copyFiles")] + copy_files: Option<Vec<String>>, + }, + /// Pause a running task + PauseTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Resume a paused task + ResumeTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Interrupt a task (gracefully or forced) + InterruptTask { + #[serde(rename = "taskId")] + task_id: Uuid, + graceful: bool, + }, + /// Send a message to a running task + SendMessage { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + /// Inject context about sibling task progress + InjectSiblingContext { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "siblingTaskId")] + sibling_task_id: Uuid, + #[serde(rename = "siblingName")] + sibling_name: String, + #[serde(rename = "siblingStatus")] + sibling_status: String, + #[serde(rename = "progressSummary")] + progress_summary: Option<String>, + #[serde(rename = "changedFiles")] + changed_files: Vec<String>, + }, + + // ========================================================================= + // Merge Commands (for orchestrators to merge subtask branches) + // ========================================================================= + + /// List all subtask branches for a task + ListBranches { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Start merging a subtask branch + MergeStart { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "sourceBranch")] + source_branch: String, + }, + /// Get current merge status + MergeStatus { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Resolve a merge conflict + MergeResolve { + #[serde(rename = "taskId")] + task_id: Uuid, + file: String, + /// "ours" or "theirs" + strategy: String, + }, + /// Commit the current merge + MergeCommit { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + /// Abort the current merge + MergeAbort { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Skip merging a subtask branch (mark as intentionally not merged) + MergeSkip { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "subtaskId")] + subtask_id: Uuid, + reason: String, + }, + /// Check if all subtask branches have been merged or skipped (completion gate) + CheckMergeComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + // ========================================================================= + // Completion Action Commands + // ========================================================================= + + /// Retry a completion action for a completed task + RetryCompletionAction { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages) + #[serde(rename = "taskName")] + task_name: String, + /// The action to execute: "branch", "merge", or "pr" + action: String, + /// Path to the target repository + #[serde(rename = "targetRepoPath")] + target_repo_path: String, + /// Target branch to merge into (for merge/pr actions) + #[serde(rename = "targetBranch")] + target_branch: Option<String>, + }, + + /// Clone worktree to a target directory + CloneWorktree { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to the target directory + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Check if a target directory exists + CheckTargetExists { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to check + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Error response + Error { code: String, message: String }, +} + +/// Active daemon connection info stored in state. +#[derive(Debug)] +pub struct DaemonConnectionInfo { + /// Database ID of the daemon + pub id: Uuid, + /// Owner ID for data isolation (from API key authentication) + pub owner_id: Uuid, + /// WebSocket connection identifier + pub connection_id: String, + /// Daemon hostname + pub hostname: Option<String>, + /// Machine identifier + pub machine_id: Option<String>, + /// Channel to send commands to this daemon + pub command_sender: mpsc::Sender<DaemonCommand>, + /// Current working directory of the daemon + pub working_directory: Option<String>, + /// Path to ~/.makima/home directory on daemon (for cloning completed work) + pub home_directory: Option<String>, + /// Path to worktrees directory (~/.makima/worktrees) on daemon + pub worktrees_directory: Option<String>, +} + /// Shared application state containing ML models and database pool. /// /// Models are wrapped in `Mutex` for thread-safe mutable access during inference. @@ -34,6 +292,16 @@ pub struct AppState { pub db_pool: Option<PgPool>, /// Broadcast channel for file update notifications pub file_updates: broadcast::Sender<FileUpdateNotification>, + /// Broadcast channel for task update notifications + pub task_updates: broadcast::Sender<TaskUpdateNotification>, + /// Broadcast channel for task output streaming + pub task_output: broadcast::Sender<TaskOutputNotification>, + /// Active daemon connections (keyed by connection_id) + pub daemon_connections: DashMap<String, DaemonConnectionInfo>, + /// Tool keys for orchestrator API access (key -> task_id) + pub tool_keys: DashMap<String, Uuid>, + /// JWT verifier for Supabase authentication (None if not configured) + pub jwt_verifier: Option<JwtVerifier>, } impl AppState { @@ -56,8 +324,38 @@ impl AppState { DiarizationConfig::callhome(), )?; - // Create broadcast channel with buffer for 256 messages + // Create broadcast channels with buffer for 256 messages let (file_updates, _) = broadcast::channel(256); + let (task_updates, _) = broadcast::channel(256); + let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming + + // Initialize JWT verifier from environment (optional) + // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256) + let jwt_verifier = match AuthConfig::from_env() { + Some(config) => match JwtVerifier::new(config) { + Ok(verifier) => { + tracing::info!("JWT authentication configured"); + Some(verifier) + } + Err(e) => { + tracing::error!("Failed to initialize JWT verifier: {}", e); + None + } + }, + None => { + // Log which env vars are missing + let has_url = std::env::var("SUPABASE_URL").is_ok(); + let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok(); + let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok(); + + if !has_url { + tracing::info!("JWT authentication not configured (SUPABASE_URL not set)"); + } else if !has_public_key && !has_secret { + tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)"); + } + None + } + }; Ok(Self { parakeet: Mutex::new(parakeet), @@ -65,6 +363,11 @@ impl AppState { sortformer: Mutex::new(sortformer), db_pool: None, file_updates, + task_updates, + task_output, + daemon_connections: DashMap::new(), + tool_keys: DashMap::new(), + jwt_verifier, }) } @@ -81,6 +384,166 @@ impl AppState { // Ignore send errors - they just mean no one is listening let _ = self.file_updates.send(notification); } + + /// Broadcast a task update notification to all subscribers. + /// + /// This is a no-op if there are no subscribers (ignores send errors). + pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) { + let _ = self.task_updates.send(notification); + } + + /// Broadcast task output to all subscribers. + /// + /// Used for streaming Claude Code container output to frontend clients. + pub fn broadcast_task_output(&self, notification: TaskOutputNotification) { + let _ = self.task_output.send(notification); + } + + /// Register a new daemon connection. + /// + /// Returns the connection_id for later reference. + pub fn register_daemon( + &self, + connection_id: String, + daemon_id: Uuid, + owner_id: Uuid, + hostname: Option<String>, + machine_id: Option<String>, + command_sender: mpsc::Sender<DaemonCommand>, + ) { + self.daemon_connections.insert( + connection_id.clone(), + DaemonConnectionInfo { + id: daemon_id, + owner_id, + connection_id, + hostname, + machine_id, + command_sender, + working_directory: None, + home_directory: None, + worktrees_directory: None, + }, + ); + } + + /// Update daemon directory information. + pub fn update_daemon_directories( + &self, + connection_id: &str, + working_directory: String, + home_directory: String, + worktrees_directory: String, + ) { + if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) { + entry.working_directory = Some(working_directory); + entry.home_directory = Some(home_directory); + entry.worktrees_directory = Some(worktrees_directory); + } + } + + /// Unregister a daemon connection. + pub fn unregister_daemon(&self, connection_id: &str) { + self.daemon_connections.remove(connection_id); + } + + /// Get a daemon connection by connection_id. + pub fn get_daemon(&self, connection_id: &str) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> { + self.daemon_connections.get(connection_id) + } + + /// Get a daemon by its database ID. + pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> { + self.daemon_connections + .iter() + .find(|entry| entry.value().id == daemon_id) + .map(|entry| { + // Return a reference to the found entry + self.daemon_connections.get(entry.key()).unwrap() + }) + } + + /// Send a command to a specific daemon by its database ID. + pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> { + if let Some(daemon) = self.daemon_connections + .iter() + .find(|entry| entry.value().id == daemon_id) + { + daemon.value().command_sender.send(command).await + .map_err(|e| format!("Failed to send command to daemon: {}", e)) + } else { + Err(format!("Daemon {} not connected", daemon_id)) + } + } + + /// Broadcast sibling progress to all running sibling tasks. + /// + /// This is used for sibling awareness - when a task makes progress, + /// its siblings are notified so they can adjust their work if needed. + pub async fn broadcast_sibling_progress( + &self, + source_task_id: Uuid, + source_task_name: &str, + source_task_status: &str, + progress_summary: Option<String>, + changed_files: Vec<String>, + running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id) + ) { + for (sibling_task_id, daemon_id) in running_sibling_daemon_ids { + let command = DaemonCommand::InjectSiblingContext { + task_id: sibling_task_id, + sibling_task_id: source_task_id, + sibling_name: source_task_name.to_string(), + sibling_status: source_task_status.to_string(), + progress_summary: progress_summary.clone(), + changed_files: changed_files.clone(), + }; + + // Fire and forget - don't block on sending to all daemons + if let Err(e) = self.send_daemon_command(daemon_id, command).await { + tracing::warn!( + "Failed to inject sibling context to task {}: {}", + sibling_task_id, + e + ); + } + } + } + + /// Get list of connected daemon IDs. + pub fn list_connected_daemon_ids(&self) -> Vec<Uuid> { + self.daemon_connections + .iter() + .map(|entry| entry.value().id) + .collect() + } + + // ========================================================================= + // Tool Key Management + // ========================================================================= + + /// Register a tool key for a task. + /// + /// This allows orchestrators to authenticate with the API using + /// the `X-Makima-Tool-Key` header. + pub fn register_tool_key(&self, key: String, task_id: Uuid) { + tracing::info!(task_id = %task_id, "Registering tool key"); + self.tool_keys.insert(key, task_id); + } + + /// Validate a tool key and return the associated task ID. + pub fn validate_tool_key(&self, key: &str) -> Option<Uuid> { + self.tool_keys.get(key).map(|entry| *entry.value()) + } + + /// Revoke a tool key for a task. + /// + /// This should be called when a task completes or is terminated. + pub fn revoke_tool_key(&self, task_id: Uuid) { + // Find and remove the key for this task + self.tool_keys.retain(|_, v| *v != task_id); + tracing::info!(task_id = %task_id, "Revoked tool key"); + } } /// Type alias for the shared application state. |
