summaryrefslogblamecommitdiff
path: root/makima/src/db/repository.rs
blob: 4137ba60d91c5bf589e8806b2c9f358c9b21b874 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11





                                                    
                                                                             



                                                                                             




































                                                                               









                                                                                              
                                                                                             


                              

                                                                                            
                                                                                                                       






                             
                     







                                                                                     
                                                                                                                    













                                                                          
                                                                                                                    









                                





                                                                                          



                           
                                            





                                             









                                                         




                                                                                


                                                                    
 




















































                                                                                                                           


























                                                                                         













































































































































                                                                                                             
//! Repository pattern for file database operations.

use chrono::Utc;
use sqlx::PgPool;
use uuid::Uuid;

use super::models::{CreateFileRequest, File, FileVersion, UpdateFileRequest};

/// Owner ID used for anonymous users.
///
/// The `files` queries in this module insert with, and filter on, this owner.
pub const ANONYMOUS_OWNER_ID: Uuid = Uuid::from_u128(2);

/// Errors returned by repository operations that can fail for reasons other
/// than a plain database error.
#[derive(Debug)]
pub enum RepositoryError {
    /// The underlying database query failed.
    Database(sqlx::Error),
    /// Version conflict (optimistic locking failure): the version the caller
    /// based its change on is not the version currently stored.
    VersionConflict {
        /// The version the client expected
        expected: i32,
        /// The actual current version in the database
        actual: i32,
    },
}

impl From<sqlx::Error> for RepositoryError {
    fn from(e: sqlx::Error) -> Self {
        RepositoryError::Database(e)
    }
}

impl std::fmt::Display for RepositoryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RepositoryError::Database(e) => write!(f, "Database error: {}", e),
            RepositoryError::VersionConflict { expected, actual } => {
                write!(
                    f,
                    "Version conflict: expected {}, actual {}",
                    expected, actual
                )
            }
        }
    }
}

impl std::error::Error for RepositoryError {
    /// Expose the wrapped `sqlx::Error` so error reporters can walk the cause
    /// chain; a version conflict has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            RepositoryError::Database(e) => Some(e),
            RepositoryError::VersionConflict { .. } => None,
        }
    }
}

/// Build a fallback file name from the current UTC time,
/// e.g. `Recording - Jan 05 2024 13:07:42`.
fn generate_default_name() -> String {
    Utc::now()
        .format("Recording - %b %d %Y %H:%M:%S")
        .to_string()
}

/// Create a new file record.
pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, sqlx::Error> {
    let name = req.name.unwrap_or_else(generate_default_name);
    let transcript_json = serde_json::to_value(&req.transcript).unwrap_or_default();
    let body_json = serde_json::to_value::<Vec<super::models::BodyElement>>(vec![]).unwrap();

    sqlx::query_as::<_, File>(
        r#"
        INSERT INTO files (owner_id, name, description, transcript, location, summary, body)
        VALUES ($1, $2, $3, $4, $5, NULL, $6)
        RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at
        "#,
    )
    .bind(ANONYMOUS_OWNER_ID)
    .bind(&name)
    .bind(&req.description)
    .bind(&transcript_json)
    .bind(&req.location)
    .bind(&body_json)
    .fetch_one(pool)
    .await
}

/// Fetch a single file by `id`, restricted to [`ANONYMOUS_OWNER_ID`].
///
/// Returns `Ok(None)` when no matching row exists.
pub async fn get_file(pool: &PgPool, id: Uuid) -> Result<Option<File>, sqlx::Error> {
    const SELECT_BY_ID: &str = r#"
        SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at
        FROM files
        WHERE id = $1 AND owner_id = $2
        "#;

    sqlx::query_as::<_, File>(SELECT_BY_ID)
        .bind(id)
        .bind(ANONYMOUS_OWNER_ID)
        .fetch_optional(pool)
        .await
}

/// Return every file owned by [`ANONYMOUS_OWNER_ID`], newest (`created_at`) first.
pub async fn list_files(pool: &PgPool) -> Result<Vec<File>, sqlx::Error> {
    const SELECT_ALL: &str = r#"
        SELECT id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at
        FROM files
        WHERE owner_id = $1
        ORDER BY created_at DESC
        "#;

    sqlx::query_as::<_, File>(SELECT_ALL)
        .bind(ANONYMOUS_OWNER_ID)
        .fetch_all(pool)
        .await
}

/// Update a file by ID with optimistic locking.
///
/// Fields left as `None` in `req` keep their current values, so `description`
/// and `summary` cannot be cleared through this call; `location` is never
/// changed here.
///
/// If `req.version` is `Some`, the update only succeeds when it matches the
/// stored version. A mismatch is reported as
/// [`RepositoryError::VersionConflict`], whether it is detected by the initial
/// read or by the version-guarded `UPDATE` losing a race with another writer.
///
/// If `req.version` is `None` (e.g., internal system updates), version
/// checking is skipped.
///
/// Returns `Ok(None)` when the file does not exist for the current owner.
pub async fn update_file(
    pool: &PgPool,
    id: Uuid,
    req: UpdateFileRequest,
) -> Result<Option<File>, RepositoryError> {
    let Some(existing) = get_file(pool, id).await? else {
        return Ok(None);
    };

    let expected_version = req.version;

    // Fail fast on a stale version before building the new row.
    if let Some(expected) = expected_version {
        if existing.version != expected {
            return Err(RepositoryError::VersionConflict {
                expected,
                actual: existing.version,
            });
        }
    }

    // Merge the request over the existing row.
    let name = req.name.unwrap_or(existing.name);
    let description = req.description.or(existing.description);
    let transcript = req.transcript.unwrap_or(existing.transcript);
    let transcript_json = serde_json::to_value(&transcript).unwrap_or_default();
    let summary = req.summary.or(existing.summary);
    let body = req.body.unwrap_or(existing.body);
    let body_json = serde_json::to_value(&body).unwrap_or_default();

    let result = match expected_version {
        // Repeat the version check in the WHERE clause so a concurrent writer
        // that slipped in after the read above makes this UPDATE match nothing.
        Some(expected) => {
            sqlx::query_as::<_, File>(
                r#"
            UPDATE files
            SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW()
            WHERE id = $1 AND owner_id = $2 AND version = $8
            RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at
            "#,
            )
            .bind(id)
            .bind(ANONYMOUS_OWNER_ID)
            .bind(&name)
            .bind(&description)
            .bind(&transcript_json)
            .bind(&summary)
            .bind(&body_json)
            .bind(expected)
            .fetch_optional(pool)
            .await?
        }
        // No version check for internal updates.
        None => {
            sqlx::query_as::<_, File>(
                r#"
            UPDATE files
            SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW()
            WHERE id = $1 AND owner_id = $2
            RETURNING id, owner_id, name, description, transcript, location, summary, body, version, created_at, updated_at
            "#,
            )
            .bind(id)
            .bind(ANONYMOUS_OWNER_ID)
            .bind(&name)
            .bind(&description)
            .bind(&transcript_json)
            .bind(&summary)
            .bind(&body_json)
            .fetch_optional(pool)
            .await?
        }
    };

    // A guarded update that matched no row lost a race: re-read to report the
    // version that won. If the file is gone entirely, fall through to Ok(None).
    if result.is_none() {
        if let Some(expected) = expected_version {
            if let Some(current) = get_file(pool, id).await? {
                return Err(RepositoryError::VersionConflict {
                    expected,
                    actual: current.version,
                });
            }
        }
    }

    Ok(result)
}

/// Delete a file by ID for [`ANONYMOUS_OWNER_ID`].
///
/// Returns `true` if a row was removed and `false` if nothing matched.
pub async fn delete_file(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> {
    const DELETE_BY_ID: &str = r#"
        DELETE FROM files
        WHERE id = $1 AND owner_id = $2
        "#;

    let outcome = sqlx::query(DELETE_BY_ID)
        .bind(id)
        .bind(ANONYMOUS_OWNER_ID)
        .execute(pool)
        .await?;
    let removed_any = outcome.rows_affected() != 0;
    Ok(removed_any)
}

/// Count how many files belong to [`ANONYMOUS_OWNER_ID`].
pub async fn count_files(pool: &PgPool) -> Result<i64, sqlx::Error> {
    let (total,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM files WHERE owner_id = $1")
        .bind(ANONYMOUS_OWNER_ID)
        .fetch_one(pool)
        .await?;
    Ok(total)
}

// =============================================================================
// Version History Functions
// =============================================================================

/// Set the version source for the current transaction.
/// This is used by the trigger to record who made the change.
///
/// The value is written to the `app.version_source` setting with
/// `set_config(..., true)`. This is the function form of `SET LOCAL`, and it
/// lets `source` be passed as a bind parameter instead of being spliced into
/// the SQL text, which closes an injection hole.
///
/// NOTE(review): a transaction-local setting only lives until the current
/// transaction ends. When this runs directly on a pool, the statement is its
/// own implicit transaction on an arbitrary connection, so the setting is gone
/// before any later query runs. It needs to be issued inside the same
/// transaction as the write it is meant to annotate.
pub async fn set_version_source(pool: &PgPool, source: &str) -> Result<(), sqlx::Error> {
    sqlx::query("SELECT set_config('app.version_source', $1, true)")
        .bind(source)
        .execute(pool)
        .await?;
    Ok(())
}

/// Set the change description for the current transaction.
///
/// The value is written to the `app.change_description` setting with
/// `set_config(..., true)` (the function form of `SET LOCAL`). Passing it as a
/// bind parameter avoids hand-escaping quotes into the SQL text, which breaks
/// on inputs the escaping does not anticipate.
///
/// NOTE(review): a transaction-local setting only lives until the current
/// transaction ends. On a pool, this statement runs in its own implicit
/// transaction, so the setting does not survive to later queries.
pub async fn set_change_description(pool: &PgPool, description: &str) -> Result<(), sqlx::Error> {
    sqlx::query("SELECT set_config('app.change_description', $1, true)")
        .bind(description)
        .execute(pool)
        .await?;
    Ok(())
}

/// List all versions of a file, ordered by version DESC.
pub async fn list_file_versions(pool: &PgPool, file_id: Uuid) -> Result<Vec<FileVersion>, sqlx::Error> {
    // First get the current version from the files table
    let current = get_file(pool, file_id).await?;

    let mut versions = sqlx::query_as::<_, FileVersion>(
        r#"
        SELECT id, file_id, version, name, description, summary, body, source, change_description, created_at
        FROM file_versions
        WHERE file_id = $1
        ORDER BY version DESC
        "#,
    )
    .bind(file_id)
    .fetch_all(pool)
    .await?;

    // Add the current version as the first entry if it exists
    if let Some(file) = current {
        let current_version = FileVersion {
            id: file.id,
            file_id: file.id,
            version: file.version,
            name: file.name,
            description: file.description,
            summary: file.summary,
            body: file.body,
            source: "user".to_string(), // Current version source
            change_description: None,
            created_at: file.updated_at,
        };
        versions.insert(0, current_version);
    }

    Ok(versions)
}

/// Get a specific version of a file.
pub async fn get_file_version(
    pool: &PgPool,
    file_id: Uuid,
    version: i32,
) -> Result<Option<FileVersion>, sqlx::Error> {
    // First check if this is the current version
    if let Some(file) = get_file(pool, file_id).await? {
        if file.version == version {
            return Ok(Some(FileVersion {
                id: file.id,
                file_id: file.id,
                version: file.version,
                name: file.name,
                description: file.description,
                summary: file.summary,
                body: file.body,
                source: "user".to_string(),
                change_description: None,
                created_at: file.updated_at,
            }));
        }
    }

    // Otherwise, look in the versions table
    sqlx::query_as::<_, FileVersion>(
        r#"
        SELECT id, file_id, version, name, description, summary, body, source, change_description, created_at
        FROM file_versions
        WHERE file_id = $1 AND version = $2
        "#,
    )
    .bind(file_id)
    .bind(version)
    .fetch_optional(pool)
    .await
}

/// Restore a file to a previous version.
/// This creates a new version with the content from the target version.
pub async fn restore_file_version(
    pool: &PgPool,
    file_id: Uuid,
    target_version: i32,
    current_version: i32,
) -> Result<Option<File>, RepositoryError> {
    // Get the target version content
    let target = get_file_version(pool, file_id, target_version).await?;
    let Some(target) = target else {
        return Ok(None);
    };

    // Set version source and description for the trigger
    set_version_source(pool, "system").await?;
    set_change_description(pool, &format!("Restored from version {}", target_version)).await?;

    // Update the file with the target version's content
    // This will trigger the save_file_version trigger to save the current state first
    let update_req = UpdateFileRequest {
        name: Some(target.name),
        description: target.description,
        transcript: None,
        summary: target.summary,
        body: Some(target.body),
        version: Some(current_version),
    };

    update_file(pool, file_id, update_req).await
}

/// Count versions for a file.
pub async fn count_file_versions(pool: &PgPool, file_id: Uuid) -> Result<i64, sqlx::Error> {
    let result: (i64,) = sqlx::query_as(
        "SELECT COUNT(*) + 1 FROM file_versions WHERE file_id = $1", // +1 for current version
    )
    .bind(file_id)
    .fetch_one(pool)
    .await?;

    Ok(result.0)
}