//! Commit discipline and validation for task commits.
//!
//! This module enforces structured commit messages and optional quality checks
//! (lint and tests) before checkpoint commits are created. Messages use the
//! conventional commit format by default (a simpler `[task-id] summary` format is
//! also available) and always end with a Co-Authored-By trailer.
use serde::{Deserialize, Serialize};
use std::path::Path;
use thiserror::Error;
use tokio::process::Command;
use uuid::Uuid;
/// Errors that can occur during commit validation.
#[derive(Debug, Error)]
pub enum CommitValidationError {
#[error("Invalid commit message format: {0}")]
InvalidFormat(String),
#[error("Missing required field: {0}")]
MissingField(String),
#[error("Quality check failed: {0}")]
QualityCheckFailed(String),
#[error("Lint check failed: {0}")]
LintFailed(String),
#[error("Tests failed: {0}")]
TestsFailed(String),
#[error("Command execution failed: {0}")]
CommandFailed(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
/// Commit message format style.
///
/// Serialized (and parsed via `FromStr`) as `"conventional"` or `"simple"`.
/// In both styles `<short-id>` is the first eight characters of the task UUID.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum MessageFormat {
    /// Conventional commit subject: `<type>: [<short-id>] <summary>`,
    /// e.g. `feat: [550e8400] Add user authentication`.
    #[default]
    Conventional,
    /// Simple subject without a type prefix: `[<short-id>] <summary>`,
    /// e.g. `[550e8400] Add user authentication`.
    Simple,
}
impl MessageFormat {
pub fn as_str(&self) -> &'static str {
match self {
MessageFormat::Conventional => "conventional",
MessageFormat::Simple => "simple",
}
}
}
impl std::str::FromStr for MessageFormat {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"conventional" => Ok(MessageFormat::Conventional),
"simple" => Ok(MessageFormat::Simple),
_ => Err(format!("Unknown message format: {}. Use 'conventional' or 'simple'", s)),
}
}
}
/// Commit type for conventional commits.
///
/// Every variant is rendered (by `as_str` and by serde) as its lowercase
/// name, e.g. `Refactor` becomes `refactor`. The default is `Feat`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum CommitType {
    /// Rendered as `feat`.
    #[default]
    Feat,
    /// Rendered as `fix`.
    Fix,
    /// Rendered as `chore`.
    Chore,
    /// Rendered as `docs`.
    Docs,
    /// Rendered as `style`.
    Style,
    /// Rendered as `refactor`.
    Refactor,
    /// Rendered as `perf`.
    Perf,
    /// Rendered as `test`.
    Test,
    /// Rendered as `build`.
    Build,
    /// Rendered as `ci`.
    Ci,
}
impl CommitType {
pub fn as_str(&self) -> &'static str {
match self {
CommitType::Feat => "feat",
CommitType::Fix => "fix",
CommitType::Chore => "chore",
CommitType::Docs => "docs",
CommitType::Style => "style",
CommitType::Refactor => "refactor",
CommitType::Perf => "perf",
CommitType::Test => "test",
CommitType::Build => "build",
CommitType::Ci => "ci",
}
}
}
impl std::str::FromStr for CommitType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"feat" | "feature" => Ok(CommitType::Feat),
"fix" => Ok(CommitType::Fix),
"chore" => Ok(CommitType::Chore),
"docs" | "doc" => Ok(CommitType::Docs),
"style" => Ok(CommitType::Style),
"refactor" => Ok(CommitType::Refactor),
"perf" | "performance" => Ok(CommitType::Perf),
"test" | "tests" => Ok(CommitType::Test),
"build" => Ok(CommitType::Build),
"ci" => Ok(CommitType::Ci),
_ => Err(format!("Unknown commit type: {}. Use feat/fix/chore/docs/style/refactor/perf/test/build/ci", s)),
}
}
}
/// Outcome of running a single quality check (a lint or test command).
#[derive(Debug, Clone)]
pub struct QualityCheckResult {
    /// `true` when the command exited successfully, or when the check was skipped.
    pub passed: bool,
    /// Check identifier, e.g. `"lint"` or `"test"`.
    pub check_name: String,
    /// Combined stdout/stderr of the command (size-capped), or a skip notice.
    pub output: String,
    /// Process exit code; `None` for a skipped check or when the process reported none.
    pub exit_code: Option<i32>,
    /// Wall-clock duration of the check in milliseconds (`0` when skipped).
    pub duration_ms: u64,
}
/// Configuration for commit discipline.
///
/// Thanks to the struct-level `#[serde(default)]`, any field missing from a
/// serialized config is filled from [`Default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CommitDisciplineConfig {
    /// Master switch for commit discipline; `true` by default.
    /// NOTE(review): not read by `CommitValidator` in this file.
    pub enabled: bool,
    /// Run the test command in `run_quality_checks`; a failing run yields an error.
    pub require_tests: bool,
    /// Run the lint command in `run_quality_checks`; a failing run yields an error.
    pub require_lint: bool,
    /// Subject style; serialized as `"conventional"` or `"simple"`.
    pub message_format: MessageFormat,
    /// Explicit test command; detected from project files when `None`.
    pub test_command: Option<String>,
    /// Explicit lint command; detected from project files when `None`.
    pub lint_command: Option<String>,
    /// Per-check timeout in seconds (300 unless overridden).
    #[serde(default = "default_check_timeout")]
    pub check_timeout_secs: u64,
}
/// Default per-check timeout: five minutes, expressed in seconds.
fn default_check_timeout() -> u64 {
    5 * 60
}
impl Default for CommitDisciplineConfig {
    /// Opinionated defaults: discipline enabled, conventional messages, no
    /// required lint/test checks, auto-detected commands, five-minute timeout.
    fn default() -> Self {
        Self {
            enabled: true,
            message_format: MessageFormat::Conventional,
            require_lint: false,
            require_tests: false,
            lint_command: None,
            test_command: None,
            check_timeout_secs: default_check_timeout(),
        }
    }
}
/// Co-Authored-By trailer appended as the final paragraph of every generated
/// commit message.
const CO_AUTHOR_TRAILER: &str = "Co-Authored-By: Claude <noreply@anthropic.com>";
/// Validator for commit messages and pre-commit quality checks.
///
/// Its [`CommitDisciplineConfig`] controls the message format and which
/// checks run.
pub struct CommitValidator {
    config: CommitDisciplineConfig,
}
impl CommitValidator {
    /// Create a new commit validator with the given configuration.
    pub fn new(config: CommitDisciplineConfig) -> Self {
        Self { config }
    }

    /// Create a new commit validator with [`CommitDisciplineConfig::default`]
    /// (conventional messages, no required lint or test checks).
    pub fn with_defaults() -> Self {
        Self::new(CommitDisciplineConfig::default())
    }

    /// Get the current configuration.
    pub fn config(&self) -> &CommitDisciplineConfig {
        &self.config
    }

    /// Validate a commit message against the configured format.
    ///
    /// The checks, in order:
    /// 1. The message is not blank.
    /// 2. Its first line (the subject) is not empty.
    /// 3. In [`MessageFormat::Conventional`] mode, the subject starts with a
    ///    lowercase commit type followed by `:` or by `(`, which opens a scope
    ///    as in `feat(auth): ...`.
    ///
    /// Subjects longer than 100 characters are accepted; they only produce a warning.
    ///
    /// # Errors
    ///
    /// * [`CommitValidationError::MissingField`] if the message or its subject is empty.
    /// * [`CommitValidationError::InvalidFormat`] if a conventional prefix is required but missing.
    pub fn validate_message(&self, message: &str) -> Result<(), CommitValidationError> {
        // Subjects longer than this many characters trigger a warning, not an error.
        const SUBJECT_WARN_CHARS: usize = 100;
        // Commit types accepted as a conventional prefix (case-sensitive).
        const CONVENTIONAL_TYPES: [&str; 10] = [
            "feat", "fix", "chore", "docs", "style", "refactor", "perf", "test", "build", "ci",
        ];

        if message.trim().is_empty() {
            return Err(CommitValidationError::MissingField(
                "commit message".to_string(),
            ));
        }

        // The subject is the first line; a message starting with a newline has none.
        let subject = message.lines().next().unwrap_or("");
        if subject.is_empty() {
            return Err(CommitValidationError::MissingField(
                "commit subject".to_string(),
            ));
        }

        // Count characters, not bytes, so non-ASCII subjects are measured correctly.
        let subject_chars = subject.chars().count();
        if subject_chars > SUBJECT_WARN_CHARS {
            tracing::warn!(
                "Commit subject exceeds recommended length (100 chars): {} chars",
                subject_chars
            );
        }

        if self.config.message_format == MessageFormat::Conventional {
            // Accept `type:` and the scoped form `type(`.
            // NOTE(review): the scoped form is not checked for a closing `):`,
            // and the breaking-change marker `type!:` is not accepted.
            let has_valid_prefix = CONVENTIONAL_TYPES.iter().any(|&ty| {
                matches!(
                    subject.strip_prefix(ty),
                    Some(rest) if rest.starts_with(':') || rest.starts_with('(')
                )
            });
            if !has_valid_prefix {
                return Err(CommitValidationError::InvalidFormat(format!(
                    "Commit message must start with a conventional commit type (feat/fix/chore/docs/style/refactor/perf/test/build/ci). Got: {}",
                    subject.chars().take(30).collect::<String>()
                )));
            }
        }

        Ok(())
    }

    /// Format a commit message according to the configured format.
    ///
    /// The subject is `<type>: [<short-id>] <summary>` in conventional mode,
    /// or `[<short-id>] <summary>` in simple mode:
    /// * `<short-id>` is the first eight characters of `task_id`.
    /// * `commit_type` defaults to [`CommitType::Feat`].
    ///
    /// A non-blank `body` follows after a blank line. The Co-Authored-By
    /// trailer is always appended as the final paragraph.
    ///
    /// `summary` and `body` are trimmed. NOTE(review): newlines embedded in
    /// `summary` are kept and would spill into the message body.
    pub fn format_message(
        &self,
        task_id: Uuid,
        summary: &str,
        body: Option<&str>,
        commit_type: Option<CommitType>,
    ) -> String {
        let short_id = Self::short_task_id(task_id);
        let summary = summary.trim();

        let mut message = match self.config.message_format {
            MessageFormat::Conventional => format!(
                "{}: [{}] {}",
                commit_type.unwrap_or_default().as_str(),
                short_id,
                summary
            ),
            MessageFormat::Simple => format!("[{}] {}", short_id, summary),
        };

        if let Some(body_text) = body.map(str::trim).filter(|text| !text.is_empty()) {
            message.push_str("\n\n");
            message.push_str(body_text);
        }

        message.push_str("\n\n");
        message.push_str(CO_AUTHOR_TRAILER);
        message
    }

    /// Format a heartbeat/WIP commit message.
    ///
    /// The summary is `WIP checkpoint (iteration N) - <timestamp>`, or
    /// `WIP checkpoint - <timestamp>` when no iteration is given. The
    /// timestamp is the current UTC time. The message is rendered as a
    /// `chore` commit via [`Self::format_message`], without a body.
    pub fn format_heartbeat_message(&self, task_id: Uuid, iteration: Option<u32>) -> String {
        let timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC");
        let summary = match iteration {
            Some(n) => format!("WIP checkpoint (iteration {}) - {}", n, timestamp),
            None => format!("WIP checkpoint - {}", timestamp),
        };
        self.format_message(task_id, &summary, None, Some(CommitType::Chore))
    }

    /// Format a `chore` checkpoint commit message with optional progress info.
    ///
    /// A blank `user_message` becomes `Checkpoint commit`. When
    /// `files_changed` is non-empty, the body lists up to 20 files as
    /// `- <path>` lines under `Files changed:`, followed by
    /// `... and N more` if any were omitted.
    pub fn format_checkpoint_message(
        &self,
        task_id: Uuid,
        user_message: &str,
        files_changed: Option<&[String]>,
    ) -> String {
        // At most this many file names are listed; the remainder is summarized as a count.
        const MAX_LISTED_FILES: usize = 20;

        let summary = match user_message.trim() {
            "" => "Checkpoint commit",
            trimmed => trimmed,
        };

        let body = files_changed
            .filter(|files| !files.is_empty())
            .map(|files| {
                let listed = files
                    .iter()
                    .take(MAX_LISTED_FILES)
                    .map(|file| format!("- {}", file))
                    .collect::<Vec<_>>()
                    .join("\n");
                if files.len() > MAX_LISTED_FILES {
                    format!(
                        "Files changed:\n{}\n... and {} more",
                        listed,
                        files.len() - MAX_LISTED_FILES
                    )
                } else {
                    format!("Files changed:\n{}", listed)
                }
            });

        self.format_message(task_id, summary, body.as_deref(), Some(CommitType::Chore))
    }

    /// Run the configured quality checks (lint first, then tests) before committing.
    ///
    /// Only the checks enabled by `require_lint` / `require_tests` run. With
    /// neither set, this returns `Ok` with an empty vector. A check with no
    /// configured or detectable command counts as passed (skipped).
    ///
    /// # Errors
    ///
    /// * [`CommitValidationError::LintFailed`] or [`CommitValidationError::TestsFailed`],
    ///   carrying the check output, when a check exits unsuccessfully. Later
    ///   checks are not run.
    /// * [`CommitValidationError::CommandFailed`] if a command cannot be started
    ///   or times out.
    pub async fn run_quality_checks(
        &self,
        worktree_path: &Path,
    ) -> Result<Vec<QualityCheckResult>, CommitValidationError> {
        let mut results = Vec::new();

        if self.config.require_lint {
            let lint = self.run_lint_check(worktree_path).await?;
            if !lint.passed {
                return Err(CommitValidationError::LintFailed(lint.output));
            }
            results.push(lint);
        }

        if self.config.require_tests {
            let tests = self.run_test_check(worktree_path).await?;
            if !tests.passed {
                return Err(CommitValidationError::TestsFailed(tests.output));
            }
            results.push(tests);
        }

        Ok(results)
    }

    /// Run the lint check using the configured command, or an auto-detected one.
    async fn run_lint_check(&self, worktree_path: &Path) -> Result<QualityCheckResult, CommitValidationError> {
        let command = match &self.config.lint_command {
            Some(cmd) => cmd.clone(),
            None => self.detect_lint_command(worktree_path).await,
        };
        self.run_command_or_skip("lint", &command, worktree_path).await
    }

    /// Run the test check using the configured command, or an auto-detected one.
    async fn run_test_check(&self, worktree_path: &Path) -> Result<QualityCheckResult, CommitValidationError> {
        let command = match &self.config.test_command {
            Some(cmd) => cmd.clone(),
            None => self.detect_test_command(worktree_path).await,
        };
        self.run_command_or_skip("test", &command, worktree_path).await
    }

    /// Run `command` as check `check_name`. When `command` is empty (nothing
    /// configured or detected), report a passing "skipped" result instead.
    async fn run_command_or_skip(
        &self,
        check_name: &str,
        command: &str,
        worktree_path: &Path,
    ) -> Result<QualityCheckResult, CommitValidationError> {
        if command.is_empty() {
            return Ok(QualityCheckResult {
                passed: true,
                check_name: check_name.to_string(),
                output: format!("No {} command detected, skipping", check_name),
                exit_code: None,
                duration_ms: 0,
            });
        }
        self.run_check_command(check_name, command, worktree_path).await
    }

    /// Run a check command in `worktree_path` and collect its result.
    ///
    /// `command` is split on whitespace into a program and its arguments. It is
    /// not run through a shell, so quoting, pipes and variable expansion are
    /// not supported. The run is bounded by `check_timeout_secs`. On timeout
    /// the spawned process is killed; NOTE(review): only the direct child is
    /// killed, and processes it started may outlive it. Combined stdout/stderr
    /// is capped at 10,000 bytes. A non-zero exit is reported as
    /// `passed: false`, not as an error.
    ///
    /// # Errors
    ///
    /// [`CommitValidationError::CommandFailed`] if the command is empty,
    /// cannot be spawned, or times out.
    async fn run_check_command(
        &self,
        check_name: &str,
        command: &str,
        worktree_path: &Path,
    ) -> Result<QualityCheckResult, CommitValidationError> {
        const MAX_OUTPUT_BYTES: usize = 10_000;

        tracing::info!(
            check = check_name,
            command = command,
            path = %worktree_path.display(),
            "Running quality check"
        );

        let start = std::time::Instant::now();

        let mut parts = command.split_whitespace();
        let program = parts.next().ok_or_else(|| {
            CommitValidationError::CommandFailed(format!("Empty command for check: {}", check_name))
        })?;

        let timeout = std::time::Duration::from_secs(self.config.check_timeout_secs);
        let output = tokio::time::timeout(
            timeout,
            Command::new(program)
                .args(parts)
                .current_dir(worktree_path)
                // On timeout the `output()` future, and with it the child, is dropped.
                // Tokio does not kill a dropped child by default, so without this
                // a hung check would keep running in the background.
                .kill_on_drop(true)
                .output(),
        )
        .await
        .map_err(|_| {
            CommitValidationError::CommandFailed(format!(
                "Check '{}' timed out after {} seconds",
                check_name, self.config.check_timeout_secs
            ))
        })?
        .map_err(|e| CommitValidationError::CommandFailed(format!("Failed to run '{}': {}", check_name, e)))?;

        let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        let passed = output.status.success();
        let exit_code = output.status.code();

        let stdout = String::from_utf8_lossy(&output.stdout);
        let stderr = String::from_utf8_lossy(&output.stderr);
        let combined_output = match (stdout.is_empty(), stderr.is_empty()) {
            (_, true) => stdout.into_owned(),
            (true, false) => stderr.into_owned(),
            (false, false) => format!("{}\n{}", stdout, stderr),
        };
        let output_trimmed = Self::truncate_output(combined_output, MAX_OUTPUT_BYTES);

        tracing::info!(
            check = check_name,
            passed = passed,
            duration_ms = duration_ms,
            exit_code = exit_code,
            "Quality check completed"
        );

        Ok(QualityCheckResult {
            passed,
            check_name: check_name.to_string(),
            output: output_trimmed,
            exit_code,
            duration_ms,
        })
    }

    /// Cap `text` at `max_bytes` bytes, appending a truncation marker when it is cut.
    ///
    /// The cut point moves back to the nearest UTF-8 character boundary, so
    /// multi-byte characters (including U+FFFD from lossy decoding) cannot
    /// cause a slicing panic.
    fn truncate_output(text: String, max_bytes: usize) -> String {
        if text.len() <= max_bytes {
            return text;
        }
        let mut cut = max_bytes;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...\n[output truncated]", &text[..cut])
    }

    /// Short task reference used in subjects: the first eight characters of
    /// the task UUID's string form.
    fn short_task_id(task_id: Uuid) -> String {
        let mut id = task_id.to_string();
        id.truncate(8);
        id
    }

    /// Detect the appropriate test command from marker files in `worktree_path`.
    ///
    /// Checked in this order; the first match wins:
    /// * `Cargo.toml`: `cargo test`
    /// * `package.json` containing `"test"`: `npm test`
    /// * `pytest.ini` / `pyproject.toml` / `setup.py`: `pytest`
    /// * `go.mod`: `go test ./...`
    /// * `pom.xml`: `mvn test`
    /// * Gradle build files: `./gradlew test` when the wrapper exists, else `gradle test`
    ///
    /// Returns an empty string when nothing matches.
    async fn detect_test_command(&self, worktree_path: &Path) -> String {
        let has = |name: &str| worktree_path.join(name).exists();

        if has("Cargo.toml") {
            return "cargo test".to_string();
        }

        if has("package.json") {
            // Naive text match: any `"test"` string in the manifest counts as a test script.
            if let Ok(content) = tokio::fs::read_to_string(worktree_path.join("package.json")).await {
                if content.contains("\"test\"") {
                    return "npm test".to_string();
                }
            }
        }

        if has("pytest.ini") || has("pyproject.toml") || has("setup.py") {
            return "pytest".to_string();
        }

        if has("go.mod") {
            return "go test ./...".to_string();
        }

        if has("pom.xml") {
            return "mvn test".to_string();
        }

        if has("build.gradle") || has("build.gradle.kts") {
            // Prefer the project's wrapper, but spawning a missing `./gradlew` always fails.
            let gradle = if has("gradlew") { "./gradlew test" } else { "gradle test" };
            return gradle.to_string();
        }

        String::new()
    }

    /// Detect the appropriate lint command from project files in `worktree_path`.
    ///
    /// Checked in this order; the first match wins:
    /// * `Cargo.toml`: `cargo clippy --all-targets`
    /// * `package.json` with a `"lint"` entry: `npm run lint`; otherwise, if it
    ///   mentions eslint: `npx eslint .`
    /// * `pyproject.toml` with a `[tool.ruff]` section: `ruff check .`;
    ///   with a `[tool.flake8]` section: `flake8`
    /// * `go.mod`: `go vet ./...`
    ///
    /// Returns an empty string when nothing matches.
    /// NOTE(review): clippy exits successfully on warnings unless `-D warnings` is passed.
    async fn detect_lint_command(&self, worktree_path: &Path) -> String {
        let has = |name: &str| worktree_path.join(name).exists();

        if has("Cargo.toml") {
            return "cargo clippy --all-targets".to_string();
        }

        if has("package.json") {
            if let Ok(content) = tokio::fs::read_to_string(worktree_path.join("package.json")).await {
                if content.contains("\"lint\"") {
                    return "npm run lint".to_string();
                }
                if content.contains("eslint") {
                    return "npx eslint .".to_string();
                }
            }
        }

        if has("pyproject.toml") {
            if let Ok(content) = tokio::fs::read_to_string(worktree_path.join("pyproject.toml")).await {
                if content.contains("[tool.ruff]") {
                    return "ruff check .".to_string();
                }
                if content.contains("[tool.flake8]") {
                    return "flake8".to_string();
                }
            }
        }

        if has("go.mod") {
            return "go vet ./...".to_string();
        }

        String::new()
    }

    /// Append the Co-Authored-By trailer to an existing message if it is not
    /// already present anywhere in it.
    ///
    /// Trailing whitespace is trimmed before the trailer is added as a new paragraph.
    pub fn ensure_co_author_trailer(&self, message: &str) -> String {
        if message.contains(CO_AUTHOR_TRAILER) {
            message.to_string()
        } else {
            format!("{}\n\n{}", message.trim_end(), CO_AUTHOR_TRAILER)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed task id whose short form is `550e8400`.
    fn task_id() -> Uuid {
        Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()
    }

    /// Validator configured for the simple `[id] summary` format.
    fn simple_validator() -> CommitValidator {
        CommitValidator::new(CommitDisciplineConfig {
            message_format: MessageFormat::Simple,
            ..Default::default()
        })
    }

    #[test]
    fn test_format_conventional_message() {
        let msg = CommitValidator::with_defaults().format_message(
            task_id(),
            "Add user authentication",
            None,
            Some(CommitType::Feat),
        );
        assert!(msg.starts_with("feat: [550e8400] Add user authentication"));
        assert!(msg.contains(CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_format_simple_message() {
        let msg = simple_validator().format_message(
            task_id(),
            "Add user authentication",
            None,
            Some(CommitType::Feat),
        );
        assert!(msg.starts_with("[550e8400] Add user authentication"));
        assert!(msg.contains(CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_format_message_with_body() {
        let msg = CommitValidator::with_defaults().format_message(
            task_id(),
            "Fix login bug",
            Some("- Fixed null pointer exception\n- Added input validation"),
            Some(CommitType::Fix),
        );
        assert!(msg.contains("fix: [550e8400] Fix login bug"));
        assert!(msg.contains("Fixed null pointer exception"));
        assert!(msg.contains(CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_format_message_defaults_to_feat_and_skips_blank_body() {
        let msg = CommitValidator::with_defaults().format_message(
            task_id(),
            "  Tidy up  ",
            Some("   "),
            None,
        );
        assert_eq!(msg, format!("feat: [550e8400] Tidy up\n\n{}", CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_validate_conventional_message() {
        let validator = CommitValidator::with_defaults();

        let valid = [
            "feat: add new feature",
            "fix: resolve bug",
            "chore: update deps",
            "feat(auth): add login",
            "docs: x\n\nbody text",
        ];
        for message in valid {
            assert!(validator.validate_message(message).is_ok(), "expected valid: {:?}", message);
        }

        let invalid = [
            "",
            "   \n  ",
            "\nfeat: subject on second line",
            "add new feature",
            "FEAT: uppercase",
            "feature: alias is not a prefix",
        ];
        for message in invalid {
            assert!(validator.validate_message(message).is_err(), "expected invalid: {:?}", message);
        }
    }

    #[test]
    fn test_validate_simple_message() {
        let validator = simple_validator();
        // Simple format accepts any non-empty message
        assert!(validator.validate_message("any message").is_ok());
        assert!(validator.validate_message("[task-id] description").is_ok());
        assert!(validator.validate_message("").is_err());
    }

    #[test]
    fn test_format_heartbeat_message() {
        let msg = CommitValidator::with_defaults().format_heartbeat_message(task_id(), Some(3));
        assert!(msg.contains("chore: [550e8400]"));
        assert!(msg.contains("WIP checkpoint (iteration 3)"));
        assert!(msg.contains(CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_format_heartbeat_message_simple_without_iteration() {
        let msg = simple_validator().format_heartbeat_message(task_id(), None);
        assert!(msg.starts_with("[550e8400] WIP checkpoint - "));
        assert!(!msg.contains("chore:"));
        assert!(msg.ends_with(CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_format_checkpoint_message_lists_at_most_twenty_files() {
        let files: Vec<String> = (0..25).map(|i| format!("file{}.rs", i)).collect();
        let msg = CommitValidator::with_defaults().format_checkpoint_message(
            task_id(),
            "   ",
            Some(files.as_slice()),
        );
        assert!(msg.starts_with("chore: [550e8400] Checkpoint commit\n\nFiles changed:\n- file0.rs\n"));
        assert!(msg.contains("- file19.rs\n... and 5 more"));
        assert!(!msg.contains("- file20.rs"));
        assert!(msg.ends_with(CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_format_checkpoint_message_without_files() {
        let no_files: Vec<String> = Vec::new();
        let msg = CommitValidator::with_defaults().format_checkpoint_message(
            task_id(),
            " Save progress ",
            Some(no_files.as_slice()),
        );
        assert_eq!(msg, format!("chore: [550e8400] Save progress\n\n{}", CO_AUTHOR_TRAILER));
    }

    #[test]
    fn test_ensure_co_author_trailer() {
        let validator = CommitValidator::with_defaults();

        let appended = validator.ensure_co_author_trailer("feat: something\n\n");
        assert_eq!(appended, format!("feat: something\n\n{}", CO_AUTHOR_TRAILER));

        let already = format!("feat: something\n\n{}", CO_AUTHOR_TRAILER);
        let result = validator.ensure_co_author_trailer(&already);
        // Should not duplicate
        assert_eq!(result.matches(CO_AUTHOR_TRAILER).count(), 1);
    }

    #[test]
    fn test_commit_type_parsing() {
        assert_eq!("feat".parse::<CommitType>().unwrap(), CommitType::Feat);
        assert_eq!("feature".parse::<CommitType>().unwrap(), CommitType::Feat);
        assert_eq!("fix".parse::<CommitType>().unwrap(), CommitType::Fix);
        assert_eq!("docs".parse::<CommitType>().unwrap(), CommitType::Docs);
        assert_eq!("PERF".parse::<CommitType>().unwrap(), CommitType::Perf);
        assert!("invalid".parse::<CommitType>().is_err());
    }

    #[test]
    fn test_message_format_parsing() {
        assert_eq!("conventional".parse::<MessageFormat>().unwrap(), MessageFormat::Conventional);
        assert_eq!("simple".parse::<MessageFormat>().unwrap(), MessageFormat::Simple);
        assert_eq!("SIMPLE".parse::<MessageFormat>().unwrap(), MessageFormat::Simple);
        assert!("invalid".parse::<MessageFormat>().is_err());
    }

    #[test]
    fn test_as_str_round_trips_through_parse() {
        let all_types = [
            CommitType::Feat,
            CommitType::Fix,
            CommitType::Chore,
            CommitType::Docs,
            CommitType::Style,
            CommitType::Refactor,
            CommitType::Perf,
            CommitType::Test,
            CommitType::Build,
            CommitType::Ci,
        ];
        for ty in all_types {
            assert_eq!(ty.as_str().parse::<CommitType>().unwrap(), ty);
        }
        for format in [MessageFormat::Conventional, MessageFormat::Simple] {
            assert_eq!(format.as_str().parse::<MessageFormat>().unwrap(), format);
        }
    }
}