//! Git patch creation and application for checkpoint recovery. use flate2::read::GzDecoder; use flate2::write::GzEncoder; use flate2::Compression; use std::io::{Read, Write}; use std::path::Path; use thiserror::Error; use tokio::process::Command; /// Errors that can occur during patch operations. #[derive(Error, Debug)] pub enum PatchError { #[error("Git command failed: {0}")] GitCommand(String), #[error("Compression error: {0}")] Compression(#[from] std::io::Error), #[error("Patch too large: {size} bytes (max: {max} bytes)")] TooLarge { size: usize, max: usize }, #[error("Empty patch (no changes)")] EmptyPatch, #[error("Failed to apply patch: {0}")] ApplyFailed(String), } /// Create a compressed git diff from worktree changes. /// /// Generates a diff between `base_sha` and HEAD, then compresses it with gzip. /// Returns the compressed patch bytes and the number of files changed. pub async fn create_patch( worktree_path: &Path, base_sha: &str, ) -> Result<(Vec, usize), PatchError> { // Get the diff between base commit and HEAD let output = Command::new("git") .current_dir(worktree_path) .args(["diff", base_sha, "HEAD", "--binary"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to run git diff: {}", e)))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(PatchError::GitCommand(format!("git diff failed: {}", stderr))); } let diff_data = output.stdout; if diff_data.is_empty() { return Err(PatchError::EmptyPatch); } // Count files changed let files_output = Command::new("git") .current_dir(worktree_path) .args(["diff", base_sha, "HEAD", "--name-only"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to count files: {}", e)))?; let files_count = if files_output.status.success() { String::from_utf8_lossy(&files_output.stdout) .lines() .filter(|l| !l.is_empty()) .count() } else { 0 }; // Compress with gzip let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); encoder.write_all(&diff_data)?; let compressed = encoder.finish()?; Ok((compressed, files_count)) } /// Apply a compressed patch to restore worktree state. /// /// The worktree should already be checked out at `base_sha` before calling this. pub async fn apply_patch(worktree_path: &Path, patch_data: &[u8]) -> Result<(), PatchError> { // Decompress gzip let mut decoder = GzDecoder::new(patch_data); let mut decompressed = Vec::new(); decoder.read_to_end(&mut decompressed)?; if decompressed.is_empty() { return Err(PatchError::EmptyPatch); } // Apply the patch using git apply let mut child = Command::new("git") .current_dir(worktree_path) .args(["apply", "--binary", "-"]) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) .spawn() .map_err(|e| PatchError::GitCommand(format!("Failed to spawn git apply: {}", e)))?; // Write patch to stdin if let Some(mut stdin) = child.stdin.take() { use tokio::io::AsyncWriteExt; stdin.write_all(&decompressed).await?; drop(stdin); // Close stdin to signal EOF } let output = child .wait_with_output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to wait for git apply: {}", e)))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(PatchError::ApplyFailed(stderr.to_string())); } Ok(()) } /// Get the parent commit SHA (HEAD~1) from a worktree. pub async fn get_parent_sha(worktree_path: &Path) -> Result { let output = Command::new("git") .current_dir(worktree_path) .args(["rev-parse", "HEAD~1"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to get parent SHA: {}", e)))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(PatchError::GitCommand(format!( "git rev-parse HEAD~1 failed: {}", stderr ))); } Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) } /// Get the current HEAD commit SHA from a worktree. pub async fn get_head_sha(worktree_path: &Path) -> Result { let output = Command::new("git") .current_dir(worktree_path) .args(["rev-parse", "HEAD"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to get HEAD SHA: {}", e)))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(PatchError::GitCommand(format!( "git rev-parse HEAD failed: {}", stderr ))); } Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) } /// Resolve the merge-base SHA for diffing against the main/master branch. /// /// Tries in order: /// 1. Upstream tracking branch merge-base /// 2. Common branches: origin/main, origin/master, main, master /// 3. Fallback: HEAD~1 /// /// Returns `Err(PatchError::EmptyPatch)` if the merge-base equals HEAD (no diff). pub async fn get_merge_base_sha(worktree_path: &Path) -> Result { // Try to get the upstream tracking branch let upstream_result = Command::new("git") .current_dir(worktree_path) .args(["rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"]) .output() .await; let base = if let Ok(output) = upstream_result { if output.status.success() { let upstream = String::from_utf8_lossy(&output.stdout).trim().to_string(); // Get merge-base with upstream let merge_base = Command::new("git") .current_dir(worktree_path) .args(["merge-base", "HEAD", &upstream]) .output() .await; if let Ok(mb_output) = merge_base { if mb_output.status.success() { Some( String::from_utf8_lossy(&mb_output.stdout) .trim() .to_string(), ) } else { None } } else { None } } else { None } } else { None }; // Get current HEAD SHA for comparison let head_sha = Command::new("git") .current_dir(worktree_path) .args(["rev-parse", "HEAD"]) .output() .await .ok() .filter(|o| o.status.success()) .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()); // If we couldn't find upstream, try common default branches let base = if base.is_none() { let default_branches = ["origin/main", "origin/master", "main", "master"]; let mut found_base = None; for branch in default_branches { let merge_base = Command::new("git") .current_dir(worktree_path) .args(["merge-base", "HEAD", branch]) .output() .await; if let Ok(output) = merge_base { if output.status.success() { let mb_sha = String::from_utf8_lossy(&output.stdout).trim().to_string(); // Skip if merge-base equals HEAD (would result in empty diff) if head_sha.as_ref() != Some(&mb_sha) { found_base = Some(mb_sha); break; } } } } found_base } else { // Also check upstream base if base.as_ref() == head_sha.as_ref() { None } else { base } }; // If still nothing, fall back to HEAD~1 Ok(base.unwrap_or_else(|| "HEAD~1".to_string())) } /// Result of creating an export patch. #[derive(Debug, Clone)] pub struct ExportPatchResult { /// The uncompressed, human-readable patch content. pub patch_content: String, /// Number of files changed in the patch. pub files_count: usize, /// Number of lines added. pub lines_added: usize, /// Number of lines removed. pub lines_removed: usize, /// The base commit SHA that the patch is diffed against. pub base_commit_sha: String, } /// Create an uncompressed git diff patch for export. /// /// This creates a human-readable patch that can be applied manually or /// shared as a file. Unlike `create_patch`, this version is not compressed /// and is suitable for display or export. /// /// If `base_sha` is provided, the diff is between that commit and HEAD. /// If `base_sha` is None, it attempts to find the merge-base with the default branch /// or falls back to diffing uncommitted changes against HEAD. pub async fn create_export_patch( worktree_path: &Path, base_sha: Option<&str>, ) -> Result { // Determine the base SHA to diff against let resolved_base_sha = match base_sha { Some(sha) => sha.to_string(), None => get_merge_base_sha(worktree_path).await?, }; // Get diff stats using --stat let stat_output = Command::new("git") .current_dir(worktree_path) .args(["diff", "--stat", &resolved_base_sha, "HEAD"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to run git diff --stat: {}", e)))?; // Parse the stat output to get line counts let (lines_added, lines_removed) = if stat_output.status.success() { parse_diff_stat(&String::from_utf8_lossy(&stat_output.stdout)) } else { (0, 0) }; // Get the actual diff content let diff_output = Command::new("git") .current_dir(worktree_path) .args(["diff", &resolved_base_sha, "HEAD"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to run git diff: {}", e)))?; if !diff_output.status.success() { let stderr = String::from_utf8_lossy(&diff_output.stderr); return Err(PatchError::GitCommand(format!("git diff failed: {}", stderr))); } let patch_content = String::from_utf8_lossy(&diff_output.stdout).to_string(); // Check for empty patch if patch_content.trim().is_empty() { return Err(PatchError::EmptyPatch); } // Count files changed let files_output = Command::new("git") .current_dir(worktree_path) .args(["diff", "--name-only", &resolved_base_sha, "HEAD"]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to count files: {}", e)))?; let files_count = if files_output.status.success() { String::from_utf8_lossy(&files_output.stdout) .lines() .filter(|l| !l.is_empty()) .count() } else { 0 }; Ok(ExportPatchResult { patch_content, files_count, lines_added, lines_removed, base_commit_sha: resolved_base_sha, }) } /// Parse git diff --stat output to extract lines added and removed. /// The last line typically looks like: " 3 files changed, 45 insertions(+), 12 deletions(-)" fn parse_diff_stat(stat_output: &str) -> (usize, usize) { let mut lines_added = 0; let mut lines_removed = 0; // Look for the summary line at the end for line in stat_output.lines().rev() { let line = line.trim(); if line.contains("changed") || line.contains("insertion") || line.contains("deletion") { // Parse insertions if let Some(idx) = line.find("insertion") { let before = &line[..idx]; if let Some(num_str) = before.split_whitespace().last() { if let Ok(num) = num_str.parse::() { lines_added = num; } } } // Parse deletions if let Some(idx) = line.find("deletion") { let before = &line[..idx]; if let Some(num_str) = before.split(',').last() { if let Some(num_str) = num_str.trim().split_whitespace().next() { if let Ok(num) = num_str.parse::() { lines_removed = num; } } } } break; } } (lines_added, lines_removed) } /// Checkout a specific commit in the worktree. pub async fn checkout_commit(worktree_path: &Path, sha: &str) -> Result<(), PatchError> { let output = Command::new("git") .current_dir(worktree_path) .args(["checkout", sha]) .output() .await .map_err(|e| PatchError::GitCommand(format!("Failed to checkout: {}", e)))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(PatchError::GitCommand(format!( "git checkout {} failed: {}", sha, stderr ))); } Ok(()) } #[cfg(test)] mod tests { use super::*; use std::fs; use tempfile::TempDir; async fn setup_test_repo() -> TempDir { let dir = TempDir::new().unwrap(); let path = dir.path(); // Initialize git repo Command::new("git") .current_dir(path) .args(["init"]) .output() .await .unwrap(); // Configure git user Command::new("git") .current_dir(path) .args(["config", "user.email", "test@test.com"]) .output() .await .unwrap(); Command::new("git") .current_dir(path) .args(["config", "user.name", "Test"]) .output() .await .unwrap(); // Create initial commit fs::write(path.join("file.txt"), "initial").unwrap(); Command::new("git") .current_dir(path) .args(["add", "."]) .output() .await .unwrap(); Command::new("git") .current_dir(path) .args(["commit", "-m", "initial"]) .output() .await .unwrap(); dir } #[tokio::test] async fn test_create_and_apply_patch() { let dir = setup_test_repo().await; let path = dir.path(); // Get base SHA let base_sha = get_parent_sha(path).await; // This will fail since there's only one commit assert!(base_sha.is_err()); // Make another commit first fs::write(path.join("file.txt"), "modified").unwrap(); Command::new("git") .current_dir(path) .args(["add", "."]) .output() .await .unwrap(); Command::new("git") .current_dir(path) .args(["commit", "-m", "modified"]) .output() .await .unwrap(); // Now get the base SHA let base_sha = get_parent_sha(path).await.unwrap(); // Create patch let (patch_data, files_count) = create_patch(path, &base_sha).await.unwrap(); assert!(!patch_data.is_empty()); assert_eq!(files_count, 1); // Reset to base and apply patch checkout_commit(path, &base_sha).await.unwrap(); assert_eq!(fs::read_to_string(path.join("file.txt")).unwrap(), "initial"); apply_patch(path, &patch_data).await.unwrap(); assert_eq!( fs::read_to_string(path.join("file.txt")).unwrap(), "modified" ); } #[tokio::test] async fn test_empty_patch() { let dir = setup_test_repo().await; let path = dir.path(); // Make another commit fs::write(path.join("file.txt"), "modified").unwrap(); Command::new("git") .current_dir(path) .args(["add", "."]) .output() .await .unwrap(); Command::new("git") .current_dir(path) .args(["commit", "-m", "modified"]) .output() .await .unwrap(); // Get current HEAD let head_output = Command::new("git") .current_dir(path) .args(["rev-parse", "HEAD"]) .output() .await .unwrap(); let head_sha = String::from_utf8_lossy(&head_output.stdout) .trim() .to_string(); // Try to create patch from HEAD to HEAD (no changes) let result = create_patch(path, &head_sha).await; assert!(matches!(result, Err(PatchError::EmptyPatch))); } #[tokio::test] async fn test_create_export_patch() { let dir = setup_test_repo().await; let path = dir.path(); // Get the initial commit SHA before making changes let initial_sha = get_head_sha(path).await.unwrap(); // Make some changes and commit fs::write(path.join("file.txt"), "modified content").unwrap(); fs::write(path.join("new_file.txt"), "new file content").unwrap(); Command::new("git") .current_dir(path) .args(["add", "."]) .output() .await .unwrap(); Command::new("git") .current_dir(path) .args(["commit", "-m", "changes for export"]) .output() .await .unwrap(); // Create export patch with explicit base let result = create_export_patch(path, Some(&initial_sha)).await.unwrap(); // Verify the result assert!(!result.patch_content.is_empty()); assert_eq!(result.files_count, 2); // file.txt and new_file.txt assert!(result.lines_added > 0); assert_eq!(result.base_commit_sha, initial_sha); // The patch should contain diff headers assert!(result.patch_content.contains("diff --git")); assert!(result.patch_content.contains("new_file.txt")); } #[tokio::test] async fn test_create_export_patch_no_base() { let dir = setup_test_repo().await; let path = dir.path(); // Make a second commit so we have something to diff fs::write(path.join("file.txt"), "modified").unwrap(); Command::new("git") .current_dir(path) .args(["add", "."]) .output() .await .unwrap(); Command::new("git") .current_dir(path) .args(["commit", "-m", "second commit"]) .output() .await .unwrap(); // Create export patch without explicit base (will use HEAD~1) let result = create_export_patch(path, None).await.unwrap(); // Verify the result assert!(!result.patch_content.is_empty()); assert_eq!(result.files_count, 1); assert!(result.patch_content.contains("diff --git")); } #[tokio::test] async fn test_parse_diff_stat() { // Test the parse_diff_stat function with various formats let stat1 = " 3 files changed, 45 insertions(+), 12 deletions(-)"; let (added, removed) = parse_diff_stat(stat1); assert_eq!(added, 45); assert_eq!(removed, 12); let stat2 = " 1 file changed, 10 insertions(+)"; let (added, removed) = parse_diff_stat(stat2); assert_eq!(added, 10); assert_eq!(removed, 0); let stat3 = " 2 files changed, 5 deletions(-)"; let (added, removed) = parse_diff_stat(stat3); assert_eq!(added, 0); assert_eq!(removed, 5); } }