summaryrefslogblamecommitdiff
path: root/makima/src/server/handlers/mesh_ws.rs
blob: d15fba7f2da816297a686a87347374baa349d892 (plain) (tree)

























































































































































































































































































































































                                                                                                                            
//! WebSocket handler for task change subscriptions and output streaming.
//!
//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications
//! when tasks are updated. They can also subscribe to task output for live terminal streaming.
//!
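//! ## Owner-scoped filtering
//!
//! Per-task subscriptions (task updates and task output) are filtered by owner_id.
//! When a client subscribes to a task, the task's owner_id is looked up from the
//! database and stored with the subscription. A broadcast is then forwarded only if the
//! client is subscribed to that task and, when both the broadcast and the subscription
//! carry an owner_id, the two match.
//!
//! Clients that send `subscribeAll` receive every task update without owner filtering.
//! Output streaming always requires a per-task `subscribeOutput`.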

use axum::{
    extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
    response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::collections::HashMap;
use uuid::Uuid;

use crate::server::state::SharedState;

/// Client message for task subscription management.
///
/// Wire format: a JSON object whose `"type"` field selects the variant by its camelCase
/// name (`"subscribe"`, `"unsubscribe"`, `"subscribeAll"`, `"unsubscribeAll"`,
/// `"subscribeOutput"`, `"unsubscribeOutput"`). Variants that target a single task carry
/// it in a `"taskId"` field, e.g. `{"type": "subscribe", "taskId": "<uuid>"}`.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TaskClientMessage {
    /// Subscribe to updates for a specific task
    Subscribe {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscribe from updates for a specific task
    Unsubscribe {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Subscribe to all task updates.
    ///
    /// The subscription handler forwards these without owner filtering; it does not
    /// enable output streaming.
    SubscribeAll,
    /// Unsubscribe from all task updates
    UnsubscribeAll,
    /// Subscribe to live output streaming for a specific task
    SubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscribe from output streaming for a specific task
    UnsubscribeOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
}

/// Server message for task subscription WebSocket.
///
/// Wire format: a JSON object whose `"type"` field names the variant in camelCase
/// (e.g. `"subscribed"`, `"taskUpdated"`, `"taskOutput"`, `"error"`). Multi-word fields
/// are renamed explicitly to camelCase (`"taskId"`, `"updatedFields"`, ...). Optional
/// `TaskOutput` fields are omitted from the JSON when they are `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum TaskServerMessage {
    /// Subscription confirmed for specific task
    Subscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Unsubscription confirmed for specific task
    Unsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Subscribed to all task updates
    SubscribedAll,
    /// Unsubscribed from all task updates
    UnsubscribedAll,
    /// Task was updated
    TaskUpdated {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        version: i32,
        status: String,
        #[serde(rename = "updatedFields")]
        updated_fields: Vec<String>,
        #[serde(rename = "updatedBy")]
        updated_by: String,
    },
    /// Live output from Claude Code container (parsed and structured)
    TaskOutput {
        #[serde(rename = "taskId")]
        task_id: Uuid,
        /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw"
        #[serde(rename = "messageType")]
        message_type: String,
        /// Main text content
        content: String,
        /// Tool name if tool_use message
        #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")]
        tool_name: Option<String>,
        /// Tool input JSON if tool_use message
        #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")]
        tool_input: Option<serde_json::Value>,
        /// Whether tool result was an error
        #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
        /// Cost in USD if result message
        #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")]
        cost_usd: Option<f64>,
        /// Duration in ms if result message
        #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")]
        duration_ms: Option<u64>,
        /// Always serialized (no `skip_serializing_if`).
        /// NOTE(review): presumably marks an incomplete/streaming chunk of output —
        /// confirm against the producer of the `task_output` broadcast.
        #[serde(rename = "isPartial")]
        is_partial: bool,
    },
    /// Output subscription confirmed
    OutputSubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Output unsubscription confirmed
    OutputUnsubscribed {
        #[serde(rename = "taskId")]
        task_id: Uuid,
    },
    /// Error occurred. The subscription handler in this module sends it with code
    /// `"PARSE_ERROR"` when a client message cannot be deserialized.
    Error { code: String, message: String },
}

/// WebSocket upgrade handler for task subscriptions.
#[utoipa::path(
    get,
    path = "/api/v1/mesh/tasks/subscribe",
    responses(
        (status = 101, description = "WebSocket connection established"),
    ),
    tag = "Mesh"
)]
pub async fn task_subscription_handler(
    ws: WebSocketUpgrade,
    State(state): State<SharedState>,
) -> Response {
    ws.on_upgrade(|socket| handle_task_subscription(socket, state))
}

/// Look up the owner_id for a task from the database.
async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option<Uuid> {
    let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1")
        .bind(task_id)
        .fetch_optional(pool)
        .await
        .ok()??;

    row.try_get("owner_id").ok()
}

async fn handle_task_subscription(socket: WebSocket, state: SharedState) {
    let (mut sender, mut receiver) = socket.split();

    // Map of task IDs to their owner_ids for this client's subscriptions
    let mut task_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();
    // Whether client is subscribed to all task updates (not owner-scoped)
    let mut subscribed_all = false;
    // Map of task IDs to their owner_ids for output streaming subscriptions
    let mut output_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new();

    // Subscribe to broadcast channels
    let mut task_update_rx = state.task_updates.subscribe();
    let mut task_output_rx = state.task_output.subscribe();

    loop {
        tokio::select! {
            // Handle incoming WebSocket messages from client
            msg = receiver.next() => {
                match msg {
                    Some(Ok(Message::Text(text))) => {
                        match serde_json::from_str::<TaskClientMessage>(&text) {
                            Ok(TaskClientMessage::Subscribe { task_id }) => {
                                // Look up owner_id for this task
                                let owner_id = if let Some(ref pool) = state.db_pool {
                                    get_task_owner_id(pool, task_id).await
                                } else {
                                    None
                                };
                                task_subscriptions.insert(task_id, owner_id);
                                let response = TaskServerMessage::Subscribed { task_id };
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id);
                            }
                            Ok(TaskClientMessage::Unsubscribe { task_id }) => {
                                task_subscriptions.remove(&task_id);
                                let response = TaskServerMessage::Unsubscribed { task_id };
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client unsubscribed from task {}", task_id);
                            }
                            Ok(TaskClientMessage::SubscribeAll) => {
                                subscribed_all = true;
                                let response = TaskServerMessage::SubscribedAll;
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client subscribed to all tasks");
                            }
                            Ok(TaskClientMessage::UnsubscribeAll) => {
                                subscribed_all = false;
                                let response = TaskServerMessage::UnsubscribedAll;
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client unsubscribed from all tasks");
                            }
                            Ok(TaskClientMessage::SubscribeOutput { task_id }) => {
                                // Look up owner_id for this task
                                let owner_id = if let Some(ref pool) = state.db_pool {
                                    get_task_owner_id(pool, task_id).await
                                } else {
                                    None
                                };
                                output_subscriptions.insert(task_id, owner_id);
                                let response = TaskServerMessage::OutputSubscribed { task_id };
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id);
                            }
                            Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => {
                                output_subscriptions.remove(&task_id);
                                let response = TaskServerMessage::OutputUnsubscribed { task_id };
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client unsubscribed from output for task {}", task_id);
                            }
                            Err(e) => {
                                let response = TaskServerMessage::Error {
                                    code: "PARSE_ERROR".into(),
                                    message: e.to_string(),
                                };
                                let json = serde_json::to_string(&response).unwrap();
                                let _ = sender.send(Message::Text(json.into())).await;
                            }
                        }
                    }
                    Some(Ok(Message::Close(_))) | None => {
                        tracing::debug!("Client disconnected from task subscription");
                        break;
                    }
                    Some(Err(e)) => {
                        tracing::warn!("Task WebSocket error: {}", e);
                        break;
                    }
                    _ => {}
                }
            }

            // Handle task update broadcasts
            notification = task_update_rx.recv() => {
                match notification {
                    Ok(notification) => {
                        // Check if client should receive this notification
                        let should_forward = if subscribed_all {
                            // SubscribeAll gets all notifications (typically for admin views)
                            true
                        } else if let Some(subscribed_owner) = task_subscriptions.get(&notification.task_id) {
                            // Client is subscribed to this specific task
                            // Verify owner_id matches (if set on both sides)
                            match (notification.owner_id, subscribed_owner) {
                                (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
                                _ => true, // Allow if owner_id not set on either side
                            }
                        } else {
                            false
                        };

                        if should_forward {
                            let response = TaskServerMessage::TaskUpdated {
                                task_id: notification.task_id,
                                version: notification.version,
                                status: notification.status,
                                updated_fields: notification.updated_fields,
                                updated_by: notification.updated_by,
                            };
                            let json = serde_json::to_string(&response).unwrap();
                            if sender.send(Message::Text(json.into())).await.is_err() {
                                break;
                            }
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                        tracing::warn!("Task subscription client lagged, skipped {} messages", n);
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        break;
                    }
                }
            }

            // Handle task output broadcasts
            output = task_output_rx.recv() => {
                match output {
                    Ok(output) => {
                        // Check if client should receive this output
                        let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) {
                            // Client is subscribed to output for this task
                            // Verify owner_id matches (if set on both sides)
                            match (output.owner_id, subscribed_owner) {
                                (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner,
                                _ => true, // Allow if owner_id not set on either side
                            }
                        } else {
                            false
                        };

                        if should_forward {
                            let response = TaskServerMessage::TaskOutput {
                                task_id: output.task_id,
                                message_type: output.message_type,
                                content: output.content,
                                tool_name: output.tool_name,
                                tool_input: output.tool_input,
                                is_error: output.is_error,
                                cost_usd: output.cost_usd,
                                duration_ms: output.duration_ms,
                                is_partial: output.is_partial,
                            };
                            let json = serde_json::to_string(&response).unwrap();
                            if sender.send(Message::Text(json.into())).await.is_err() {
                                break;
                            }
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                        tracing::warn!("Task output subscription client lagged, skipped {} messages", n);
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        break;
                    }
                }
            }
        }
    }
}