summaryrefslogblamecommitdiff
path: root/makima/src/server/state.rs
blob: e89197ac4c7daefc8a788e823a26c0a936534283 (plain) (tree)
1
2
3
4
5
6
7
8
9
                                                                 

                   
                     
                 
                                          
               
 
                                                                             
                                                   
 












                                                                               































































































































































































































































                                                                                                   
                                                                    


                                                                                  
                                           
                                     

                                                      

                                              

                                         

                                                                









                                                                         





                                                          

                                                                         
                                                                                 

                                 
                               
                                    

                                                                               
                                                                                 




                                                 
 
                                                                 
                                                        





























                                                                                                                                                 
 

                                           
                                                   
                                               
                          
                         




                                               

          





                                                         







                                                                               































































































































































                                                                                                                            



                                                
//! Application state holding shared ML models and database pool.

use std::sync::Arc;
use dashmap::DashMap;
use sqlx::PgPool;
use tokio::sync::{broadcast, mpsc, Mutex};
use uuid::Uuid;

use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
use crate::server::auth::{AuthConfig, JwtVerifier};

/// Message sent to WebSocket subscribers whenever a file is modified.
///
/// Delivered through the `file_updates` broadcast channel on `AppState`.
#[derive(Clone, Debug)]
pub struct FileUpdateNotification {
    /// Identifier of the file that changed.
    pub file_id: Uuid,
    /// Version number the file carries after this change.
    pub version: i32,
    /// Names of the fields touched by this change.
    pub updated_fields: Vec<String>,
    /// Origin of the change: "user", "llm", or "system".
    pub updated_by: String,
}

// =============================================================================
// Task/Mesh Notifications
// =============================================================================

/// Message sent to WebSocket subscribers whenever a task changes.
///
/// Delivered through the `task_updates` broadcast channel on `AppState`.
#[derive(Clone, Debug)]
pub struct TaskUpdateNotification {
    /// Identifier of the task that changed.
    pub task_id: Uuid,
    /// Owning user, used for data isolation: notifications are scoped to this owner.
    pub owner_id: Option<Uuid>,
    /// Version number the task carries after this change.
    pub version: i32,
    /// Task status at the time of the change.
    pub status: String,
    /// Names of the fields touched by this change.
    pub updated_fields: Vec<String>,
    /// Origin of the change: "user", "daemon", or "system".
    pub updated_by: String,
}

/// One chunk of output streamed from a task's Claude Code container.
///
/// Serialized with camelCase keys. `owner_id` is never serialized, since it is only used
/// for owner scoping. Optional fields that are `None` are omitted from the output.
#[derive(Clone, Debug, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskOutputNotification {
    /// Task that emitted this output.
    pub task_id: Uuid,
    /// Owning user, used for data isolation (notifications are scoped to owner).
    #[serde(skip)]
    pub owner_id: Option<Uuid>,
    /// Kind of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw".
    pub message_type: String,
    /// Primary text payload of the message.
    pub content: String,
    /// Name of the invoked tool (tool_use messages only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// JSON arguments passed to the tool (tool_use messages only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_input: Option<serde_json::Value>,
    /// Set when a tool result reports an error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
    /// Cost in USD (result messages only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_usd: Option<f64>,
    /// Elapsed time in milliseconds (result messages only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>,
    /// `true` when this is a partial line with more to follow, `false` when complete.
    pub is_partial: bool,
}

/// Command sent from server to daemon.
///
/// Serialization: `tag = "type"` makes each variant an internally tagged object
/// whose `type` field is the camelCase variant name (e.g. `"spawnTask"`,
/// `"injectSiblingContext"`). On an enum, `rename_all` renames only the
/// variants, not their fields. That is why each multi-word field carries its own
/// `#[serde(rename = "...")]`. Removing one of those attributes changes the wire
/// format the daemon sees.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum DaemonCommand {
    /// Confirm successful authentication
    Authenticated {
        #[serde(rename = "daemonId")]
        daemon_id: Uuid,
    },
    /// Spawn a new task in a container
    SpawnTask {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Human-readable task name (used for commit messages)
        #[serde(rename = "taskName")]
        task_name: String,
        plan: String,
        #[serde(rename = "repoUrl")]
        repo_url: Option<String>,
        #[serde(rename = "baseBranch")]
        base_branch: Option<String>,
        /// Target branch to merge into (used for completion actions)
        #[serde(rename = "targetBranch")]
        target_branch: Option<String>,
        /// Parent task ID if this is a subtask
        #[serde(rename = "parentTaskId")]
        parent_task_id: Option<Uuid>,
        /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask)
        depth: i32,
        /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks)
        #[serde(rename = "isOrchestrator")]
        is_orchestrator: bool,
        /// Path to user's local repository (outside ~/.makima) for completion actions
        #[serde(rename = "targetRepoPath")]
        target_repo_path: Option<String>,
        /// Action on completion: "none", "branch", "merge", "pr"
        #[serde(rename = "completionAction")]
        completion_action: Option<String>,
        /// Task ID to continue from (copy worktree from this task)
        #[serde(rename = "continueFromTaskId")]
        continue_from_task_id: Option<Uuid>,
        /// Files to copy from parent task's worktree
        #[serde(rename = "copyFiles")]
        copy_files: Option<Vec<String>>,
    },
    /// Pause a running task
    PauseTask {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Resume a paused task
    ResumeTask {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Interrupt a task (gracefully or forced)
    InterruptTask {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        graceful: bool,
    },
    /// Send a message to a running task
    SendMessage {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        message: String,
    },
    /// Inject context about sibling task progress
    ///
    /// `task_id` is the task that receives the context. The `sibling_*` fields
    /// describe the task whose progress is being reported.
    InjectSiblingContext {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        #[serde(rename = "siblingTaskId")]
        sibling_task_id: Uuid,
        #[serde(rename = "siblingName")]
        sibling_name: String,
        #[serde(rename = "siblingStatus")]
        sibling_status: String,
        #[serde(rename = "progressSummary")]
        progress_summary: Option<String>,
        #[serde(rename = "changedFiles")]
        changed_files: Vec<String>,
    },

    // =========================================================================
    // Merge Commands (for orchestrators to merge subtask branches)
    // =========================================================================

    /// List all subtask branches for a task
    ListBranches {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Start merging a subtask branch
    MergeStart {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        #[serde(rename = "sourceBranch")]
        source_branch: String,
    },
    /// Get current merge status
    MergeStatus {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Resolve a merge conflict
    MergeResolve {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        file: String,
        /// "ours" or "theirs"
        strategy: String,
    },
    /// Commit the current merge
    MergeCommit {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        message: String,
    },
    /// Abort the current merge
    MergeAbort {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Skip merging a subtask branch (mark as intentionally not merged)
    MergeSkip {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        #[serde(rename = "subtaskId")]
        subtask_id: Uuid,
        reason: String,
    },
    /// Check if all subtask branches have been merged or skipped (completion gate)
    CheckMergeComplete {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },

    // =========================================================================
    // Completion Action Commands
    // =========================================================================

    /// Retry a completion action for a completed task
    RetryCompletionAction {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Human-readable task name (used for commit messages)
        #[serde(rename = "taskName")]
        task_name: String,
        /// The action to execute: "branch", "merge", or "pr"
        action: String,
        /// Path to the target repository
        #[serde(rename = "targetRepoPath")]
        target_repo_path: String,
        /// Target branch to merge into (for merge/pr actions)
        #[serde(rename = "targetBranch")]
        target_branch: Option<String>,
    },

    /// Clone worktree to a target directory
    CloneWorktree {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Path to the target directory
        #[serde(rename = "targetDir")]
        target_dir: String,
    },

    /// Check if a target directory exists
    CheckTargetExists {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Path to check
        #[serde(rename = "targetDir")]
        target_dir: String,
    },

    /// Error response
    ///
    /// Carries an error `code` and a `message`, both free-form strings.
    Error { code: String, message: String },
}

/// Bookkeeping for one live daemon WebSocket connection.
///
/// Stored in `AppState::daemon_connections`, keyed by `connection_id`. The
/// directory fields stay `None` until the daemon reports them.
#[derive(Debug)]
pub struct DaemonConnectionInfo {
    /// Daemon's row ID in the database.
    pub id: Uuid,
    /// Owning user, taken from API key authentication (used for data isolation).
    pub owner_id: Uuid,
    /// Identifier of the WebSocket connection.
    pub connection_id: String,
    /// Hostname reported by the daemon, if any.
    pub hostname: Option<String>,
    /// Machine identifier reported by the daemon, if any.
    pub machine_id: Option<String>,
    /// Sending half of the channel that delivers commands to this daemon.
    pub command_sender: mpsc::Sender<DaemonCommand>,
    /// Daemon's current working directory.
    pub working_directory: Option<String>,
    /// Location of `~/.makima/home` on the daemon (used when cloning completed work).
    pub home_directory: Option<String>,
    /// Location of the worktrees directory (`~/.makima/worktrees`) on the daemon.
    pub worktrees_directory: Option<String>,
}

/// Process-wide state: ML models, database pool, notification channels,
/// daemon registry, tool keys, and JWT verification.
///
/// Each model sits behind a tokio `Mutex`, because inference needs mutable
/// access and the state is shared across tasks.
pub struct AppState {
    /// Parakeet TDT speech-to-text model.
    pub parakeet: Mutex<ParakeetTDT>,
    /// Parakeet end-of-utterance detector used for streaming.
    pub parakeet_eou: Mutex<ParakeetEOU>,
    /// Sortformer speaker-diarization model.
    pub sortformer: Mutex<Sortformer>,
    /// Database connection pool, if one was attached.
    pub db_pool: Option<PgPool>,
    /// Broadcast sender for file-update notifications.
    pub file_updates: broadcast::Sender<FileUpdateNotification>,
    /// Broadcast sender for task-update notifications.
    pub task_updates: broadcast::Sender<TaskUpdateNotification>,
    /// Broadcast sender for streamed task output.
    pub task_output: broadcast::Sender<TaskOutputNotification>,
    /// Live daemon connections, keyed by connection_id.
    pub daemon_connections: DashMap<String, DaemonConnectionInfo>,
    /// Orchestrator tool keys for API access, mapping key -> task_id.
    pub tool_keys: DashMap<String, Uuid>,
    /// Supabase JWT verifier (`None` when not configured).
    pub jwt_verifier: Option<JwtVerifier>,
}

impl AppState {
    /// Load all ML models from the specified directories.
    ///
    /// # Arguments
    /// * `parakeet_model_dir` - Path to the Parakeet TDT model directory
    /// * `parakeet_eou_dir` - Path to the Parakeet EOU model directory
    /// * `sortformer_model_path` - Path to the Sortformer diarization model file
    pub fn new(
        parakeet_model_dir: &str,
        parakeet_eou_dir: &str,
        sortformer_model_path: &str,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?;
        let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?;
        let sortformer = Sortformer::with_config(
            sortformer_model_path,
            None,
            DiarizationConfig::callhome(),
        )?;

        // Create broadcast channels with buffer for 256 messages
        let (file_updates, _) = broadcast::channel(256);
        let (task_updates, _) = broadcast::channel(256);
        let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming

        // Initialize JWT verifier from environment (optional)
        // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256)
        let jwt_verifier = match AuthConfig::from_env() {
            Some(config) => match JwtVerifier::new(config) {
                Ok(verifier) => {
                    tracing::info!("JWT authentication configured");
                    Some(verifier)
                }
                Err(e) => {
                    tracing::error!("Failed to initialize JWT verifier: {}", e);
                    None
                }
            },
            None => {
                // Log which env vars are missing
                let has_url = std::env::var("SUPABASE_URL").is_ok();
                let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok();
                let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok();

                if !has_url {
                    tracing::info!("JWT authentication not configured (SUPABASE_URL not set)");
                } else if !has_public_key && !has_secret {
                    tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)");
                }
                None
            }
        };

        Ok(Self {
            parakeet: Mutex::new(parakeet),
            parakeet_eou: Mutex::new(parakeet_eou),
            sortformer: Mutex::new(sortformer),
            db_pool: None,
            file_updates,
            task_updates,
            task_output,
            daemon_connections: DashMap::new(),
            tool_keys: DashMap::new(),
            jwt_verifier,
        })
    }

    /// Set the database pool.
    pub fn with_db_pool(mut self, pool: PgPool) -> Self {
        self.db_pool = Some(pool);
        self
    }

    /// Broadcast a file update notification to all subscribers.
    ///
    /// This is a no-op if there are no subscribers (ignores send errors).
    pub fn broadcast_file_update(&self, notification: FileUpdateNotification) {
        // Ignore send errors - they just mean no one is listening
        let _ = self.file_updates.send(notification);
    }

    /// Broadcast a task update notification to all subscribers.
    ///
    /// This is a no-op if there are no subscribers (ignores send errors).
    pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) {
        let _ = self.task_updates.send(notification);
    }

    /// Broadcast task output to all subscribers.
    ///
    /// Used for streaming Claude Code container output to frontend clients.
    pub fn broadcast_task_output(&self, notification: TaskOutputNotification) {
        let _ = self.task_output.send(notification);
    }

    /// Register a new daemon connection.
    ///
    /// Returns the connection_id for later reference.
    pub fn register_daemon(
        &self,
        connection_id: String,
        daemon_id: Uuid,
        owner_id: Uuid,
        hostname: Option<String>,
        machine_id: Option<String>,
        command_sender: mpsc::Sender<DaemonCommand>,
    ) {
        self.daemon_connections.insert(
            connection_id.clone(),
            DaemonConnectionInfo {
                id: daemon_id,
                owner_id,
                connection_id,
                hostname,
                machine_id,
                command_sender,
                working_directory: None,
                home_directory: None,
                worktrees_directory: None,
            },
        );
    }

    /// Update daemon directory information.
    pub fn update_daemon_directories(
        &self,
        connection_id: &str,
        working_directory: String,
        home_directory: String,
        worktrees_directory: String,
    ) {
        if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) {
            entry.working_directory = Some(working_directory);
            entry.home_directory = Some(home_directory);
            entry.worktrees_directory = Some(worktrees_directory);
        }
    }

    /// Unregister a daemon connection.
    pub fn unregister_daemon(&self, connection_id: &str) {
        self.daemon_connections.remove(connection_id);
    }

    /// Get a daemon connection by connection_id.
    pub fn get_daemon(&self, connection_id: &str) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> {
        self.daemon_connections.get(connection_id)
    }

    /// Get a daemon by its database ID.
    pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> {
        self.daemon_connections
            .iter()
            .find(|entry| entry.value().id == daemon_id)
            .map(|entry| {
                // Return a reference to the found entry
                self.daemon_connections.get(entry.key()).unwrap()
            })
    }

    /// Send a command to a specific daemon by its database ID.
    pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> {
        if let Some(daemon) = self.daemon_connections
            .iter()
            .find(|entry| entry.value().id == daemon_id)
        {
            daemon.value().command_sender.send(command).await
                .map_err(|e| format!("Failed to send command to daemon: {}", e))
        } else {
            Err(format!("Daemon {} not connected", daemon_id))
        }
    }

    /// Broadcast sibling progress to all running sibling tasks.
    ///
    /// This is used for sibling awareness - when a task makes progress,
    /// its siblings are notified so they can adjust their work if needed.
    pub async fn broadcast_sibling_progress(
        &self,
        source_task_id: Uuid,
        source_task_name: &str,
        source_task_status: &str,
        progress_summary: Option<String>,
        changed_files: Vec<String>,
        running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id)
    ) {
        for (sibling_task_id, daemon_id) in running_sibling_daemon_ids {
            let command = DaemonCommand::InjectSiblingContext {
                task_id: sibling_task_id,
                sibling_task_id: source_task_id,
                sibling_name: source_task_name.to_string(),
                sibling_status: source_task_status.to_string(),
                progress_summary: progress_summary.clone(),
                changed_files: changed_files.clone(),
            };

            // Fire and forget - don't block on sending to all daemons
            if let Err(e) = self.send_daemon_command(daemon_id, command).await {
                tracing::warn!(
                    "Failed to inject sibling context to task {}: {}",
                    sibling_task_id,
                    e
                );
            }
        }
    }

    /// Get list of connected daemon IDs.
    pub fn list_connected_daemon_ids(&self) -> Vec<Uuid> {
        self.daemon_connections
            .iter()
            .map(|entry| entry.value().id)
            .collect()
    }

    // =========================================================================
    // Tool Key Management
    // =========================================================================

    /// Register a tool key for a task.
    ///
    /// This allows orchestrators to authenticate with the API using
    /// the `X-Makima-Tool-Key` header.
    pub fn register_tool_key(&self, key: String, task_id: Uuid) {
        tracing::info!(task_id = %task_id, "Registering tool key");
        self.tool_keys.insert(key, task_id);
    }

    /// Validate a tool key and return the associated task ID.
    pub fn validate_tool_key(&self, key: &str) -> Option<Uuid> {
        self.tool_keys.get(key).map(|entry| *entry.value())
    }

    /// Revoke a tool key for a task.
    ///
    /// This should be called when a task completes or is terminated.
    pub fn revoke_tool_key(&self, task_id: Uuid) {
        // Find and remove the key for this task
        self.tool_keys.retain(|_, v| *v != task_id);
        tracing::info!(task_id = %task_id, "Revoked tool key");
    }
}

/// Reference-counted, thread-safe handle to [`AppState`]; cloning it is cheap.
pub type SharedState = Arc<AppState>;