diff options
Diffstat (limited to 'makima/src/db')
| -rw-r--r-- | makima/src/db/models.rs | 213 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 344 |
2 files changed, 506 insertions, 51 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 9e624c9..2eeba87 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -1114,6 +1114,108 @@ pub struct MergeCompleteCheckResponse { } // ============================================================================= +// Contract Type Templates (User-defined) +// ============================================================================= + +/// A phase definition within a contract template +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PhaseDefinition { + /// Phase identifier (e.g., "research", "plan", "execute") + pub id: String, + /// Display name for the phase + pub name: String, + /// Order in the workflow (0-indexed) + pub order: i32, +} + +/// A deliverable definition within a phase +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DeliverableDefinition { + /// Deliverable identifier (e.g., "plan-document", "pull-request") + pub id: String, + /// Display name for the deliverable + pub name: String, + /// Priority: "required", "recommended", or "optional" + #[serde(default = "default_priority")] + pub priority: String, +} + +fn default_priority() -> String { + "required".to_string() +} + +/// Phase configuration stored on a contract (copied from template at creation) +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PhaseConfig { + /// Ordered list of phases in the workflow + pub phases: Vec<PhaseDefinition>, + /// Default starting phase + pub default_phase: String, + /// Deliverables per phase: { "phase_id": [deliverables] } + #[serde(default)] + pub deliverables: std::collections::HashMap<String, Vec<DeliverableDefinition>>, +} + +/// Contract type template record from the database +#[derive(Debug, Clone, FromRow, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ContractTypeTemplateRecord { + pub id: Uuid, + pub owner_id: Uuid, + pub name: String, + pub description: Option<String>, + #[sqlx(json)] + pub phases: Vec<PhaseDefinition>, + pub default_phase: String, + #[sqlx(json)] + pub deliverables: Option<std::collections::HashMap<String, Vec<DeliverableDefinition>>>, + pub version: i32, + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, +} + +/// Request to create a new contract type template +#[derive(Debug, Clone, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateTemplateRequest { + pub name: String, + pub description: Option<String>, + pub phases: Vec<PhaseDefinition>, + pub default_phase: String, + pub deliverables: Option<std::collections::HashMap<String, Vec<DeliverableDefinition>>>, +} + +/// Request to update a contract type template +#[derive(Debug, Clone, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct UpdateTemplateRequest { + pub name: Option<String>, + pub description: Option<String>, + pub phases: Option<Vec<PhaseDefinition>>, + pub default_phase: Option<String>, + pub deliverables: Option<std::collections::HashMap<String, Vec<DeliverableDefinition>>>, + /// Version for optimistic locking + pub version: Option<i32>, +} + +/// Summary of a contract type template for list views +#[derive(Debug, Clone, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ContractTypeTemplateSummary { + pub id: Uuid, + pub name: String, + pub description: Option<String>, + pub phases: Vec<PhaseDefinition>, + pub default_phase: String, + pub is_builtin: bool, + pub version: i32, + pub created_at: DateTime<Utc>, +} + +// ============================================================================= // Contract Types // ============================================================================= @@ -1355,6 +1457,11 @@ pub struct Contract { /// when evaluating task outputs. #[serde(skip_serializing_if = "Option::is_none")] pub red_team_prompt: Option<String>, + /// Phase configuration copied from template at contract creation. + /// When present, this overrides the built-in contract type phases. + #[sqlx(json)] + #[serde(skip_serializing_if = "Option::is_none")] + pub phase_config: Option<PhaseConfig>, pub version: i32, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, @@ -1376,37 +1483,96 @@ impl Contract { self.status.parse() } - /// Get valid phases for this contract type - pub fn valid_phases(&self) -> Vec<ContractPhase> { + /// Get valid phase IDs for this contract (as strings) + pub fn valid_phase_ids(&self) -> Vec<String> { + // Check phase_config first (for custom templates) + if let Some(ref config) = self.phase_config { + let mut phases: Vec<_> = config.phases.iter().collect(); + phases.sort_by_key(|p| p.order); + return phases.iter().map(|p| p.id.clone()).collect(); + } + + // Fall back to built-in contract types match self.contract_type.as_str() { - "simple" => vec![ContractPhase::Plan, ContractPhase::Execute], + "simple" => vec!["plan".to_string(), "execute".to_string()], "specification" => vec![ - ContractPhase::Research, - ContractPhase::Specify, - ContractPhase::Plan, - ContractPhase::Execute, - ContractPhase::Review, + "research".to_string(), + "specify".to_string(), + "plan".to_string(), + "execute".to_string(), + "review".to_string(), ], - "execute" => vec![ContractPhase::Execute], // Execute-only, single phase - _ => vec![ContractPhase::Plan, ContractPhase::Execute], // Default to simple + "execute" => vec!["execute".to_string()], + _ => vec!["plan".to_string(), "execute".to_string()], + } + } + + /// Get valid phases for this contract type (as ContractPhase enums) + /// Note: For custom templates with non-standard phases, this only returns + /// phases that map to the ContractPhase enum. + pub fn valid_phases(&self) -> Vec<ContractPhase> { + self.valid_phase_ids() + .iter() + .filter_map(|id| id.parse::<ContractPhase>().ok()) + .collect() + } + + /// Get the initial phase ID for this contract type (as string) + pub fn initial_phase_id(&self) -> String { + // Check phase_config first (for custom templates) + if let Some(ref config) = self.phase_config { + return config.default_phase.clone(); + } + + // Fall back to built-in contract types + match self.contract_type.as_str() { + "specification" => "research".to_string(), + "execute" => "execute".to_string(), + _ => "plan".to_string(), } } - /// Get the initial phase for this contract type + /// Get the initial phase for this contract type (as ContractPhase enum) pub fn initial_phase(&self) -> ContractPhase { + self.initial_phase_id() + .parse() + .unwrap_or(ContractPhase::Plan) + } + + /// Get the terminal phase ID for this contract type (as string) + pub fn terminal_phase_id(&self) -> String { + // Check phase_config first (for custom templates) + if let Some(ref config) = self.phase_config { + // Last phase in sorted order is the terminal phase + let mut phases: Vec<_> = config.phases.iter().collect(); + phases.sort_by_key(|p| p.order); + if let Some(last) = phases.last() { + return last.id.clone(); + } + } + + // Fall back to built-in contract types match self.contract_type.as_str() { - "specification" => ContractPhase::Research, - "execute" => ContractPhase::Execute, - _ => ContractPhase::Plan, // simple and default + "specification" => "review".to_string(), + _ => "execute".to_string(), } } /// Get the terminal phase for this contract type (phase where contract can be completed) pub fn terminal_phase(&self) -> ContractPhase { - match self.contract_type.as_str() { - "specification" => ContractPhase::Review, - _ => ContractPhase::Execute, // simple and execute both end at execute - } + self.terminal_phase_id() + .parse() + .unwrap_or(ContractPhase::Execute) + } + + /// Check if a phase ID is valid for this contract + pub fn is_valid_phase(&self, phase_id: &str) -> bool { + self.valid_phase_ids().contains(&phase_id.to_string()) + } + + /// Get the phase configuration for custom templates + pub fn get_phase_config(&self) -> Option<&PhaseConfig> { + self.phase_config.as_ref() } /// Get completed deliverable IDs for a specific phase @@ -1507,12 +1673,19 @@ pub struct CreateContractRequest { pub name: String, /// Optional description pub description: Option<String>, - /// Contract type: "simple" (default) or "specification" + /// Contract type: "simple" (default), "specification", "execute", or a custom template name. + /// For built-in types: /// - simple: Plan -> Execute workflow /// - specification: Research -> Specify -> Plan -> Execute -> Review + /// - execute: Execute only + /// For custom templates, use the template name or provide template_id. #[serde(default)] pub contract_type: Option<String>, - /// Initial phase to start in (defaults based on contract_type) + /// UUID of a custom template to use. If provided, this takes precedence over contract_type. + /// The template's phase configuration will be copied to the contract. + #[serde(default)] + pub template_id: Option<Uuid>, + /// Initial phase to start in (defaults based on contract_type or template) /// - simple: defaults to "plan" /// - specification: defaults to "research" #[serde(default)] diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index b947cdd..1ab4165 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -8,11 +8,12 @@ use uuid::Uuid; use super::models::{ CheckpointPatch, CheckpointPatchInfo, Contract, ContractChatConversation, ContractChatMessageRecord, ContractEvent, ContractRepository, ContractSummary, - ConversationMessage, ConversationSnapshot, CreateContractRequest, CreateFileRequest, - CreateTaskRequest, Daemon, DaemonTaskAssignment, DaemonWithCapacity, File, FileSummary, - FileVersion, HistoryEvent, HistoryQueryFilters, MeshChatConversation, MeshChatMessageRecord, + ContractTypeTemplateRecord, ConversationMessage, ConversationSnapshot, CreateContractRequest, + CreateFileRequest, CreateTaskRequest, CreateTemplateRequest, Daemon, DaemonTaskAssignment, + DaemonWithCapacity, DeliverableDefinition, File, FileSummary, FileVersion, HistoryEvent, + HistoryQueryFilters, MeshChatConversation, MeshChatMessageRecord, PhaseConfig, PhaseDefinition, RedTeamNotification, SupervisorState, Task, TaskCheckpoint, TaskEvent, TaskSummary, - UpdateContractRequest, UpdateFileRequest, UpdateTaskRequest, + UpdateContractRequest, UpdateFileRequest, UpdateTaskRequest, UpdateTemplateRequest, }; /// Repository error types. @@ -2141,68 +2142,349 @@ pub async fn clear_contract_conversation( } // ============================================================================= +// Contract Type Template Functions (Owner-Scoped) +// ============================================================================= + +/// Create a new contract type template for a specific owner. +pub async fn create_template_for_owner( + pool: &PgPool, + owner_id: Uuid, + req: CreateTemplateRequest, +) -> Result<ContractTypeTemplateRecord, sqlx::Error> { + sqlx::query_as::<_, ContractTypeTemplateRecord>( + r#" + INSERT INTO contract_type_templates (owner_id, name, description, phases, default_phase, deliverables) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING * + "#, + ) + .bind(owner_id) + .bind(&req.name) + .bind(&req.description) + .bind(serde_json::to_value(&req.phases).unwrap_or_default()) + .bind(&req.default_phase) + .bind(req.deliverables.as_ref().map(|d| serde_json::to_value(d).unwrap_or_default())) + .fetch_one(pool) + .await +} + +/// Get a contract type template by ID, scoped to owner. +pub async fn get_template_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, +) -> Result<Option<ContractTypeTemplateRecord>, sqlx::Error> { + sqlx::query_as::<_, ContractTypeTemplateRecord>( + r#" + SELECT * + FROM contract_type_templates + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .fetch_optional(pool) + .await +} + +/// Get a contract type template by ID (internal use, no owner scoping). +pub async fn get_template_by_id( + pool: &PgPool, + id: Uuid, +) -> Result<Option<ContractTypeTemplateRecord>, sqlx::Error> { + sqlx::query_as::<_, ContractTypeTemplateRecord>( + r#" + SELECT * + FROM contract_type_templates + WHERE id = $1 + "#, + ) + .bind(id) + .fetch_optional(pool) + .await +} + +/// List all contract type templates for an owner, ordered by name. +pub async fn list_templates_for_owner( + pool: &PgPool, + owner_id: Uuid, +) -> Result<Vec<ContractTypeTemplateRecord>, sqlx::Error> { + sqlx::query_as::<_, ContractTypeTemplateRecord>( + r#" + SELECT * + FROM contract_type_templates + WHERE owner_id = $1 + ORDER BY name ASC + "#, + ) + .bind(owner_id) + .fetch_all(pool) + .await +} + +/// Update a contract type template for an owner. +pub async fn update_template_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, + req: UpdateTemplateRequest, +) -> Result<Option<ContractTypeTemplateRecord>, RepositoryError> { + // Build dynamic update query + let mut query = String::from("UPDATE contract_type_templates SET updated_at = NOW()"); + let mut param_idx = 3; // $1 = id, $2 = owner_id + + if req.name.is_some() { + query.push_str(&format!(", name = ${}", param_idx)); + param_idx += 1; + } + if req.description.is_some() { + query.push_str(&format!(", description = ${}", param_idx)); + param_idx += 1; + } + if req.phases.is_some() { + query.push_str(&format!(", phases = ${}", param_idx)); + param_idx += 1; + } + if req.default_phase.is_some() { + query.push_str(&format!(", default_phase = ${}", param_idx)); + param_idx += 1; + } + if req.deliverables.is_some() { + query.push_str(&format!(", deliverables = ${}", param_idx)); + param_idx += 1; + } + + // Optimistic locking + if req.version.is_some() { + query.push_str(&format!(", version = version + 1 WHERE id = $1 AND owner_id = $2 AND version = ${}", param_idx)); + } else { + query.push_str(", version = version + 1 WHERE id = $1 AND owner_id = $2"); + } + query.push_str(" RETURNING *"); + + let mut sql_query = sqlx::query_as::<_, ContractTypeTemplateRecord>(&query); + sql_query = sql_query.bind(id).bind(owner_id); + + if let Some(ref name) = req.name { + sql_query = sql_query.bind(name); + } + if let Some(ref description) = req.description { + sql_query = sql_query.bind(description); + } + if let Some(ref phases) = req.phases { + sql_query = sql_query.bind(serde_json::to_value(phases).unwrap_or_default()); + } + if let Some(ref default_phase) = req.default_phase { + sql_query = sql_query.bind(default_phase); + } + if let Some(ref deliverables) = req.deliverables { + sql_query = sql_query.bind(serde_json::to_value(deliverables).unwrap_or_default()); + } + if let Some(version) = req.version { + sql_query = sql_query.bind(version); + } + + match sql_query.fetch_optional(pool).await { + Ok(result) => { + if result.is_none() && req.version.is_some() { + // Check if it's a version conflict + if let Some(current) = get_template_for_owner(pool, id, owner_id).await? { + return Err(RepositoryError::VersionConflict { + expected: req.version.unwrap(), + actual: current.version, + }); + } + } + Ok(result) + } + Err(e) => Err(RepositoryError::Database(e)), + } +} + +/// Delete a contract type template for an owner. +pub async fn delete_template_for_owner( + pool: &PgPool, + id: Uuid, + owner_id: Uuid, +) -> Result<bool, sqlx::Error> { + let result = sqlx::query( + r#" + DELETE FROM contract_type_templates + WHERE id = $1 AND owner_id = $2 + "#, + ) + .bind(id) + .bind(owner_id) + .execute(pool) + .await?; + + Ok(result.rows_affected() > 0) +} + +/// Helper function to build PhaseConfig from a template. +pub fn build_phase_config_from_template(template: &ContractTypeTemplateRecord) -> PhaseConfig { + PhaseConfig { + phases: template.phases.clone(), + default_phase: template.default_phase.clone(), + deliverables: template.deliverables.clone().unwrap_or_default(), + } +} + +/// Helper function to build PhaseConfig for built-in contract types. +pub fn build_phase_config_for_builtin(contract_type: &str) -> PhaseConfig { + match contract_type { + "simple" => PhaseConfig { + phases: vec![ + PhaseDefinition { id: "plan".to_string(), name: "Plan".to_string(), order: 0 }, + PhaseDefinition { id: "execute".to_string(), name: "Execute".to_string(), order: 1 }, + ], + default_phase: "plan".to_string(), + deliverables: [ + ("plan".to_string(), vec![DeliverableDefinition { + id: "plan-document".to_string(), + name: "Plan".to_string(), + priority: "required".to_string(), + }]), + ("execute".to_string(), vec![DeliverableDefinition { + id: "pull-request".to_string(), + name: "Pull Request".to_string(), + priority: "required".to_string(), + }]), + ].into_iter().collect(), + }, + "specification" => PhaseConfig { + phases: vec![ + PhaseDefinition { id: "research".to_string(), name: "Research".to_string(), order: 0 }, + PhaseDefinition { id: "specify".to_string(), name: "Specify".to_string(), order: 1 }, + PhaseDefinition { id: "plan".to_string(), name: "Plan".to_string(), order: 2 }, + PhaseDefinition { id: "execute".to_string(), name: "Execute".to_string(), order: 3 }, + PhaseDefinition { id: "review".to_string(), name: "Review".to_string(), order: 4 }, + ], + default_phase: "research".to_string(), + deliverables: [ + ("research".to_string(), vec![DeliverableDefinition { + id: "research-notes".to_string(), + name: "Research Notes".to_string(), + priority: "required".to_string(), + }]), + ("specify".to_string(), vec![DeliverableDefinition { + id: "requirements-document".to_string(), + name: "Requirements Document".to_string(), + priority: "required".to_string(), + }]), + ("plan".to_string(), vec![DeliverableDefinition { + id: "plan-document".to_string(), + name: "Plan".to_string(), + priority: "required".to_string(), + }]), + ("execute".to_string(), vec![DeliverableDefinition { + id: "pull-request".to_string(), + name: "Pull Request".to_string(), + priority: "required".to_string(), + }]), + ("review".to_string(), vec![DeliverableDefinition { + id: "release-notes".to_string(), + name: "Release Notes".to_string(), + priority: "required".to_string(), + }]), + ].into_iter().collect(), + }, + "execute" | _ => PhaseConfig { + phases: vec![ + PhaseDefinition { id: "execute".to_string(), name: "Execute".to_string(), order: 0 }, + ], + default_phase: "execute".to_string(), + deliverables: std::collections::HashMap::new(), + }, + } +} + +// ============================================================================= // Contract Functions (Owner-Scoped) // ============================================================================= /// Create a new contract for a specific owner. +/// Supports both built-in contract types (simple, specification, execute) and custom templates. pub async fn create_contract_for_owner( pool: &PgPool, owner_id: Uuid, req: CreateContractRequest, ) -> Result<Contract, sqlx::Error> { - // Default contract type is "simple" - let contract_type = req.contract_type.as_deref().unwrap_or("simple"); + // Determine phase configuration based on template_id or contract_type + let (phase_config, contract_type_str, default_phase): (PhaseConfig, String, String) = + if let Some(template_id) = req.template_id { + // Look up the custom template + let template = get_template_by_id(pool, template_id) + .await? + .ok_or_else(|| { + sqlx::Error::Protocol(format!("Template not found: {}", template_id)) + })?; + + let config = build_phase_config_from_template(&template); + let default = config.default_phase.clone(); + // For custom templates, store the template name as the contract_type + (config, template.name.clone(), default) + } else { + // Use built-in contract type + let contract_type = req.contract_type.as_deref().unwrap_or("simple"); - // Validate contract type - let valid_types = ["simple", "specification", "execute"]; - if !valid_types.contains(&contract_type) { - return Err(sqlx::Error::Protocol(format!( - "Invalid contract_type '{}'. Must be one of: {}", - contract_type, - valid_types.join(", ") - ))); - } + // Validate contract type + let valid_types = ["simple", "specification", "execute"]; + if !valid_types.contains(&contract_type) { + return Err(sqlx::Error::Protocol(format!( + "Invalid contract_type '{}'. Must be one of: {} or provide a template_id", + contract_type, + valid_types.join(", ") + ))); + } - // Determine valid phases based on contract type - let (valid_phases, default_phase): (&[&str], &str) = match contract_type { - "simple" => (&["plan", "execute"], "plan"), - "specification" => (&["research", "specify", "plan", "execute", "review"], "research"), - "execute" => (&["execute"], "execute"), - _ => (&["plan", "execute"], "plan"), - }; + let config = build_phase_config_for_builtin(contract_type); + let default = config.default_phase.clone(); + (config, contract_type.to_string(), default) + }; - // Use provided initial_phase or default based on contract type - let phase = req.initial_phase.as_deref().unwrap_or(default_phase); + // Get valid phase IDs from the configuration + let valid_phase_ids: Vec<String> = phase_config.phases.iter().map(|p| p.id.clone()).collect(); - // Validate the phase is valid for this contract type - if !valid_phases.contains(&phase) { + // Use provided initial_phase or default based on contract type/template + let phase = req.initial_phase.as_deref().unwrap_or(&default_phase); + + // Validate the phase is valid for this contract type/template + if !valid_phase_ids.contains(&phase.to_string()) { return Err(sqlx::Error::Protocol(format!( "Invalid initial_phase '{}' for contract type '{}'. Must be one of: {}", phase, - contract_type, - valid_phases.join(", ") + contract_type_str, + valid_phase_ids.join(", ") ))); } let autonomous_loop = req.autonomous_loop.unwrap_or(false); let phase_guard = req.phase_guard.unwrap_or(false); let local_only = req.local_only.unwrap_or(false); + let red_team_enabled = req.red_team_enabled.unwrap_or(false); + + // Serialize phase_config to JSON + let phase_config_json = serde_json::to_value(&phase_config).ok(); sqlx::query_as::<_, Contract>( r#" - INSERT INTO contracts (owner_id, name, description, contract_type, phase, autonomous_loop, phase_guard, local_only) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + INSERT INTO contracts (owner_id, name, description, contract_type, phase, autonomous_loop, phase_guard, local_only, red_team_enabled, red_team_prompt, phase_config) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING * "#, ) .bind(owner_id) .bind(&req.name) .bind(&req.description) - .bind(contract_type) + .bind(&contract_type_str) .bind(phase) .bind(autonomous_loop) .bind(phase_guard) .bind(local_only) + .bind(red_team_enabled) + .bind(&req.red_team_prompt) + .bind(phase_config_json) .fetch_one(pool) .await } |
