summaryrefslogblamecommitdiff
path: root/makima/src/daemon/db/local.rs
blob: f3ed45ad0851661012787089b0e47ffd7c46649e (plain) (tree)
1
2
3
4
5
6
7
8
9







                                                                   
                                   








































































































































































































































































































































                                                                                                                                                                                      
                         



















































                                                                       
//! Local SQLite database for crash recovery and state persistence.

use std::path::Path;

use chrono::{DateTime, Utc};
use rusqlite::{params, Connection, Result as SqliteResult};
use uuid::Uuid;

use crate::daemon::task::TaskState;

/// Local task record for persistence.
///
/// One row of the `tasks` table in [`LocalDb`]. Timestamps are stored as
/// RFC 3339 text and converted back to UTC when loaded.
#[derive(Debug, Clone, PartialEq)]
pub struct LocalTask {
    /// Primary key of the local record.
    pub id: Uuid,
    /// Identifier of the corresponding task on the server side.
    pub server_task_id: Uuid,
    /// Lifecycle state, persisted as the text returned by `TaskState::as_str`.
    pub state: TaskState,
    /// Container identifier, if one has been recorded.
    pub container_id: Option<String>,
    /// Overlay path, if one has been recorded.
    pub overlay_path: Option<String>,
    /// Repository URL, if any.
    pub repo_url: Option<String>,
    /// Base branch name, if any.
    pub base_branch: Option<String>,
    /// Plan text for the task (required).
    pub plan: String,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Start time (UTC), if the task has started.
    pub started_at: Option<DateTime<Utc>>,
    /// Completion time (UTC), if the task has completed.
    pub completed_at: Option<DateTime<Utc>>,
    /// Error message, if one has been recorded.
    pub error_message: Option<String>,
}

/// Buffered output for reliable delivery.
///
/// One row of the `output_buffer` table, as returned by
/// `LocalDb::get_unsent_outputs`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferedOutput {
    /// Auto-increment row id; pass it to `LocalDb::mark_outputs_sent` to
    /// acknowledge delivery.
    pub id: i64,
    /// Task the output belongs to.
    pub task_id: Uuid,
    /// The buffered output text.
    pub output: String,
    /// Flag stored alongside the output (presumably marks an incomplete
    /// chunk; meaning is defined by the producer — confirm with callers).
    pub is_partial: bool,
    /// Timestamp associated with the row, in UTC.
    pub timestamp: DateTime<Utc>,
}

/// Local database for daemon state persistence.
///
/// Thin wrapper around a single SQLite [`Connection`]; every method issues its
/// statements directly on this connection.
pub struct LocalDb {
    /// Underlying SQLite connection (file-backed via `open`, in-memory via
    /// `open_memory` in tests).
    conn: Connection,
}

impl LocalDb {
    /// Open or create the local database.
    pub fn open(path: &Path) -> SqliteResult<Self> {
        // Create parent directory if needed
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).ok();
        }

        let conn = Connection::open(path)?;

        // Initialize schema
        conn.execute_batch(Self::schema())?;

        Ok(Self { conn })
    }

    /// Open an in-memory database (for testing).
    #[cfg(test)]
    pub fn open_memory() -> SqliteResult<Self> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch(Self::schema())?;
        Ok(Self { conn })
    }

    /// Database schema.
    fn schema() -> &'static str {
        r#"
        -- Local task state for crash recovery
        CREATE TABLE IF NOT EXISTS tasks (
            id TEXT PRIMARY KEY,
            server_task_id TEXT NOT NULL,
            state TEXT NOT NULL,
            container_id TEXT,
            overlay_path TEXT,
            repo_url TEXT,
            base_branch TEXT,
            plan TEXT NOT NULL,
            created_at TEXT NOT NULL,
            started_at TEXT,
            completed_at TEXT,
            error_message TEXT
        );

        -- Buffered output for reliable delivery
        CREATE TABLE IF NOT EXISTS output_buffer (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id TEXT NOT NULL,
            output TEXT NOT NULL,
            is_partial INTEGER NOT NULL,
            timestamp TEXT NOT NULL,
            sent INTEGER NOT NULL DEFAULT 0
        );

        -- Daemon state key-value store
        CREATE TABLE IF NOT EXISTS daemon_state (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );

        -- Indexes
        CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);
        CREATE INDEX IF NOT EXISTS idx_output_buffer_sent ON output_buffer(sent, id);
        CREATE INDEX IF NOT EXISTS idx_output_buffer_task ON output_buffer(task_id);
        "#
    }

    /// Save a task.
    pub fn save_task(&self, task: &LocalTask) -> SqliteResult<()> {
        self.conn.execute(
            r#"
            INSERT OR REPLACE INTO tasks
            (id, server_task_id, state, container_id, overlay_path, repo_url, base_branch, plan, created_at, started_at, completed_at, error_message)
            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)
            "#,
            params![
                task.id.to_string(),
                task.server_task_id.to_string(),
                task.state.as_str(),
                task.container_id,
                task.overlay_path,
                task.repo_url,
                task.base_branch,
                task.plan,
                task.created_at.to_rfc3339(),
                task.started_at.map(|t| t.to_rfc3339()),
                task.completed_at.map(|t| t.to_rfc3339()),
                task.error_message,
            ],
        )?;
        Ok(())
    }

    /// Get a task by ID.
    pub fn get_task(&self, id: Uuid) -> SqliteResult<Option<LocalTask>> {
        let mut stmt = self.conn.prepare(
            "SELECT id, server_task_id, state, container_id, overlay_path, repo_url, base_branch, plan, created_at, started_at, completed_at, error_message FROM tasks WHERE id = ?1",
        )?;

        let mut rows = stmt.query(params![id.to_string()])?;

        if let Some(row) = rows.next()? {
            Ok(Some(Self::task_from_row(row)?))
        } else {
            Ok(None)
        }
    }

    /// Get all running/active tasks (for recovery).
    pub fn get_active_tasks(&self) -> SqliteResult<Vec<LocalTask>> {
        let mut stmt = self.conn.prepare(
            r#"
            SELECT id, server_task_id, state, container_id, overlay_path, repo_url, base_branch, plan, created_at, started_at, completed_at, error_message
            FROM tasks
            WHERE state IN ('initializing', 'starting', 'running', 'paused', 'blocked')
            "#,
        )?;

        let rows = stmt.query_map([], |row| Self::task_from_row(row))?;

        rows.collect()
    }

    /// Delete a task.
    pub fn delete_task(&self, id: Uuid) -> SqliteResult<()> {
        self.conn.execute(
            "DELETE FROM tasks WHERE id = ?1",
            params![id.to_string()],
        )?;
        Ok(())
    }

    /// Update task state.
    pub fn update_task_state(&self, id: Uuid, state: TaskState) -> SqliteResult<()> {
        self.conn.execute(
            "UPDATE tasks SET state = ?2 WHERE id = ?1",
            params![id.to_string(), state.as_str()],
        )?;
        Ok(())
    }

    /// Buffer output for reliable delivery.
    pub fn buffer_output(&self, task_id: Uuid, output: &str, is_partial: bool) -> SqliteResult<i64> {
        self.conn.execute(
            r#"
            INSERT INTO output_buffer (task_id, output, is_partial, timestamp, sent)
            VALUES (?1, ?2, ?3, datetime('now'), 0)
            "#,
            params![task_id.to_string(), output, is_partial as i32],
        )?;
        Ok(self.conn.last_insert_rowid())
    }

    /// Get unsent outputs.
    pub fn get_unsent_outputs(&self, limit: i64) -> SqliteResult<Vec<BufferedOutput>> {
        let mut stmt = self.conn.prepare(
            r#"
            SELECT id, task_id, output, is_partial, timestamp
            FROM output_buffer
            WHERE sent = 0
            ORDER BY id
            LIMIT ?1
            "#,
        )?;

        let rows = stmt.query_map(params![limit], |row| {
            let id: i64 = row.get(0)?;
            let task_id_str: String = row.get(1)?;
            let task_id = Uuid::parse_str(&task_id_str).unwrap_or_default();
            let output: String = row.get(2)?;
            let is_partial: i32 = row.get(3)?;
            let timestamp_str: String = row.get(4)?;
            let timestamp = DateTime::parse_from_rfc3339(&timestamp_str)
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(|_| Utc::now());

            Ok(BufferedOutput {
                id,
                task_id,
                output,
                is_partial: is_partial != 0,
                timestamp,
            })
        })?;

        rows.collect()
    }

    /// Mark outputs as sent.
    pub fn mark_outputs_sent(&self, ids: &[i64]) -> SqliteResult<()> {
        if ids.is_empty() {
            return Ok(());
        }

        let placeholders: Vec<&str> = ids.iter().map(|_| "?").collect();
        let sql = format!(
            "UPDATE output_buffer SET sent = 1 WHERE id IN ({})",
            placeholders.join(",")
        );

        let params: Vec<rusqlite::types::Value> = ids
            .iter()
            .map(|id| rusqlite::types::Value::Integer(*id))
            .collect();

        self.conn.execute(&sql, rusqlite::params_from_iter(params))?;
        Ok(())
    }

    /// Clean up old sent outputs.
    pub fn cleanup_sent_outputs(&self, older_than_hours: i64) -> SqliteResult<usize> {
        let result = self.conn.execute(
            r#"
            DELETE FROM output_buffer
            WHERE sent = 1 AND timestamp < datetime('now', ?1 || ' hours')
            "#,
            params![format!("-{}", older_than_hours)],
        )?;
        Ok(result)
    }

    /// Get daemon state value.
    pub fn get_state(&self, key: &str) -> SqliteResult<Option<String>> {
        let mut stmt = self.conn.prepare(
            "SELECT value FROM daemon_state WHERE key = ?1",
        )?;

        let mut rows = stmt.query(params![key])?;

        if let Some(row) = rows.next()? {
            let value: String = row.get(0)?;
            Ok(Some(value))
        } else {
            Ok(None)
        }
    }

    /// Set daemon state value.
    pub fn set_state(&self, key: &str, value: &str) -> SqliteResult<()> {
        self.conn.execute(
            r#"
            INSERT OR REPLACE INTO daemon_state (key, value, updated_at)
            VALUES (?1, ?2, datetime('now'))
            "#,
            params![key, value],
        )?;
        Ok(())
    }

    /// Parse a task from a database row.
    fn task_from_row(row: &rusqlite::Row) -> SqliteResult<LocalTask> {
        let id_str: String = row.get(0)?;
        let server_task_id_str: String = row.get(1)?;
        let state_str: String = row.get(2)?;
        let container_id: Option<String> = row.get(3)?;
        let overlay_path: Option<String> = row.get(4)?;
        let repo_url: Option<String> = row.get(5)?;
        let base_branch: Option<String> = row.get(6)?;
        let plan: String = row.get(7)?;
        let created_at_str: String = row.get(8)?;
        let started_at_str: Option<String> = row.get(9)?;
        let completed_at_str: Option<String> = row.get(10)?;
        let error_message: Option<String> = row.get(11)?;

        let id = Uuid::parse_str(&id_str).unwrap_or_default();
        let server_task_id = Uuid::parse_str(&server_task_id_str).unwrap_or_default();
        let state = TaskState::from_str(&state_str).unwrap_or_default();
        let created_at = DateTime::parse_from_rfc3339(&created_at_str)
            .map(|dt| dt.with_timezone(&Utc))
            .unwrap_or_else(|_| Utc::now());
        let started_at = started_at_str
            .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
            .map(|dt| dt.with_timezone(&Utc));
        let completed_at = completed_at_str
            .and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
            .map(|dt| dt.with_timezone(&Utc));

        Ok(LocalTask {
            id,
            server_task_id,
            state,
            container_id,
            overlay_path,
            repo_url,
            base_branch,
            plan,
            created_at,
            started_at,
            completed_at,
            error_message,
        })
    }
}

#[cfg(test)]
mod tests {
    // `super::*` brings in everything the parent module can see (LocalDb,
    // LocalTask, TaskState, Uuid, Utc, ...); `crate::daemon::*` only sees what
    // the `daemon` module chooses to re-export publicly.
    use super::*;

    /// A representative running task used by several tests.
    fn sample_task() -> LocalTask {
        LocalTask {
            id: Uuid::new_v4(),
            server_task_id: Uuid::new_v4(),
            state: TaskState::Running,
            container_id: Some("abc123".to_string()),
            overlay_path: Some("/tmp/overlay".to_string()),
            repo_url: Some("https://github.com/test/repo".to_string()),
            base_branch: Some("main".to_string()),
            plan: "Build the feature".to_string(),
            created_at: Utc::now(),
            started_at: Some(Utc::now()),
            completed_at: None,
            error_message: None,
        }
    }

    #[test]
    fn test_open_memory() {
        let db = LocalDb::open_memory().unwrap();
        assert!(db.get_active_tasks().unwrap().is_empty());
    }

    #[test]
    fn test_save_and_get_task() {
        let db = LocalDb::open_memory().unwrap();

        let task = sample_task();

        db.save_task(&task).unwrap();

        let loaded = db.get_task(task.id).unwrap().unwrap();
        assert_eq!(loaded.id, task.id);
        assert_eq!(loaded.server_task_id, task.server_task_id);
        assert_eq!(loaded.state, TaskState::Running);
        assert_eq!(loaded.plan, "Build the feature");
        assert_eq!(loaded.container_id.as_deref(), Some("abc123"));
        assert!(loaded.completed_at.is_none());
    }

    #[test]
    fn test_get_missing_task_is_none() {
        let db = LocalDb::open_memory().unwrap();
        assert!(db.get_task(Uuid::new_v4()).unwrap().is_none());
    }

    #[test]
    fn test_delete_task() {
        let db = LocalDb::open_memory().unwrap();
        let task = sample_task();

        db.save_task(&task).unwrap();
        db.delete_task(task.id).unwrap();

        assert!(db.get_task(task.id).unwrap().is_none());
    }

    #[test]
    fn test_output_buffer() {
        let db = LocalDb::open_memory().unwrap();
        let task_id = Uuid::new_v4();

        db.buffer_output(task_id, "line 1", false).unwrap();
        db.buffer_output(task_id, "line 2", true).unwrap();

        let unsent = db.get_unsent_outputs(10).unwrap();
        assert_eq!(unsent.len(), 2);
        assert_eq!(unsent[0].task_id, task_id);
        assert_eq!(unsent[0].output, "line 1");
        assert!(!unsent[0].is_partial);
        assert_eq!(unsent[1].output, "line 2");
        assert!(unsent[1].is_partial);

        let ids: Vec<i64> = unsent.iter().map(|o| o.id).collect();
        db.mark_outputs_sent(&ids).unwrap();

        let unsent = db.get_unsent_outputs(10).unwrap();
        assert!(unsent.is_empty());
    }

    #[test]
    fn test_mark_outputs_sent_empty_is_noop() {
        let db = LocalDb::open_memory().unwrap();
        db.buffer_output(Uuid::new_v4(), "pending", false).unwrap();

        db.mark_outputs_sent(&[]).unwrap();

        assert_eq!(db.get_unsent_outputs(10).unwrap().len(), 1);
    }

    #[test]
    fn test_daemon_state_roundtrip() {
        let db = LocalDb::open_memory().unwrap();

        assert_eq!(db.get_state("key").unwrap(), None);

        db.set_state("key", "v1").unwrap();
        db.set_state("key", "v2").unwrap();

        assert_eq!(db.get_state("key").unwrap().as_deref(), Some("v2"));
    }
}