From 87044a747b47bd83249d61a45842c7f7b2eae56d Mon Sep 17 00:00:00 2001 From: soryu Date: Sun, 11 Jan 2026 05:52:14 +0000 Subject: Contract system --- makima/src/server/state.rs | 364 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 341 insertions(+), 23 deletions(-) (limited to 'makima/src/server/state.rs') diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index e89197a..1c28544 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use dashmap::DashMap; use sqlx::PgPool; -use tokio::sync::{broadcast, mpsc, Mutex}; +use tokio::sync::{broadcast, mpsc, Mutex, OnceCell}; use uuid::Uuid; use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; @@ -75,6 +75,34 @@ pub struct TaskOutputNotification { pub is_partial: bool, } +/// Notification for task completion events (for supervisor tasks to monitor). +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskCompletionNotification { + /// ID of the completed task + pub task_id: Uuid, + /// Owner ID for data isolation + #[serde(skip)] + pub owner_id: Option, + /// Contract ID if task belongs to a contract + #[serde(skip_serializing_if = "Option::is_none")] + pub contract_id: Option, + /// Parent task ID (to notify parent/supervisor) + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_task_id: Option, + /// Final status: "done", "failed", etc. + pub status: String, + /// Summary of task output/results + #[serde(skip_serializing_if = "Option::is_none")] + pub output_summary: Option, + /// Path to the task's worktree (for reading files) + #[serde(skip_serializing_if = "Option::is_none")] + pub worktree_path: Option, + /// Error message if task failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, +} + /// Command sent from server to daemon. #[derive(Debug, Clone, serde::Serialize)] #[serde(tag = "type", rename_all = "camelCase")] @@ -119,6 +147,12 @@ pub enum DaemonCommand { /// Files to copy from parent task's worktree #[serde(rename = "copyFiles")] copy_files: Option>, + /// Contract ID if this task is associated with a contract + #[serde(rename = "contractId")] + contract_id: Option, + /// Whether this task is a supervisor (long-running contract orchestrator) + #[serde(rename = "isSupervisor")] + is_supervisor: bool, }, /// Pause a running task PauseTask { @@ -251,6 +285,69 @@ pub enum DaemonCommand { target_dir: String, }, + // ========================================================================= + // Contract File Commands + // ========================================================================= + + /// Read a file from a repository linked to a contract + ReadRepoFile { + /// Request ID for correlating response + #[serde(rename = "requestId")] + request_id: Uuid, + /// Contract ID (used for logging/context) + #[serde(rename = "contractId")] + contract_id: Uuid, + /// Path to the file within the repository + #[serde(rename = "filePath")] + file_path: String, + /// Full repository path on daemon's filesystem + #[serde(rename = "repoPath")] + repo_path: String, + }, + + // ========================================================================= + // Supervisor Git Commands + // ========================================================================= + + /// Create a new branch in a task's worktree + CreateBranch { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "branchName")] + branch_name: String, + /// Optional reference to create branch from (task_id or SHA) + #[serde(rename = "fromRef")] + from_ref: Option, + }, + + /// Merge a task's changes to a target branch + MergeTaskToTarget { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Target branch to merge into (default: task's base branch) + #[serde(rename = "targetBranch")] + target_branch: Option, + /// Whether to squash commits + squash: bool, + }, + + /// Create a pull request for a task's changes + CreatePR { + #[serde(rename = "taskId")] + task_id: Uuid, + title: String, + body: Option, + /// Base branch for the PR (default: main) + #[serde(rename = "baseBranch")] + base_branch: String, + }, + + /// Get the diff for a task's changes + GetTaskDiff { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Error response Error { code: String, message: String }, } @@ -278,16 +375,29 @@ pub struct DaemonConnectionInfo { pub worktrees_directory: Option, } -/// Shared application state containing ML models and database pool. -/// -/// Models are wrapped in `Mutex` for thread-safe mutable access during inference. -pub struct AppState { - /// Speech-to-text model (Parakeet TDT) +/// Configuration paths for ML models (used for lazy loading). +#[derive(Clone)] +pub struct ModelConfig { + pub parakeet_model_dir: String, + pub parakeet_eou_dir: String, + pub sortformer_model_path: String, +} + +/// Lazily-loaded ML models. +pub struct MlModels { pub parakeet: Mutex, - /// End-of-Utterance detection model for streaming pub parakeet_eou: Mutex, - /// Speaker diarization model (Sortformer) pub sortformer: Mutex, +} + +/// Shared application state containing ML models and database pool. +/// +/// Models are lazily loaded on first use to speed up server startup. +pub struct AppState { + /// ML model configuration (paths for lazy loading) + pub model_config: Option, + /// Lazily-loaded ML models (initialized on first Listen connection) + pub ml_models: OnceCell, /// Optional database connection pool pub db_pool: Option, /// Broadcast channel for file update notifications @@ -296,6 +406,8 @@ pub struct AppState { pub task_updates: broadcast::Sender, /// Broadcast channel for task output streaming pub task_output: broadcast::Sender, + /// Broadcast channel for task completion notifications (for supervisors) + pub task_completions: broadcast::Sender, /// Active daemon connections (keyed by connection_id) pub daemon_connections: DashMap, /// Tool keys for orchestrator API access (key -> task_id) @@ -305,7 +417,9 @@ pub struct AppState { } impl AppState { - /// Load all ML models from the specified directories. + /// Create AppState with ML model configuration for lazy loading. + /// + /// Models are NOT loaded at startup - they will be loaded on first Listen connection. /// /// # Arguments /// * `parakeet_model_dir` - Path to the Parakeet TDT model directory @@ -315,19 +429,12 @@ impl AppState { parakeet_model_dir: &str, parakeet_eou_dir: &str, sortformer_model_path: &str, - ) -> Result> { - let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?; - let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?; - let sortformer = Sortformer::with_config( - sortformer_model_path, - None, - DiarizationConfig::callhome(), - )?; - + ) -> Self { // Create broadcast channels with buffer for 256 messages let (file_updates, _) = broadcast::channel(256); let (task_updates, _) = broadcast::channel(256); let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming + let (task_completions, _) = broadcast::channel(256); // For supervisor task monitoring // Initialize JWT verifier from environment (optional) // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256) @@ -357,18 +464,61 @@ impl AppState { } }; - Ok(Self { - parakeet: Mutex::new(parakeet), - parakeet_eou: Mutex::new(parakeet_eou), - sortformer: Mutex::new(sortformer), + Self { + model_config: Some(ModelConfig { + parakeet_model_dir: parakeet_model_dir.to_string(), + parakeet_eou_dir: parakeet_eou_dir.to_string(), + sortformer_model_path: sortformer_model_path.to_string(), + }), + ml_models: OnceCell::new(), db_pool: None, file_updates, task_updates, task_output, + task_completions, daemon_connections: DashMap::new(), tool_keys: DashMap::new(), jwt_verifier, - }) + } + } + + /// Get or initialize ML models (lazy loading). + /// + /// Models are loaded on first call and cached for subsequent calls. + /// Returns None if model config is not set. + pub async fn get_ml_models(&self) -> Result<&MlModels, Box> { + let config = self.model_config.as_ref() + .ok_or_else(|| "ML model configuration not set")?; + + self.ml_models.get_or_try_init(|| async { + tracing::info!( + parakeet = %config.parakeet_model_dir, + eou = %config.parakeet_eou_dir, + sortformer = %config.sortformer_model_path, + "Lazy-loading ML models on first Listen connection..." + ); + + let parakeet = ParakeetTDT::from_pretrained(&config.parakeet_model_dir, None)?; + let parakeet_eou = ParakeetEOU::from_pretrained(&config.parakeet_eou_dir, None)?; + let sortformer = Sortformer::with_config( + &config.sortformer_model_path, + None, + DiarizationConfig::callhome(), + )?; + + tracing::info!("ML models loaded successfully"); + + Ok(MlModels { + parakeet: Mutex::new(parakeet), + parakeet_eou: Mutex::new(parakeet_eou), + sortformer: Mutex::new(sortformer), + }) + }).await + } + + /// Check if ML models are loaded. + pub fn are_models_loaded(&self) -> bool { + self.ml_models.initialized() } /// Set the database pool. @@ -399,6 +549,13 @@ impl AppState { let _ = self.task_output.send(notification); } + /// Broadcast a task completion notification to all subscribers. + /// + /// Used to notify supervisor tasks when their child tasks complete. + pub fn broadcast_task_completion(&self, notification: TaskCompletionNotification) { + let _ = self.task_completions.send(notification); + } + /// Register a new daemon connection. /// /// Returns the connection_id for later reference. @@ -544,6 +701,167 @@ impl AppState { self.tool_keys.retain(|_, v| *v != task_id); tracing::info!(task_id = %task_id, "Revoked tool key"); } + + // ========================================================================= + // Supervisor Notifications + // ========================================================================= + + /// Notify a contract's supervisor task about an event. + /// + /// This sends a message to the supervisor's stdin so it can react to changes + /// in tasks or contract state. + pub async fn notify_supervisor( + &self, + supervisor_task_id: Uuid, + supervisor_daemon_id: Option, + message: &str, + ) -> Result<(), String> { + // Only send if we have a daemon ID + let daemon_id = match supervisor_daemon_id { + Some(id) => id, + None => { + tracing::debug!( + supervisor_task_id = %supervisor_task_id, + "Supervisor has no daemon assigned, skipping notification" + ); + return Ok(()); + } + }; + + let command = DaemonCommand::SendMessage { + task_id: supervisor_task_id, + message: message.to_string(), + }; + + self.send_daemon_command(daemon_id, command).await + } + + /// Format and send a task completion notification to a supervisor. + pub async fn notify_supervisor_of_task_completion( + &self, + supervisor_task_id: Uuid, + supervisor_daemon_id: Option, + completed_task_id: Uuid, + completed_task_name: &str, + status: &str, + progress_summary: Option<&str>, + error_message: Option<&str>, + ) { + let mut message = format!( + "TASK_COMPLETED task_id={} name=\"{}\" status={}", + completed_task_id, completed_task_name, status + ); + + if let Some(summary) = progress_summary { + // Escape newlines in summary + let escaped = summary.replace('\n', "\\n"); + message.push_str(&format!(" summary=\"{}\"", escaped)); + } + + if let Some(err) = error_message { + let escaped = err.replace('\n', "\\n"); + message.push_str(&format!(" error=\"{}\"", escaped)); + } + + if let Err(e) = self.notify_supervisor( + supervisor_task_id, + supervisor_daemon_id, + &message, + ).await { + tracing::warn!( + supervisor_task_id = %supervisor_task_id, + completed_task_id = %completed_task_id, + "Failed to notify supervisor of task completion: {}", + e + ); + } + } + + /// Format and send a task status change notification to a supervisor. + pub async fn notify_supervisor_of_task_update( + &self, + supervisor_task_id: Uuid, + supervisor_daemon_id: Option, + updated_task_id: Uuid, + updated_task_name: &str, + new_status: &str, + updated_fields: &[String], + ) { + let message = format!( + "TASK_UPDATED task_id={} name=\"{}\" status={} fields={}", + updated_task_id, + updated_task_name, + new_status, + updated_fields.join(",") + ); + + if let Err(e) = self.notify_supervisor( + supervisor_task_id, + supervisor_daemon_id, + &message, + ).await { + tracing::warn!( + supervisor_task_id = %supervisor_task_id, + updated_task_id = %updated_task_id, + "Failed to notify supervisor of task update: {}", + e + ); + } + } + + /// Format and send a contract phase change notification to a supervisor. + pub async fn notify_supervisor_of_phase_change( + &self, + supervisor_task_id: Uuid, + supervisor_daemon_id: Option, + contract_id: Uuid, + new_phase: &str, + ) { + let message = format!( + "PHASE_CHANGED contract_id={} phase={}", + contract_id, new_phase + ); + + if let Err(e) = self.notify_supervisor( + supervisor_task_id, + supervisor_daemon_id, + &message, + ).await { + tracing::warn!( + supervisor_task_id = %supervisor_task_id, + contract_id = %contract_id, + "Failed to notify supervisor of phase change: {}", + e + ); + } + } + + /// Format and send a new task created notification to a supervisor. + pub async fn notify_supervisor_of_task_created( + &self, + supervisor_task_id: Uuid, + supervisor_daemon_id: Option, + new_task_id: Uuid, + new_task_name: &str, + ) { + let message = format!( + "TASK_CREATED task_id={} name=\"{}\"", + new_task_id, new_task_name + ); + + if let Err(e) = self.notify_supervisor( + supervisor_task_id, + supervisor_daemon_id, + &message, + ).await { + tracing::warn!( + supervisor_task_id = %supervisor_task_id, + new_task_id = %new_task_id, + "Failed to notify supervisor of task creation: {}", + e + ); + } + } } /// Type alias for the shared application state. -- cgit v1.2.3