//! WebSocket handler for task change subscriptions and output streaming. //! //! Clients can subscribe to specific tasks or all tasks to receive real-time notifications //! when tasks are updated. They can also subscribe to task output for live terminal streaming. //! //! ## Owner-scoped filtering //! //! Notifications are filtered by owner_id. If a notification has an owner_id set, //! it will only be delivered to clients who are subscribed to tasks belonging to that owner. //! The task's owner_id is looked up from the database when the client subscribes. use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, response::Response, }; use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use sqlx::Row; use std::collections::HashMap; use uuid::Uuid; use crate::server::state::SharedState; /// Client message for task subscription management. #[derive(Debug, Clone, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum TaskClientMessage { /// Subscribe to updates for a specific task Subscribe { #[serde(rename = "taskId")] task_id: Uuid, }, /// Unsubscribe from updates for a specific task Unsubscribe { #[serde(rename = "taskId")] task_id: Uuid, }, /// Subscribe to all task updates SubscribeAll, /// Unsubscribe from all task updates UnsubscribeAll, /// Subscribe to live output streaming for a specific task SubscribeOutput { #[serde(rename = "taskId")] task_id: Uuid, }, /// Unsubscribe from output streaming for a specific task UnsubscribeOutput { #[serde(rename = "taskId")] task_id: Uuid, }, } /// Server message for task subscription WebSocket. #[derive(Debug, Clone, Serialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum TaskServerMessage { /// Subscription confirmed for specific task Subscribed { #[serde(rename = "taskId")] task_id: Uuid, }, /// Unsubscription confirmed for specific task Unsubscribed { #[serde(rename = "taskId")] task_id: Uuid, }, /// Subscribed to all task updates SubscribedAll, /// Unsubscribed from all task updates UnsubscribedAll, /// Task was updated TaskUpdated { #[serde(rename = "taskId")] task_id: Uuid, version: i32, status: String, #[serde(rename = "updatedFields")] updated_fields: Vec, #[serde(rename = "updatedBy")] updated_by: String, }, /// Live output from Claude Code container (parsed and structured) TaskOutput { #[serde(rename = "taskId")] task_id: Uuid, /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" #[serde(rename = "messageType")] message_type: String, /// Main text content content: String, /// Tool name if tool_use message #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")] tool_name: Option, /// Tool input JSON if tool_use message #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")] tool_input: Option, /// Whether tool result was an error #[serde(rename = "isError", skip_serializing_if = "Option::is_none")] is_error: Option, /// Cost in USD if result message #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")] cost_usd: Option, /// Duration in ms if result message #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")] duration_ms: Option, #[serde(rename = "isPartial")] is_partial: bool, }, /// Output subscription confirmed OutputSubscribed { #[serde(rename = "taskId")] task_id: Uuid, }, /// Output unsubscription confirmed OutputUnsubscribed { #[serde(rename = "taskId")] task_id: Uuid, }, /// Error occurred Error { code: String, message: String }, } /// WebSocket upgrade handler for task subscriptions. #[utoipa::path( get, path = "/api/v1/mesh/tasks/subscribe", responses( (status = 101, description = "WebSocket connection established"), ), tag = "Mesh" )] pub async fn task_subscription_handler( ws: WebSocketUpgrade, State(state): State, ) -> Response { ws.on_upgrade(|socket| handle_task_subscription(socket, state)) } /// Look up the owner_id for a task from the database. async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option { let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") .bind(task_id) .fetch_optional(pool) .await .ok()??; row.try_get("owner_id").ok() } async fn handle_task_subscription(socket: WebSocket, state: SharedState) { let (mut sender, mut receiver) = socket.split(); // Map of task IDs to their owner_ids for this client's subscriptions let mut task_subscriptions: HashMap> = HashMap::new(); // Whether client is subscribed to all task updates (not owner-scoped) let mut subscribed_all = false; // Map of task IDs to their owner_ids for output streaming subscriptions let mut output_subscriptions: HashMap> = HashMap::new(); // Subscribe to broadcast channels let mut task_update_rx = state.task_updates.subscribe(); let mut task_output_rx = state.task_output.subscribe(); loop { tokio::select! { // Handle incoming WebSocket messages from client msg = receiver.next() => { match msg { Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(TaskClientMessage::Subscribe { task_id }) => { // Look up owner_id for this task let owner_id = if let Some(ref pool) = state.db_pool { get_task_owner_id(pool, task_id).await } else { None }; task_subscriptions.insert(task_id, owner_id); let response = TaskServerMessage::Subscribed { task_id }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id); } Ok(TaskClientMessage::Unsubscribe { task_id }) => { task_subscriptions.remove(&task_id); let response = TaskServerMessage::Unsubscribed { task_id }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client unsubscribed from task {}", task_id); } Ok(TaskClientMessage::SubscribeAll) => { subscribed_all = true; let response = TaskServerMessage::SubscribedAll; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client subscribed to all tasks"); } Ok(TaskClientMessage::UnsubscribeAll) => { subscribed_all = false; let response = TaskServerMessage::UnsubscribedAll; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client unsubscribed from all tasks"); } Ok(TaskClientMessage::SubscribeOutput { task_id }) => { // Look up owner_id for this task let owner_id = if let Some(ref pool) = state.db_pool { get_task_owner_id(pool, task_id).await } else { None }; output_subscriptions.insert(task_id, owner_id); let response = TaskServerMessage::OutputSubscribed { task_id }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id); } Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => { output_subscriptions.remove(&task_id); let response = TaskServerMessage::OutputUnsubscribed { task_id }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client unsubscribed from output for task {}", task_id); } Err(e) => { let response = TaskServerMessage::Error { code: "PARSE_ERROR".into(), message: e.to_string(), }; let json = serde_json::to_string(&response).unwrap(); let _ = sender.send(Message::Text(json.into())).await; } } } Some(Ok(Message::Close(_))) | None => { tracing::debug!("Client disconnected from task subscription"); break; } Some(Err(e)) => { tracing::warn!("Task WebSocket error: {}", e); break; } _ => {} } } // Handle task update broadcasts notification = task_update_rx.recv() => { match notification { Ok(notification) => { // Check if client should receive this notification let should_forward = if subscribed_all { // SubscribeAll gets all notifications (typically for admin views) true } else if let Some(subscribed_owner) = task_subscriptions.get(¬ification.task_id) { // Client is subscribed to this specific task // Verify owner_id matches (if set on both sides) match (notification.owner_id, subscribed_owner) { (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, _ => true, // Allow if owner_id not set on either side } } else { false }; if should_forward { let response = TaskServerMessage::TaskUpdated { task_id: notification.task_id, version: notification.version, status: notification.status, updated_fields: notification.updated_fields, updated_by: notification.updated_by, }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { tracing::warn!("Task subscription client lagged, skipped {} messages", n); } Err(tokio::sync::broadcast::error::RecvError::Closed) => { break; } } } // Handle task output broadcasts output = task_output_rx.recv() => { match output { Ok(output) => { // Check if client should receive this output let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) { // Client is subscribed to output for this task // Verify owner_id matches (if set on both sides) match (output.owner_id, subscribed_owner) { (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, _ => true, // Allow if owner_id not set on either side } } else { false }; if should_forward { let response = TaskServerMessage::TaskOutput { task_id: output.task_id, message_type: output.message_type, content: output.content, tool_name: output.tool_name, tool_input: output.tool_input, is_error: output.is_error, cost_usd: output.cost_usd, duration_ms: output.duration_ms, is_partial: output.is_partial, }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { tracing::warn!("Task output subscription client lagged, skipped {} messages", n); } Err(tokio::sync::broadcast::error::RecvError::Closed) => { break; } } } } } }