//! WebSocket handler for task change subscriptions and output streaming.
//!
//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications
//! when tasks are updated. They can also subscribe to task output for live terminal streaming.
//!
//! ## Owner-scoped filtering
//!
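//! Per-task notifications are filtered by owner_id. When a client subscribes to a task (or to
//! its output), the task's owner_id is looked up from the database and recorded. A notification
//! is then forwarded only if its owner_id matches the recorded one. If either side has no
//! owner_id (for example, no database pool is configured or the lookup returned nothing), the
//! notification is forwarded anyway. Clients that send `subscribeAll` receive every task update
//! regardless of owner.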
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::collections::HashMap;
use uuid::Uuid;
use crate::server::state::SharedState;
/// Client message for task subscription management.
///
/// Deserialized from a JSON object with a `"type"` discriminator whose value is the
/// camelCase variant name. The per-task variants carry a `"taskId"` field, e.g.
/// `{"type": "subscribe", "taskId": "<uuid>"}` or `{"type": "subscribeAll"}`.
/// Anything that fails to parse is answered with a `PARSE_ERROR` error message.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TaskClientMessage {
    /// Subscribe to updates for a specific task (`"subscribe"`).
    Subscribe {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscribe from updates for a specific task (`"unsubscribe"`).
    Unsubscribe {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Subscribe to all task updates (`"subscribeAll"`); not owner-scoped.
    SubscribeAll,
    /// Unsubscribe from all task updates (`"unsubscribeAll"`).
    /// Per-task subscriptions made with `Subscribe` are kept.
    UnsubscribeAll,
    /// Subscribe to live output streaming for a specific task (`"subscribeOutput"`).
    SubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscribe from output streaming for a specific task (`"unsubscribeOutput"`).
    UnsubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
}
/// Server message for task subscription WebSocket.
///
/// Serialized as a JSON object with a `"type"` discriminator holding the camelCase
/// variant name (e.g. `"subscribed"`, `"taskUpdated"`, `"taskOutput"`). The variant's
/// fields sit alongside it, under the names given by the `rename` attributes below.
/// Optional `TaskOutput` fields that are `None` are omitted from the JSON entirely.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TaskServerMessage {
    /// Subscription confirmed for specific task
    Subscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscription confirmed for specific task
    Unsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Subscribed to all task updates
    SubscribedAll,
    /// Unsubscribed from all task updates
    UnsubscribedAll,
    /// Task was updated.
    ///
    /// All fields are copied verbatim from the task-update broadcast notification.
    TaskUpdated {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        version: i32,
        status: String,
        #[serde(rename = "updatedFields")]
        updated_fields: Vec<String>,
        #[serde(rename = "updatedBy")]
        updated_by: String,
    },
    /// Live output from Claude Code container (parsed and structured).
    ///
    /// All fields are copied verbatim from the task-output broadcast.
    TaskOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw"
        #[serde(rename = "messageType")]
        message_type: String,
        /// Main text content
        content: String,
        /// Tool name if tool_use message
        #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")]
        tool_name: Option<String>,
        /// Tool input JSON if tool_use message
        #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")]
        tool_input: Option<serde_json::Value>,
        /// Whether tool result was an error
        #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        /// Cost in USD if result message
        #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")]
        cost_usd: Option<f64>,
        /// Duration in ms if result message
        #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")]
        duration_ms: Option<u64>,
        /// Always serialized. Presumably marks an incomplete/streaming chunk — confirm
        /// against the producer of the output broadcast.
        #[serde(rename = "isPartial")]
        is_partial: bool,
    },
    /// Output subscription confirmed
    OutputSubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Output unsubscription confirmed
    OutputUnsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Error occurred. `code` is a machine-readable tag (the handler emits
    /// `"PARSE_ERROR"` for unparseable client messages); `message` is human-readable.
    Error { code: String, message: String },
}
/// WebSocket upgrade handler for task subscriptions.
#[utoipa::path(
get,
path = "/api/v1/mesh/tasks/subscribe",
responses(
(status = 101, description = "WebSocket connection established"),
),
tag = "Mesh"
)]
pub async fn task_subscription_handler(
ws: WebSocketUpgrade,
State(state): State<SharedState>,
) -> Response {
ws.on_upgrade(|socket| handle_task_subscription(socket, state))
}
/// Look up the owner_id for a task from the database.
async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option<Uuid> {
let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1")
.bind(task_id)
.fetch_optional(pool)
.await
.ok()??;
row.try_get("owner_id").ok()
}
async fn handle_task_subscription(socket: WebSocket, state: SharedState) {
let (mut sender, mut receiver) = socket.split();
// Map of task IDs to their owner_ids for this client's subscriptions
let mut task_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
// Whether client is subscribed to all task updates (not owner-scoped)
let mut subscribed_all = false;
// Map of task IDs to their owner_ids for output streaming subscriptions
let mut output_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
// Subscribe to broadcast channels
let mut task_update_rx = state.task_updates.subscribe();
let mut task_output_rx = state.task_output.subscribe();
loop {
tokio::select! {
// Handle incoming WebSocket messages from client
msg = receiver.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<TaskClientMessage>(&text) {
Ok(TaskClientMessage::Subscribe { task_id }) => {
// Look up owner_id for this task
let owner_id = if let Some(ref pool) = state.db_pool {
get_task_owner_id(pool, task_id).await
} else {
None
};
task_subscriptions.insert(task_id, owner_id);
let response = TaskServerMessage::Subscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id);
}
Ok(TaskClientMessage::Unsubscribe { task_id }) => {
task_subscriptions.remove(&task_id);
let response = TaskServerMessage::Unsubscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from task {}", task_id);
}
Ok(TaskClientMessage::SubscribeAll) => {
subscribed_all = true;
let response = TaskServerMessage::SubscribedAll;
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to all tasks");
}
Ok(TaskClientMessage::UnsubscribeAll) => {
subscribed_all = false;
let response = TaskServerMessage::UnsubscribedAll;
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from all tasks");
}
Ok(TaskClientMessage::SubscribeOutput { task_id }) => {
// Look up owner_id for this task
let owner_id = if let Some(ref pool) = state.db_pool {
get_task_owner_id(pool, task_id).await
} else {
None
};
output_subscriptions.insert(task_id, owner_id);
let response = TaskServerMessage::OutputSubscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id);
}
Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => {
output_subscriptions.remove(&task_id);
let response = TaskServerMessage::OutputUnsubscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from output for task {}", task_id);
}
Err(e) => {
let response = TaskServerMessage::Error {
code: "PARSE_ERROR".into(),
message: e.to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let _ = sender.send(Message::Text(json.into())).await;
}
}
}
Some(Ok(Message::Close(_))) | None => {
tracing::debug!("Client disconnected from task subscription");
break;
}
Some(Err(e)) => {
tracing::warn!("Task WebSocket error: {}", e);
break;
}
_ => {}
}
}
// Handle task update broadcasts
notification = task_update_rx.recv() => {
match notification {
Ok(notification) => {
// Check if client should receive this notification
let should_forward = if subscribed_all {
// SubscribeAll gets all notifications (typically for admin views)
true
} else if let Some(subscribed_owner) = task_subscriptions.get(¬ification.task_id) {
// Client is subscribed to this specific task
// Verify owner_id matches (if set on both sides)
match (notification.owner_id, subscribed_owner) {
(Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
_ => true, // Allow if owner_id not set on either side
}
} else {
false
};
if should_forward {
let response = TaskServerMessage::TaskUpdated {
task_id: notification.task_id,
version: notification.version,
status: notification.status,
updated_fields: notification.updated_fields,
updated_by: notification.updated_by,
};
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Task subscription client lagged, skipped {} messages", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
// Handle task output broadcasts
output = task_output_rx.recv() => {
match output {
Ok(output) => {
// Check if client should receive this output
let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) {
// Client is subscribed to output for this task
// Verify owner_id matches (if set on both sides)
match (output.owner_id, subscribed_owner) {
(Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
_ => true, // Allow if owner_id not set on either side
}
} else {
false
};
if should_forward {
let response = TaskServerMessage::TaskOutput {
task_id: output.task_id,
message_type: output.message_type,
content: output.content,
tool_name: output.tool_name,
tool_input: output.tool_input,
is_error: output.is_error,
cost_usd: output.cost_usd,
duration_ms: output.duration_ms,
is_partial: output.is_partial,
};
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Task output subscription client lagged, skipped {} messages", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
}