From 8b17a175c3e7e27b789812eba4e3cd760beadb10 Mon Sep 17 00:00:00 2001 From: soryu Date: Tue, 6 Jan 2026 04:08:11 +0000 Subject: Initial Control system --- makima/src/server/handlers/mesh_ws.rs | 346 ++++++++++++++++++++++++++++++++++ 1 file changed, 346 insertions(+) create mode 100644 makima/src/server/handlers/mesh_ws.rs (limited to 'makima/src/server/handlers/mesh_ws.rs') diff --git a/makima/src/server/handlers/mesh_ws.rs b/makima/src/server/handlers/mesh_ws.rs new file mode 100644 index 0000000..d15fba7 --- /dev/null +++ b/makima/src/server/handlers/mesh_ws.rs @@ -0,0 +1,346 @@ +//! WebSocket handler for task change subscriptions and output streaming. +//! +//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications +//! when tasks are updated. They can also subscribe to task output for live terminal streaming. +//! +//! ## Owner-scoped filtering +//! +//! Notifications are filtered by owner_id. If a notification has an owner_id set, +//! it will only be delivered to clients who are subscribed to tasks belonging to that owner. +//! The task's owner_id is looked up from the database when the client subscribes. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use sqlx::Row; +use std::collections::HashMap; +use uuid::Uuid; + +use crate::server::state::SharedState; + +/// Client message for task subscription management. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum TaskClientMessage { + /// Subscribe to updates for a specific task + Subscribe { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscribe from updates for a specific task + Unsubscribe { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Subscribe to all task updates + SubscribeAll, + /// Unsubscribe from all task updates + UnsubscribeAll, + /// Subscribe to live output streaming for a specific task + SubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscribe from output streaming for a specific task + UnsubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, +} + +/// Server message for task subscription WebSocket. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum TaskServerMessage { + /// Subscription confirmed for specific task + Subscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscription confirmed for specific task + Unsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Subscribed to all task updates + SubscribedAll, + /// Unsubscribed from all task updates + UnsubscribedAll, + /// Task was updated + TaskUpdated { + #[serde(rename = "taskId")] + task_id: Uuid, + version: i32, + status: String, + #[serde(rename = "updatedFields")] + updated_fields: Vec, + #[serde(rename = "updatedBy")] + updated_by: String, + }, + /// Live output from Claude Code container (parsed and structured) + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + #[serde(rename = "messageType")] + message_type: String, + /// Main text content + content: String, + /// Tool name if tool_use message + #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")] + tool_name: Option, + /// Tool input JSON if tool_use message + #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")] + tool_input: Option, + /// Whether tool result was an error + #[serde(rename = "isError", skip_serializing_if = "Option::is_none")] + is_error: Option, + /// Cost in USD if result message + #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")] + cost_usd: Option, + /// Duration in ms if result message + #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")] + duration_ms: Option, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + /// Output subscription confirmed + OutputSubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Output unsubscription confirmed + OutputUnsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Error occurred + Error { code: String, message: String }, +} + +/// WebSocket upgrade handler for task subscriptions. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/subscribe", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "Mesh" +)] +pub async fn task_subscription_handler( + ws: WebSocketUpgrade, + State(state): State, +) -> Response { + ws.on_upgrade(|socket| handle_task_subscription(socket, state)) +} + +/// Look up the owner_id for a task from the database. +async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option { + let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(pool) + .await + .ok()??; + + row.try_get("owner_id").ok() +} + +async fn handle_task_subscription(socket: WebSocket, state: SharedState) { + let (mut sender, mut receiver) = socket.split(); + + // Map of task IDs to their owner_ids for this client's subscriptions + let mut task_subscriptions: HashMap> = HashMap::new(); + // Whether client is subscribed to all task updates (not owner-scoped) + let mut subscribed_all = false; + // Map of task IDs to their owner_ids for output streaming subscriptions + let mut output_subscriptions: HashMap> = HashMap::new(); + + // Subscribe to broadcast channels + let mut task_update_rx = state.task_updates.subscribe(); + let mut task_output_rx = state.task_output.subscribe(); + + loop { + tokio::select! { + // Handle incoming WebSocket messages from client + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::(&text) { + Ok(TaskClientMessage::Subscribe { task_id }) => { + // Look up owner_id for this task + let owner_id = if let Some(ref pool) = state.db_pool { + get_task_owner_id(pool, task_id).await + } else { + None + }; + task_subscriptions.insert(task_id, owner_id); + let response = TaskServerMessage::Subscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id); + } + Ok(TaskClientMessage::Unsubscribe { task_id }) => { + task_subscriptions.remove(&task_id); + let response = TaskServerMessage::Unsubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from task {}", task_id); + } + Ok(TaskClientMessage::SubscribeAll) => { + subscribed_all = true; + let response = TaskServerMessage::SubscribedAll; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to all tasks"); + } + Ok(TaskClientMessage::UnsubscribeAll) => { + subscribed_all = false; + let response = TaskServerMessage::UnsubscribedAll; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from all tasks"); + } + Ok(TaskClientMessage::SubscribeOutput { task_id }) => { + // Look up owner_id for this task + let owner_id = if let Some(ref pool) = state.db_pool { + get_task_owner_id(pool, task_id).await + } else { + None + }; + output_subscriptions.insert(task_id, owner_id); + let response = TaskServerMessage::OutputSubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id); + } + Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => { + output_subscriptions.remove(&task_id); + let response = TaskServerMessage::OutputUnsubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from output for task {}", task_id); + } + Err(e) => { + let response = TaskServerMessage::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Client disconnected from task subscription"); + break; + } + Some(Err(e)) => { + tracing::warn!("Task WebSocket error: {}", e); + break; + } + _ => {} + } + } + + // Handle task update broadcasts + notification = task_update_rx.recv() => { + match notification { + Ok(notification) => { + // Check if client should receive this notification + let should_forward = if subscribed_all { + // SubscribeAll gets all notifications (typically for admin views) + true + } else if let Some(subscribed_owner) = task_subscriptions.get(¬ification.task_id) { + // Client is subscribed to this specific task + // Verify owner_id matches (if set on both sides) + match (notification.owner_id, subscribed_owner) { + (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, + _ => true, // Allow if owner_id not set on either side + } + } else { + false + }; + + if should_forward { + let response = TaskServerMessage::TaskUpdated { + task_id: notification.task_id, + version: notification.version, + status: notification.status, + updated_fields: notification.updated_fields, + updated_by: notification.updated_by, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Task subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + + // Handle task output broadcasts + output = task_output_rx.recv() => { + match output { + Ok(output) => { + // Check if client should receive this output + let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) { + // Client is subscribed to output for this task + // Verify owner_id matches (if set on both sides) + match (output.owner_id, subscribed_owner) { + (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, + _ => true, // Allow if owner_id not set on either side + } + } else { + false + }; + + if should_forward { + let response = TaskServerMessage::TaskOutput { + task_id: output.task_id, + message_type: output.message_type, + content: output.content, + tool_name: output.tool_name, + tool_input: output.tool_input, + is_error: output.is_error, + cost_usd: output.cost_usd, + duration_ms: output.duration_ms, + is_partial: output.is_partial, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Task output subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + } + } +} -- cgit v1.2.3