//! DAG validation and traversal for chain contracts.
//!
//! Provides cycle detection, topological sorting, readiness checks, and depth
//! (layering) computation for contract dependencies.
use std::collections::{HashMap, HashSet, VecDeque};
use thiserror::Error;
use super::parser::ChainDefinition;
/// Error type for DAG operations.
#[derive(Error, Debug)]
pub enum DagError {
#[error("Cycle detected in dependency graph: {0}")]
CycleDetected(String),
#[error("Unknown contract in dependency: {0}")]
UnknownContract(String),
}
/// Validates that the chain definition forms a valid DAG (no cycles).
///
/// Every `depends_on` entry must name a contract declared in the chain, and
/// the dependency graph must be acyclic. Cycle detection is a depth-first
/// search that tracks which contracts are on the current path. Reaching a
/// contract that is still on the path closes a cycle.
///
/// The search starts from contracts in declaration order, so the same chain
/// always reports the same cycle. If several contracts share a name, their
/// dependency lists are merged.
///
/// # Errors
///
/// - [`DagError::UnknownContract`] if a contract depends on a name that is not
///   declared in the chain (the first such dependency, in declaration order).
/// - [`DagError::CycleDetected`] with the cycle rendered as `"A -> B -> A"`.
pub fn validate_dag(chain: &ChainDefinition) -> Result<(), DagError> {
    /// DFS state of a contract; contracts absent from the state map are unvisited.
    enum Visit {
        /// On the current DFS path.
        InProgress,
        /// Fully explored; no cycle is reachable from it.
        Done,
    }

    /// Depth-first visit of `node`, failing on the first back edge found.
    fn dfs<'a>(
        node: &'a str,
        adjacency: &HashMap<&'a str, Vec<&'a str>>,
        state: &mut HashMap<&'a str, Visit>,
        path: &mut Vec<&'a str>,
    ) -> Result<(), DagError> {
        state.insert(node, Visit::InProgress);
        path.push(node);
        for &dep in adjacency.get(node).into_iter().flatten() {
            match state.get(dep) {
                Some(Visit::InProgress) => {
                    // `dep` is an ancestor on the current path: report the loop
                    // starting from where it was first entered.
                    let start = path
                        .iter()
                        .position(|&n| n == dep)
                        .expect("an in-progress contract is always on the DFS path");
                    return Err(DagError::CycleDetected(format!(
                        "{} -> {}",
                        path[start..].join(" -> "),
                        dep
                    )));
                }
                Some(Visit::Done) => {}
                None => dfs(dep, adjacency, state, path)?,
            }
        }
        state.insert(node, Visit::Done);
        path.pop();
        Ok(())
    }

    let contract_names: HashSet<&str> =
        chain.contracts.iter().map(|c| c.name.as_str()).collect();

    // Build the adjacency list, rejecting dependencies on undeclared contracts.
    // `entry` (not `insert`) so duplicate contract names merge their edges
    // instead of silently dropping the earlier declaration's dependencies.
    let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::with_capacity(contract_names.len());
    for contract in &chain.contracts {
        let edges = adjacency.entry(contract.name.as_str()).or_default();
        if let Some(deps) = &contract.depends_on {
            for dep in deps {
                let dep = dep.as_str();
                if !contract_names.contains(dep) {
                    return Err(DagError::UnknownContract(format!(
                        "Contract '{}' depends on unknown contract '{}'",
                        contract.name, dep
                    )));
                }
                edges.push(dep);
            }
        }
    }

    // Start a DFS from every unvisited contract, in declaration order so the
    // reported cycle is deterministic.
    let mut state: HashMap<&str, Visit> = HashMap::with_capacity(adjacency.len());
    let mut path: Vec<&str> = Vec::new();
    for contract in &chain.contracts {
        let name = contract.name.as_str();
        if !state.contains_key(name) {
            dfs(name, &adjacency, &mut state, &mut path)?;
        }
    }
    Ok(())
}
/// Returns contracts in topological order (dependencies before dependents).
///
/// Uses Kahn's algorithm for topological sorting.
pub fn topological_sort(chain: &ChainDefinition) -> Result<Vec<&str>, DagError> {
// Validate first
validate_dag(chain)?;
// Build in-degree map and adjacency list
let mut in_degree: HashMap<&str, usize> = HashMap::new();
let mut dependents: HashMap<&str, Vec<&str>> = HashMap::new();
for contract in &chain.contracts {
in_degree.entry(contract.name.as_str()).or_insert(0);
dependents.entry(contract.name.as_str()).or_default();
if let Some(deps) = &contract.depends_on {
for dep in deps {
*in_degree.entry(contract.name.as_str()).or_insert(0) += 1;
dependents
.entry(dep.as_str())
.or_default()
.push(contract.name.as_str());
}
}
}
// Kahn's algorithm
let mut queue: VecDeque<&str> = VecDeque::new();
let mut result: Vec<&str> = Vec::new();
// Start with nodes that have no dependencies
for (name, °ree) in &in_degree {
if degree == 0 {
queue.push_back(name);
}
}
while let Some(node) = queue.pop_front() {
result.push(node);
if let Some(deps) = dependents.get(node) {
for dep in deps {
if let Some(degree) = in_degree.get_mut(dep) {
*degree -= 1;
if *degree == 0 {
queue.push_back(dep);
}
}
}
}
}
Ok(result)
}
/// Lists the contracts that can be started now.
///
/// A contract is ready when its name is not already in `completed` and every
/// entry of its `depends_on` list (if it has one) is in `completed`. The
/// result preserves the declaration order of `chain.contracts`.
pub fn get_ready_contracts<'a>(
    chain: &'a ChainDefinition,
    completed: &HashSet<&str>,
) -> Vec<&'a str> {
    let mut ready = Vec::new();
    for contract in &chain.contracts {
        let name = contract.name.as_str();
        if completed.contains(name) {
            // Finished contracts are never re-offered.
            continue;
        }
        // A missing `depends_on` flattens to nothing, so `all` is vacuously true.
        let unblocked = contract
            .depends_on
            .iter()
            .flatten()
            .all(|dep| completed.contains(dep.as_str()));
        if unblocked {
            ready.push(name);
        }
    }
    ready
}
/// Get the depth of each contract in the DAG (for layout purposes).
///
/// Root contracts have depth 0. A root is a contract whose `depends_on` is
/// absent or an explicitly empty list. Every other contract has depth
/// `max(dependency depths) + 1`.
///
/// Contracts on a cycle, or that (transitively) depend on a contract that is
/// not declared, never resolve and are left out of the returned map.
///
/// Depths are found by repeated passes over the contract list (at most one
/// pass per contract), so the worst case is quadratic in the number of
/// contracts.
pub fn get_contract_depths(chain: &ChainDefinition) -> HashMap<&str, usize> {
    let mut depths: HashMap<&str, usize> = HashMap::with_capacity(chain.contracts.len());
    // Each pass resolves at least one more level of an acyclic chain, so
    // `len()` passes always suffice; stop early once a pass changes nothing.
    for _ in 0..chain.contracts.len() {
        let mut changed = false;
        for contract in &chain.contracts {
            let new_depth = match &contract.depends_on {
                None => 0,
                Some(deps) => {
                    // `None` until every dependency has a depth. Folding from 0
                    // makes an empty list a root (depth 0), like `None` above.
                    let resolved = deps.iter().try_fold(0, |depth: usize, d| {
                        depths.get(d.as_str()).map(|&dep_depth| depth.max(dep_depth + 1))
                    });
                    match resolved {
                        Some(depth) => depth,
                        None => continue, // Dependencies not yet computed
                    }
                }
            };
            if depths.insert(contract.name.as_str(), new_depth) != Some(new_depth) {
                changed = true;
            }
        }
        if !changed {
            break;
        }
    }
    depths
}
// Unit tests for the DAG helpers. Each test parses an inline YAML chain with
// `parse_chain_yaml` and checks one function's behavior on it.
#[cfg(test)]
mod tests {
use super::*;
use crate::daemon::chain::parser::parse_chain_yaml;
// A diamond (A -> {B, C} -> D) has no cycle and must validate.
#[test]
fn test_valid_dag() {
let yaml = r#"
name: Valid DAG
contracts:
- name: A
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
- name: C
depends_on: [A]
tasks:
- name: Task
plan: "Do C"
- name: D
depends_on: [B, C]
tasks:
- name: Task
plan: "Do D"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
assert!(validate_dag(&chain).is_ok());
}
// Two contracts depending on each other form a 2-cycle.
#[test]
fn test_simple_cycle() {
let yaml = r#"
name: Simple Cycle
contracts:
- name: A
depends_on: [B]
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
let result = validate_dag(&chain);
assert!(result.is_err());
// Only the message prefix is checked; the cycle text itself is not asserted.
assert!(result.unwrap_err().to_string().contains("Cycle detected"));
}
// A -> C -> B -> A: a cycle that spans three contracts.
#[test]
fn test_longer_cycle() {
let yaml = r#"
name: Longer Cycle
contracts:
- name: A
depends_on: [C]
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
- name: C
depends_on: [B]
tasks:
- name: Task
plan: "Do C"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
let result = validate_dag(&chain);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Cycle detected"));
}
// The sort checks relative positions only, so it does not depend on how
// independent contracts (B and C) are ordered between themselves.
#[test]
fn test_topological_sort() {
let yaml = r#"
name: Topo Test
contracts:
- name: A
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
- name: C
depends_on: [A]
tasks:
- name: Task
plan: "Do C"
- name: D
depends_on: [B, C]
tasks:
- name: Task
plan: "Do D"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
let sorted = topological_sort(&chain).unwrap();
// A must come before B, C; B and C must come before D
let pos_a = sorted.iter().position(|&n| n == "A").unwrap();
let pos_b = sorted.iter().position(|&n| n == "B").unwrap();
let pos_c = sorted.iter().position(|&n| n == "C").unwrap();
let pos_d = sorted.iter().position(|&n| n == "D").unwrap();
assert!(pos_a < pos_b);
assert!(pos_a < pos_c);
assert!(pos_b < pos_d);
assert!(pos_c < pos_d);
}
// Readiness is recomputed from the `completed` set on every call.
#[test]
fn test_get_ready_contracts() {
let yaml = r#"
name: Ready Test
contracts:
- name: A
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
- name: C
tasks:
- name: Task
plan: "Do C"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
// Initially A and C are ready (no dependencies)
let completed = HashSet::new();
let mut ready = get_ready_contracts(&chain, &completed);
// Sorted so the assertion does not depend on the returned order.
ready.sort();
assert_eq!(ready, vec!["A", "C"]);
// After A completes, B becomes ready
let mut completed = HashSet::new();
completed.insert("A");
let ready = get_ready_contracts(&chain, &completed);
assert!(ready.contains(&"B"));
assert!(ready.contains(&"C")); // C still ready if not started
}
// Linear chain A -> B -> C: each step adds one level of depth.
#[test]
fn test_get_contract_depths() {
let yaml = r#"
name: Depth Test
contracts:
- name: A
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
- name: C
depends_on: [B]
tasks:
- name: Task
plan: "Do C"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
let depths = get_contract_depths(&chain);
assert_eq!(depths.get("A"), Some(&0));
assert_eq!(depths.get("B"), Some(&1));
assert_eq!(depths.get("C"), Some(&2));
}
// Diamond: D's depth is one more than the deepest of its dependencies (B, C).
#[test]
fn test_diamond_dependency_depths() {
let yaml = r#"
name: Diamond Test
contracts:
- name: A
tasks:
- name: Task
plan: "Do A"
- name: B
depends_on: [A]
tasks:
- name: Task
plan: "Do B"
- name: C
depends_on: [A]
tasks:
- name: Task
plan: "Do C"
- name: D
depends_on: [B, C]
tasks:
- name: Task
plan: "Do D"
"#;
let chain = parse_chain_yaml(yaml).unwrap();
let depths = get_contract_depths(&chain);
assert_eq!(depths.get("A"), Some(&0));
assert_eq!(depths.get("B"), Some(&1));
assert_eq!(depths.get("C"), Some(&1));
assert_eq!(depths.get("D"), Some(&2));
}
}