//! WebSocket handler for file change subscriptions. //! //! Clients can subscribe to specific files and receive real-time notifications //! when those files are updated by any source (user edits, LLM modifications, etc.). use axum::{ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, response::Response, }; use futures::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use uuid::Uuid; use crate::server::state::SharedState; /// Client message for file subscription management. #[derive(Debug, Clone, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum FileClientMessage { /// Subscribe to updates for a specific file Subscribe { #[serde(rename = "fileId")] file_id: Uuid, }, /// Unsubscribe from updates for a specific file Unsubscribe { #[serde(rename = "fileId")] file_id: Uuid, }, } /// Server message for file subscription WebSocket. #[derive(Debug, Clone, Serialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum FileServerMessage { /// Subscription confirmed Subscribed { #[serde(rename = "fileId")] file_id: Uuid, }, /// Unsubscription confirmed Unsubscribed { #[serde(rename = "fileId")] file_id: Uuid, }, /// File was updated FileUpdated { #[serde(rename = "fileId")] file_id: Uuid, version: i32, #[serde(rename = "updatedFields")] updated_fields: Vec, #[serde(rename = "updatedBy")] updated_by: String, }, /// Error occurred Error { code: String, message: String }, } /// WebSocket upgrade handler for file subscriptions. #[utoipa::path( get, path = "/api/v1/files/subscribe", responses( (status = 101, description = "WebSocket connection established"), ), tag = "Files" )] pub async fn file_subscription_handler( ws: WebSocketUpgrade, State(state): State, ) -> Response { ws.on_upgrade(|socket| handle_file_subscription(socket, state)) } async fn handle_file_subscription(socket: WebSocket, state: SharedState) { let (mut sender, mut receiver) = socket.split(); // Set of file IDs this client is subscribed to let mut subscriptions: HashSet = HashSet::new(); // Subscribe to the broadcast channel let mut broadcast_rx = state.file_updates.subscribe(); loop { tokio::select! { // Handle incoming WebSocket messages from client msg = receiver.next() => { match msg { Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(FileClientMessage::Subscribe { file_id }) => { subscriptions.insert(file_id); let response = FileServerMessage::Subscribed { file_id }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client subscribed to file {}", file_id); } Ok(FileClientMessage::Unsubscribe { file_id }) => { subscriptions.remove(&file_id); let response = FileServerMessage::Unsubscribed { file_id }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } tracing::debug!("Client unsubscribed from file {}", file_id); } Err(e) => { let response = FileServerMessage::Error { code: "PARSE_ERROR".into(), message: e.to_string(), }; let json = serde_json::to_string(&response).unwrap(); let _ = sender.send(Message::Text(json.into())).await; } } } Some(Ok(Message::Close(_))) | None => { tracing::debug!("Client disconnected from file subscription"); break; } Some(Err(e)) => { tracing::warn!("WebSocket error: {}", e); break; } _ => {} } } // Handle broadcast notifications notification = broadcast_rx.recv() => { match notification { Ok(notification) => { // Only forward if client is subscribed to this file if subscriptions.contains(¬ification.file_id) { let response = FileServerMessage::FileUpdated { file_id: notification.file_id, version: notification.version, updated_fields: notification.updated_fields, updated_by: notification.updated_by, }; let json = serde_json::to_string(&response).unwrap(); if sender.send(Message::Text(json.into())).await.is_err() { break; } } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { // Client is too slow, skip some messages tracing::warn!("File subscription client lagged, skipped {} messages", n); } Err(tokio::sync::broadcast::error::RecvError::Closed) => { // Channel closed, exit break; } } } } } }