1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
|
//! WebSocket handler for task change subscriptions and output streaming.
//!
//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications
//! when tasks are updated. They can also subscribe to task output for live terminal streaming.
//!
//! ## Owner-scoped filtering
//!
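//! Notifications for individually subscribed tasks (both updates and output) are filtered by
//! owner_id. The task's owner_id is looked up from the database when the client subscribes,
//! and a notification is delivered only if its owner_id matches that stored value. If either
//! side has no owner_id (for example, no database pool is configured or the lookup failed),
//! the notification is delivered. Clients using `subscribeAll` receive every task update
//! without any owner filtering.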
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::collections::HashMap;
use uuid::Uuid;
use crate::server::state::SharedState;
/// Client-to-server message for managing task subscriptions.
///
/// Internally tagged on `"type"`. Variant names are camelCase, and task IDs travel in a
/// `taskId` field. Examples: `{"type":"subscribe","taskId":"<uuid>"}`,
/// `{"type":"subscribeAll"}`, `{"type":"unsubscribeOutput","taskId":"<uuid>"}`.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TaskClientMessage {
    /// Subscribe to update notifications for one task (`"subscribe"`).
    Subscribe {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Stop receiving update notifications for one task (`"unsubscribe"`).
    Unsubscribe {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Receive update notifications for every task (`"subscribeAll"`).
    SubscribeAll,
    /// Turn off the receive-everything mode (`"unsubscribeAll"`); per-task subscriptions remain.
    UnsubscribeAll,
    /// Subscribe to live output streaming for one task (`"subscribeOutput"`).
    SubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Stop live output streaming for one task (`"unsubscribeOutput"`).
    UnsubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
}
/// Server-to-client message on the task subscription WebSocket.
///
/// Internally tagged on `"type"` with camelCase variant names (e.g. `"subscribed"`,
/// `"taskUpdated"`, `"taskOutput"`, `"error"`). Field names are renamed individually to
/// camelCase. `Option` fields of `TaskOutput` are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TaskServerMessage {
    /// Confirms a `subscribe` request for one task.
    Subscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Confirms an `unsubscribe` request for one task.
    Unsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Confirms a `subscribeAll` request.
    SubscribedAll,
    /// Confirms an `unsubscribeAll` request.
    UnsubscribedAll,
    /// A task was updated; the fields are copied from the task-update broadcast.
    TaskUpdated {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        version: i32,
        status: String,
        #[serde(rename = "updatedFields")]
        updated_fields: Vec<String>,
        #[serde(rename = "updatedBy")]
        updated_by: String,
    },
    /// Live output from Claude Code container (parsed and structured)
    TaskOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw"
        #[serde(rename = "messageType")]
        message_type: String,
        /// Main text content
        content: String,
        /// Tool name if tool_use message (omitted when `None`)
        #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")]
        tool_name: Option<String>,
        /// Tool input JSON if tool_use message (omitted when `None`)
        #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")]
        tool_input: Option<serde_json::Value>,
        /// Whether tool result was an error (omitted when `None`)
        #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        /// Cost in USD if result message (omitted when `None`)
        #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")]
        cost_usd: Option<f64>,
        /// Duration in ms if result message (omitted when `None`)
        #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")]
        duration_ms: Option<u64>,
        /// Copied from the output broadcast. NOTE(review): presumably marks an incomplete,
        /// still-streaming chunk; confirm with the producer.
        #[serde(rename = "isPartial")]
        is_partial: bool,
    },
    /// Confirms a `subscribeOutput` request.
    OutputSubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Confirms an `unsubscribeOutput` request.
    OutputUnsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Error report; the handler in this file sends `code: "PARSE_ERROR"` for unparseable client messages.
    Error { code: String, message: String },
}
/// WebSocket upgrade handler for task subscriptions.
#[utoipa::path(
get,
path = "/api/v1/mesh/tasks/subscribe",
responses(
(status = 101, description = "WebSocket connection established"),
),
tag = "Mesh"
)]
pub async fn task_subscription_handler(
ws: WebSocketUpgrade,
State(state): State<SharedState>,
) -> Response {
ws.on_upgrade(|socket| handle_task_subscription(socket, state))
}
/// Look up the owner_id for a task from the database.
async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option<Uuid> {
let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1")
.bind(task_id)
.fetch_optional(pool)
.await
.ok()??;
row.try_get("owner_id").ok()
}
async fn handle_task_subscription(socket: WebSocket, state: SharedState) {
let (mut sender, mut receiver) = socket.split();
// Map of task IDs to their owner_ids for this client's subscriptions
let mut task_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
// Whether client is subscribed to all task updates (not owner-scoped)
let mut subscribed_all = false;
// Map of task IDs to their owner_ids for output streaming subscriptions
let mut output_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
// Subscribe to broadcast channels
let mut task_update_rx = state.task_updates.subscribe();
let mut task_output_rx = state.task_output.subscribe();
loop {
tokio::select! {
// Handle incoming WebSocket messages from client
msg = receiver.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<TaskClientMessage>(&text) {
Ok(TaskClientMessage::Subscribe { task_id }) => {
// Look up owner_id for this task
let owner_id = if let Some(ref pool) = state.db_pool {
get_task_owner_id(pool, task_id).await
} else {
None
};
task_subscriptions.insert(task_id, owner_id);
let response = TaskServerMessage::Subscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id);
}
Ok(TaskClientMessage::Unsubscribe { task_id }) => {
task_subscriptions.remove(&task_id);
let response = TaskServerMessage::Unsubscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from task {}", task_id);
}
Ok(TaskClientMessage::SubscribeAll) => {
subscribed_all = true;
let response = TaskServerMessage::SubscribedAll;
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to all tasks");
}
Ok(TaskClientMessage::UnsubscribeAll) => {
subscribed_all = false;
let response = TaskServerMessage::UnsubscribedAll;
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from all tasks");
}
Ok(TaskClientMessage::SubscribeOutput { task_id }) => {
// Look up owner_id for this task
let owner_id = if let Some(ref pool) = state.db_pool {
get_task_owner_id(pool, task_id).await
} else {
None
};
output_subscriptions.insert(task_id, owner_id);
let response = TaskServerMessage::OutputSubscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id);
}
Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => {
output_subscriptions.remove(&task_id);
let response = TaskServerMessage::OutputUnsubscribed { task_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from output for task {}", task_id);
}
Err(e) => {
let response = TaskServerMessage::Error {
code: "PARSE_ERROR".into(),
message: e.to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let _ = sender.send(Message::Text(json.into())).await;
}
}
}
Some(Ok(Message::Close(_))) | None => {
tracing::debug!("Client disconnected from task subscription");
break;
}
Some(Err(e)) => {
tracing::warn!("Task WebSocket error: {}", e);
break;
}
_ => {}
}
}
// Handle task update broadcasts
notification = task_update_rx.recv() => {
match notification {
Ok(notification) => {
// Check if client should receive this notification
let should_forward = if subscribed_all {
// SubscribeAll gets all notifications (typically for admin views)
true
} else if let Some(subscribed_owner) = task_subscriptions.get(¬ification.task_id) {
// Client is subscribed to this specific task
// Verify owner_id matches (if set on both sides)
match (notification.owner_id, subscribed_owner) {
(Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
_ => true, // Allow if owner_id not set on either side
}
} else {
false
};
if should_forward {
let response = TaskServerMessage::TaskUpdated {
task_id: notification.task_id,
version: notification.version,
status: notification.status,
updated_fields: notification.updated_fields,
updated_by: notification.updated_by,
};
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Task subscription client lagged, skipped {} messages", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
// Handle task output broadcasts
output = task_output_rx.recv() => {
match output {
Ok(output) => {
// Check if client should receive this output
let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) {
// Client is subscribed to output for this task
// Verify owner_id matches (if set on both sides)
match (output.owner_id, subscribed_owner) {
(Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
_ => true, // Allow if owner_id not set on either side
}
} else {
false
};
if should_forward {
let response = TaskServerMessage::TaskOutput {
task_id: output.task_id,
message_type: output.message_type,
content: output.content,
tool_name: output.tool_name,
tool_input: output.tool_input,
is_error: output.is_error,
cost_usd: output.cost_usd,
duration_ms: output.duration_ms,
is_partial: output.is_partial,
};
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("Task output subscription client lagged, skipped {} messages", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
}
|