summaryrefslogtreecommitdiff
path: root/makima/src/server/state.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-06 04:08:11 +0000
committersoryu <soryu@soryu.co>2026-01-11 03:01:13 +0000
commit8b17a175c3e7e27b789812eba4e3cd760beadb10 (patch)
tree7864dcaa2fa9db47fdfd4e8bfdb0b1dde832aa33 /makima/src/server/state.rs
parentf79c416c58557d2f946aa5332989afdfa8c021cd (diff)
downloadsoryu-8b17a175c3e7e27b789812eba4e3cd760beadb10.tar.gz
soryu-8b17a175c3e7e27b789812eba4e3cd760beadb10.zip
Initial Control system
Diffstat (limited to 'makima/src/server/state.rs')
-rw-r--r--makima/src/server/state.rs467
1 files changed, 465 insertions, 2 deletions
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
index 239ab77..e89197a 100644
--- a/makima/src/server/state.rs
+++ b/makima/src/server/state.rs
@@ -1,11 +1,13 @@
//! Application state holding shared ML models and database pool.
use std::sync::Arc;
+use dashmap::DashMap;
use sqlx::PgPool;
-use tokio::sync::{broadcast, Mutex};
+use tokio::sync::{broadcast, mpsc, Mutex};
use uuid::Uuid;
use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
+use crate::server::auth::{AuthConfig, JwtVerifier};
/// Notification payload for file updates (broadcast to WebSocket subscribers).
#[derive(Debug, Clone)]
@@ -20,6 +22,262 @@ pub struct FileUpdateNotification {
pub updated_by: String,
}
+// =============================================================================
+// Task/Mesh Notifications
+// =============================================================================
+
+/// Notification payload for task updates (broadcast to WebSocket subscribers).
+#[derive(Debug, Clone)]
+pub struct TaskUpdateNotification {
+ /// ID of the updated task
+ pub task_id: Uuid,
+ /// Owner ID for data isolation (notifications are scoped to owner)
+ pub owner_id: Option<Uuid>,
+ /// New version number after update
+ pub version: i32,
+ /// Current task status
+ pub status: String,
+ /// List of fields that were updated
+ pub updated_fields: Vec<String>,
+ /// Source of the update: "user", "daemon", or "system"
+ pub updated_by: String,
+}
+
+/// Notification for streaming task output from Claude Code containers.
+#[derive(Debug, Clone, serde::Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TaskOutputNotification {
+ /// ID of the task producing output
+ pub task_id: Uuid,
+ /// Owner ID for data isolation (notifications are scoped to owner)
+ #[serde(skip)]
+ pub owner_id: Option<Uuid>,
+ /// Type of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw"
+ pub message_type: String,
+ /// Main text content of the message
+ pub content: String,
+ /// Tool name if this is a tool_use message
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_name: Option<String>,
+ /// Tool input (JSON) if this is a tool_use message
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_input: Option<serde_json::Value>,
+ /// Whether tool result was an error
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub is_error: Option<bool>,
+ /// Cost in USD if this is a result message
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub cost_usd: Option<f64>,
+ /// Duration in milliseconds if this is a result message
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub duration_ms: Option<u64>,
+ /// Whether this is a partial line (more coming) or complete
+ pub is_partial: bool,
+}
+
+/// Command sent from server to daemon.
+#[derive(Debug, Clone, serde::Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum DaemonCommand {
+ /// Confirm successful authentication
+ Authenticated {
+ #[serde(rename = "daemonId")]
+ daemon_id: Uuid,
+ },
+ /// Spawn a new task in a container
+ SpawnTask {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Human-readable task name (used for commit messages)
+ #[serde(rename = "taskName")]
+ task_name: String,
+ plan: String,
+ #[serde(rename = "repoUrl")]
+ repo_url: Option<String>,
+ #[serde(rename = "baseBranch")]
+ base_branch: Option<String>,
+ /// Target branch to merge into (used for completion actions)
+ #[serde(rename = "targetBranch")]
+ target_branch: Option<String>,
+ /// Parent task ID if this is a subtask
+ #[serde(rename = "parentTaskId")]
+ parent_task_id: Option<Uuid>,
+ /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask)
+ depth: i32,
+ /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks)
+ #[serde(rename = "isOrchestrator")]
+ is_orchestrator: bool,
+ /// Path to user's local repository (outside ~/.makima) for completion actions
+ #[serde(rename = "targetRepoPath")]
+ target_repo_path: Option<String>,
+ /// Action on completion: "none", "branch", "merge", "pr"
+ #[serde(rename = "completionAction")]
+ completion_action: Option<String>,
+ /// Task ID to continue from (copy worktree from this task)
+ #[serde(rename = "continueFromTaskId")]
+ continue_from_task_id: Option<Uuid>,
+ /// Files to copy from parent task's worktree
+ #[serde(rename = "copyFiles")]
+ copy_files: Option<Vec<String>>,
+ },
+ /// Pause a running task
+ PauseTask {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Resume a paused task
+ ResumeTask {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Interrupt a task (gracefully or forced)
+ InterruptTask {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ graceful: bool,
+ },
+ /// Send a message to a running task
+ SendMessage {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ message: String,
+ },
+ /// Inject context about sibling task progress
+ InjectSiblingContext {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ #[serde(rename = "siblingTaskId")]
+ sibling_task_id: Uuid,
+ #[serde(rename = "siblingName")]
+ sibling_name: String,
+ #[serde(rename = "siblingStatus")]
+ sibling_status: String,
+ #[serde(rename = "progressSummary")]
+ progress_summary: Option<String>,
+ #[serde(rename = "changedFiles")]
+ changed_files: Vec<String>,
+ },
+
+ // =========================================================================
+ // Merge Commands (for orchestrators to merge subtask branches)
+ // =========================================================================
+
+ /// List all subtask branches for a task
+ ListBranches {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Start merging a subtask branch
+ MergeStart {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ #[serde(rename = "sourceBranch")]
+ source_branch: String,
+ },
+ /// Get current merge status
+ MergeStatus {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Resolve a merge conflict
+ MergeResolve {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ file: String,
+ /// "ours" or "theirs"
+ strategy: String,
+ },
+ /// Commit the current merge
+ MergeCommit {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ message: String,
+ },
+ /// Abort the current merge
+ MergeAbort {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Skip merging a subtask branch (mark as intentionally not merged)
+ MergeSkip {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ #[serde(rename = "subtaskId")]
+ subtask_id: Uuid,
+ reason: String,
+ },
+ /// Check if all subtask branches have been merged or skipped (completion gate)
+ CheckMergeComplete {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+
+ // =========================================================================
+ // Completion Action Commands
+ // =========================================================================
+
+ /// Retry a completion action for a completed task
+ RetryCompletionAction {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Human-readable task name (used for commit messages)
+ #[serde(rename = "taskName")]
+ task_name: String,
+ /// The action to execute: "branch", "merge", or "pr"
+ action: String,
+ /// Path to the target repository
+ #[serde(rename = "targetRepoPath")]
+ target_repo_path: String,
+ /// Target branch to merge into (for merge/pr actions)
+ #[serde(rename = "targetBranch")]
+ target_branch: Option<String>,
+ },
+
+ /// Clone worktree to a target directory
+ CloneWorktree {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Path to the target directory
+ #[serde(rename = "targetDir")]
+ target_dir: String,
+ },
+
+ /// Check if a target directory exists
+ CheckTargetExists {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Path to check
+ #[serde(rename = "targetDir")]
+ target_dir: String,
+ },
+
+ /// Error response
+ Error { code: String, message: String },
+}
+
+/// Active daemon connection info stored in state.
+#[derive(Debug)]
+pub struct DaemonConnectionInfo {
+ /// Database ID of the daemon
+ pub id: Uuid,
+ /// Owner ID for data isolation (from API key authentication)
+ pub owner_id: Uuid,
+ /// WebSocket connection identifier
+ pub connection_id: String,
+ /// Daemon hostname
+ pub hostname: Option<String>,
+ /// Machine identifier
+ pub machine_id: Option<String>,
+ /// Channel to send commands to this daemon
+ pub command_sender: mpsc::Sender<DaemonCommand>,
+ /// Current working directory of the daemon
+ pub working_directory: Option<String>,
+ /// Path to ~/.makima/home directory on daemon (for cloning completed work)
+ pub home_directory: Option<String>,
+ /// Path to worktrees directory (~/.makima/worktrees) on daemon
+ pub worktrees_directory: Option<String>,
+}
+
/// Shared application state containing ML models and database pool.
///
/// Models are wrapped in `Mutex` for thread-safe mutable access during inference.
@@ -34,6 +292,16 @@ pub struct AppState {
pub db_pool: Option<PgPool>,
/// Broadcast channel for file update notifications
pub file_updates: broadcast::Sender<FileUpdateNotification>,
+ /// Broadcast channel for task update notifications
+ pub task_updates: broadcast::Sender<TaskUpdateNotification>,
+ /// Broadcast channel for task output streaming
+ pub task_output: broadcast::Sender<TaskOutputNotification>,
+ /// Active daemon connections (keyed by connection_id)
+ pub daemon_connections: DashMap<String, DaemonConnectionInfo>,
+ /// Tool keys for orchestrator API access (key -> task_id)
+ pub tool_keys: DashMap<String, Uuid>,
+ /// JWT verifier for Supabase authentication (None if not configured)
+ pub jwt_verifier: Option<JwtVerifier>,
}
impl AppState {
@@ -56,8 +324,38 @@ impl AppState {
DiarizationConfig::callhome(),
)?;
- // Create broadcast channel with buffer for 256 messages
+ // Create broadcast channels with buffer for 256 messages
let (file_updates, _) = broadcast::channel(256);
+ let (task_updates, _) = broadcast::channel(256);
+ let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming
+
+ // Initialize JWT verifier from environment (optional)
+ // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256)
+ let jwt_verifier = match AuthConfig::from_env() {
+ Some(config) => match JwtVerifier::new(config) {
+ Ok(verifier) => {
+ tracing::info!("JWT authentication configured");
+ Some(verifier)
+ }
+ Err(e) => {
+ tracing::error!("Failed to initialize JWT verifier: {}", e);
+ None
+ }
+ },
+ None => {
+ // Log which env vars are missing
+ let has_url = std::env::var("SUPABASE_URL").is_ok();
+ let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok();
+ let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok();
+
+ if !has_url {
+ tracing::info!("JWT authentication not configured (SUPABASE_URL not set)");
+ } else if !has_public_key && !has_secret {
+ tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)");
+ }
+ None
+ }
+ };
Ok(Self {
parakeet: Mutex::new(parakeet),
@@ -65,6 +363,11 @@ impl AppState {
sortformer: Mutex::new(sortformer),
db_pool: None,
file_updates,
+ task_updates,
+ task_output,
+ daemon_connections: DashMap::new(),
+ tool_keys: DashMap::new(),
+ jwt_verifier,
})
}
@@ -81,6 +384,166 @@ impl AppState {
// Ignore send errors - they just mean no one is listening
let _ = self.file_updates.send(notification);
}
+
+ /// Broadcast a task update notification to all subscribers.
+ ///
+ /// This is a no-op if there are no subscribers (ignores send errors).
+ pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) {
+ let _ = self.task_updates.send(notification);
+ }
+
+ /// Broadcast task output to all subscribers.
+ ///
+ /// Used for streaming Claude Code container output to frontend clients.
+ pub fn broadcast_task_output(&self, notification: TaskOutputNotification) {
+ let _ = self.task_output.send(notification);
+ }
+
+ /// Register a new daemon connection.
+ ///
+ /// Returns the connection_id for later reference.
+ pub fn register_daemon(
+ &self,
+ connection_id: String,
+ daemon_id: Uuid,
+ owner_id: Uuid,
+ hostname: Option<String>,
+ machine_id: Option<String>,
+ command_sender: mpsc::Sender<DaemonCommand>,
+ ) {
+ self.daemon_connections.insert(
+ connection_id.clone(),
+ DaemonConnectionInfo {
+ id: daemon_id,
+ owner_id,
+ connection_id,
+ hostname,
+ machine_id,
+ command_sender,
+ working_directory: None,
+ home_directory: None,
+ worktrees_directory: None,
+ },
+ );
+ }
+
+ /// Update daemon directory information.
+ pub fn update_daemon_directories(
+ &self,
+ connection_id: &str,
+ working_directory: String,
+ home_directory: String,
+ worktrees_directory: String,
+ ) {
+ if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) {
+ entry.working_directory = Some(working_directory);
+ entry.home_directory = Some(home_directory);
+ entry.worktrees_directory = Some(worktrees_directory);
+ }
+ }
+
+ /// Unregister a daemon connection.
+ pub fn unregister_daemon(&self, connection_id: &str) {
+ self.daemon_connections.remove(connection_id);
+ }
+
+ /// Get a daemon connection by connection_id.
+ pub fn get_daemon(&self, connection_id: &str) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> {
+ self.daemon_connections.get(connection_id)
+ }
+
+ /// Get a daemon by its database ID.
+ pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> {
+ self.daemon_connections
+ .iter()
+ .find(|entry| entry.value().id == daemon_id)
+ .map(|entry| {
+ // Return a reference to the found entry
+ self.daemon_connections.get(entry.key()).unwrap()
+ })
+ }
+
+ /// Send a command to a specific daemon by its database ID.
+ pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> {
+ if let Some(daemon) = self.daemon_connections
+ .iter()
+ .find(|entry| entry.value().id == daemon_id)
+ {
+ daemon.value().command_sender.send(command).await
+ .map_err(|e| format!("Failed to send command to daemon: {}", e))
+ } else {
+ Err(format!("Daemon {} not connected", daemon_id))
+ }
+ }
+
+ /// Broadcast sibling progress to all running sibling tasks.
+ ///
+ /// This is used for sibling awareness - when a task makes progress,
+ /// its siblings are notified so they can adjust their work if needed.
+ pub async fn broadcast_sibling_progress(
+ &self,
+ source_task_id: Uuid,
+ source_task_name: &str,
+ source_task_status: &str,
+ progress_summary: Option<String>,
+ changed_files: Vec<String>,
+ running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id)
+ ) {
+ for (sibling_task_id, daemon_id) in running_sibling_daemon_ids {
+ let command = DaemonCommand::InjectSiblingContext {
+ task_id: sibling_task_id,
+ sibling_task_id: source_task_id,
+ sibling_name: source_task_name.to_string(),
+ sibling_status: source_task_status.to_string(),
+ progress_summary: progress_summary.clone(),
+ changed_files: changed_files.clone(),
+ };
+
+ // Fire and forget - don't block on sending to all daemons
+ if let Err(e) = self.send_daemon_command(daemon_id, command).await {
+ tracing::warn!(
+ "Failed to inject sibling context to task {}: {}",
+ sibling_task_id,
+ e
+ );
+ }
+ }
+ }
+
+ /// Get list of connected daemon IDs.
+ pub fn list_connected_daemon_ids(&self) -> Vec<Uuid> {
+ self.daemon_connections
+ .iter()
+ .map(|entry| entry.value().id)
+ .collect()
+ }
+
+ // =========================================================================
+ // Tool Key Management
+ // =========================================================================
+
+ /// Register a tool key for a task.
+ ///
+ /// This allows orchestrators to authenticate with the API using
+ /// the `X-Makima-Tool-Key` header.
+ pub fn register_tool_key(&self, key: String, task_id: Uuid) {
+ tracing::info!(task_id = %task_id, "Registering tool key");
+ self.tool_keys.insert(key, task_id);
+ }
+
+ /// Validate a tool key and return the associated task ID.
+ pub fn validate_tool_key(&self, key: &str) -> Option<Uuid> {
+ self.tool_keys.get(key).map(|entry| *entry.value())
+ }
+
+ /// Revoke a tool key for a task.
+ ///
+ /// This should be called when a task completes or is terminated.
+ pub fn revoke_tool_key(&self, task_id: Uuid) {
+ // Find and remove the key for this task
+ self.tool_keys.retain(|_, v| *v != task_id);
+ tracing::info!(task_id = %task_id, "Revoked tool key");
+ }
}
/// Type alias for the shared application state.