//! Chain planner for LLM-based execution plan generation.
//!
//! Generates chains (DAGs of steps) from directive goals and requirements.
//! Supports both initial plan generation and replanning while preserving
//! completed work.
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
use uuid::Uuid;
use crate::db::models::{AddStepRequest, ChainStep, Directive};
/// Error type for planner operations.
#[derive(Error, Debug)]
pub enum PlannerError {
#[error("Cycle detected in DAG: {0}")]
CycleDetected(String),
#[error("Invalid dependency: step '{step}' depends on unknown step '{dependency}'")]
InvalidDependency { step: String, dependency: String },
#[error("LLM generation failed: {0}")]
LlmError(String),
#[error("Requirement not covered: {0}")]
RequirementNotCovered(String),
#[error("Invalid plan: {0}")]
InvalidPlan(String),
#[error("Empty plan generated")]
EmptyPlan,
}
/// Generated step from LLM planning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedStep {
/// Unique name within the chain
pub name: String,
/// Type of step (e.g., "research", "implement", "test", "review")
pub step_type: String,
/// Description of what this step accomplishes
pub description: String,
/// Names of steps this depends on
pub depends_on: Vec<String>,
/// IDs of requirements this step addresses
pub requirement_ids: Vec<String>,
/// Contract template fields
pub contract_template: Option<ContractTemplate>,
}
/// Contract blueprint attached to a generated step.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ContractTemplate {
    /// Name of the contract to create.
    pub name: String,
    /// Human-readable contract description.
    pub description: String,
    /// Kind of contract (e.g., "simple", "agentic").
    pub contract_type: String,
    /// Ordered phase names for the contract.
    pub phases: Vec<String>,
    /// Tasks to be carried out within the contract.
    pub tasks: Vec<TaskTemplate>,
    /// Deliverables the contract is expected to produce.
    pub deliverables: Vec<DeliverableTemplate>,
}
/// One task inside a [`ContractTemplate`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TaskTemplate {
    /// Task name.
    pub name: String,
    /// Free-form plan text for the task.
    pub plan: String,
}
/// One expected deliverable inside a [`ContractTemplate`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DeliverableTemplate {
    /// Deliverable identifier (e.g. "del-1").
    pub id: String,
    /// Deliverable name.
    pub name: String,
    /// Priority label as written by the model (e.g. "required").
    pub priority: String,
}
/// A complete plan returned by the LLM: a named chain and its steps.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GeneratedChain {
    /// Chain name.
    pub name: String,
    /// Summary of the execution plan.
    pub description: String,
    /// The steps; their `depends_on` links form the chain's dependency graph.
    pub steps: Vec<GeneratedStep>,
}
/// Builds LLM planning prompts and validates/converts the plans the LLM returns.
pub struct ChainPlanner {
    /// Step-type vocabulary. Not read anywhere yet (reserved for future use), which is why
    /// the `dead_code` lint is silenced.
    #[allow(dead_code)]
    default_step_types: Vec<String>,
}
impl Default for ChainPlanner {
fn default() -> Self {
Self::new()
}
}
impl ChainPlanner {
    /// JSON response format shown to the LLM by both prompt builders.
    ///
    /// The field names must match [`GeneratedChain`], [`GeneratedStep`] and
    /// [`ContractTemplate`] so that [`ChainPlanner::parse_plan_response`] can deserialize the
    /// reply. Keeping it in one place stops the planning and replanning prompts from drifting
    /// apart.
    const PLAN_RESPONSE_FORMAT: &'static str = r#"```json
{
"name": "chain-name",
"description": "Brief description of the plan",
"steps": [
{
"name": "step-name",
"step_type": "implement",
"description": "What this step does",
"depends_on": ["prior-step-name"],
"requirement_ids": ["REQ-001"],
"contract_template": {
"name": "Contract Name",
"description": "Contract description",
"contract_type": "simple",
"phases": ["plan", "execute"],
"tasks": [
{"name": "Task 1", "plan": "Detailed plan for this task"}
],
"deliverables": [
{"id": "del-1", "name": "Deliverable 1", "priority": "required"}
]
}
}
]
}
```"#;

    /// Create a new chain planner with the default step-type vocabulary.
    pub fn new() -> Self {
        Self {
            default_step_types: vec![
                "research".to_string(),
                "design".to_string(),
                "implement".to_string(),
                "test".to_string(),
                "review".to_string(),
                "document".to_string(),
            ],
        }
    }

    /// Render a JSON array of `{ "id": ..., <text_key>: ... }` objects as `- id: text` lines.
    ///
    /// Non-object entries are skipped. Missing or non-string fields render as `?`. A value
    /// that is not an array yields no lines.
    fn describe_entries(entries: &serde_json::Value, text_key: &str) -> Vec<String> {
        entries
            .as_array()
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_object())
                    .map(|obj| {
                        let id = obj.get("id").and_then(|v| v.as_str()).unwrap_or("?");
                        let text = obj.get(text_key).and_then(|v| v.as_str()).unwrap_or("?");
                        format!("- {}: {}", id, text)
                    })
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Build the initial planning prompt for the LLM from a directive.
    ///
    /// Includes the directive's goal, its requirements (`id`/`description`), its acceptance
    /// criteria (`id`/`criterion`), its string constraints, and the expected JSON response
    /// format.
    pub fn build_planning_prompt(&self, directive: &Directive) -> String {
        let requirements = Self::describe_entries(&directive.requirements, "description");
        let criteria = Self::describe_entries(&directive.acceptance_criteria, "criterion");
        let constraints: Vec<String> = directive
            .constraints
            .as_array()
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str())
                    .map(|s| format!("- {}", s))
                    .collect()
            })
            .unwrap_or_default();
        format!(
            r#"You are a software architect planning an execution chain for a coding task.
## Directive Goal
{goal}
## Requirements
{requirements}
## Acceptance Criteria
{criteria}
## Constraints
{constraints}
## Instructions
Create an execution plan as a chain of steps. Each step should:
1. Have a unique, descriptive name (kebab-case)
2. Specify its type (research, design, implement, test, review, document)
3. Declare dependencies on prior steps (if any)
4. Map to specific requirement IDs it addresses
5. Include a contract template with tasks and deliverables
The chain should form a valid DAG (no cycles). Steps can run in parallel if they don't depend on each other.
Respond with a JSON object in this format:
{response_format}
Generate the optimal execution plan now."#,
            goal = directive.goal,
            requirements = requirements.join("\n"),
            criteria = criteria.join("\n"),
            constraints = constraints.join("\n"),
            response_format = Self::PLAN_RESPONSE_FORMAT,
        )
    }

    /// Parse an LLM response into a validated [`GeneratedChain`].
    ///
    /// # Errors
    ///
    /// - [`PlannerError::InvalidPlan`] if no JSON can be extracted or it fails to deserialize.
    /// - [`PlannerError::EmptyPlan`] if the plan has no steps.
    /// - Any error from [`ChainPlanner::validate_chain`].
    pub fn parse_plan_response(&self, response: &str) -> Result<GeneratedChain, PlannerError> {
        // The model may wrap the JSON in markdown fences or surrounding prose.
        let json_str = extract_json_from_response(response)?;
        let chain: GeneratedChain = serde_json::from_str(&json_str)
            .map_err(|e| PlannerError::InvalidPlan(format!("JSON parse error: {}", e)))?;
        if chain.steps.is_empty() {
            return Err(PlannerError::EmptyPlan);
        }
        self.validate_chain(&chain)?;
        Ok(chain)
    }

    /// Validate a generated chain's structure.
    ///
    /// # Errors
    ///
    /// - [`PlannerError::InvalidPlan`] naming the first duplicated step name.
    /// - [`PlannerError::InvalidDependency`] for the first dependency on an unknown step.
    /// - [`PlannerError::CycleDetected`] if the dependency graph has a cycle.
    pub fn validate_chain(&self, chain: &GeneratedChain) -> Result<(), PlannerError> {
        let mut step_names: HashSet<&str> = HashSet::with_capacity(chain.steps.len());
        for step in &chain.steps {
            if !step_names.insert(step.name.as_str()) {
                return Err(PlannerError::InvalidPlan(format!(
                    "Duplicate step names detected: {}",
                    step.name
                )));
            }
        }
        for step in &chain.steps {
            if let Some(dep) = step
                .depends_on
                .iter()
                .find(|dep| !step_names.contains(dep.as_str()))
            {
                return Err(PlannerError::InvalidDependency {
                    step: step.name.clone(),
                    dependency: dep.clone(),
                });
            }
        }
        self.detect_cycles(chain)
    }

    /// Detect cycles in the chain's dependency graph using DFS.
    ///
    /// The error payload names a step that lies on the detected cycle (the step whose
    /// dependency edge closes it), not merely the step the search started from.
    fn detect_cycles(&self, chain: &GeneratedChain) -> Result<(), PlannerError> {
        let adj: HashMap<&str, Vec<&str>> = chain
            .steps
            .iter()
            .map(|s| (s.name.as_str(), s.depends_on.iter().map(|d| d.as_str()).collect()))
            .collect();
        let mut visited = HashSet::new();
        let mut on_stack = HashSet::new();
        for step in &chain.steps {
            if visited.contains(step.name.as_str()) {
                continue;
            }
            if let Some(node) =
                Self::find_cycle(step.name.as_str(), &adj, &mut visited, &mut on_stack)
            {
                return Err(PlannerError::CycleDetected(node.to_string()));
            }
        }
        Ok(())
    }

    /// DFS from `node`. Returns the first node found on the current DFS path again (a back
    /// edge), which is therefore part of a cycle.
    ///
    /// `on_stack` holds the nodes of the current DFS path. It is not cleaned up once a cycle
    /// is found, because the caller aborts at that point.
    fn find_cycle<'a>(
        node: &'a str,
        adj: &HashMap<&'a str, Vec<&'a str>>,
        visited: &mut HashSet<&'a str>,
        on_stack: &mut HashSet<&'a str>,
    ) -> Option<&'a str> {
        visited.insert(node);
        on_stack.insert(node);
        if let Some(deps) = adj.get(node) {
            for &dep in deps {
                if on_stack.contains(dep) {
                    return Some(dep);
                }
                if !visited.contains(dep) {
                    if let Some(found) = Self::find_cycle(dep, adj, visited, on_stack) {
                        return Some(found);
                    }
                }
            }
        }
        on_stack.remove(node);
        None
    }

    /// Check that every requirement ID in the directive is addressed by at least one step.
    ///
    /// Requirements are checked in the order they appear in the directive, so the reported
    /// requirement is deterministic.
    ///
    /// # Errors
    ///
    /// [`PlannerError::RequirementNotCovered`] with the first uncovered requirement ID.
    pub fn check_requirement_coverage(
        &self,
        chain: &GeneratedChain,
        directive: &Directive,
    ) -> Result<(), PlannerError> {
        let covered: HashSet<&str> = chain
            .steps
            .iter()
            .flat_map(|s| s.requirement_ids.iter().map(String::as_str))
            .collect();
        let uncovered = directive
            .requirements
            .as_array()
            .into_iter()
            .flatten()
            .filter_map(|v| v.get("id").and_then(|id| id.as_str()))
            .find(|id| !covered.contains(id));
        match uncovered {
            Some(id) => Err(PlannerError::RequirementNotCovered(id.to_string())),
            None => Ok(()),
        }
    }

    /// Return step names in a dependency-respecting (topological) order.
    ///
    /// Uses Kahn's algorithm with a FIFO queue seeded in chain order, so the result is
    /// deterministic: among steps whose dependencies are satisfied, earlier-declared steps
    /// come first.
    ///
    /// # Errors
    ///
    /// - [`PlannerError::InvalidPlan`] if a step name is duplicated.
    /// - [`PlannerError::InvalidDependency`] if a step depends on an unknown step.
    /// - [`PlannerError::CycleDetected`] listing the steps that could not be ordered.
    pub fn topological_sort<'a>(
        &self,
        chain: &'a GeneratedChain,
    ) -> Result<Vec<&'a str>, PlannerError> {
        let mut in_degree: HashMap<&'a str, usize> = HashMap::with_capacity(chain.steps.len());
        // Reversed edges: dependency -> steps that depend on it.
        let mut dependents: HashMap<&'a str, Vec<&'a str>> =
            HashMap::with_capacity(chain.steps.len());
        for step in &chain.steps {
            let name = step.name.as_str();
            if in_degree.insert(name, 0).is_some() {
                return Err(PlannerError::InvalidPlan(format!(
                    "Duplicate step names detected: {}",
                    name
                )));
            }
            dependents.insert(name, Vec::new());
        }
        for step in &chain.steps {
            for dep in &step.depends_on {
                match dependents.get_mut(dep.as_str()) {
                    Some(list) => list.push(step.name.as_str()),
                    None => {
                        return Err(PlannerError::InvalidDependency {
                            step: step.name.clone(),
                            dependency: dep.clone(),
                        })
                    }
                }
                if let Some(deg) = in_degree.get_mut(step.name.as_str()) {
                    *deg += 1;
                }
            }
        }

        let mut ready: std::collections::VecDeque<&'a str> = chain
            .steps
            .iter()
            .map(|s| s.name.as_str())
            .filter(|name| in_degree.get(name) == Some(&0))
            .collect();
        let mut order = Vec::with_capacity(chain.steps.len());
        while let Some(node) = ready.pop_front() {
            order.push(node);
            if let Some(next_steps) = dependents.get(node) {
                for &next in next_steps {
                    if let Some(deg) = in_degree.get_mut(next) {
                        *deg -= 1;
                        if *deg == 0 {
                            ready.push_back(next);
                        }
                    }
                }
            }
        }

        if order.len() != chain.steps.len() {
            // Steps still waiting on dependencies are on a cycle or downstream of one.
            let stuck: Vec<&str> = chain
                .steps
                .iter()
                .map(|s| s.name.as_str())
                .filter(|name| matches!(in_degree.get(name), Some(&deg) if deg > 0))
                .collect();
            return Err(PlannerError::CycleDetected(format!(
                "steps that could not be ordered: {}",
                stuck.join(", ")
            )));
        }
        Ok(order)
    }

    /// Convert generated steps into [`AddStepRequest`]s for database insertion.
    ///
    /// Dependency names are resolved through `step_id_map`. Names missing from the map are
    /// silently dropped. Only the first contract task's plan becomes `task_plan`, and every
    /// request starts in the `"plan"` phase with no editor position.
    pub fn steps_to_requests(
        &self,
        chain: &GeneratedChain,
        step_id_map: &HashMap<String, Uuid>,
    ) -> Vec<AddStepRequest> {
        chain
            .steps
            .iter()
            .map(|step| {
                let template = step.contract_template.as_ref();
                let depends_on: Vec<Uuid> = step
                    .depends_on
                    .iter()
                    .filter_map(|name| step_id_map.get(name))
                    .copied()
                    .collect();
                AddStepRequest {
                    name: step.name.clone(),
                    description: Some(step.description.clone()),
                    step_type: Some(step.step_type.clone()),
                    contract_type: template.map(|t| t.contract_type.clone()),
                    initial_phase: Some("plan".to_string()),
                    task_plan: template
                        .and_then(|t| t.tasks.first())
                        .map(|t| t.plan.clone()),
                    phases: template.map(|t| t.phases.clone()),
                    depends_on: Some(depends_on),
                    parallel_group: None,
                    requirement_ids: Some(step.requirement_ids.clone()),
                    acceptance_criteria_ids: None,
                    verifier_config: None,
                    editor_x: None,
                    editor_y: None,
                }
            })
            .collect()
    }

    /// Compute editor `(x, y)` positions for steps from the DAG layout.
    ///
    /// The column (x) is determined by the step's depth in the dependency graph. The row (y)
    /// is the step's index among steps of the same depth, in chain order.
    pub fn compute_editor_positions(
        &self,
        chain: &GeneratedChain,
    ) -> HashMap<String, (f64, f64)> {
        const X_SPACING: f64 = 250.0;
        const Y_SPACING: f64 = 150.0;
        const MARGIN: f64 = 100.0;

        let depths = self.get_step_depths(chain);
        let mut by_depth: HashMap<usize, Vec<&str>> = HashMap::new();
        for step in &chain.steps {
            let depth = depths.get(&step.name).copied().unwrap_or(0);
            by_depth.entry(depth).or_default().push(step.name.as_str());
        }

        let mut positions: HashMap<String, (f64, f64)> =
            HashMap::with_capacity(chain.steps.len());
        for (depth, names) in &by_depth {
            let x = (*depth as f64) * X_SPACING + MARGIN;
            for (i, name) in names.iter().enumerate() {
                let y = (i as f64) * Y_SPACING + MARGIN;
                positions.insert(name.to_string(), (x, y));
            }
        }
        positions
    }

    /// Depth of each step: 0 for steps with no dependencies, otherwise one more than the
    /// deepest dependency.
    ///
    /// Terminates on cyclic graphs (depths along a cycle are then arbitrary but finite), so
    /// this is safe to call on an unvalidated chain. Unknown dependency names get depth 0.
    fn get_step_depths(&self, chain: &GeneratedChain) -> HashMap<String, usize> {
        fn compute_depth(
            name: &str,
            deps: &HashMap<&str, &[String]>,
            depths: &mut HashMap<String, usize>,
        ) -> usize {
            if let Some(&d) = depths.get(name) {
                return d;
            }
            // Provisional entry. A cycle leading back to `name` reads 0 here instead of
            // recursing forever. In an acyclic graph `name` is never revisited while it is
            // being computed, so this value is never observed.
            depths.insert(name.to_string(), 0);
            let depth = deps
                .get(name)
                .map(|dep_list| {
                    dep_list
                        .iter()
                        .map(|d| compute_depth(d, deps, depths) + 1)
                        .max()
                        .unwrap_or(0)
                })
                .unwrap_or(0);
            depths.insert(name.to_string(), depth);
            depth
        }

        let deps: HashMap<&str, &[String]> = chain
            .steps
            .iter()
            .map(|s| (s.name.as_str(), s.depends_on.as_slice()))
            .collect();
        let mut depths: HashMap<String, usize> = HashMap::with_capacity(chain.steps.len());
        for step in &chain.steps {
            compute_depth(&step.name, &deps, &mut depths);
        }
        depths
    }

    /// Build a replanning prompt that preserves completed steps.
    ///
    /// The prompt is self-contained. It repeats the JSON response format because the reply
    /// is parsed with the same schema as the initial plan.
    pub fn build_replan_prompt(
        &self,
        directive: &Directive,
        completed_steps: &[ChainStep],
        failed_step: Option<&ChainStep>,
        reason: &str,
    ) -> String {
        let completed = completed_steps
            .iter()
            .map(|s| format!("- {} ({}): completed", s.name, s.step_type))
            .collect::<Vec<_>>()
            .join("\n");
        let failed = failed_step
            .map(|s| {
                format!(
                    "Failed step: {} - {}",
                    s.name,
                    s.description.as_deref().unwrap_or("")
                )
            })
            .unwrap_or_default();
        format!(
            r#"You are a software architect replanning an execution chain.
## Original Goal
{goal}
## Completed Steps (preserve these)
{completed}
## Failure Information
{failed}
Reason: {reason}
## Instructions
Generate a new execution plan that:
1. Preserves all completed work
2. Addresses the failure
3. Continues toward the original goal
Do not include already completed steps. Respond with a JSON object in this format:
{response_format}"#,
            goal = directive.goal,
            completed = completed,
            failed = failed,
            reason = reason,
            response_format = Self::PLAN_RESPONSE_FORMAT,
        )
    }
}
/// Extract the JSON payload from an LLM response.
///
/// Candidates are tried in this order:
/// 1. If the trimmed response itself starts with `{`, it is treated as raw JSON and returned
///    up to and including its last `}`, which drops any trailing commentary. This check runs
///    first so that code fences *inside* JSON string values (e.g. a task plan containing a
///    code snippet) are not mistaken for markdown fences.
/// 2. The body of the first ```` ```json ```` fenced block.
/// 3. The body of the first generic ```` ``` ```` fenced block, skipping the rest of the
///    opening fence line (the language identifier, if any).
/// 4. The span from the first `{` to the last `}` anywhere in the response, for JSON
///    embedded in prose without fences.
///
/// The returned text is only a candidate; it is not validated as JSON here.
///
/// # Errors
///
/// Returns [`PlannerError::InvalidPlan`] when none of the strategies finds a candidate.
fn extract_json_from_response(response: &str) -> Result<String, PlannerError> {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    let trimmed = response.trim();
    if trimmed.starts_with('{') {
        let end = trimmed.rfind('}').map_or(trimmed.len(), |i| i + 1);
        return Ok(trimmed[..end].to_string());
    }

    if let Some(start) = response.find(JSON_FENCE) {
        let body = &response[start + JSON_FENCE.len()..];
        if let Some(end) = find_closing_fence(body) {
            return Ok(body[..end].trim().to_string());
        }
    }

    if let Some(start) = response.find(FENCE) {
        let after_fence = &response[start + FENCE.len()..];
        // Skip the language identifier (rest of the opening fence line), if any.
        let body = after_fence
            .find('\n')
            .map_or(after_fence, |i| &after_fence[i + 1..]);
        if let Some(end) = find_closing_fence(body) {
            return Ok(body[..end].trim().to_string());
        }
    }

    if let (Some(open), Some(close)) = (response.find('{'), response.rfind('}')) {
        if open < close {
            return Ok(response[open..=close].to_string());
        }
    }

    Err(PlannerError::InvalidPlan(
        "Could not extract JSON from response".to_string(),
    ))
}

/// Return the byte offset in `body` (text after an opening fence) where the closing fence
/// begins.
///
/// A fence at the start of a line is preferred. JSON strings cannot contain a raw newline,
/// so a newline followed by three backticks can never be inside a JSON string value. Falls
/// back to the first run of three backticks to support single-line blocks.
fn find_closing_fence(body: &str) -> Option<usize> {
    body.find("\n```").or_else(|| body.find("```"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert borrowed string slices into owned `String`s.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Build a step without a contract template.
    fn make_step(
        name: &str,
        step_type: &str,
        description: &str,
        depends_on: &[&str],
        requirement_ids: &[&str],
    ) -> GeneratedStep {
        GeneratedStep {
            name: name.to_string(),
            step_type: step_type.to_string(),
            description: description.to_string(),
            depends_on: strings(depends_on),
            requirement_ids: strings(requirement_ids),
            contract_template: None,
        }
    }

    /// Linear three-step chain: step-a -> step-b -> step-c.
    fn make_test_chain() -> GeneratedChain {
        GeneratedChain {
            name: "test-chain".to_string(),
            description: "Test chain".to_string(),
            steps: vec![
                make_step("step-a", "research", "Research step", &[], &["REQ-001"]),
                make_step(
                    "step-b",
                    "implement",
                    "Implementation step",
                    &["step-a"],
                    &["REQ-002"],
                ),
                make_step("step-c", "test", "Test step", &["step-b"], &["REQ-001"]),
            ],
        }
    }

    #[test]
    fn test_validate_chain_valid() {
        let chain = make_test_chain();
        assert!(ChainPlanner::new().validate_chain(&chain).is_ok());
    }

    #[test]
    fn test_validate_chain_invalid_dependency() {
        let mut chain = make_test_chain();
        chain.steps[1].depends_on = strings(&["nonexistent"]);
        let outcome = ChainPlanner::new().validate_chain(&chain);
        assert!(matches!(outcome, Err(PlannerError::InvalidDependency { .. })));
    }

    #[test]
    fn test_validate_chain_cycle() {
        // a -> c -> b -> a
        let chain = GeneratedChain {
            name: "cyclic".to_string(),
            description: "Has cycle".to_string(),
            steps: vec![
                make_step("a", "research", "A", &["c"], &[]),
                make_step("b", "implement", "B", &["a"], &[]),
                make_step("c", "test", "C", &["b"], &[]),
            ],
        };
        let outcome = ChainPlanner::new().validate_chain(&chain);
        assert!(matches!(outcome, Err(PlannerError::CycleDetected(_))));
    }

    #[test]
    fn test_topological_sort() {
        let chain = make_test_chain();
        let order = ChainPlanner::new().topological_sort(&chain).unwrap();
        let index_of = |wanted: &str| order.iter().position(|&n| n == wanted).unwrap();
        // Every step must come after the step it depends on.
        assert!(index_of("step-a") < index_of("step-b"));
        assert!(index_of("step-b") < index_of("step-c"));
    }

    #[test]
    fn test_extract_json_from_code_block() {
        let response = r#"
Here's the plan:
```json
{"name": "test"}
```
That's it!
"#;
        let extracted = extract_json_from_response(response).unwrap();
        assert_eq!(extracted, r#"{"name": "test"}"#);
    }

    #[test]
    fn test_extract_json_raw() {
        let response = r#"{"name": "test"}"#;
        let extracted = extract_json_from_response(response).unwrap();
        assert_eq!(extracted, r#"{"name": "test"}"#);
    }
}