From 8b17a175c3e7e27b789812eba4e3cd760beadb10 Mon Sep 17 00:00:00 2001 From: soryu Date: Tue, 6 Jan 2026 04:08:11 +0000 Subject: Initial Control system --- makima/src/server/state.rs | 467 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 465 insertions(+), 2 deletions(-) (limited to 'makima/src/server/state.rs') diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index 239ab77..e89197a 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -1,11 +1,13 @@ //! Application state holding shared ML models and database pool. use std::sync::Arc; +use dashmap::DashMap; use sqlx::PgPool; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::{broadcast, mpsc, Mutex}; use uuid::Uuid; use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; +use crate::server::auth::{AuthConfig, JwtVerifier}; /// Notification payload for file updates (broadcast to WebSocket subscribers). #[derive(Debug, Clone)] @@ -20,6 +22,262 @@ pub struct FileUpdateNotification { pub updated_by: String, } +// ============================================================================= +// Task/Mesh Notifications +// ============================================================================= + +/// Notification payload for task updates (broadcast to WebSocket subscribers). +#[derive(Debug, Clone)] +pub struct TaskUpdateNotification { + /// ID of the updated task + pub task_id: Uuid, + /// Owner ID for data isolation (notifications are scoped to owner) + pub owner_id: Option, + /// New version number after update + pub version: i32, + /// Current task status + pub status: String, + /// List of fields that were updated + pub updated_fields: Vec, + /// Source of the update: "user", "daemon", or "system" + pub updated_by: String, +} + +/// Notification for streaming task output from Claude Code containers. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskOutputNotification { + /// ID of the task producing output + pub task_id: Uuid, + /// Owner ID for data isolation (notifications are scoped to owner) + #[serde(skip)] + pub owner_id: Option, + /// Type of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + pub message_type: String, + /// Main text content of the message + pub content: String, + /// Tool name if this is a tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option, + /// Tool input (JSON) if this is a tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_input: Option, + /// Whether tool result was an error + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, + /// Cost in USD if this is a result message + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_usd: Option, + /// Duration in milliseconds if this is a result message + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Whether this is a partial line (more coming) or complete + pub is_partial: bool, +} + +/// Command sent from server to daemon. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonCommand { + /// Confirm successful authentication + Authenticated { + #[serde(rename = "daemonId")] + daemon_id: Uuid, + }, + /// Spawn a new task in a container + SpawnTask { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages) + #[serde(rename = "taskName")] + task_name: String, + plan: String, + #[serde(rename = "repoUrl")] + repo_url: Option, + #[serde(rename = "baseBranch")] + base_branch: Option, + /// Target branch to merge into (used for completion actions) + #[serde(rename = "targetBranch")] + target_branch: Option, + /// Parent task ID if this is a subtask + #[serde(rename = "parentTaskId")] + parent_task_id: Option, + /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask) + depth: i32, + /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks) + #[serde(rename = "isOrchestrator")] + is_orchestrator: bool, + /// Path to user's local repository (outside ~/.makima) for completion actions + #[serde(rename = "targetRepoPath")] + target_repo_path: Option, + /// Action on completion: "none", "branch", "merge", "pr" + #[serde(rename = "completionAction")] + completion_action: Option, + /// Task ID to continue from (copy worktree from this task) + #[serde(rename = "continueFromTaskId")] + continue_from_task_id: Option, + /// Files to copy from parent task's worktree + #[serde(rename = "copyFiles")] + copy_files: Option>, + }, + /// Pause a running task + PauseTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Resume a paused task + ResumeTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Interrupt a task (gracefully or forced) + InterruptTask { + #[serde(rename = "taskId")] + task_id: Uuid, + graceful: bool, + }, + /// Send a message to a running task + SendMessage { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + /// Inject context about sibling task progress + InjectSiblingContext { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "siblingTaskId")] + sibling_task_id: Uuid, + #[serde(rename = "siblingName")] + sibling_name: String, + #[serde(rename = "siblingStatus")] + sibling_status: String, + #[serde(rename = "progressSummary")] + progress_summary: Option, + #[serde(rename = "changedFiles")] + changed_files: Vec, + }, + + // ========================================================================= + // Merge Commands (for orchestrators to merge subtask branches) + // ========================================================================= + + /// List all subtask branches for a task + ListBranches { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Start merging a subtask branch + MergeStart { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "sourceBranch")] + source_branch: String, + }, + /// Get current merge status + MergeStatus { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Resolve a merge conflict + MergeResolve { + #[serde(rename = "taskId")] + task_id: Uuid, + file: String, + /// "ours" or "theirs" + strategy: String, + }, + /// Commit the current merge + MergeCommit { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + /// Abort the current merge + MergeAbort { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Skip merging a subtask branch (mark as intentionally not merged) + MergeSkip { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "subtaskId")] + subtask_id: Uuid, + reason: String, + }, + /// Check if all subtask branches have been merged or skipped (completion gate) + CheckMergeComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + // ========================================================================= + // Completion Action Commands + // ========================================================================= + + /// Retry a completion action for a completed task + RetryCompletionAction { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages) + #[serde(rename = "taskName")] + task_name: String, + /// The action to execute: "branch", "merge", or "pr" + action: String, + /// Path to the target repository + #[serde(rename = "targetRepoPath")] + target_repo_path: String, + /// Target branch to merge into (for merge/pr actions) + #[serde(rename = "targetBranch")] + target_branch: Option, + }, + + /// Clone worktree to a target directory + CloneWorktree { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to the target directory + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Check if a target directory exists + CheckTargetExists { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to check + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Error response + Error { code: String, message: String }, +} + +/// Active daemon connection info stored in state. +#[derive(Debug)] +pub struct DaemonConnectionInfo { + /// Database ID of the daemon + pub id: Uuid, + /// Owner ID for data isolation (from API key authentication) + pub owner_id: Uuid, + /// WebSocket connection identifier + pub connection_id: String, + /// Daemon hostname + pub hostname: Option, + /// Machine identifier + pub machine_id: Option, + /// Channel to send commands to this daemon + pub command_sender: mpsc::Sender, + /// Current working directory of the daemon + pub working_directory: Option, + /// Path to ~/.makima/home directory on daemon (for cloning completed work) + pub home_directory: Option, + /// Path to worktrees directory (~/.makima/worktrees) on daemon + pub worktrees_directory: Option, +} + /// Shared application state containing ML models and database pool. /// /// Models are wrapped in `Mutex` for thread-safe mutable access during inference. @@ -34,6 +292,16 @@ pub struct AppState { pub db_pool: Option, /// Broadcast channel for file update notifications pub file_updates: broadcast::Sender, + /// Broadcast channel for task update notifications + pub task_updates: broadcast::Sender, + /// Broadcast channel for task output streaming + pub task_output: broadcast::Sender, + /// Active daemon connections (keyed by connection_id) + pub daemon_connections: DashMap, + /// Tool keys for orchestrator API access (key -> task_id) + pub tool_keys: DashMap, + /// JWT verifier for Supabase authentication (None if not configured) + pub jwt_verifier: Option, } impl AppState { @@ -56,8 +324,38 @@ impl AppState { DiarizationConfig::callhome(), )?; - // Create broadcast channel with buffer for 256 messages + // Create broadcast channels with buffer for 256 messages let (file_updates, _) = broadcast::channel(256); + let (task_updates, _) = broadcast::channel(256); + let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming + + // Initialize JWT verifier from environment (optional) + // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256) + let jwt_verifier = match AuthConfig::from_env() { + Some(config) => match JwtVerifier::new(config) { + Ok(verifier) => { + tracing::info!("JWT authentication configured"); + Some(verifier) + } + Err(e) => { + tracing::error!("Failed to initialize JWT verifier: {}", e); + None + } + }, + None => { + // Log which env vars are missing + let has_url = std::env::var("SUPABASE_URL").is_ok(); + let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok(); + let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok(); + + if !has_url { + tracing::info!("JWT authentication not configured (SUPABASE_URL not set)"); + } else if !has_public_key && !has_secret { + tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)"); + } + None + } + }; Ok(Self { parakeet: Mutex::new(parakeet), @@ -65,6 +363,11 @@ impl AppState { sortformer: Mutex::new(sortformer), db_pool: None, file_updates, + task_updates, + task_output, + daemon_connections: DashMap::new(), + tool_keys: DashMap::new(), + jwt_verifier, }) } @@ -81,6 +384,166 @@ impl AppState { // Ignore send errors - they just mean no one is listening let _ = self.file_updates.send(notification); } + + /// Broadcast a task update notification to all subscribers. + /// + /// This is a no-op if there are no subscribers (ignores send errors). + pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) { + let _ = self.task_updates.send(notification); + } + + /// Broadcast task output to all subscribers. + /// + /// Used for streaming Claude Code container output to frontend clients. + pub fn broadcast_task_output(&self, notification: TaskOutputNotification) { + let _ = self.task_output.send(notification); + } + + /// Register a new daemon connection. + /// + /// Returns the connection_id for later reference. + pub fn register_daemon( + &self, + connection_id: String, + daemon_id: Uuid, + owner_id: Uuid, + hostname: Option, + machine_id: Option, + command_sender: mpsc::Sender, + ) { + self.daemon_connections.insert( + connection_id.clone(), + DaemonConnectionInfo { + id: daemon_id, + owner_id, + connection_id, + hostname, + machine_id, + command_sender, + working_directory: None, + home_directory: None, + worktrees_directory: None, + }, + ); + } + + /// Update daemon directory information. + pub fn update_daemon_directories( + &self, + connection_id: &str, + working_directory: String, + home_directory: String, + worktrees_directory: String, + ) { + if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) { + entry.working_directory = Some(working_directory); + entry.home_directory = Some(home_directory); + entry.worktrees_directory = Some(worktrees_directory); + } + } + + /// Unregister a daemon connection. + pub fn unregister_daemon(&self, connection_id: &str) { + self.daemon_connections.remove(connection_id); + } + + /// Get a daemon connection by connection_id. + pub fn get_daemon(&self, connection_id: &str) -> Option> { + self.daemon_connections.get(connection_id) + } + + /// Get a daemon by its database ID. + pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option> { + self.daemon_connections + .iter() + .find(|entry| entry.value().id == daemon_id) + .map(|entry| { + // Return a reference to the found entry + self.daemon_connections.get(entry.key()).unwrap() + }) + } + + /// Send a command to a specific daemon by its database ID. + pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> { + if let Some(daemon) = self.daemon_connections + .iter() + .find(|entry| entry.value().id == daemon_id) + { + daemon.value().command_sender.send(command).await + .map_err(|e| format!("Failed to send command to daemon: {}", e)) + } else { + Err(format!("Daemon {} not connected", daemon_id)) + } + } + + /// Broadcast sibling progress to all running sibling tasks. + /// + /// This is used for sibling awareness - when a task makes progress, + /// its siblings are notified so they can adjust their work if needed. + pub async fn broadcast_sibling_progress( + &self, + source_task_id: Uuid, + source_task_name: &str, + source_task_status: &str, + progress_summary: Option, + changed_files: Vec, + running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id) + ) { + for (sibling_task_id, daemon_id) in running_sibling_daemon_ids { + let command = DaemonCommand::InjectSiblingContext { + task_id: sibling_task_id, + sibling_task_id: source_task_id, + sibling_name: source_task_name.to_string(), + sibling_status: source_task_status.to_string(), + progress_summary: progress_summary.clone(), + changed_files: changed_files.clone(), + }; + + // Fire and forget - don't block on sending to all daemons + if let Err(e) = self.send_daemon_command(daemon_id, command).await { + tracing::warn!( + "Failed to inject sibling context to task {}: {}", + sibling_task_id, + e + ); + } + } + } + + /// Get list of connected daemon IDs. + pub fn list_connected_daemon_ids(&self) -> Vec { + self.daemon_connections + .iter() + .map(|entry| entry.value().id) + .collect() + } + + // ========================================================================= + // Tool Key Management + // ========================================================================= + + /// Register a tool key for a task. + /// + /// This allows orchestrators to authenticate with the API using + /// the `X-Makima-Tool-Key` header. + pub fn register_tool_key(&self, key: String, task_id: Uuid) { + tracing::info!(task_id = %task_id, "Registering tool key"); + self.tool_keys.insert(key, task_id); + } + + /// Validate a tool key and return the associated task ID. + pub fn validate_tool_key(&self, key: &str) -> Option { + self.tool_keys.get(key).map(|entry| *entry.value()) + } + + /// Revoke a tool key for a task. + /// + /// This should be called when a task completes or is terminated. + pub fn revoke_tool_key(&self, task_id: Uuid) { + // Find and remove the key for this task + self.tool_keys.retain(|_, v| *v != task_id); + tracing::info!(task_id = %task_id, "Revoked tool key"); + } } /// Type alias for the shared application state. -- cgit v1.2.3