summaryrefslogtreecommitdiff
path: root/makima/src/server/state.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-11 05:52:14 +0000
committersoryu <soryu@soryu.co>2026-01-15 00:21:16 +0000
commit87044a747b47bd83249d61a45842c7f7b2eae56d (patch)
treeef2000ce79ffcc2723ef841acef5aa1deb1d5378 /makima/src/server/state.rs
parent077820c4167c168072d217a1b01df840463a12a8 (diff)
downloadsoryu-87044a747b47bd83249d61a45842c7f7b2eae56d.tar.gz
soryu-87044a747b47bd83249d61a45842c7f7b2eae56d.zip
Contract system
Diffstat (limited to 'makima/src/server/state.rs')
-rw-r--r--makima/src/server/state.rs364
1 files changed, 341 insertions, 23 deletions
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
index e89197a..1c28544 100644
--- a/makima/src/server/state.rs
+++ b/makima/src/server/state.rs
@@ -3,7 +3,7 @@
use std::sync::Arc;
use dashmap::DashMap;
use sqlx::PgPool;
-use tokio::sync::{broadcast, mpsc, Mutex};
+use tokio::sync::{broadcast, mpsc, Mutex, OnceCell};
use uuid::Uuid;
use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
@@ -75,6 +75,34 @@ pub struct TaskOutputNotification {
pub is_partial: bool,
}
+/// Notification for task completion events (for supervisor tasks to monitor).
+#[derive(Debug, Clone, serde::Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TaskCompletionNotification {
+ /// ID of the completed task
+ pub task_id: Uuid,
+ /// Owner ID for data isolation
+ #[serde(skip)]
+ pub owner_id: Option<Uuid>,
+ /// Contract ID if task belongs to a contract
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub contract_id: Option<Uuid>,
+ /// Parent task ID (to notify parent/supervisor)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub parent_task_id: Option<Uuid>,
+ /// Final status: "done", "failed", etc.
+ pub status: String,
+ /// Summary of task output/results
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub output_summary: Option<String>,
+ /// Path to the task's worktree (for reading files)
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub worktree_path: Option<String>,
+ /// Error message if task failed
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub error_message: Option<String>,
+}
+
/// Command sent from server to daemon.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
@@ -119,6 +147,12 @@ pub enum DaemonCommand {
/// Files to copy from parent task's worktree
#[serde(rename = "copyFiles")]
copy_files: Option<Vec<String>>,
+ /// Contract ID if this task is associated with a contract
+ #[serde(rename = "contractId")]
+ contract_id: Option<Uuid>,
+ /// Whether this task is a supervisor (long-running contract orchestrator)
+ #[serde(rename = "isSupervisor")]
+ is_supervisor: bool,
},
/// Pause a running task
PauseTask {
@@ -251,6 +285,69 @@ pub enum DaemonCommand {
target_dir: String,
},
+ // =========================================================================
+ // Contract File Commands
+ // =========================================================================
+
+ /// Read a file from a repository linked to a contract
+ ReadRepoFile {
+ /// Request ID for correlating response
+ #[serde(rename = "requestId")]
+ request_id: Uuid,
+ /// Contract ID (used for logging/context)
+ #[serde(rename = "contractId")]
+ contract_id: Uuid,
+ /// Path to the file within the repository
+ #[serde(rename = "filePath")]
+ file_path: String,
+ /// Full repository path on daemon's filesystem
+ #[serde(rename = "repoPath")]
+ repo_path: String,
+ },
+
+ // =========================================================================
+ // Supervisor Git Commands
+ // =========================================================================
+
+ /// Create a new branch in a task's worktree
+ CreateBranch {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ #[serde(rename = "branchName")]
+ branch_name: String,
+ /// Optional reference to create branch from (task_id or SHA)
+ #[serde(rename = "fromRef")]
+ from_ref: Option<String>,
+ },
+
+ /// Merge a task's changes to a target branch
+ MergeTaskToTarget {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Target branch to merge into (default: task's base branch)
+ #[serde(rename = "targetBranch")]
+ target_branch: Option<String>,
+ /// Whether to squash commits
+ squash: bool,
+ },
+
+ /// Create a pull request for a task's changes
+ CreatePR {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ title: String,
+ body: Option<String>,
+ /// Base branch for the PR (default: main)
+ #[serde(rename = "baseBranch")]
+ base_branch: String,
+ },
+
+ /// Get the diff for a task's changes
+ GetTaskDiff {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+
/// Error response
Error { code: String, message: String },
}
@@ -278,16 +375,29 @@ pub struct DaemonConnectionInfo {
pub worktrees_directory: Option<String>,
}
-/// Shared application state containing ML models and database pool.
-///
-/// Models are wrapped in `Mutex` for thread-safe mutable access during inference.
-pub struct AppState {
- /// Speech-to-text model (Parakeet TDT)
+/// Configuration paths for ML models (used for lazy loading).
+#[derive(Clone)]
+pub struct ModelConfig {
+ pub parakeet_model_dir: String,
+ pub parakeet_eou_dir: String,
+ pub sortformer_model_path: String,
+}
+
+/// Lazily-loaded ML models.
+pub struct MlModels {
pub parakeet: Mutex<ParakeetTDT>,
- /// End-of-Utterance detection model for streaming
pub parakeet_eou: Mutex<ParakeetEOU>,
- /// Speaker diarization model (Sortformer)
pub sortformer: Mutex<Sortformer>,
+}
+
+/// Shared application state containing ML models and database pool.
+///
+/// Models are lazily loaded on first use to speed up server startup.
+pub struct AppState {
+ /// ML model configuration (paths for lazy loading)
+ pub model_config: Option<ModelConfig>,
+ /// Lazily-loaded ML models (initialized on first Listen connection)
+ pub ml_models: OnceCell<MlModels>,
/// Optional database connection pool
pub db_pool: Option<PgPool>,
/// Broadcast channel for file update notifications
@@ -296,6 +406,8 @@ pub struct AppState {
pub task_updates: broadcast::Sender<TaskUpdateNotification>,
/// Broadcast channel for task output streaming
pub task_output: broadcast::Sender<TaskOutputNotification>,
+ /// Broadcast channel for task completion notifications (for supervisors)
+ pub task_completions: broadcast::Sender<TaskCompletionNotification>,
/// Active daemon connections (keyed by connection_id)
pub daemon_connections: DashMap<String, DaemonConnectionInfo>,
/// Tool keys for orchestrator API access (key -> task_id)
@@ -305,7 +417,9 @@ pub struct AppState {
}
impl AppState {
- /// Load all ML models from the specified directories.
+ /// Create AppState with ML model configuration for lazy loading.
+ ///
+ /// Models are NOT loaded at startup - they will be loaded on first Listen connection.
///
/// # Arguments
/// * `parakeet_model_dir` - Path to the Parakeet TDT model directory
@@ -315,19 +429,12 @@ impl AppState {
parakeet_model_dir: &str,
parakeet_eou_dir: &str,
sortformer_model_path: &str,
- ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
- let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?;
- let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?;
- let sortformer = Sortformer::with_config(
- sortformer_model_path,
- None,
- DiarizationConfig::callhome(),
- )?;
-
+ ) -> Self {
// Create broadcast channels with buffer for 256 messages
let (file_updates, _) = broadcast::channel(256);
let (task_updates, _) = broadcast::channel(256);
let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming
+ let (task_completions, _) = broadcast::channel(256); // For supervisor task monitoring
// Initialize JWT verifier from environment (optional)
// Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256)
@@ -357,18 +464,61 @@ impl AppState {
}
};
- Ok(Self {
- parakeet: Mutex::new(parakeet),
- parakeet_eou: Mutex::new(parakeet_eou),
- sortformer: Mutex::new(sortformer),
+ Self {
+ model_config: Some(ModelConfig {
+ parakeet_model_dir: parakeet_model_dir.to_string(),
+ parakeet_eou_dir: parakeet_eou_dir.to_string(),
+ sortformer_model_path: sortformer_model_path.to_string(),
+ }),
+ ml_models: OnceCell::new(),
db_pool: None,
file_updates,
task_updates,
task_output,
+ task_completions,
daemon_connections: DashMap::new(),
tool_keys: DashMap::new(),
jwt_verifier,
- })
+ }
+ }
+
+ /// Get or initialize ML models (lazy loading).
+ ///
+ /// Models are loaded on first call and cached for subsequent calls.
+ /// Returns None if model config is not set.
+ pub async fn get_ml_models(&self) -> Result<&MlModels, Box<dyn std::error::Error + Send + Sync>> {
+ let config = self.model_config.as_ref()
+ .ok_or_else(|| "ML model configuration not set")?;
+
+ self.ml_models.get_or_try_init(|| async {
+ tracing::info!(
+ parakeet = %config.parakeet_model_dir,
+ eou = %config.parakeet_eou_dir,
+ sortformer = %config.sortformer_model_path,
+ "Lazy-loading ML models on first Listen connection..."
+ );
+
+ let parakeet = ParakeetTDT::from_pretrained(&config.parakeet_model_dir, None)?;
+ let parakeet_eou = ParakeetEOU::from_pretrained(&config.parakeet_eou_dir, None)?;
+ let sortformer = Sortformer::with_config(
+ &config.sortformer_model_path,
+ None,
+ DiarizationConfig::callhome(),
+ )?;
+
+ tracing::info!("ML models loaded successfully");
+
+ Ok(MlModels {
+ parakeet: Mutex::new(parakeet),
+ parakeet_eou: Mutex::new(parakeet_eou),
+ sortformer: Mutex::new(sortformer),
+ })
+ }).await
+ }
+
+ /// Check if ML models are loaded.
+ pub fn are_models_loaded(&self) -> bool {
+ self.ml_models.initialized()
}
/// Set the database pool.
@@ -399,6 +549,13 @@ impl AppState {
let _ = self.task_output.send(notification);
}
+ /// Broadcast a task completion notification to all subscribers.
+ ///
+ /// Used to notify supervisor tasks when their child tasks complete.
+ pub fn broadcast_task_completion(&self, notification: TaskCompletionNotification) {
+ let _ = self.task_completions.send(notification);
+ }
+
/// Register a new daemon connection.
///
/// Returns the connection_id for later reference.
@@ -544,6 +701,167 @@ impl AppState {
self.tool_keys.retain(|_, v| *v != task_id);
tracing::info!(task_id = %task_id, "Revoked tool key");
}
+
+ // =========================================================================
+ // Supervisor Notifications
+ // =========================================================================
+
+ /// Notify a contract's supervisor task about an event.
+ ///
+ /// This sends a message to the supervisor's stdin so it can react to changes
+ /// in tasks or contract state.
+ pub async fn notify_supervisor(
+ &self,
+ supervisor_task_id: Uuid,
+ supervisor_daemon_id: Option<Uuid>,
+ message: &str,
+ ) -> Result<(), String> {
+ // Only send if we have a daemon ID
+ let daemon_id = match supervisor_daemon_id {
+ Some(id) => id,
+ None => {
+ tracing::debug!(
+ supervisor_task_id = %supervisor_task_id,
+ "Supervisor has no daemon assigned, skipping notification"
+ );
+ return Ok(());
+ }
+ };
+
+ let command = DaemonCommand::SendMessage {
+ task_id: supervisor_task_id,
+ message: message.to_string(),
+ };
+
+ self.send_daemon_command(daemon_id, command).await
+ }
+
+ /// Format and send a task completion notification to a supervisor.
+ pub async fn notify_supervisor_of_task_completion(
+ &self,
+ supervisor_task_id: Uuid,
+ supervisor_daemon_id: Option<Uuid>,
+ completed_task_id: Uuid,
+ completed_task_name: &str,
+ status: &str,
+ progress_summary: Option<&str>,
+ error_message: Option<&str>,
+ ) {
+ let mut message = format!(
+ "TASK_COMPLETED task_id={} name=\"{}\" status={}",
+ completed_task_id, completed_task_name, status
+ );
+
+ if let Some(summary) = progress_summary {
+ // Escape newlines in summary
+ let escaped = summary.replace('\n', "\\n");
+ message.push_str(&format!(" summary=\"{}\"", escaped));
+ }
+
+ if let Some(err) = error_message {
+ let escaped = err.replace('\n', "\\n");
+ message.push_str(&format!(" error=\"{}\"", escaped));
+ }
+
+ if let Err(e) = self.notify_supervisor(
+ supervisor_task_id,
+ supervisor_daemon_id,
+ &message,
+ ).await {
+ tracing::warn!(
+ supervisor_task_id = %supervisor_task_id,
+ completed_task_id = %completed_task_id,
+ "Failed to notify supervisor of task completion: {}",
+ e
+ );
+ }
+ }
+
+ /// Format and send a task status change notification to a supervisor.
+ pub async fn notify_supervisor_of_task_update(
+ &self,
+ supervisor_task_id: Uuid,
+ supervisor_daemon_id: Option<Uuid>,
+ updated_task_id: Uuid,
+ updated_task_name: &str,
+ new_status: &str,
+ updated_fields: &[String],
+ ) {
+ let message = format!(
+ "TASK_UPDATED task_id={} name=\"{}\" status={} fields={}",
+ updated_task_id,
+ updated_task_name,
+ new_status,
+ updated_fields.join(",")
+ );
+
+ if let Err(e) = self.notify_supervisor(
+ supervisor_task_id,
+ supervisor_daemon_id,
+ &message,
+ ).await {
+ tracing::warn!(
+ supervisor_task_id = %supervisor_task_id,
+ updated_task_id = %updated_task_id,
+ "Failed to notify supervisor of task update: {}",
+ e
+ );
+ }
+ }
+
+ /// Format and send a contract phase change notification to a supervisor.
+ pub async fn notify_supervisor_of_phase_change(
+ &self,
+ supervisor_task_id: Uuid,
+ supervisor_daemon_id: Option<Uuid>,
+ contract_id: Uuid,
+ new_phase: &str,
+ ) {
+ let message = format!(
+ "PHASE_CHANGED contract_id={} phase={}",
+ contract_id, new_phase
+ );
+
+ if let Err(e) = self.notify_supervisor(
+ supervisor_task_id,
+ supervisor_daemon_id,
+ &message,
+ ).await {
+ tracing::warn!(
+ supervisor_task_id = %supervisor_task_id,
+ contract_id = %contract_id,
+ "Failed to notify supervisor of phase change: {}",
+ e
+ );
+ }
+ }
+
+ /// Format and send a new task created notification to a supervisor.
+ pub async fn notify_supervisor_of_task_created(
+ &self,
+ supervisor_task_id: Uuid,
+ supervisor_daemon_id: Option<Uuid>,
+ new_task_id: Uuid,
+ new_task_name: &str,
+ ) {
+ let message = format!(
+ "TASK_CREATED task_id={} name=\"{}\"",
+ new_task_id, new_task_name
+ );
+
+ if let Err(e) = self.notify_supervisor(
+ supervisor_task_id,
+ supervisor_daemon_id,
+ &message,
+ ).await {
+ tracing::warn!(
+ supervisor_task_id = %supervisor_task_id,
+ new_task_id = %new_task_id,
+ "Failed to notify supervisor of task creation: {}",
+ e
+ );
+ }
+ }
}
/// Type alias for the shared application state.