//! Git patch creation and application for checkpoint recovery.
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::{Read, Write};
use std::path::Path;
use thiserror::Error;
use tokio::process::Command;
/// Errors that can occur during patch operations.
///
/// The user-facing message of each variant is defined by its `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum PatchError {
/// A `git` subprocess could not be started or waited on, or it exited unsuccessfully.
/// The payload is a human-readable description, usually including git's stderr.
#[error("Git command failed: {0}")]
GitCommand(String),
/// An underlying I/O error. Because of `#[from]`, every `std::io::Error` propagated
/// with `?` in this module lands here, not only gzip encode/decode failures.
#[error("Compression error: {0}")]
Compression(#[from] std::io::Error),
/// The patch exceeds a size limit (`size` and `max` are in bytes).
/// NOTE(review): not constructed anywhere in the visible part of this file —
/// presumably enforced by a caller; confirm.
#[error("Patch too large: {size} bytes (max: {max} bytes)")]
TooLarge { size: usize, max: usize },
/// The diff (or the decompressed patch) contained no data.
#[error("Empty patch (no changes)")]
EmptyPatch,
/// `git apply` exited unsuccessfully; the payload is its stderr output.
#[error("Failed to apply patch: {0}")]
ApplyFailed(String),
}
/// Build a gzip-compressed binary diff of the commits between `base_sha` and HEAD.
///
/// Only committed history is compared (`git diff <base_sha> HEAD --binary`).
///
/// # Returns
/// The gzip bytes of the diff together with the number of changed files. The
/// file count comes from a second `git diff --name-only` run and is reported as
/// `0` when that run exits unsuccessfully.
///
/// # Errors
/// * [`PatchError::GitCommand`] if git cannot be launched or the diff fails.
/// * [`PatchError::EmptyPatch`] if the diff produced no output.
/// * [`PatchError::Compression`] if gzip encoding fails.
pub async fn create_patch(
    worktree_path: &Path,
    base_sha: &str,
) -> Result<(Vec<u8>, usize), PatchError> {
    // Step 1: the raw diff itself.
    let diff_run = Command::new("git")
        .current_dir(worktree_path)
        .args(["diff", base_sha, "HEAD", "--binary"])
        .output()
        .await;
    let diff_run = match diff_run {
        Ok(run) => run,
        Err(e) => {
            return Err(PatchError::GitCommand(format!(
                "Failed to run git diff: {}",
                e
            )))
        }
    };
    if !diff_run.status.success() {
        return Err(PatchError::GitCommand(format!(
            "git diff failed: {}",
            String::from_utf8_lossy(&diff_run.stderr)
        )));
    }
    let raw_diff = diff_run.stdout;
    if raw_diff.is_empty() {
        return Err(PatchError::EmptyPatch);
    }

    // Step 2: how many paths does the diff touch? A failing exit status is
    // tolerated here; only a failure to launch git is an error.
    let names_run = Command::new("git")
        .current_dir(worktree_path)
        .args(["diff", base_sha, "HEAD", "--name-only"])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to count files: {}", e)))?;
    let changed_files = if !names_run.status.success() {
        0
    } else {
        let listing = String::from_utf8_lossy(&names_run.stdout);
        listing.lines().filter(|name| !name.is_empty()).count()
    };

    // Step 3: gzip the diff into an in-memory buffer.
    let mut gz = GzEncoder::new(Vec::new(), Compression::default());
    gz.write_all(&raw_diff)?;
    let packed = gz.finish()?;

    Ok((packed, changed_files))
}
/// Apply a compressed patch to restore worktree state.
///
/// `patch_data` is gunzipped and piped to `git apply --binary -` running in
/// `worktree_path`. The worktree should already be checked out at the base
/// commit the patch was created against before calling this.
///
/// # Errors
/// * [`PatchError::Compression`] if `patch_data` cannot be gunzipped.
/// * [`PatchError::EmptyPatch`] if the decompressed patch is empty.
/// * [`PatchError::ApplyFailed`] (carrying git's stderr) if `git apply` exits
///   unsuccessfully.
/// * [`PatchError::GitCommand`] if git cannot be spawned or waited on, or if the
///   patch could not be fully written to git's stdin even though git itself
///   reported success.
pub async fn apply_patch(worktree_path: &Path, patch_data: &[u8]) -> Result<(), PatchError> {
    use tokio::io::AsyncWriteExt;

    // Decompress gzip
    let mut decompressed = Vec::new();
    GzDecoder::new(patch_data).read_to_end(&mut decompressed)?;
    if decompressed.is_empty() {
        return Err(PatchError::EmptyPatch);
    }

    // Apply the patch using git apply
    let mut child = Command::new("git")
        .current_dir(worktree_path)
        .args(["apply", "--binary", "-"])
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| PatchError::GitCommand(format!("Failed to spawn git apply: {}", e)))?;

    // Feed the patch to git, but do not return on a write error yet: a broken
    // pipe usually means git already exited with a diagnostic on stderr, and
    // that message is far more useful than the raw I/O error. Deferring the
    // error also guarantees the child is always waited on.
    let write_result = match child.stdin.take() {
        Some(mut stdin) => {
            let result = stdin.write_all(&decompressed).await;
            drop(stdin); // Close stdin to signal EOF
            result
        }
        None => Ok(()),
    };

    let output = child
        .wait_with_output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to wait for git apply: {}", e)))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(PatchError::ApplyFailed(stderr.to_string()));
    }

    // git reported success; still surface a failed write, as a git-pipe error
    // rather than a misleading "Compression error".
    write_result.map_err(|e| {
        PatchError::GitCommand(format!("Failed to write patch to git apply: {}", e))
    })?;
    Ok(())
}
/// Resolve `HEAD~1` (the first parent of the current commit) to its SHA.
///
/// # Errors
/// [`PatchError::GitCommand`] if git cannot be launched or `git rev-parse HEAD~1`
/// exits unsuccessfully.
pub async fn get_parent_sha(worktree_path: &Path) -> Result<String, PatchError> {
    let mut rev_parse = Command::new("git");
    rev_parse
        .current_dir(worktree_path)
        .args(["rev-parse", "HEAD~1"]);

    let run = rev_parse
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to get parent SHA: {}", e)))?;

    if run.status.success() {
        // stdout is the SHA followed by a newline.
        let sha = String::from_utf8_lossy(&run.stdout);
        Ok(sha.trim().to_string())
    } else {
        Err(PatchError::GitCommand(format!(
            "git rev-parse HEAD~1 failed: {}",
            String::from_utf8_lossy(&run.stderr)
        )))
    }
}
/// Resolve the worktree's current `HEAD` to its commit SHA.
///
/// # Errors
/// [`PatchError::GitCommand`] if git cannot be launched or `git rev-parse HEAD`
/// exits unsuccessfully.
pub async fn get_head_sha(worktree_path: &Path) -> Result<String, PatchError> {
    let run = match Command::new("git")
        .current_dir(worktree_path)
        .args(["rev-parse", "HEAD"])
        .output()
        .await
    {
        Ok(run) => run,
        Err(e) => {
            return Err(PatchError::GitCommand(format!(
                "Failed to get HEAD SHA: {}",
                e
            )))
        }
    };

    if !run.status.success() {
        let problem = String::from_utf8_lossy(&run.stderr);
        return Err(PatchError::GitCommand(format!(
            "git rev-parse HEAD failed: {}",
            problem
        )));
    }

    // Strip the trailing newline git prints after the SHA.
    let sha = String::from_utf8_lossy(&run.stdout).trim().to_string();
    Ok(sha)
}
/// Resolve the merge-base SHA for diffing against the main/master branch.
///
/// Tries in order, skipping any candidate that equals the current HEAD
/// (diffing HEAD against itself would be empty):
/// 1. Merge-base with the upstream tracking branch (`@{u}`)
/// 2. Merge-base with the first usable of `origin/main`, `origin/master`,
///    `main`, `master`
/// 3. Fallback: the parent commit `HEAD~1`, resolved to a SHA when possible
///
/// Individual git failures are treated as "candidate not available"; the
/// function currently always returns `Ok`. If even `HEAD~1` cannot be resolved
/// (for example in a single-commit repository) the literal string `"HEAD~1"`
/// is returned and the caller's subsequent git command will report the error.
pub async fn get_merge_base_sha(worktree_path: &Path) -> Result<String, PatchError> {
    // Current HEAD, used to reject merge-bases that would give an empty diff.
    // If HEAD itself cannot be resolved, no candidate is rejected.
    let head_sha = git_stdout_if_ok(worktree_path, &["rev-parse", "HEAD"]).await;
    let differs_from_head = |sha: &String| head_sha.as_ref() != Some(sha);

    // 1. Upstream tracking branch.
    let mut base = None;
    let upstream = git_stdout_if_ok(
        worktree_path,
        &["rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"],
    )
    .await;
    if let Some(upstream) = upstream {
        base = git_stdout_if_ok(worktree_path, &["merge-base", "HEAD", upstream.as_str()])
            .await
            .filter(|sha| differs_from_head(sha));
    }

    // 2. Common default branches. This also runs when the upstream merge-base
    // was rejected for being HEAD itself (e.g. a fully pushed feature branch),
    // so the documented order is honoured instead of jumping to HEAD~1.
    if base.is_none() {
        for branch in ["origin/main", "origin/master", "main", "master"] {
            let candidate = git_stdout_if_ok(worktree_path, &["merge-base", "HEAD", branch]).await;
            if let Some(sha) = candidate {
                if differs_from_head(&sha) {
                    base = Some(sha);
                    break;
                }
            }
        }
    }

    // 3. Parent commit, as a real SHA when it can be resolved.
    let resolved = match base {
        Some(sha) => sha,
        None => git_stdout_if_ok(worktree_path, &["rev-parse", "HEAD~1"])
            .await
            .unwrap_or_else(|| "HEAD~1".to_string()),
    };
    Ok(resolved)
}

/// Run `git <args>` in `worktree_path` and return its trimmed stdout, or `None`
/// if git could not be launched, exited unsuccessfully, or printed nothing.
async fn git_stdout_if_ok(worktree_path: &Path, args: &[&str]) -> Option<String> {
    let output = Command::new("git")
        .current_dir(worktree_path)
        .args(args)
        .output()
        .await
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let text = String::from_utf8_lossy(&output.stdout).trim().to_string();
    if text.is_empty() {
        None
    } else {
        Some(text)
    }
}
/// Result of creating an export patch with `create_export_patch`.
#[derive(Debug, Clone)]
pub struct ExportPatchResult {
/// The uncompressed, human-readable patch content (stdout of `git diff`,
/// converted to UTF-8 lossily).
pub patch_content: String,
/// Number of files changed in the patch, counted from `git diff --name-only`;
/// `0` if that command exited unsuccessfully.
pub files_count: usize,
/// Number of lines added, parsed from the `git diff --stat` summary line;
/// `0` if the stat command exited unsuccessfully or the summary had no insertions.
pub lines_added: usize,
/// Number of lines removed, parsed from the `git diff --stat` summary line;
/// `0` if the stat command exited unsuccessfully or the summary had no deletions.
pub lines_removed: usize,
/// The base revision that the patch is diffed against: either the caller-supplied
/// `base_sha` verbatim, or the value returned by `get_merge_base_sha`, which can be
/// the literal `HEAD~1` rather than a resolved SHA.
pub base_commit_sha: String,
}
/// Create an uncompressed git diff patch for export.
///
/// This creates a human-readable patch that can be applied manually or
/// shared as a file. Unlike `create_patch`, this version is not compressed
/// and is suitable for display or export.
///
/// If `base_sha` is provided, the diff is between that commit and HEAD.
/// If `base_sha` is `None`, the base is resolved with `get_merge_base_sha`
/// (merge-base with the upstream or a default branch, otherwise `HEAD~1`).
/// Only committed changes are included; uncommitted worktree changes are not
/// part of the diff.
///
/// # Errors
/// * [`PatchError::GitCommand`] if git cannot be launched or `git diff` fails.
/// * [`PatchError::EmptyPatch`] if the diff is empty or whitespace-only.
///
/// A failing exit status from the `--stat` or `--name-only` helper commands is
/// tolerated: the corresponding counters are reported as `0`.
pub async fn create_export_patch(
    worktree_path: &Path,
    base_sha: Option<&str>,
) -> Result<ExportPatchResult, PatchError> {
    // Determine the base SHA to diff against
    let resolved_base_sha = match base_sha {
        Some(sha) => sha.to_string(),
        None => get_merge_base_sha(worktree_path).await?,
    };

    // Get diff stats using --stat. git translates the summary line
    // ("N files changed, ...") according to the user's locale, while
    // `parse_diff_stat` looks for the English keywords, so force the C locale
    // for this one command to keep the counts from silently becoming 0.
    let stat_output = Command::new("git")
        .current_dir(worktree_path)
        .env("LC_ALL", "C")
        .args(["diff", "--stat", &resolved_base_sha, "HEAD"])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to run git diff --stat: {}", e)))?;

    // Parse the stat output to get line counts
    let (lines_added, lines_removed) = if stat_output.status.success() {
        parse_diff_stat(&String::from_utf8_lossy(&stat_output.stdout))
    } else {
        (0, 0)
    };

    // Get the actual diff content
    let diff_output = Command::new("git")
        .current_dir(worktree_path)
        .args(["diff", &resolved_base_sha, "HEAD"])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to run git diff: {}", e)))?;
    if !diff_output.status.success() {
        let stderr = String::from_utf8_lossy(&diff_output.stderr);
        return Err(PatchError::GitCommand(format!("git diff failed: {}", stderr)));
    }
    let patch_content = String::from_utf8_lossy(&diff_output.stdout).to_string();

    // Check for empty patch
    if patch_content.trim().is_empty() {
        return Err(PatchError::EmptyPatch);
    }

    // Count files changed
    let files_output = Command::new("git")
        .current_dir(worktree_path)
        .args(["diff", "--name-only", &resolved_base_sha, "HEAD"])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to count files: {}", e)))?;
    let files_count = if files_output.status.success() {
        String::from_utf8_lossy(&files_output.stdout)
            .lines()
            .filter(|l| !l.is_empty())
            .count()
    } else {
        0
    };

    Ok(ExportPatchResult {
        patch_content,
        files_count,
        lines_added,
        lines_removed,
        base_commit_sha: resolved_base_sha,
    })
}
/// Extract `(lines_added, lines_removed)` from `git diff --stat` output.
///
/// Scanning from the bottom, the first line mentioning "changed", "insertion"
/// or "deletion" is taken as the summary, e.g.
/// `" 3 files changed, 45 insertions(+), 12 deletions(-)"`. A count that is
/// absent or not a valid number is reported as `0`; so is everything when no
/// such line exists.
fn parse_diff_stat(stat_output: &str) -> (usize, usize) {
    let summary = stat_output.lines().rev().map(str::trim).find(|candidate| {
        ["changed", "insertion", "deletion"]
            .iter()
            .any(|keyword| candidate.contains(keyword))
    });
    let summary = match summary {
        Some(text) => text,
        None => return (0, 0),
    };

    // Insertions: the whitespace-separated token right before "insertion".
    let added = summary
        .find("insertion")
        .and_then(|at| summary[..at].split_whitespace().next_back())
        .and_then(|token| token.parse::<usize>().ok())
        .unwrap_or(0);

    // Deletions: the first token of the last comma-separated segment before
    // "deletion".
    let removed = summary
        .find("deletion")
        .and_then(|at| summary[..at].rsplit(',').next())
        .and_then(|segment| segment.split_whitespace().next())
        .and_then(|token| token.parse::<usize>().ok())
        .unwrap_or(0);

    (added, removed)
}
/// Switch the worktree to `sha` by running `git checkout <sha>`.
///
/// # Errors
/// [`PatchError::GitCommand`] if git cannot be launched or the checkout exits
/// unsuccessfully (the message then includes `sha` and git's stderr).
pub async fn checkout_commit(worktree_path: &Path, sha: &str) -> Result<(), PatchError> {
    let run = Command::new("git")
        .current_dir(worktree_path)
        .args(["checkout", sha])
        .output()
        .await;

    match run {
        Err(e) => Err(PatchError::GitCommand(format!(
            "Failed to checkout: {}",
            e
        ))),
        Ok(done) if done.status.success() => Ok(()),
        Ok(done) => Err(PatchError::GitCommand(format!(
            "git checkout {} failed: {}",
            sha,
            String::from_utf8_lossy(&done.stderr)
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Run `git <args>` in `repo`, panicking with git's stderr if it does not
    /// succeed, so a broken fixture fails at the offending command instead of
    /// as a confusing assertion later on.
    async fn git(repo: &Path, args: &[&str]) {
        let output = Command::new("git")
            .current_dir(repo)
            .args(args)
            .output()
            .await
            .expect("failed to spawn git");
        assert!(
            output.status.success(),
            "git {:?} failed: {}",
            args,
            String::from_utf8_lossy(&output.stderr)
        );
    }

    /// Stage everything in `repo` and commit it with `message`.
    async fn commit_all(repo: &Path, message: &str) {
        git(repo, &["add", "."]).await;
        git(repo, &["commit", "-m", message]).await;
    }

    /// Create a temporary repository containing one commit with `file.txt`
    /// set to "initial".
    async fn setup_test_repo() -> TempDir {
        let dir = TempDir::new().unwrap();
        let path = dir.path();

        // Initialize git repo
        git(path, &["init"]).await;

        // Configure git user; disable signing so a global gpgsign setting
        // cannot make the fixture commits fail.
        git(path, &["config", "user.email", "test@test.com"]).await;
        git(path, &["config", "user.name", "Test"]).await;
        git(path, &["config", "commit.gpgsign", "false"]).await;

        // Create initial commit
        fs::write(path.join("file.txt"), "initial").unwrap();
        commit_all(path, "initial").await;

        dir
    }

    #[tokio::test]
    async fn test_create_and_apply_patch() {
        let dir = setup_test_repo().await;
        let path = dir.path();

        // With only one commit there is no parent to resolve.
        assert!(get_parent_sha(path).await.is_err());

        // Make another commit first
        fs::write(path.join("file.txt"), "modified").unwrap();
        commit_all(path, "modified").await;

        // Now get the base SHA
        let base_sha = get_parent_sha(path).await.unwrap();

        // Create patch
        let (patch_data, files_count) = create_patch(path, &base_sha).await.unwrap();
        assert!(!patch_data.is_empty());
        assert_eq!(files_count, 1);

        // Reset to base and apply patch
        checkout_commit(path, &base_sha).await.unwrap();
        assert_eq!(fs::read_to_string(path.join("file.txt")).unwrap(), "initial");
        apply_patch(path, &patch_data).await.unwrap();
        assert_eq!(
            fs::read_to_string(path.join("file.txt")).unwrap(),
            "modified"
        );
    }

    #[tokio::test]
    async fn test_empty_patch() {
        let dir = setup_test_repo().await;
        let path = dir.path();

        // Make another commit
        fs::write(path.join("file.txt"), "modified").unwrap();
        commit_all(path, "modified").await;

        // Try to create patch from HEAD to HEAD (no changes)
        let head_sha = get_head_sha(path).await.unwrap();
        let result = create_patch(path, &head_sha).await;
        assert!(matches!(result, Err(PatchError::EmptyPatch)));
    }

    #[tokio::test]
    async fn test_create_export_patch() {
        let dir = setup_test_repo().await;
        let path = dir.path();

        // Get the initial commit SHA before making changes
        let initial_sha = get_head_sha(path).await.unwrap();

        // Make some changes and commit
        fs::write(path.join("file.txt"), "modified content").unwrap();
        fs::write(path.join("new_file.txt"), "new file content").unwrap();
        commit_all(path, "changes for export").await;

        // Create export patch with explicit base
        let result = create_export_patch(path, Some(&initial_sha)).await.unwrap();

        // Verify the result
        assert!(!result.patch_content.is_empty());
        assert_eq!(result.files_count, 2); // file.txt and new_file.txt
        assert!(result.lines_added > 0);
        assert_eq!(result.base_commit_sha, initial_sha);

        // The patch should contain diff headers
        assert!(result.patch_content.contains("diff --git"));
        assert!(result.patch_content.contains("new_file.txt"));
    }

    #[tokio::test]
    async fn test_create_export_patch_no_base() {
        let dir = setup_test_repo().await;
        let path = dir.path();

        // Make a second commit so we have something to diff
        fs::write(path.join("file.txt"), "modified").unwrap();
        commit_all(path, "second commit").await;

        // Create export patch without explicit base (resolves to the parent commit)
        let result = create_export_patch(path, None).await.unwrap();

        // Verify the result
        assert!(!result.patch_content.is_empty());
        assert_eq!(result.files_count, 1);
        assert!(result.patch_content.contains("diff --git"));
    }

    // Pure string parsing: no async runtime needed.
    #[test]
    fn test_parse_diff_stat() {
        let stat1 = " 3 files changed, 45 insertions(+), 12 deletions(-)";
        assert_eq!(parse_diff_stat(stat1), (45, 12));

        let stat2 = " 1 file changed, 10 insertions(+)";
        assert_eq!(parse_diff_stat(stat2), (10, 0));

        let stat3 = " 2 files changed, 5 deletions(-)";
        assert_eq!(parse_diff_stat(stat3), (0, 5));
    }
}