diff options
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/db/models.rs | 7 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 132 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 46 | ||||
| -rw-r--r-- | makima/src/server/handlers/file_ws.rs | 163 | ||||
| -rw-r--r-- | makima/src/server/handlers/files.rs | 57 | ||||
| -rw-r--r-- | makima/src/server/handlers/listen.rs | 1 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 1 | ||||
| -rw-r--r-- | makima/src/server/mod.rs | 3 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 30 |
9 files changed, 402 insertions, 38 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 135ae75..8204b86 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -68,6 +68,8 @@ pub struct File { /// Structured body content (headings, paragraphs, charts) #[sqlx(json)] pub body: Vec<BodyElement>, + /// Version number for optimistic locking + pub version: i32, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, } @@ -100,6 +102,8 @@ pub struct UpdateFileRequest { pub summary: Option<String>, /// Structured body content (optional) pub body: Option<Vec<BodyElement>>, + /// Version for optimistic locking (required for updates from frontend) + pub version: Option<i32>, } /// Response for file list endpoint. @@ -120,6 +124,8 @@ pub struct FileSummary { pub transcript_count: usize, /// Duration derived from last transcript end time pub duration: Option<f32>, + /// Version number for optimistic locking + pub version: i32, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, } @@ -137,6 +143,7 @@ impl From<File> for FileSummary { description: file.description, transcript_count: file.transcript.len(), duration: if duration > 0.0 { Some(duration) } else { None }, + version: file.version, created_at: file.created_at, updated_at: file.updated_at, } diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index f8b90b3..5b962ee 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -9,6 +9,43 @@ use super::models::{CreateFileRequest, File, UpdateFileRequest}; /// Default owner ID for anonymous users. pub const ANONYMOUS_OWNER_ID: Uuid = Uuid::from_u128(0x00000000_0000_0000_0000_000000000002); +/// Repository error types. +#[derive(Debug)] +pub enum RepositoryError { + /// Database error + Database(sqlx::Error), + /// Version conflict (optimistic locking failure) + VersionConflict { + /// The version the client expected + expected: i32, + /// The actual current version in the database + actual: i32, + }, +} + +impl From<sqlx::Error> for RepositoryError { + fn from(e: sqlx::Error) -> Self { + RepositoryError::Database(e) + } +} + +impl std::fmt::Display for RepositoryError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RepositoryError::Database(e) => write!(f, "Database error: {}", e), + RepositoryError::VersionConflict { expected, actual } => { + write!( + f, + "Version conflict: expected {}, actual {}", + expected, actual + ) + } + } + } +} + +impl std::error::Error for RepositoryError {} + /// Generate a default name based on current timestamp. fn generate_default_name() -> String { let now = Utc::now(); @@ -25,7 +62,7 @@ pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, r#" INSERT INTO files (owner_id, name, description, transcript, location, summary, body) VALUES ($1, $2, $3, $4, $5, NULL, $6) - RETURNING id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at + RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at "#, ) .bind(ANONYMOUS_OWNER_ID) @@ -42,7 +79,7 @@ pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, pub async fn get_file(pool: &PgPool, id: Uuid) -> Result<Option<File>, sqlx::Error> { sqlx::query_as::<_, File>( r#" - SELECT id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at + SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at FROM files WHERE id = $1 AND owner_id = $2 "#, @@ -57,7 +94,7 @@ pub async fn get_file(pool: &PgPool, id: Uuid) -> Result<Option<File>, sqlx::Err pub async fn list_files(pool: &PgPool) -> Result<Vec<File>, sqlx::Error> { sqlx::query_as::<_, File>( r#" - SELECT id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at + SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at FROM files WHERE owner_id = $1 ORDER BY created_at DESC @@ -68,18 +105,33 @@ pub async fn list_files(pool: &PgPool) -> Result<Vec<File>, sqlx::Error> { .await } -/// Update a file by ID. +/// Update a file by ID with optimistic locking. +/// +/// If `req.version` is provided, the update will only succeed if the current +/// version matches. Returns `RepositoryError::VersionConflict` if there's a mismatch. +/// +/// If `req.version` is None (e.g., internal system updates), version checking is skipped. pub async fn update_file( pool: &PgPool, id: Uuid, req: UpdateFileRequest, -) -> Result<Option<File>, sqlx::Error> { +) -> Result<Option<File>, RepositoryError> { // Get the existing file first let existing = get_file(pool, id).await?; let Some(existing) = existing else { return Ok(None); }; + // Check version if provided (optimistic locking) + if let Some(expected_version) = req.version { + if existing.version != expected_version { + return Err(RepositoryError::VersionConflict { + expected: expected_version, + actual: existing.version, + }); + } + } + // Apply updates let name = req.name.unwrap_or(existing.name); let description = req.description.or(existing.description); @@ -89,23 +141,59 @@ pub async fn update_file( let body = req.body.unwrap_or(existing.body); let body_json = serde_json::to_value(&body).unwrap_or_default(); - sqlx::query_as::<_, File>( - r#" - UPDATE files - SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() - WHERE id = $1 AND owner_id = $2 - RETURNING id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at - "#, - ) - .bind(id) - .bind(ANONYMOUS_OWNER_ID) - .bind(&name) - .bind(&description) - .bind(&transcript_json) - .bind(&summary) - .bind(&body_json) - .fetch_optional(pool) - .await + // Update with version check in WHERE clause for race condition safety + let result = if req.version.is_some() { + sqlx::query_as::<_, File>( + r#" + UPDATE files + SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() + WHERE id = $1 AND owner_id = $2 AND version = $8 + RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + "#, + ) + .bind(id) + .bind(ANONYMOUS_OWNER_ID) + .bind(&name) + .bind(&description) + .bind(&transcript_json) + .bind(&summary) + .bind(&body_json) + .bind(req.version.unwrap()) + .fetch_optional(pool) + .await? + } else { + // No version check for internal updates + sqlx::query_as::<_, File>( + r#" + UPDATE files + SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() + WHERE id = $1 AND owner_id = $2 + RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at + "#, + ) + .bind(id) + .bind(ANONYMOUS_OWNER_ID) + .bind(&name) + .bind(&description) + .bind(&transcript_json) + .bind(&summary) + .bind(&body_json) + .fetch_optional(pool) + .await? + }; + + // If versioned update returned None, there was a race condition + if result.is_none() && req.version.is_some() { + // Re-fetch to get the actual version + if let Some(current) = get_file(pool, id).await? { + return Err(RepositoryError::VersionConflict { + expected: req.version.unwrap(), + actual: current.version, + }); + } + } + + Ok(result) } /// Delete a file by ID. diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index 92c4ec8..3bdbc74 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -17,7 +17,7 @@ use crate::llm::{ groq::{GroqClient, GroqError, Message, ToolCallResponse}, LlmModel, ToolCall, ToolResult, AVAILABLE_TOOLS, }; -use crate::server::state::SharedState; +use crate::server::state::{FileUpdateNotification, SharedState}; /// Maximum number of tool-calling rounds to prevent infinite loops const MAX_TOOL_ROUNDS: usize = 10; @@ -385,17 +385,43 @@ pub async fn chat_handler( transcript: None, summary: current_summary.clone(), body: Some(current_body.clone()), + version: None, // Internal update, skip version check }; - if let Err(e) = repository::update_file(pool, id, update_req).await { - tracing::error!("Failed to save file changes: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("Failed to save changes: {}", e) - })), - ) - .into_response(); + match repository::update_file(pool, id, update_req).await { + Ok(Some(updated_file)) => { + // Broadcast update notification for LLM changes + let mut updated_fields = vec!["body".to_string()]; + if current_summary.is_some() { + updated_fields.push("summary".to_string()); + } + state.broadcast_file_update(FileUpdateNotification { + file_id: id, + version: updated_file.version, + updated_fields, + updated_by: "llm".to_string(), + }); + } + Ok(None) => { + // File was deleted during processing + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ + "error": "File not found" + })), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to save file changes: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Failed to save changes: {}", e) + })), + ) + .into_response(); + } } } diff --git a/makima/src/server/handlers/file_ws.rs b/makima/src/server/handlers/file_ws.rs new file mode 100644 index 0000000..5a44309 --- /dev/null +++ b/makima/src/server/handlers/file_ws.rs @@ -0,0 +1,163 @@ +//! WebSocket handler for file change subscriptions. +//! +//! Clients can subscribe to specific files and receive real-time notifications +//! when those files are updated by any source (user edits, LLM modifications, etc.). + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use uuid::Uuid; + +use crate::server::state::SharedState; + +/// Client message for file subscription management. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum FileClientMessage { + /// Subscribe to updates for a specific file + Subscribe { + #[serde(rename = "fileId")] + file_id: Uuid, + }, + /// Unsubscribe from updates for a specific file + Unsubscribe { + #[serde(rename = "fileId")] + file_id: Uuid, + }, +} + +/// Server message for file subscription WebSocket. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum FileServerMessage { + /// Subscription confirmed + Subscribed { + #[serde(rename = "fileId")] + file_id: Uuid, + }, + /// Unsubscription confirmed + Unsubscribed { + #[serde(rename = "fileId")] + file_id: Uuid, + }, + /// File was updated + FileUpdated { + #[serde(rename = "fileId")] + file_id: Uuid, + version: i32, + #[serde(rename = "updatedFields")] + updated_fields: Vec<String>, + #[serde(rename = "updatedBy")] + updated_by: String, + }, + /// Error occurred + Error { code: String, message: String }, +} + +/// WebSocket upgrade handler for file subscriptions. +#[utoipa::path( + get, + path = "/api/v1/files/subscribe", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "Files" +)] +pub async fn file_subscription_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, +) -> Response { + ws.on_upgrade(|socket| handle_file_subscription(socket, state)) +} + +async fn handle_file_subscription(socket: WebSocket, state: SharedState) { + let (mut sender, mut receiver) = socket.split(); + + // Set of file IDs this client is subscribed to + let mut subscriptions: HashSet<Uuid> = HashSet::new(); + + // Subscribe to the broadcast channel + let mut broadcast_rx = state.file_updates.subscribe(); + + loop { + tokio::select! { + // Handle incoming WebSocket messages from client + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<FileClientMessage>(&text) { + Ok(FileClientMessage::Subscribe { file_id }) => { + subscriptions.insert(file_id); + let response = FileServerMessage::Subscribed { file_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to file {}", file_id); + } + Ok(FileClientMessage::Unsubscribe { file_id }) => { + subscriptions.remove(&file_id); + let response = FileServerMessage::Unsubscribed { file_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from file {}", file_id); + } + Err(e) => { + let response = FileServerMessage::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Client disconnected from file subscription"); + break; + } + Some(Err(e)) => { + tracing::warn!("WebSocket error: {}", e); + break; + } + _ => {} + } + } + + // Handle broadcast notifications + notification = broadcast_rx.recv() => { + match notification { + Ok(notification) => { + // Only forward if client is subscribed to this file + if subscriptions.contains(¬ification.file_id) { + let response = FileServerMessage::FileUpdated { + file_id: notification.file_id, + version: notification.version, + updated_fields: notification.updated_fields, + updated_by: notification.updated_by, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + // Client is too slow, skip some messages + tracing::warn!("File subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + // Channel closed, exit + break; + } + } + } + } + } +} diff --git a/makima/src/server/handlers/files.rs b/makima/src/server/handlers/files.rs index 746d66b..c65eed5 100644 --- a/makima/src/server/handlers/files.rs +++ b/makima/src/server/handlers/files.rs @@ -9,9 +9,9 @@ use axum::{ use uuid::Uuid; use crate::db::models::{CreateFileRequest, FileListResponse, FileSummary, UpdateFileRequest}; -use crate::db::repository; +use crate::db::repository::{self, RepositoryError}; use crate::server::messages::ApiError; -use crate::server::state::SharedState; +use crate::server::state::{FileUpdateNotification, SharedState}; /// List all files for the current owner. #[utoipa::path( @@ -148,6 +148,7 @@ pub async fn create_file( responses( (status = 200, description = "File updated", body = crate::db::models::File), (status = 404, description = "File not found", body = ApiError), + (status = 409, description = "Version conflict", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), @@ -166,14 +167,62 @@ pub async fn update_file( .into_response(); }; + // Collect which fields are being updated for broadcast + let mut updated_fields = Vec::new(); + if req.name.is_some() { + updated_fields.push("name".to_string()); + } + if req.description.is_some() { + updated_fields.push("description".to_string()); + } + if req.transcript.is_some() { + updated_fields.push("transcript".to_string()); + } + if req.summary.is_some() { + updated_fields.push("summary".to_string()); + } + if req.body.is_some() { + updated_fields.push("body".to_string()); + } + match repository::update_file(pool, id, req).await { - Ok(Some(file)) => Json(file).into_response(), + Ok(Some(file)) => { + // Broadcast update notification + state.broadcast_file_update(FileUpdateNotification { + file_id: id, + version: file.version, + updated_fields, + updated_by: "user".to_string(), + }); + Json(file).into_response() + } Ok(None) => ( StatusCode::NOT_FOUND, Json(ApiError::new("NOT_FOUND", "File not found")), ) .into_response(), - Err(e) => { + Err(RepositoryError::VersionConflict { expected, actual }) => { + tracing::info!( + "Version conflict on file {}: expected {}, actual {}", + id, + expected, + actual + ); + ( + StatusCode::CONFLICT, + Json(serde_json::json!({ + "code": "VERSION_CONFLICT", + "message": format!( + "File was modified by another user. Expected version {}, actual version {}", + expected, actual + ), + "expectedVersion": expected, + "actualVersion": actual, + })), + ) + .into_response() + } + Err(RepositoryError::Database(e)) => { tracing::error!("Failed to update file {}: {}", id, e); ( StatusCode::INTERNAL_SERVER_ERROR, diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs index 5fc5cea..a26c208 100644 --- a/makima/src/server/handlers/listen.rs +++ b/makima/src/server/handlers/listen.rs @@ -467,6 +467,7 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { transcript: Some(final_entries.clone()), summary: None, body: None, + version: None, // Internal update, skip version check }).await { Ok(_) => { tracing::info!( diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index b13668a..c08f1bd 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -1,5 +1,6 @@ //! HTTP and WebSocket request handlers. pub mod chat; +pub mod file_ws; pub mod files; pub mod listen; diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs index a8f98a6..f132cf4 100644 --- a/makima/src/server/mod.rs +++ b/makima/src/server/mod.rs @@ -17,7 +17,7 @@ use tower_http::trace::TraceLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use crate::server::handlers::{chat, files, listen}; +use crate::server::handlers::{chat, file_ws, files, listen}; use crate::server::openapi::ApiDoc; use crate::server::state::SharedState; @@ -43,6 +43,7 @@ pub fn make_router(state: SharedState) -> Router { // API v1 routes let api_v1 = Router::new() .route("/listen", get(listen::websocket_handler)) + .route("/files/subscribe", get(file_ws::file_subscription_handler)) .route("/files", get(files::list_files).post(files::create_file)) .route( "/files/{id}", diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index 8cdc26c..239ab77 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -2,10 +2,24 @@ use std::sync::Arc; use sqlx::PgPool; -use tokio::sync::Mutex; +use tokio::sync::{broadcast, Mutex}; +use uuid::Uuid; use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; +/// Notification payload for file updates (broadcast to WebSocket subscribers). +#[derive(Debug, Clone)] +pub struct FileUpdateNotification { + /// ID of the updated file + pub file_id: Uuid, + /// New version number after update + pub version: i32, + /// List of fields that were updated + pub updated_fields: Vec<String>, + /// Source of the update: "user", "llm", or "system" + pub updated_by: String, +} + /// Shared application state containing ML models and database pool. /// /// Models are wrapped in `Mutex` for thread-safe mutable access during inference. @@ -18,6 +32,8 @@ pub struct AppState { pub sortformer: Mutex<Sortformer>, /// Optional database connection pool pub db_pool: Option<PgPool>, + /// Broadcast channel for file update notifications + pub file_updates: broadcast::Sender<FileUpdateNotification>, } impl AppState { @@ -40,11 +56,15 @@ impl AppState { DiarizationConfig::callhome(), )?; + // Create broadcast channel with buffer for 256 messages + let (file_updates, _) = broadcast::channel(256); + Ok(Self { parakeet: Mutex::new(parakeet), parakeet_eou: Mutex::new(parakeet_eou), sortformer: Mutex::new(sortformer), db_pool: None, + file_updates, }) } @@ -53,6 +73,14 @@ impl AppState { self.db_pool = Some(pool); self } + + /// Broadcast a file update notification to all subscribers. + /// + /// This is a no-op if there are no subscribers (ignores send errors). + pub fn broadcast_file_update(&self, notification: FileUpdateNotification) { + // Ignore send errors - they just mean no one is listening + let _ = self.file_updates.send(notification); + } } /// Type alias for the shared application state. |
