summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/mesh_ws.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-06 04:08:11 +0000
committersoryu <soryu@soryu.co>2026-01-11 03:01:13 +0000
commit8b17a175c3e7e27b789812eba4e3cd760beadb10 (patch)
tree7864dcaa2fa9db47fdfd4e8bfdb0b1dde832aa33 /makima/src/server/handlers/mesh_ws.rs
parentf79c416c58557d2f946aa5332989afdfa8c021cd (diff)
downloadsoryu-8b17a175c3e7e27b789812eba4e3cd760beadb10.tar.gz
soryu-8b17a175c3e7e27b789812eba4e3cd760beadb10.zip
Initial Control system
Diffstat (limited to 'makima/src/server/handlers/mesh_ws.rs')
-rw-r--r--makima/src/server/handlers/mesh_ws.rs346
1 files changed, 346 insertions, 0 deletions
diff --git a/makima/src/server/handlers/mesh_ws.rs b/makima/src/server/handlers/mesh_ws.rs
new file mode 100644
index 0000000..d15fba7
--- /dev/null
+++ b/makima/src/server/handlers/mesh_ws.rs
@@ -0,0 +1,346 @@
+//! WebSocket handler for task change subscriptions and output streaming.
+//!
+//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications
+//! when tasks are updated. They can also subscribe to task output for live terminal streaming.
+//!
+//! ## Owner-scoped filtering
+//!
+//! Notifications are filtered by owner_id. If a notification has an owner_id set,
+//! it will only be delivered to clients who are subscribed to tasks belonging to that owner.
+//! The task's owner_id is looked up from the database when the client subscribes.
+
+use axum::{
+ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
+ response::Response,
+};
+use futures::{SinkExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use sqlx::Row;
+use std::collections::HashMap;
+use uuid::Uuid;
+
+use crate::server::state::SharedState;
+
+/// Client message for task subscription management.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum TaskClientMessage {
+ /// Subscribe to updates for a specific task
+ Subscribe {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Unsubscribe from updates for a specific task
+ Unsubscribe {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Subscribe to all task updates
+ SubscribeAll,
+ /// Unsubscribe from all task updates
+ UnsubscribeAll,
+ /// Subscribe to live output streaming for a specific task
+ SubscribeOutput {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Unsubscribe from output streaming for a specific task
+ UnsubscribeOutput {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+}
+
+/// Server message for task subscription WebSocket.
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum TaskServerMessage {
+ /// Subscription confirmed for specific task
+ Subscribed {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Unsubscription confirmed for specific task
+ Unsubscribed {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Subscribed to all task updates
+ SubscribedAll,
+ /// Unsubscribed from all task updates
+ UnsubscribedAll,
+ /// Task was updated
+ TaskUpdated {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ version: i32,
+ status: String,
+ #[serde(rename = "updatedFields")]
+ updated_fields: Vec<String>,
+ #[serde(rename = "updatedBy")]
+ updated_by: String,
+ },
+ /// Live output from Claude Code container (parsed and structured)
+ TaskOutput {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw"
+ #[serde(rename = "messageType")]
+ message_type: String,
+ /// Main text content
+ content: String,
+ /// Tool name if tool_use message
+ #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")]
+ tool_name: Option<String>,
+ /// Tool input JSON if tool_use message
+ #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")]
+ tool_input: Option<serde_json::Value>,
+ /// Whether tool result was an error
+ #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
+ is_error: Option<bool>,
+ /// Cost in USD if result message
+ #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")]
+ cost_usd: Option<f64>,
+ /// Duration in ms if result message
+ #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")]
+ duration_ms: Option<u64>,
+ #[serde(rename = "isPartial")]
+ is_partial: bool,
+ },
+ /// Output subscription confirmed
+ OutputSubscribed {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Output unsubscription confirmed
+ OutputUnsubscribed {
+ #[serde(rename = "taskId")]
+ task_id: Uuid,
+ },
+ /// Error occurred
+ Error { code: String, message: String },
+}
+
+/// WebSocket upgrade handler for task subscriptions.
+#[utoipa::path(
+ get,
+ path = "/api/v1/mesh/tasks/subscribe",
+ responses(
+ (status = 101, description = "WebSocket connection established"),
+ ),
+ tag = "Mesh"
+)]
+pub async fn task_subscription_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_task_subscription(socket, state))
+}
+
+/// Look up the owner_id for a task from the database.
+async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option<Uuid> {
+ let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1")
+ .bind(task_id)
+ .fetch_optional(pool)
+ .await
+ .ok()??;
+
+ row.try_get("owner_id").ok()
+}
+
+async fn handle_task_subscription(socket: WebSocket, state: SharedState) {
+ let (mut sender, mut receiver) = socket.split();
+
+ // Map of task IDs to their owner_ids for this client's subscriptions
+ let mut task_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
+ // Whether client is subscribed to all task updates (not owner-scoped)
+ let mut subscribed_all = false;
+ // Map of task IDs to their owner_ids for output streaming subscriptions
+ let mut output_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
+
+ // Subscribe to broadcast channels
+ let mut task_update_rx = state.task_updates.subscribe();
+ let mut task_output_rx = state.task_output.subscribe();
+
+ loop {
+ tokio::select! {
+ // Handle incoming WebSocket messages from client
+ msg = receiver.next() => {
+ match msg {
+ Some(Ok(Message::Text(text))) => {
+ match serde_json::from_str::<TaskClientMessage>(&text) {
+ Ok(TaskClientMessage::Subscribe { task_id }) => {
+ // Look up owner_id for this task
+ let owner_id = if let Some(ref pool) = state.db_pool {
+ get_task_owner_id(pool, task_id).await
+ } else {
+ None
+ };
+ task_subscriptions.insert(task_id, owner_id);
+ let response = TaskServerMessage::Subscribed { task_id };
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id);
+ }
+ Ok(TaskClientMessage::Unsubscribe { task_id }) => {
+ task_subscriptions.remove(&task_id);
+ let response = TaskServerMessage::Unsubscribed { task_id };
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ tracing::debug!("Client unsubscribed from task {}", task_id);
+ }
+ Ok(TaskClientMessage::SubscribeAll) => {
+ subscribed_all = true;
+ let response = TaskServerMessage::SubscribedAll;
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ tracing::debug!("Client subscribed to all tasks");
+ }
+ Ok(TaskClientMessage::UnsubscribeAll) => {
+ subscribed_all = false;
+ let response = TaskServerMessage::UnsubscribedAll;
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ tracing::debug!("Client unsubscribed from all tasks");
+ }
+ Ok(TaskClientMessage::SubscribeOutput { task_id }) => {
+ // Look up owner_id for this task
+ let owner_id = if let Some(ref pool) = state.db_pool {
+ get_task_owner_id(pool, task_id).await
+ } else {
+ None
+ };
+ output_subscriptions.insert(task_id, owner_id);
+ let response = TaskServerMessage::OutputSubscribed { task_id };
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id);
+ }
+ Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => {
+ output_subscriptions.remove(&task_id);
+ let response = TaskServerMessage::OutputUnsubscribed { task_id };
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ tracing::debug!("Client unsubscribed from output for task {}", task_id);
+ }
+ Err(e) => {
+ let response = TaskServerMessage::Error {
+ code: "PARSE_ERROR".into(),
+ message: e.to_string(),
+ };
+ let json = serde_json::to_string(&response).unwrap();
+ let _ = sender.send(Message::Text(json.into())).await;
+ }
+ }
+ }
+ Some(Ok(Message::Close(_))) | None => {
+ tracing::debug!("Client disconnected from task subscription");
+ break;
+ }
+ Some(Err(e)) => {
+ tracing::warn!("Task WebSocket error: {}", e);
+ break;
+ }
+ _ => {}
+ }
+ }
+
+ // Handle task update broadcasts
+ notification = task_update_rx.recv() => {
+ match notification {
+ Ok(notification) => {
+ // Check if client should receive this notification
+ let should_forward = if subscribed_all {
+ // SubscribeAll gets all notifications (typically for admin views)
+ true
+ } else if let Some(subscribed_owner) = task_subscriptions.get(&notification.task_id) {
+ // Client is subscribed to this specific task
+ // Verify owner_id matches (if set on both sides)
+ match (notification.owner_id, subscribed_owner) {
+ (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
+ _ => true, // Allow if owner_id not set on either side
+ }
+ } else {
+ false
+ };
+
+ if should_forward {
+ let response = TaskServerMessage::TaskUpdated {
+ task_id: notification.task_id,
+ version: notification.version,
+ status: notification.status,
+ updated_fields: notification.updated_fields,
+ updated_by: notification.updated_by,
+ };
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ }
+ }
+ Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
+ tracing::warn!("Task subscription client lagged, skipped {} messages", n);
+ }
+ Err(tokio::sync::broadcast::error::RecvError::Closed) => {
+ break;
+ }
+ }
+ }
+
+ // Handle task output broadcasts
+ output = task_output_rx.recv() => {
+ match output {
+ Ok(output) => {
+ // Check if client should receive this output
+ let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) {
+ // Client is subscribed to output for this task
+ // Verify owner_id matches (if set on both sides)
+ match (output.owner_id, subscribed_owner) {
+ (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
+ _ => true, // Allow if owner_id not set on either side
+ }
+ } else {
+ false
+ };
+
+ if should_forward {
+ let response = TaskServerMessage::TaskOutput {
+ task_id: output.task_id,
+ message_type: output.message_type,
+ content: output.content,
+ tool_name: output.tool_name,
+ tool_input: output.tool_input,
+ is_error: output.is_error,
+ cost_usd: output.cost_usd,
+ duration_ms: output.duration_ms,
+ is_partial: output.is_partial,
+ };
+ let json = serde_json::to_string(&response).unwrap();
+ if sender.send(Message::Text(json.into())).await.is_err() {
+ break;
+ }
+ }
+ }
+ Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
+ tracing::warn!("Task output subscription client lagged, skipped {} messages", n);
+ }
+ Err(tokio::sync::broadcast::error::RecvError::Closed) => {
+ break;
+ }
+ }
+ }
+ }
+ }
+}