// cgit view: summary | refs | log | blame | commit | diff
// path: root/makima/src/server/handlers/file_ws.rs
// blob: 5a4430958dc1dcb39ba42e2f2995635fc3acd80d (plain) (tree)


































































































































































                                                                                                  
//! WebSocket handler for file change subscriptions.
//!
//! Clients can subscribe to specific files and receive real-time notifications
//! when those files are updated by any source (user edits, LLM modifications, etc.).

use axum::{
    extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
    response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use uuid::Uuid;

use crate::server::state::SharedState;

/// Messages a client may send over the file-subscription WebSocket.
///
/// Encoded as JSON objects internally tagged by a `type` field whose value is
/// the camelCase variant name (`"subscribe"` / `"unsubscribe"`). The payload
/// field is spelled `fileId` on the wire.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum FileClientMessage {
    /// Start receiving update notifications for the given file.
    #[serde(rename_all = "camelCase")]
    Subscribe { file_id: Uuid },
    /// Stop receiving update notifications for the given file.
    #[serde(rename_all = "camelCase")]
    Unsubscribe { file_id: Uuid },
}

/// Messages the server sends over the file-subscription WebSocket.
///
/// Encoded as JSON objects internally tagged by a `type` field whose value is
/// the camelCase variant name (`"subscribed"`, `"unsubscribed"`,
/// `"fileUpdated"`, `"error"`). Payload fields are camelCase on the wire.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum FileServerMessage {
    /// Acknowledges a `Subscribe` request for `file_id`.
    #[serde(rename_all = "camelCase")]
    Subscribed { file_id: Uuid },
    /// Acknowledges an `Unsubscribe` request for `file_id`.
    #[serde(rename_all = "camelCase")]
    Unsubscribed { file_id: Uuid },
    /// Reports that a subscribed file changed.
    #[serde(rename_all = "camelCase")]
    FileUpdated {
        /// The file that changed.
        file_id: Uuid,
        /// Version number carried by the update notification.
        version: i32,
        /// Names of the fields reported as updated.
        updated_fields: Vec<String>,
        /// Identifier of whoever made the change, as carried by the notification.
        updated_by: String,
    },
    /// Reports a problem with a client request (e.g. an unparseable message).
    Error { code: String, message: String },
}

/// Upgrades an HTTP request to a WebSocket dedicated to file subscriptions.
///
/// Once the upgrade completes, the socket and the shared application state
/// are handed to `handle_file_subscription`, which runs the connection.
#[utoipa::path(
    get,
    path = "/api/v1/files/subscribe",
    responses(
        (status = 101, description = "WebSocket connection established"),
    ),
    tag = "Files"
)]
pub async fn file_subscription_handler(
    ws: WebSocketUpgrade,
    State(state): State<SharedState>,
) -> Response {
    ws.on_upgrade(move |socket| async move {
        handle_file_subscription(socket, state).await;
    })
}

async fn handle_file_subscription(socket: WebSocket, state: SharedState) {
    let (mut sender, mut receiver) = socket.split();

    // Set of file IDs this client is subscribed to
    let mut subscriptions: HashSet<Uuid> = HashSet::new();

    // Subscribe to the broadcast channel
    let mut broadcast_rx = state.file_updates.subscribe();

    loop {
        tokio::select! {
            // Handle incoming WebSocket messages from client
            msg = receiver.next() => {
                match msg {
                    Some(Ok(Message::Text(text))) => {
                        match serde_json::from_str::<FileClientMessage>(&text) {
                            Ok(FileClientMessage::Subscribe { file_id }) => {
                                subscriptions.insert(file_id);
                                let response = FileServerMessage::Subscribed { file_id };
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client subscribed to file {}", file_id);
                            }
                            Ok(FileClientMessage::Unsubscribe { file_id }) => {
                                subscriptions.remove(&file_id);
                                let response = FileServerMessage::Unsubscribed { file_id };
                                let json = serde_json::to_string(&response).unwrap();
                                if sender.send(Message::Text(json.into())).await.is_err() {
                                    break;
                                }
                                tracing::debug!("Client unsubscribed from file {}", file_id);
                            }
                            Err(e) => {
                                let response = FileServerMessage::Error {
                                    code: "PARSE_ERROR".into(),
                                    message: e.to_string(),
                                };
                                let json = serde_json::to_string(&response).unwrap();
                                let _ = sender.send(Message::Text(json.into())).await;
                            }
                        }
                    }
                    Some(Ok(Message::Close(_))) | None => {
                        tracing::debug!("Client disconnected from file subscription");
                        break;
                    }
                    Some(Err(e)) => {
                        tracing::warn!("WebSocket error: {}", e);
                        break;
                    }
                    _ => {}
                }
            }

            // Handle broadcast notifications
            notification = broadcast_rx.recv() => {
                match notification {
                    Ok(notification) => {
                        // Only forward if client is subscribed to this file
                        if subscriptions.contains(&notification.file_id) {
                            let response = FileServerMessage::FileUpdated {
                                file_id: notification.file_id,
                                version: notification.version,
                                updated_fields: notification.updated_fields,
                                updated_by: notification.updated_by,
                            };
                            let json = serde_json::to_string(&response).unwrap();
                            if sender.send(Message::Text(json.into())).await.is_err() {
                                break;
                            }
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                        // Client is too slow, skip some messages
                        tracing::warn!("File subscription client lagged, skipped {} messages", n);
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        // Channel closed, exit
                        break;
                    }
                }
            }
        }
    }
}