diff options
Diffstat (limited to 'makima/src/server/handlers/file_ws.rs')
| -rw-r--r-- | makima/src/server/handlers/file_ws.rs | 163 |
1 files changed, 163 insertions, 0 deletions
diff --git a/makima/src/server/handlers/file_ws.rs b/makima/src/server/handlers/file_ws.rs new file mode 100644 index 0000000..5a44309 --- /dev/null +++ b/makima/src/server/handlers/file_ws.rs @@ -0,0 +1,163 @@ +//! WebSocket handler for file change subscriptions. +//! +//! Clients can subscribe to specific files and receive real-time notifications +//! when those files are updated by any source (user edits, LLM modifications, etc.). + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use uuid::Uuid; + +use crate::server::state::SharedState; + +/// Client message for file subscription management. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum FileClientMessage { + /// Subscribe to updates for a specific file + Subscribe { + #[serde(rename = "fileId")] + file_id: Uuid, + }, + /// Unsubscribe from updates for a specific file + Unsubscribe { + #[serde(rename = "fileId")] + file_id: Uuid, + }, +} + +/// Server message for file subscription WebSocket. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum FileServerMessage { + /// Subscription confirmed + Subscribed { + #[serde(rename = "fileId")] + file_id: Uuid, + }, + /// Unsubscription confirmed + Unsubscribed { + #[serde(rename = "fileId")] + file_id: Uuid, + }, + /// File was updated + FileUpdated { + #[serde(rename = "fileId")] + file_id: Uuid, + version: i32, + #[serde(rename = "updatedFields")] + updated_fields: Vec<String>, + #[serde(rename = "updatedBy")] + updated_by: String, + }, + /// Error occurred + Error { code: String, message: String }, +} + +/// WebSocket upgrade handler for file subscriptions. +#[utoipa::path( + get, + path = "/api/v1/files/subscribe", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "Files" +)] +pub async fn file_subscription_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, +) -> Response { + ws.on_upgrade(|socket| handle_file_subscription(socket, state)) +} + +async fn handle_file_subscription(socket: WebSocket, state: SharedState) { + let (mut sender, mut receiver) = socket.split(); + + // Set of file IDs this client is subscribed to + let mut subscriptions: HashSet<Uuid> = HashSet::new(); + + // Subscribe to the broadcast channel + let mut broadcast_rx = state.file_updates.subscribe(); + + loop { + tokio::select! { + // Handle incoming WebSocket messages from client + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<FileClientMessage>(&text) { + Ok(FileClientMessage::Subscribe { file_id }) => { + subscriptions.insert(file_id); + let response = FileServerMessage::Subscribed { file_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to file {}", file_id); + } + Ok(FileClientMessage::Unsubscribe { file_id }) => { + subscriptions.remove(&file_id); + let response = FileServerMessage::Unsubscribed { file_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from file {}", file_id); + } + Err(e) => { + let response = FileServerMessage::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Client disconnected from file subscription"); + break; + } + Some(Err(e)) => { + tracing::warn!("WebSocket error: {}", e); + break; + } + _ => {} + } + } + + // Handle broadcast notifications + notification = broadcast_rx.recv() => { + match notification { + Ok(notification) => { + // Only forward if client is subscribed to this file + if subscriptions.contains(¬ification.file_id) { + let response = FileServerMessage::FileUpdated { + file_id: notification.file_id, + version: notification.version, + updated_fields: notification.updated_fields, + updated_by: notification.updated_by, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + // Client is too slow, skip some messages + tracing::warn!("File subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + // Channel closed, exit + break; + } + } + } + } + } +} |
