summaryrefslogtreecommitdiff
path: root/makima/src/daemon/chain/dag.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/daemon/chain/dag.rs')
-rw-r--r--makima/src/daemon/chain/dag.rs450
1 files changed, 450 insertions, 0 deletions
diff --git a/makima/src/daemon/chain/dag.rs b/makima/src/daemon/chain/dag.rs
new file mode 100644
index 0000000..7ba5904
--- /dev/null
+++ b/makima/src/daemon/chain/dag.rs
@@ -0,0 +1,450 @@
+//! DAG validation and traversal for chain contracts.
+//!
+//! Provides cycle detection and topological sorting for contract dependencies.
+
+use std::collections::{HashMap, HashSet, VecDeque};
+use thiserror::Error;
+
+use super::parser::ChainDefinition;
+
+/// Errors reported while validating or ordering a chain's contract dependency graph.
+#[derive(Error, Debug)]
+pub enum DagError {
+ #[error("Cycle detected in dependency graph: {0}")]
+ CycleDetected(String), // payload: the offending cycle rendered as "A -> B -> A"
+
+ #[error("Unknown contract in dependency: {0}")]
+ UnknownContract(String), // payload: sentence naming the contract and the dependency it could not resolve
+}
+
+/// Validates that the chain definition forms a valid DAG (no cycles).
+///
+/// Two checks run in sequence:
+///
+/// 1. Every name in a contract's `depends_on` must be the name of a contract
+///    in `chain`.
+/// 2. A depth-first search over the dependency edges must never reach a
+///    contract that is still on the current search path.
+///
+/// Contracts are checked and explored in declaration order, so a given chain
+/// always yields the same error and the same cycle description.
+///
+/// # Errors
+///
+/// * [`DagError::UnknownContract`] - a dependency names a contract that is not
+///   part of the chain.
+/// * [`DagError::CycleDetected`] - the graph has a cycle; the message lists it
+///   as `A -> B -> A`.
+pub fn validate_dag(chain: &ChainDefinition) -> Result<(), DagError> {
+    /// DFS state of a contract; a contract missing from the map is unvisited.
+    #[derive(Clone, Copy, PartialEq, Eq)]
+    enum Mark {
+        /// On the current DFS path.
+        InProgress,
+        /// Fully explored, known to be cycle-free.
+        Done,
+    }
+
+    /// Explores `node` and everything it depends on, failing on the first
+    /// edge that points back into the current path.
+    fn visit<'a>(
+        node: &'a str,
+        adjacency: &HashMap<&'a str, Vec<&'a str>>,
+        marks: &mut HashMap<&'a str, Mark>,
+        path: &mut Vec<&'a str>,
+    ) -> Result<(), DagError> {
+        marks.insert(node, Mark::InProgress);
+        path.push(node);
+
+        if let Some(deps) = adjacency.get(node) {
+            for &dep in deps {
+                match marks.get(dep).copied() {
+                    Some(Mark::InProgress) => {
+                        // An in-progress contract is always on `path`; should that
+                        // ever not hold, report the whole path instead of panicking.
+                        let start = path.iter().position(|&n| n == dep).unwrap_or(0);
+                        return Err(DagError::CycleDetected(format!(
+                            "{} -> {}",
+                            path[start..].join(" -> "),
+                            dep
+                        )));
+                    }
+                    Some(Mark::Done) => {}
+                    None => visit(dep, adjacency, marks, path)?,
+                }
+            }
+        }
+
+        marks.insert(node, Mark::Done);
+        path.pop();
+        Ok(())
+    }
+
+    let contract_names: HashSet<&str> = chain.contracts.iter().map(|c| c.name.as_str()).collect();
+    let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::with_capacity(chain.contracts.len());
+
+    for contract in &chain.contracts {
+        let deps: Vec<&str> = contract
+            .depends_on
+            .as_ref()
+            .map(|d| d.iter().map(|s| s.as_str()).collect())
+            .unwrap_or_default();
+
+        // Reject dangling references before any traversal so the DFS can
+        // assume every edge targets a known contract.
+        if let Some(&missing) = deps.iter().find(|&&dep| !contract_names.contains(dep)) {
+            return Err(DagError::UnknownContract(format!(
+                "Contract '{}' depends on unknown contract '{}'",
+                contract.name, missing
+            )));
+        }
+
+        adjacency.insert(contract.name.as_str(), deps);
+    }
+
+    let mut marks: HashMap<&str, Mark> = HashMap::with_capacity(adjacency.len());
+    // `visit` pushes and pops symmetrically, so the buffer is empty again
+    // after every successful traversal and can be shared between roots.
+    let mut path: Vec<&str> = Vec::new();
+
+    // Declaration order (rather than hash-set order) keeps the reported cycle stable.
+    for contract in &chain.contracts {
+        let name = contract.name.as_str();
+        if !marks.contains_key(name) {
+            visit(name, &adjacency, &mut marks, &mut path)?;
+        }
+    }
+
+    Ok(())
+}
+
+/// Returns contracts in topological order (dependencies before dependents).
+///
+/// Uses Kahn's algorithm. Contracts whose dependencies are all satisfied are
+/// emitted in declaration order, so the same chain always produces the same
+/// ordering.
+///
+/// # Errors
+///
+/// Propagates any [`DagError`] returned by [`validate_dag`], which runs first.
+pub fn topological_sort(chain: &ChainDefinition) -> Result<Vec<&str>, DagError> {
+    // Validation guarantees every dependency is known and no cycle exists,
+    // so the loop below drains every contract.
+    validate_dag(chain)?;
+
+    let count = chain.contracts.len();
+    let mut in_degree: HashMap<&str, usize> = HashMap::with_capacity(count);
+    let mut dependents: HashMap<&str, Vec<&str>> = HashMap::with_capacity(count);
+    // Unique contract names in the order they were first declared; used to
+    // seed the queue deterministically instead of iterating the hash map.
+    let mut declared: Vec<&str> = Vec::with_capacity(count);
+
+    for contract in &chain.contracts {
+        let name = contract.name.as_str();
+        let degree = in_degree.entry(name).or_insert_with(|| {
+            declared.push(name);
+            0
+        });
+        dependents.entry(name).or_default();
+
+        if let Some(deps) = &contract.depends_on {
+            for dep in deps {
+                *degree += 1;
+                dependents.entry(dep.as_str()).or_default().push(name);
+            }
+        }
+    }
+
+    // Start with the contracts that have no dependencies.
+    let mut queue: VecDeque<&str> = declared
+        .iter()
+        .copied()
+        .filter(|name| in_degree.get(name) == Some(&0))
+        .collect();
+    let mut result: Vec<&str> = Vec::with_capacity(declared.len());
+
+    while let Some(node) = queue.pop_front() {
+        result.push(node);
+
+        if let Some(children) = dependents.get(node) {
+            for &child in children {
+                if let Some(degree) = in_degree.get_mut(child) {
+                    // One decrement per dependency edge; a contract is
+                    // released once its last dependency has been emitted.
+                    *degree -= 1;
+                    if *degree == 0 {
+                        queue.push_back(child);
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(result)
+}
+
+/// Lists the contracts that can be started right now.
+///
+/// A contract is ready when it is not itself in `completed` and every name in
+/// its `depends_on` list is. Contracts without a dependency list are ready
+/// until they are completed. Names are returned in declaration order.
+pub fn get_ready_contracts<'a>(
+    chain: &'a ChainDefinition,
+    completed: &HashSet<&str>,
+) -> Vec<&'a str> {
+    let mut ready = Vec::new();
+
+    for contract in &chain.contracts {
+        let name = contract.name.as_str();
+
+        // Finished contracts are never offered again.
+        if completed.contains(name) {
+            continue;
+        }
+
+        // Blocked as soon as a single dependency is still outstanding.
+        let blocked = match &contract.depends_on {
+            Some(deps) => deps.iter().any(|dep| !completed.contains(dep.as_str())),
+            None => false,
+        };
+
+        if !blocked {
+            ready.push(name);
+        }
+    }
+
+    ready
+}
+
+/// Get the depth of each contract in the DAG (for layout purposes).
+///
+/// Root nodes (no dependencies) have depth 0. This covers both a missing
+/// `depends_on` and an explicitly empty one.
+/// Each dependent has depth = max(dependency depths) + 1.
+///
+/// Depths are resolved by repeated passes over the contracts, at most one pass
+/// per contract. A contract whose dependencies can never all be resolved
+/// (for example because one is unknown or part of a cycle) is left out of the
+/// returned map.
+pub fn get_contract_depths(chain: &ChainDefinition) -> HashMap<&str, usize> {
+    let mut depths: HashMap<&str, usize> = HashMap::with_capacity(chain.contracts.len());
+
+    // Each pass can only resolve contracts whose dependencies were resolved
+    // earlier, so in the worst case one new contract is settled per pass.
+    for _ in 0..chain.contracts.len() {
+        let mut changed = false;
+
+        for contract in &chain.contracts {
+            // Folds to `None` as soon as one dependency has no depth yet.
+            // Starting from 0 means an empty dependency list yields depth 0
+            // rather than 1.
+            let resolved = contract
+                .depends_on
+                .iter()
+                .flatten()
+                .try_fold(0usize, |deepest, dep| {
+                    depths.get(dep.as_str()).map(|&d| deepest.max(d + 1))
+                });
+
+            if let Some(new_depth) = resolved {
+                let name = contract.name.as_str();
+                if depths.get(name) != Some(&new_depth) {
+                    depths.insert(name, new_depth);
+                    changed = true;
+                }
+            }
+        }
+
+        // Fixed point reached: a further pass would change nothing.
+        if !changed {
+            break;
+        }
+    }
+
+    depths
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::daemon::chain::parser::parse_chain_yaml;
+
+ #[test]
+ fn test_valid_dag() { // Diamond: B and C depend on A, D depends on both; no cycle, so validation passes.
+ let yaml = r#"
+name: Valid DAG
+contracts:
+ - name: A
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+ - name: C
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do C"
+ - name: D
+ depends_on: [B, C]
+ tasks:
+ - name: Task
+ plan: "Do D"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+ assert!(validate_dag(&chain).is_ok());
+ }
+
+ #[test]
+ fn test_simple_cycle() { // A and B depend on each other; validation must report a cycle.
+ let yaml = r#"
+name: Simple Cycle
+contracts:
+ - name: A
+ depends_on: [B]
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+ let result = validate_dag(&chain);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("Cycle detected"));
+ }
+
+ #[test]
+ fn test_longer_cycle() { // Three-contract loop: A needs C, C needs B, B needs A.
+ let yaml = r#"
+name: Longer Cycle
+contracts:
+ - name: A
+ depends_on: [C]
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+ - name: C
+ depends_on: [B]
+ tasks:
+ - name: Task
+ plan: "Do C"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+ let result = validate_dag(&chain);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("Cycle detected"));
+ }
+
+ #[test]
+ fn test_topological_sort() { // Same diamond as test_valid_dag; only relative positions are asserted.
+ let yaml = r#"
+name: Topo Test
+contracts:
+ - name: A
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+ - name: C
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do C"
+ - name: D
+ depends_on: [B, C]
+ tasks:
+ - name: Task
+ plan: "Do D"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+ let sorted = topological_sort(&chain).unwrap();
+
+ // B and C may appear in either order, so check relative positions only: A before B and C, both before D
+ let pos_a = sorted.iter().position(|&n| n == "A").unwrap();
+ let pos_b = sorted.iter().position(|&n| n == "B").unwrap();
+ let pos_c = sorted.iter().position(|&n| n == "C").unwrap();
+ let pos_d = sorted.iter().position(|&n| n == "D").unwrap();
+
+ assert!(pos_a < pos_b);
+ assert!(pos_a < pos_c);
+ assert!(pos_b < pos_d);
+ assert!(pos_c < pos_d);
+ }
+
+ #[test]
+ fn test_get_ready_contracts() { // B depends on A; A and C have no dependencies.
+ let yaml = r#"
+name: Ready Test
+contracts:
+ - name: A
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+ - name: C
+ tasks:
+ - name: Task
+ plan: "Do C"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+
+ // With nothing completed, only the dependency-free contracts A and C are ready
+ let completed = HashSet::new();
+ let mut ready = get_ready_contracts(&chain, &completed);
+ ready.sort();
+ assert_eq!(ready, vec!["A", "C"]);
+
+ // Once A is in the completed set, its dependent B becomes ready
+ let mut completed = HashSet::new();
+ completed.insert("A");
+ let ready = get_ready_contracts(&chain, &completed);
+ assert!(ready.contains(&"B"));
+ assert!(ready.contains(&"C")); // C is not in `completed`, so it is still reported as ready
+ }
+
+ #[test]
+ fn test_get_contract_depths() { // Linear chain: B depends on A, C depends on B; depth grows by one per hop.
+ let yaml = r#"
+name: Depth Test
+contracts:
+ - name: A
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+ - name: C
+ depends_on: [B]
+ tasks:
+ - name: Task
+ plan: "Do C"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+ let depths = get_contract_depths(&chain);
+
+ assert_eq!(depths.get("A"), Some(&0));
+ assert_eq!(depths.get("B"), Some(&1));
+ assert_eq!(depths.get("C"), Some(&2));
+ }
+
+ #[test]
+ fn test_diamond_dependency_depths() { // D sits one level below the deeper of its two dependencies B and C.
+ let yaml = r#"
+name: Diamond Test
+contracts:
+ - name: A
+ tasks:
+ - name: Task
+ plan: "Do A"
+ - name: B
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do B"
+ - name: C
+ depends_on: [A]
+ tasks:
+ - name: Task
+ plan: "Do C"
+ - name: D
+ depends_on: [B, C]
+ tasks:
+ - name: Task
+ plan: "Do D"
+"#;
+ let chain = parse_chain_yaml(yaml).unwrap();
+ let depths = get_contract_depths(&chain);
+
+ assert_eq!(depths.get("A"), Some(&0));
+ assert_eq!(depths.get("B"), Some(&1));
+ assert_eq!(depths.get("C"), Some(&1));
+ assert_eq!(depths.get("D"), Some(&2));
+ }
+}