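//! Shared application state: ML models, the database pool, WebSocket notification channels, active daemon connections, orchestrator tool keys, and the optional JWT verifier.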
use std::sync::Arc;
use dashmap::DashMap;
use sqlx::PgPool;
use tokio::sync::{broadcast, mpsc, Mutex};
use uuid::Uuid;
use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
use crate::server::auth::{AuthConfig, JwtVerifier};
/// Notification payload for file updates (broadcast to WebSocket subscribers).
///
/// Published on `AppState::file_updates` via `AppState::broadcast_file_update`.
#[derive(Debug, Clone)]
pub struct FileUpdateNotification {
/// ID of the updated file
pub file_id: Uuid,
/// New version number after update
pub version: i32,
/// List of fields that were updated
pub updated_fields: Vec<String>,
/// Source of the update: "user", "llm", or "system"
pub updated_by: String,
}
// =============================================================================
// Task/Mesh Notifications
// =============================================================================
/// Notification payload for task updates (broadcast to WebSocket subscribers).
///
/// Published on `AppState::task_updates` via `AppState::broadcast_task_update`.
/// Subscribers are expected to filter on `owner_id` before forwarding to clients.
#[derive(Debug, Clone)]
pub struct TaskUpdateNotification {
/// ID of the updated task
pub task_id: Uuid,
/// Owner ID for data isolation (notifications are scoped to owner)
pub owner_id: Option<Uuid>,
/// New version number after update
pub version: i32,
/// Current task status
pub status: String,
/// List of fields that were updated
pub updated_fields: Vec<String>,
/// Source of the update: "user", "daemon", or "system"
pub updated_by: String,
}
/// Notification for streaming task output from Claude Code containers.
///
/// Published on `AppState::task_output` via `AppState::broadcast_task_output`.
/// Serialized with camelCase field names (e.g. `taskId`, `messageType`); optional
/// fields that are `None` are omitted, and `owner_id` is never serialized.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskOutputNotification {
/// ID of the task producing output
pub task_id: Uuid,
/// Owner ID for data isolation (notifications are scoped to owner).
/// Skipped during serialization so it is not sent in the payload.
#[serde(skip)]
pub owner_id: Option<Uuid>,
/// Type of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw"
pub message_type: String,
/// Main text content of the message
pub content: String,
/// Tool name if this is a tool_use message
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// Tool input (JSON) if this is a tool_use message
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_input: Option<serde_json::Value>,
/// Whether tool result was an error
#[serde(skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
/// Cost in USD if this is a result message
#[serde(skip_serializing_if = "Option::is_none")]
pub cost_usd: Option<f64>,
/// Duration in milliseconds if this is a result message
#[serde(skip_serializing_if = "Option::is_none")]
pub duration_ms: Option<u64>,
/// Whether this is a partial line (more coming) or complete
pub is_partial: bool,
}
/// Command sent from server to daemon.
///
/// Serialized as an internally tagged JSON object: the variant name, converted to
/// camelCase (e.g. `SpawnTask` -> `"spawnTask"`), is stored under the `"type"` key.
/// On an enum, serde's `rename_all` renames only the variants, not the fields inside
/// them. Every multi-word field therefore carries an explicit
/// `#[serde(rename = "...")]` to produce camelCase keys. Keep this in mind when
/// adding new fields, because the daemon side depends on these exact key names.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum DaemonCommand {
/// Confirm successful authentication
Authenticated {
#[serde(rename = "daemonId")]
daemon_id: Uuid,
},
/// Spawn a new task in a container
SpawnTask {
#[serde(rename = "taskId")]
task_id: Uuid,
/// Human-readable task name (used for commit messages)
#[serde(rename = "taskName")]
task_name: String,
plan: String,
#[serde(rename = "repoUrl")]
repo_url: Option<String>,
#[serde(rename = "baseBranch")]
base_branch: Option<String>,
/// Target branch to merge into (used for completion actions)
#[serde(rename = "targetBranch")]
target_branch: Option<String>,
/// Parent task ID if this is a subtask
#[serde(rename = "parentTaskId")]
parent_task_id: Option<Uuid>,
/// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask)
depth: i32,
/// Whether this task should run as an orchestrator (true if depth==0 and has subtasks)
#[serde(rename = "isOrchestrator")]
is_orchestrator: bool,
/// Path to user's local repository (outside ~/.makima) for completion actions
#[serde(rename = "targetRepoPath")]
target_repo_path: Option<String>,
/// Action on completion: "none", "branch", "merge", "pr"
#[serde(rename = "completionAction")]
completion_action: Option<String>,
/// Task ID to continue from (copy worktree from this task)
#[serde(rename = "continueFromTaskId")]
continue_from_task_id: Option<Uuid>,
/// Files to copy from parent task's worktree
#[serde(rename = "copyFiles")]
copy_files: Option<Vec<String>>,
},
/// Pause a running task
PauseTask {
#[serde(rename = "taskId")]
task_id: Uuid,
},
/// Resume a paused task
ResumeTask {
#[serde(rename = "taskId")]
task_id: Uuid,
},
/// Interrupt a task (gracefully or forced)
InterruptTask {
#[serde(rename = "taskId")]
task_id: Uuid,
/// `true` requests a graceful interrupt, `false` a forced one
graceful: bool,
},
/// Send a message to a running task
SendMessage {
#[serde(rename = "taskId")]
task_id: Uuid,
message: String,
},
/// Inject context about sibling task progress
InjectSiblingContext {
/// Task that receives the context
#[serde(rename = "taskId")]
task_id: Uuid,
/// Task whose progress is being reported
#[serde(rename = "siblingTaskId")]
sibling_task_id: Uuid,
#[serde(rename = "siblingName")]
sibling_name: String,
#[serde(rename = "siblingStatus")]
sibling_status: String,
#[serde(rename = "progressSummary")]
progress_summary: Option<String>,
#[serde(rename = "changedFiles")]
changed_files: Vec<String>,
},
// =========================================================================
// Merge Commands (for orchestrators to merge subtask branches)
// =========================================================================
/// List all subtask branches for a task
ListBranches {
#[serde(rename = "taskId")]
task_id: Uuid,
},
/// Start merging a subtask branch
MergeStart {
#[serde(rename = "taskId")]
task_id: Uuid,
#[serde(rename = "sourceBranch")]
source_branch: String,
},
/// Get current merge status
MergeStatus {
#[serde(rename = "taskId")]
task_id: Uuid,
},
/// Resolve a merge conflict
MergeResolve {
#[serde(rename = "taskId")]
task_id: Uuid,
/// Conflicted file to resolve
file: String,
/// "ours" or "theirs"
strategy: String,
},
/// Commit the current merge
MergeCommit {
#[serde(rename = "taskId")]
task_id: Uuid,
/// Commit message
message: String,
},
/// Abort the current merge
MergeAbort {
#[serde(rename = "taskId")]
task_id: Uuid,
},
/// Skip merging a subtask branch (mark as intentionally not merged)
MergeSkip {
#[serde(rename = "taskId")]
task_id: Uuid,
#[serde(rename = "subtaskId")]
subtask_id: Uuid,
reason: String,
},
/// Check if all subtask branches have been merged or skipped (completion gate)
CheckMergeComplete {
#[serde(rename = "taskId")]
task_id: Uuid,
},
// =========================================================================
// Completion Action Commands
// =========================================================================
/// Retry a completion action for a completed task
RetryCompletionAction {
#[serde(rename = "taskId")]
task_id: Uuid,
/// Human-readable task name (used for commit messages)
#[serde(rename = "taskName")]
task_name: String,
/// The action to execute: "branch", "merge", or "pr"
action: String,
/// Path to the target repository
#[serde(rename = "targetRepoPath")]
target_repo_path: String,
/// Target branch to merge into (for merge/pr actions)
#[serde(rename = "targetBranch")]
target_branch: Option<String>,
},
/// Clone worktree to a target directory
CloneWorktree {
#[serde(rename = "taskId")]
task_id: Uuid,
/// Path to the target directory
#[serde(rename = "targetDir")]
target_dir: String,
},
/// Check if a target directory exists
CheckTargetExists {
#[serde(rename = "taskId")]
task_id: Uuid,
/// Path to check
#[serde(rename = "targetDir")]
target_dir: String,
},
/// Error response
Error { code: String, message: String },
}
/// Active daemon connection info stored in state.
///
/// Created by `AppState::register_daemon` with all directory fields set to `None`.
/// The directory fields are filled in later by `AppState::update_daemon_directories`.
#[derive(Debug)]
pub struct DaemonConnectionInfo {
/// Database ID of the daemon
pub id: Uuid,
/// Owner ID for data isolation (from API key authentication)
pub owner_id: Uuid,
/// WebSocket connection identifier (also the key in `AppState::daemon_connections`)
pub connection_id: String,
/// Daemon hostname
pub hostname: Option<String>,
/// Machine identifier
pub machine_id: Option<String>,
/// Channel to send commands to this daemon
pub command_sender: mpsc::Sender<DaemonCommand>,
/// Current working directory of the daemon
pub working_directory: Option<String>,
/// Path to ~/.makima/home directory on daemon (for cloning completed work)
pub home_directory: Option<String>,
/// Path to worktrees directory (~/.makima/worktrees) on daemon
pub worktrees_directory: Option<String>,
}
/// Shared application state containing ML models, database pool, notification
/// channels, daemon connections, tool keys and the JWT verifier.
///
/// Models are wrapped in `Mutex` (tokio's async mutex) for thread-safe mutable access during inference.
/// Build it with `AppState::new`, optionally attach a pool with `with_db_pool`, and
/// share it across handlers as `SharedState`.
pub struct AppState {
/// Speech-to-text model (Parakeet TDT)
pub parakeet: Mutex<ParakeetTDT>,
/// End-of-Utterance detection model for streaming
pub parakeet_eou: Mutex<ParakeetEOU>,
/// Speaker diarization model (Sortformer)
pub sortformer: Mutex<Sortformer>,
/// Optional database connection pool (`None` until `with_db_pool` is called)
pub db_pool: Option<PgPool>,
/// Broadcast channel for file update notifications
pub file_updates: broadcast::Sender<FileUpdateNotification>,
/// Broadcast channel for task update notifications
pub task_updates: broadcast::Sender<TaskUpdateNotification>,
/// Broadcast channel for task output streaming
pub task_output: broadcast::Sender<TaskOutputNotification>,
/// Active daemon connections (keyed by connection_id)
pub daemon_connections: DashMap<String, DaemonConnectionInfo>,
/// Tool keys for orchestrator API access (key -> task_id)
pub tool_keys: DashMap<String, Uuid>,
/// JWT verifier for Supabase authentication (None if not configured)
pub jwt_verifier: Option<JwtVerifier>,
}
impl AppState {
    /// Load all ML models and initialize the shared server state.
    ///
    /// Besides loading the models, this:
    /// - creates the broadcast channels (capacity 256 for file and task updates,
    ///   1024 for task output);
    /// - starts with empty daemon and tool-key maps;
    /// - builds an optional JWT verifier from the environment.
    ///
    /// The database pool starts as `None`; attach one with [`AppState::with_db_pool`].
    ///
    /// # Arguments
    /// * `parakeet_model_dir` - Path to the Parakeet TDT model directory
    /// * `parakeet_eou_dir` - Path to the Parakeet EOU model directory
    /// * `sortformer_model_path` - Path to the Sortformer diarization model file
    ///
    /// # Errors
    /// Returns an error if any of the three models fails to load. JWT
    /// misconfiguration is not an error: it is logged and `jwt_verifier` is `None`.
    pub fn new(
        parakeet_model_dir: &str,
        parakeet_eou_dir: &str,
        sortformer_model_path: &str,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let parakeet = ParakeetTDT::from_pretrained(parakeet_model_dir, None)?;
        let parakeet_eou = ParakeetEOU::from_pretrained(parakeet_eou_dir, None)?;
        let sortformer = Sortformer::with_config(
            sortformer_model_path,
            None,
            DiarizationConfig::callhome(),
        )?;

        // Create broadcast channels with buffer for 256 messages
        let (file_updates, _) = broadcast::channel(256);
        let (task_updates, _) = broadcast::channel(256);
        // Larger buffer for output streaming, which produces far more messages.
        let (task_output, _) = broadcast::channel(1024);

        // Initialize JWT verifier from environment (optional).
        // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256).
        let jwt_verifier = match AuthConfig::from_env() {
            Some(config) => match JwtVerifier::new(config) {
                Ok(verifier) => {
                    tracing::info!("JWT authentication configured");
                    Some(verifier)
                }
                Err(e) => {
                    tracing::error!("Failed to initialize JWT verifier: {}", e);
                    None
                }
            },
            None => {
                // Log which env vars are missing so operators know what to set.
                let has_url = std::env::var("SUPABASE_URL").is_ok();
                let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok();
                let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok();
                if !has_url {
                    tracing::info!("JWT authentication not configured (SUPABASE_URL not set)");
                } else if !has_public_key && !has_secret {
                    tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)");
                } else {
                    // The variables are present but `AuthConfig::from_env` still produced
                    // no config; surface this instead of silently disabling auth.
                    tracing::warn!("JWT authentication not configured (SUPABASE_URL and a key are set, but no auth config could be built from the environment)");
                }
                None
            }
        };

        Ok(Self {
            parakeet: Mutex::new(parakeet),
            parakeet_eou: Mutex::new(parakeet_eou),
            sortformer: Mutex::new(sortformer),
            db_pool: None,
            file_updates,
            task_updates,
            task_output,
            daemon_connections: DashMap::new(),
            tool_keys: DashMap::new(),
            jwt_verifier,
        })
    }

    /// Set the database pool, replacing any existing one.
    pub fn with_db_pool(mut self, pool: PgPool) -> Self {
        self.db_pool = Some(pool);
        self
    }

    /// Broadcast a file update notification to all subscribers.
    ///
    /// This is a no-op if there are no subscribers (ignores send errors).
    pub fn broadcast_file_update(&self, notification: FileUpdateNotification) {
        // Ignore send errors - they just mean no one is listening
        let _ = self.file_updates.send(notification);
    }

    /// Broadcast a task update notification to all subscribers.
    ///
    /// This is a no-op if there are no subscribers (ignores send errors).
    pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) {
        let _ = self.task_updates.send(notification);
    }

    /// Broadcast task output to all subscribers.
    ///
    /// Used for streaming Claude Code container output to frontend clients.
    /// Send errors (no subscribers) are ignored.
    pub fn broadcast_task_output(&self, notification: TaskOutputNotification) {
        let _ = self.task_output.send(notification);
    }

    /// Register a new daemon connection under `connection_id`.
    ///
    /// Directory fields start as `None`; fill them in with
    /// [`AppState::update_daemon_directories`]. An existing entry with the same
    /// `connection_id` is replaced.
    pub fn register_daemon(
        &self,
        connection_id: String,
        daemon_id: Uuid,
        owner_id: Uuid,
        hostname: Option<String>,
        machine_id: Option<String>,
        command_sender: mpsc::Sender<DaemonCommand>,
    ) {
        self.daemon_connections.insert(
            connection_id.clone(),
            DaemonConnectionInfo {
                id: daemon_id,
                owner_id,
                connection_id,
                hostname,
                machine_id,
                command_sender,
                working_directory: None,
                home_directory: None,
                worktrees_directory: None,
            },
        );
    }

    /// Update daemon directory information.
    ///
    /// Does nothing if no connection with `connection_id` is registered.
    pub fn update_daemon_directories(
        &self,
        connection_id: &str,
        working_directory: String,
        home_directory: String,
        worktrees_directory: String,
    ) {
        if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) {
            entry.working_directory = Some(working_directory);
            entry.home_directory = Some(home_directory);
            entry.worktrees_directory = Some(worktrees_directory);
        }
    }

    /// Unregister a daemon connection (no-op if it is not registered).
    pub fn unregister_daemon(&self, connection_id: &str) {
        self.daemon_connections.remove(connection_id);
    }

    /// Get a daemon connection by connection_id.
    ///
    /// The returned guard holds a read lock on the map shard; drop it before
    /// awaiting or mutating `daemon_connections`.
    pub fn get_daemon(&self, connection_id: &str) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> {
        self.daemon_connections.get(connection_id)
    }

    /// Get a daemon by its database ID.
    ///
    /// This is a linear scan over all connections. If the same daemon has several
    /// connections, whichever matching entry the scan finds first is returned.
    pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> {
        // Resolve the key first so the iterator (and the shard read guard it holds)
        // is dropped before the keyed lookup. This avoids taking a second lock on
        // the map while still holding one.
        let connection_id = self
            .daemon_connections
            .iter()
            .find(|entry| entry.value().id == daemon_id)
            .map(|entry| entry.key().clone())?;
        // If the daemon was unregistered in between, `get` simply returns `None`.
        self.daemon_connections.get(&connection_id)
    }

    /// Send a command to a specific daemon by its database ID.
    ///
    /// Waits if the daemon's command channel is full.
    ///
    /// # Errors
    /// Returns an error string if no connection for `daemon_id` is registered or
    /// the daemon's command channel has been closed.
    pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> {
        // Clone the sender and release the DashMap guard *before* awaiting. `send`
        // can wait for channel capacity, and holding the shard lock across that
        // await would block every writer (register/unregister/update) on the shard.
        let sender = self
            .daemon_connections
            .iter()
            .find(|entry| entry.value().id == daemon_id)
            .map(|entry| entry.value().command_sender.clone())
            .ok_or_else(|| format!("Daemon {} not connected", daemon_id))?;

        sender
            .send(command)
            .await
            .map_err(|e| format!("Failed to send command to daemon: {}", e))
    }

    /// Broadcast sibling progress to all running sibling tasks.
    ///
    /// This is used for sibling awareness - when a task makes progress,
    /// its siblings are notified so they can adjust their work if needed.
    ///
    /// `running_sibling_daemon_ids` holds `(task_id, daemon_id)` pairs.
    pub async fn broadcast_sibling_progress(
        &self,
        source_task_id: Uuid,
        source_task_name: &str,
        source_task_status: &str,
        progress_summary: Option<String>,
        changed_files: Vec<String>,
        running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id)
    ) {
        for (sibling_task_id, daemon_id) in running_sibling_daemon_ids {
            let command = DaemonCommand::InjectSiblingContext {
                task_id: sibling_task_id,
                sibling_task_id: source_task_id,
                sibling_name: source_task_name.to_string(),
                sibling_status: source_task_status.to_string(),
                progress_summary: progress_summary.clone(),
                changed_files: changed_files.clone(),
            };
            // Sends are awaited one after another, and each waits only if that
            // daemon's channel is full. A failure is logged and does not stop
            // delivery to the remaining siblings.
            if let Err(e) = self.send_daemon_command(daemon_id, command).await {
                tracing::warn!(
                    "Failed to inject sibling context to task {}: {}",
                    sibling_task_id,
                    e
                );
            }
        }
    }

    /// Get list of connected daemon IDs.
    ///
    /// The list may contain duplicates if one daemon has several connections.
    pub fn list_connected_daemon_ids(&self) -> Vec<Uuid> {
        self.daemon_connections
            .iter()
            .map(|entry| entry.value().id)
            .collect()
    }

    // =========================================================================
    // Tool Key Management
    // =========================================================================

    /// Register a tool key for a task.
    ///
    /// This allows orchestrators to authenticate with the API using
    /// the `X-Makima-Tool-Key` header. Re-registering an existing key overwrites
    /// its task mapping.
    pub fn register_tool_key(&self, key: String, task_id: Uuid) {
        tracing::info!(task_id = %task_id, "Registering tool key");
        self.tool_keys.insert(key, task_id);
    }

    /// Validate a tool key and return the associated task ID.
    ///
    /// Returns `None` if the key is unknown or has been revoked.
    pub fn validate_tool_key(&self, key: &str) -> Option<Uuid> {
        self.tool_keys.get(key).map(|entry| *entry.value())
    }

    /// Revoke every tool key mapped to `task_id`.
    ///
    /// This should be called when a task completes or is terminated.
    pub fn revoke_tool_key(&self, task_id: Uuid) {
        // Keys are stored key -> task_id, so remove by value.
        self.tool_keys.retain(|_, v| *v != task_id);
        tracing::info!(task_id = %task_id, "Revoked tool key");
    }
}
/// Type alias for the shared application state.
///
/// `Arc` lets request handlers and background tasks share one `AppState`.
/// Mutation goes through the interior-mutable fields (`Mutex`, `DashMap`, broadcast senders).
pub type SharedState = Arc<AppState>;