//! WebSocket handler for file change subscriptions.
//!
//! Clients can subscribe to specific files and receive real-time notifications
//! when those files are updated by any source (user edits, LLM modifications, etc.).
use axum::{
extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use uuid::Uuid;
use crate::server::state::SharedState;
/// Client message for file subscription management.
///
/// Deserialized from internally tagged JSON: the `"type"` field carries the
/// camelCased variant name, and the file ID is read from `"fileId"`. The
/// accepted shapes are therefore:
///
/// ```json
/// {"type": "subscribe",   "fileId": "<uuid>"}
/// {"type": "unsubscribe", "fileId": "<uuid>"}
/// ```
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum FileClientMessage {
    /// Subscribe to updates for a specific file.
    Subscribe {
        /// File whose updates the client wants to receive.
        #[serde(rename = "fileId")]
        file_id: Uuid,
    },
    /// Unsubscribe from updates for a specific file.
    Unsubscribe {
        /// File whose updates the client no longer wants.
        #[serde(rename = "fileId")]
        file_id: Uuid,
    },
}
/// Server message for file subscription WebSocket.
///
/// Serialized as internally tagged JSON: the `"type"` field holds the
/// camelCased variant name (`subscribed`, `unsubscribed`, `fileUpdated`,
/// `error`). Multi-word fields are renamed to camelCase individually, e.g.
/// `{"type": "subscribed", "fileId": "<uuid>"}`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum FileServerMessage {
    /// Acknowledges a `subscribe` request for `file_id`.
    Subscribed {
        #[serde(rename = "fileId")]
        file_id: Uuid,
    },
    /// Acknowledges an `unsubscribe` request for `file_id`.
    Unsubscribed {
        #[serde(rename = "fileId")]
        file_id: Uuid,
    },
    /// A file the client is subscribed to was updated.
    FileUpdated {
        /// File the update applies to.
        #[serde(rename = "fileId")]
        file_id: Uuid,
        /// Version number carried by the update notification.
        /// NOTE(review): presumably the file's version after the change — confirm
        /// against the code that publishes notifications.
        version: i32,
        /// Names of the fields changed by the update, as reported by the
        /// notification source.
        #[serde(rename = "updatedFields")]
        updated_fields: Vec<String>,
        /// Identifies who or what made the change (the module docs mention user
        /// edits and LLM modifications). NOTE(review): exact format is set by the
        /// publisher and is not visible here.
        #[serde(rename = "updatedBy")]
        updated_by: String,
    },
    /// An error reported to the client; the subscription handler emits code
    /// `PARSE_ERROR` when a client frame cannot be parsed.
    Error {
        /// Machine-readable error code.
        code: String,
        /// Human-readable description (for `PARSE_ERROR`, the parser's error text).
        message: String,
    },
}
/// WebSocket upgrade handler for file subscriptions.
#[utoipa::path(
get,
path = "/api/v1/files/subscribe",
responses(
(status = 101, description = "WebSocket connection established"),
),
tag = "Files"
)]
pub async fn file_subscription_handler(
ws: WebSocketUpgrade,
State(state): State<SharedState>,
) -> Response {
ws.on_upgrade(|socket| handle_file_subscription(socket, state))
}
async fn handle_file_subscription(socket: WebSocket, state: SharedState) {
let (mut sender, mut receiver) = socket.split();
// Set of file IDs this client is subscribed to
let mut subscriptions: HashSet<Uuid> = HashSet::new();
// Subscribe to the broadcast channel
let mut broadcast_rx = state.file_updates.subscribe();
loop {
tokio::select! {
// Handle incoming WebSocket messages from client
msg = receiver.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<FileClientMessage>(&text) {
Ok(FileClientMessage::Subscribe { file_id }) => {
subscriptions.insert(file_id);
let response = FileServerMessage::Subscribed { file_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client subscribed to file {}", file_id);
}
Ok(FileClientMessage::Unsubscribe { file_id }) => {
subscriptions.remove(&file_id);
let response = FileServerMessage::Unsubscribed { file_id };
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
tracing::debug!("Client unsubscribed from file {}", file_id);
}
Err(e) => {
let response = FileServerMessage::Error {
code: "PARSE_ERROR".into(),
message: e.to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let _ = sender.send(Message::Text(json.into())).await;
}
}
}
Some(Ok(Message::Close(_))) | None => {
tracing::debug!("Client disconnected from file subscription");
break;
}
Some(Err(e)) => {
tracing::warn!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
// Handle broadcast notifications
notification = broadcast_rx.recv() => {
match notification {
Ok(notification) => {
// Only forward if client is subscribed to this file
if subscriptions.contains(¬ification.file_id) {
let response = FileServerMessage::FileUpdated {
file_id: notification.file_id,
version: notification.version,
updated_fields: notification.updated_fields,
updated_by: notification.updated_by,
};
let json = serde_json::to_string(&response).unwrap();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
// Client is too slow, skip some messages
tracing::warn!("File subscription client lagged, skipped {} messages", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
// Channel closed, exit
break;
}
}
}
}
}
}