//! Application state holding shared ML models and database pool. use std::sync::Arc; use dashmap::DashMap; use sqlx::PgPool; use tokio::sync::{broadcast, mpsc, Mutex}; use uuid::Uuid; use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; use crate::server::auth::{AuthConfig, JwtVerifier}; /// Notification payload for file updates (broadcast to WebSocket subscribers). #[derive(Debug, Clone)] pub struct FileUpdateNotification { /// ID of the updated file pub file_id: Uuid, /// New version number after update pub version: i32, /// List of fields that were updated pub updated_fields: Vec, /// Source of the update: "user", "llm", or "system" pub updated_by: String, } // ============================================================================= // Task/Mesh Notifications // ============================================================================= /// Notification payload for task updates (broadcast to WebSocket subscribers). #[derive(Debug, Clone)] pub struct TaskUpdateNotification { /// ID of the updated task pub task_id: Uuid, /// Owner ID for data isolation (notifications are scoped to owner) pub owner_id: Option, /// New version number after update pub version: i32, /// Current task status pub status: String, /// List of fields that were updated pub updated_fields: Vec, /// Source of the update: "user", "daemon", or "system" pub updated_by: String, } /// Notification for streaming task output from Claude Code containers. #[derive(Debug, Clone, serde::Serialize)] #[serde(rename_all = "camelCase")] pub struct TaskOutputNotification { /// ID of the task producing output pub task_id: Uuid, /// Owner ID for data isolation (notifications are scoped to owner) #[serde(skip)] pub owner_id: Option, /// Type of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" pub message_type: String, /// Main text content of the message pub content: String, /// Tool name if this is a tool_use message #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option, /// Tool input (JSON) if this is a tool_use message #[serde(skip_serializing_if = "Option::is_none")] pub tool_input: Option, /// Whether tool result was an error #[serde(skip_serializing_if = "Option::is_none")] pub is_error: Option, /// Cost in USD if this is a result message #[serde(skip_serializing_if = "Option::is_none")] pub cost_usd: Option, /// Duration in milliseconds if this is a result message #[serde(skip_serializing_if = "Option::is_none")] pub duration_ms: Option, /// Whether this is a partial line (more coming) or complete pub is_partial: bool, } /// Command sent from server to daemon. #[derive(Debug, Clone, serde::Serialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum DaemonCommand { /// Confirm successful authentication Authenticated { #[serde(rename = "daemonId")] daemon_id: Uuid, }, /// Spawn a new task in a container SpawnTask { #[serde(rename = "taskId")] task_id: Uuid, /// Human-readable task name (used for commit messages) #[serde(rename = "taskName")] task_name: String, plan: String, #[serde(rename = "repoUrl")] repo_url: Option, #[serde(rename = "baseBranch")] base_branch: Option, /// Target branch to merge into (used for completion actions) #[serde(rename = "targetBranch")] target_branch: Option, /// Parent task ID if this is a subtask #[serde(rename = "parentTaskId")] parent_task_id: Option, /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask) depth: i32, /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks) #[serde(rename = "isOrchestrator")] is_orchestrator: bool, /// Path to user's local repository (outside ~/.makima) for completion actions #[serde(rename = "targetRepoPath")] target_repo_path: Option, /// Action on completion: "none", "branch", "merge", "pr" #[serde(rename = "completionAction")] completion_action: Option, /// Task ID to continue from (copy worktree from this task) #[serde(rename = "continueFromTaskId")] continue_from_task_id: Option, /// Files to copy from parent task's worktree #[serde(rename = "copyFiles")] copy_files: Option>, }, /// Pause a running task PauseTask { #[serde(rename = "taskId")] task_id: Uuid, }, /// Resume a paused task ResumeTask { #[serde(rename = "taskId")] task_id: Uuid, }, /// Interrupt a task (gracefully or forced) InterruptTask { #[serde(rename = "taskId")] task_id: Uuid, graceful: bool, }, /// Send a message to a running task SendMessage { #[serde(rename = "taskId")] task_id: Uuid, message: String, }, /// Inject context about sibling task progress InjectSiblingContext { #[serde(rename = "taskId")] task_id: Uuid, #[serde(rename = "siblingTaskId")] sibling_task_id: Uuid, #[serde(rename = "siblingName")] sibling_name: String, #[serde(rename = "siblingStatus")] sibling_status: String, #[serde(rename = "progressSummary")] progress_summary: Option, #[serde(rename = "changedFiles")] changed_files: Vec, }, // ========================================================================= // Merge Commands (for orchestrators to merge subtask branches) // ========================================================================= /// List all subtask branches for a task ListBranches { #[serde(rename = "taskId")] task_id: Uuid, }, /// Start merging a subtask branch MergeStart { #[serde(rename = "taskId")] task_id: Uuid, #[serde(rename = "sourceBranch")] source_branch: String, }, /// Get current merge status MergeStatus { #[serde(rename = "taskId")] task_id: Uuid, }, /// Resolve a merge conflict MergeResolve { #[serde(rename = "taskId")] task_id: Uuid, file: String, /// "ours" or "theirs" strategy: String, }, /// Commit the current merge MergeCommit { #[serde(rename = "taskId")] task_id: Uuid, message: String, }, /// Abort the current merge MergeAbort { #[serde(rename = "taskId")] task_id: Uuid, }, /// Skip merging a subtask branch (mark as intentionally not merged) MergeSkip { #[serde(rename = "taskId")] task_id: Uuid, #[serde(rename = "subtaskId")] subtask_id: Uuid, reason: String, }, /// Check if all subtask branches have been merged or skipped (completion gate) CheckMergeComplete { #[serde(rename = "taskId")] task_id: Uuid, }, // ========================================================================= // Completion Action Commands // ========================================================================= /// Retry a completion action for a completed task RetryCompletionAction { #[serde(rename = "taskId")] task_id: Uuid, /// Human-readable task name (used for commit messages) #[serde(rename = "taskName")] task_name: String, /// The action to execute: "branch", "merge", or "pr" action: String, /// Path to the target repository #[serde(rename = "targetRepoPath")] target_repo_path: String, /// Target branch to merge into (for merge/pr actions) #[serde(rename = "targetBranch")] target_branch: Option, }, /// Clone worktree to a target directory CloneWorktree { #[serde(rename = "taskId")] task_id: Uuid, /// Path to the target directory #[serde(rename = "targetDir")] target_dir: String, }, /// Check if a target directory exists CheckTargetExists { #[serde(rename = "taskId")] task_id: Uuid, /// Path to check #[serde(rename = "targetDir")] target_dir: String, }, /// Error response Error { code: String, message: String }, } /// Active daemon connection info stored in state. #[derive(Debug)] pub struct DaemonConnectionInfo { /// Database ID of the daemon pub id: Uuid, /// Owner ID for data isolation (from API key authentication) pub owner_id: Uuid, /// WebSocket connection identifier pub connection_id: String, /// Daemon hostname pub hostname: Option, /// Machine identifier pub machine_id: Option, /// Channel to send commands to this daemon pub command_sender: mpsc::Sender, /// Current working directory of the daemon pub working_directory: Option, /// Path to ~/.makima/home directory on daemon (for cloning completed work) pub home_directory: Option, /// Path to worktrees directory (~/.makima/worktrees) on daemon pub worktrees_directory: Option, } /// Shared application state containing ML models and database pool. /// /// Models are wrapped in `Mutex` for thread-safe mutable access during inference. pub struct AppState { /// Speech-to-text model (Parakeet TDT) pub parakeet: Mutex, /// End-of-Utterance detection model for streaming pub parakeet_eou: Mutex, /// Speaker diarization model (Sortformer) pub sortformer: Mutex, /// Optional database connection pool pub db_pool: Option, /// Broadcast channel for file update notifications pub file_updates: broadcast::Sender, /// Broadcast channel for task update notifications pub task_updates: broadcast::Sender, /// Broadcast channel for task output streaming pub task_output: broadcast::Sender, /// Active daemon connections (keyed by connection_id) pub daemon_connections: DashMap, /// Tool keys for orchestrator API access (key -> task_id) pub tool_keys: DashMap, /// JWT verifier for Supabase authentication (None if not configured) pub jwt_verifier: Option, } impl AppState { /// Load all ML models from the specified directories. /// /// # Arguments /// * `parakeet_model_dir` - Path to the Parakeet TDT model directory /// * `parakeet_eou_dir` - Path to the Parakeet EOU model directory /// * `sortformer_model_path` - Path to the Sortformer diarization model file pub fn new( parakeet_model_dir: &str, parakeet_eou_dir: &str, sortformer_model_path: &str, ) -> Result> { let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?; let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?; let sortformer = Sortformer::with_config( sortformer_model_path, None, DiarizationConfig::callhome(), )?; // Create broadcast channels with buffer for 256 messages let (file_updates, _) = broadcast::channel(256); let (task_updates, _) = broadcast::channel(256); let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming // Initialize JWT verifier from environment (optional) // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256) let jwt_verifier = match AuthConfig::from_env() { Some(config) => match JwtVerifier::new(config) { Ok(verifier) => { tracing::info!("JWT authentication configured"); Some(verifier) } Err(e) => { tracing::error!("Failed to initialize JWT verifier: {}", e); None } }, None => { // Log which env vars are missing let has_url = std::env::var("SUPABASE_URL").is_ok(); let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok(); let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok(); if !has_url { tracing::info!("JWT authentication not configured (SUPABASE_URL not set)"); } else if !has_public_key && !has_secret { tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)"); } None } }; Ok(Self { parakeet: Mutex::new(parakeet), parakeet_eou: Mutex::new(parakeet_eou), sortformer: Mutex::new(sortformer), db_pool: None, file_updates, task_updates, task_output, daemon_connections: DashMap::new(), tool_keys: DashMap::new(), jwt_verifier, }) } /// Set the database pool. pub fn with_db_pool(mut self, pool: PgPool) -> Self { self.db_pool = Some(pool); self } /// Broadcast a file update notification to all subscribers. /// /// This is a no-op if there are no subscribers (ignores send errors). pub fn broadcast_file_update(&self, notification: FileUpdateNotification) { // Ignore send errors - they just mean no one is listening let _ = self.file_updates.send(notification); } /// Broadcast a task update notification to all subscribers. /// /// This is a no-op if there are no subscribers (ignores send errors). pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) { let _ = self.task_updates.send(notification); } /// Broadcast task output to all subscribers. /// /// Used for streaming Claude Code container output to frontend clients. pub fn broadcast_task_output(&self, notification: TaskOutputNotification) { let _ = self.task_output.send(notification); } /// Register a new daemon connection. /// /// Returns the connection_id for later reference. pub fn register_daemon( &self, connection_id: String, daemon_id: Uuid, owner_id: Uuid, hostname: Option, machine_id: Option, command_sender: mpsc::Sender, ) { self.daemon_connections.insert( connection_id.clone(), DaemonConnectionInfo { id: daemon_id, owner_id, connection_id, hostname, machine_id, command_sender, working_directory: None, home_directory: None, worktrees_directory: None, }, ); } /// Update daemon directory information. pub fn update_daemon_directories( &self, connection_id: &str, working_directory: String, home_directory: String, worktrees_directory: String, ) { if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) { entry.working_directory = Some(working_directory); entry.home_directory = Some(home_directory); entry.worktrees_directory = Some(worktrees_directory); } } /// Unregister a daemon connection. pub fn unregister_daemon(&self, connection_id: &str) { self.daemon_connections.remove(connection_id); } /// Get a daemon connection by connection_id. pub fn get_daemon(&self, connection_id: &str) -> Option> { self.daemon_connections.get(connection_id) } /// Get a daemon by its database ID. pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option> { self.daemon_connections .iter() .find(|entry| entry.value().id == daemon_id) .map(|entry| { // Return a reference to the found entry self.daemon_connections.get(entry.key()).unwrap() }) } /// Send a command to a specific daemon by its database ID. pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> { if let Some(daemon) = self.daemon_connections .iter() .find(|entry| entry.value().id == daemon_id) { daemon.value().command_sender.send(command).await .map_err(|e| format!("Failed to send command to daemon: {}", e)) } else { Err(format!("Daemon {} not connected", daemon_id)) } } /// Broadcast sibling progress to all running sibling tasks. /// /// This is used for sibling awareness - when a task makes progress, /// its siblings are notified so they can adjust their work if needed. pub async fn broadcast_sibling_progress( &self, source_task_id: Uuid, source_task_name: &str, source_task_status: &str, progress_summary: Option, changed_files: Vec, running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id) ) { for (sibling_task_id, daemon_id) in running_sibling_daemon_ids { let command = DaemonCommand::InjectSiblingContext { task_id: sibling_task_id, sibling_task_id: source_task_id, sibling_name: source_task_name.to_string(), sibling_status: source_task_status.to_string(), progress_summary: progress_summary.clone(), changed_files: changed_files.clone(), }; // Fire and forget - don't block on sending to all daemons if let Err(e) = self.send_daemon_command(daemon_id, command).await { tracing::warn!( "Failed to inject sibling context to task {}: {}", sibling_task_id, e ); } } } /// Get list of connected daemon IDs. pub fn list_connected_daemon_ids(&self) -> Vec { self.daemon_connections .iter() .map(|entry| entry.value().id) .collect() } // ========================================================================= // Tool Key Management // ========================================================================= /// Register a tool key for a task. /// /// This allows orchestrators to authenticate with the API using /// the `X-Makima-Tool-Key` header. pub fn register_tool_key(&self, key: String, task_id: Uuid) { tracing::info!(task_id = %task_id, "Registering tool key"); self.tool_keys.insert(key, task_id); } /// Validate a tool key and return the associated task ID. pub fn validate_tool_key(&self, key: &str) -> Option { self.tool_keys.get(key).map(|entry| *entry.value()) } /// Revoke a tool key for a task. /// /// This should be called when a task completes or is terminated. pub fn revoke_tool_key(&self, task_id: Uuid) { // Find and remove the key for this task self.tool_keys.retain(|_, v| *v != task_id); tracing::info!(task_id = %task_id, "Revoked tool key"); } } /// Type alias for the shared application state. pub type SharedState = Arc;