//! Git patch creation and application for checkpoint recovery.
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::{Read, Write};
use std::path::Path;
use thiserror::Error;
use tokio::process::Command;
/// Errors that can occur during patch operations.
#[derive(Error, Debug)]
pub enum PatchError {
#[error("Git command failed: {0}")]
GitCommand(String),
#[error("Compression error: {0}")]
Compression(#[from] std::io::Error),
#[error("Patch too large: {size} bytes (max: {max} bytes)")]
TooLarge { size: usize, max: usize },
#[error("Empty patch (no changes)")]
EmptyPatch,
#[error("Failed to apply patch: {0}")]
ApplyFailed(String),
}
/// Create a gzip-compressed git diff of the commits between `base_sha` and `HEAD`.
///
/// Runs `git diff <base_sha> HEAD` (with binary support) inside `worktree_path`.
/// Only committed history is compared. Uncommitted or untracked files in the
/// worktree are not part of the patch.
///
/// # Returns
///
/// `(compressed_patch, files_changed)`. `files_changed` comes from a separate
/// `git diff --name-only` run and is `0` if that command exits unsuccessfully.
///
/// # Errors
///
/// - [`PatchError::GitCommand`] if `base_sha` starts with `-`, if either git
///   command cannot be spawned, or if `git diff` exits unsuccessfully.
/// - [`PatchError::EmptyPatch`] if `git diff` produces no output.
/// - [`PatchError::Compression`] if gzip encoding fails.
pub async fn create_patch(
    worktree_path: &Path,
    base_sha: &str,
) -> Result<(Vec<u8>, usize), PatchError> {
    // Revisions never begin with '-'. Refuse such input so git cannot
    // interpret it as a command-line option.
    if base_sha.starts_with('-') {
        return Err(PatchError::GitCommand(format!(
            "invalid base revision: {}",
            base_sha
        )));
    }

    // Pin the diff format so user/repo config cannot produce a diff that
    // `git apply` rejects:
    // - colour escape codes (`--no-color`)
    // - external diff drivers (`--no-ext-diff`)
    // - textconv filters (`--no-textconv`)
    // - `diff.noprefix` / custom prefixes (`--src-prefix`/`--dst-prefix`)
    // The trailing `--` forces `base_sha` and `HEAD` to be read as revisions.
    let output = Command::new("git")
        .current_dir(worktree_path)
        .args([
            "diff",
            "--binary",
            "--no-color",
            "--no-ext-diff",
            "--no-textconv",
            "--src-prefix=a/",
            "--dst-prefix=b/",
            base_sha,
            "HEAD",
            "--",
        ])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to run git diff: {}", e)))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(PatchError::GitCommand(format!("git diff failed: {}", stderr)));
    }
    let diff_data = output.stdout;
    if diff_data.is_empty() {
        return Err(PatchError::EmptyPatch);
    }

    // Count changed files. This is best effort: a failing command yields 0
    // rather than an error.
    let files_output = Command::new("git")
        .current_dir(worktree_path)
        .args(["diff", "--name-only", "--no-color", base_sha, "HEAD", "--"])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to count files: {}", e)))?;
    let files_count = if files_output.status.success() {
        String::from_utf8_lossy(&files_output.stdout)
            .lines()
            .filter(|l| !l.is_empty())
            .count()
    } else {
        0
    };

    // Compress with gzip.
    let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
    encoder.write_all(&diff_data)?;
    let compressed = encoder.finish()?;
    Ok((compressed, files_count))
}
/// Apply a gzip-compressed patch (as produced by `create_patch`) to a worktree.
///
/// The patch is decompressed in memory and piped into `git apply --binary -`,
/// which runs inside `worktree_path`. The worktree should already be checked
/// out at the `base_sha` the patch was created against before calling this.
///
/// # Errors
///
/// - [`PatchError::Compression`] if `patch_data` is not valid gzip, or if
///   writing the patch to git's stdin fails while git itself reports success.
/// - [`PatchError::EmptyPatch`] if the decompressed patch is empty.
/// - [`PatchError::GitCommand`] if `git apply` cannot be spawned or awaited.
/// - [`PatchError::ApplyFailed`] (carrying git's stderr) if `git apply` exits
///   unsuccessfully.
pub async fn apply_patch(worktree_path: &Path, patch_data: &[u8]) -> Result<(), PatchError> {
    use tokio::io::AsyncWriteExt;

    // Decompress gzip.
    let mut decoder = GzDecoder::new(patch_data);
    let mut decompressed = Vec::new();
    decoder.read_to_end(&mut decompressed)?;
    if decompressed.is_empty() {
        return Err(PatchError::EmptyPatch);
    }

    // Apply the patch using `git apply`, reading the patch from stdin.
    let mut child = Command::new("git")
        .current_dir(worktree_path)
        .args(["apply", "--binary", "-"])
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| PatchError::GitCommand(format!("Failed to spawn git apply: {}", e)))?;

    // Feed the patch on stdin, then close it so git sees EOF.
    //
    // A write failure (typically a broken pipe because git exited early) is
    // held back instead of returned immediately. That way git's own stderr,
    // which explains the failure, is what the caller sees.
    //
    // NOTE(review): stdout/stderr are drained only after the write finishes.
    // This assumes `git apply` consumes all of stdin before emitting much
    // output.
    let write_result = match child.stdin.take() {
        Some(mut stdin) => {
            let result = stdin.write_all(&decompressed).await;
            drop(stdin); // Close stdin to signal EOF
            result
        }
        None => Ok(()),
    };

    let output = child
        .wait_with_output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to wait for git apply: {}", e)))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(PatchError::ApplyFailed(stderr.to_string()));
    }

    // Git reported success, but still surface any failure to deliver the patch.
    write_result?;
    Ok(())
}
/// Resolve the SHA of `HEAD~1` (the first parent of `HEAD`) in the given worktree.
///
/// # Errors
///
/// Returns [`PatchError::GitCommand`] if `git` cannot be spawned, or if
/// `git rev-parse HEAD~1` exits unsuccessfully (for example, when `HEAD` is
/// the repository's only commit and has no parent).
pub async fn get_parent_sha(worktree_path: &Path) -> Result<String, PatchError> {
    let spawned = Command::new("git")
        .current_dir(worktree_path)
        .args(["rev-parse", "HEAD~1"])
        .output()
        .await;

    let rev_parse = match spawned {
        Ok(out) => out,
        Err(err) => {
            return Err(PatchError::GitCommand(format!(
                "Failed to get parent SHA: {}",
                err
            )));
        }
    };

    if rev_parse.status.success() {
        let sha = String::from_utf8_lossy(&rev_parse.stdout);
        Ok(sha.trim().to_string())
    } else {
        Err(PatchError::GitCommand(format!(
            "git rev-parse HEAD~1 failed: {}",
            String::from_utf8_lossy(&rev_parse.stderr)
        )))
    }
}
/// Check out the revision `sha` in the worktree.
///
/// Runs `git checkout <sha> --`. The trailing `--` makes git treat `sha`
/// strictly as a revision. Without it, a value that is not a valid revision
/// but matches a path could be handled as a pathspec, restoring that path
/// instead of switching commits.
///
/// # Errors
///
/// Returns [`PatchError::GitCommand`] if `sha` starts with `-` (it would
/// otherwise be parsed as an option), if `git` cannot be spawned, or if the
/// checkout exits unsuccessfully.
pub async fn checkout_commit(worktree_path: &Path, sha: &str) -> Result<(), PatchError> {
    if sha.starts_with('-') {
        return Err(PatchError::GitCommand(format!("invalid revision: {}", sha)));
    }
    let output = Command::new("git")
        .current_dir(worktree_path)
        .args(["checkout", sha, "--"])
        .output()
        .await
        .map_err(|e| PatchError::GitCommand(format!("Failed to checkout: {}", e)))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(PatchError::GitCommand(format!(
            "git checkout {} failed: {}",
            sha, stderr
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Run `git <args>` in `dir` and return its trimmed stdout.
    ///
    /// Panics with git's stderr on failure, so a broken setup step is
    /// reported where it happens rather than as a later, confusing assertion.
    async fn git(dir: &Path, args: &[&str]) -> String {
        let output = Command::new("git")
            .current_dir(dir)
            .args(args)
            .output()
            .await
            .expect("failed to spawn git");
        assert!(
            output.status.success(),
            "git {:?} failed: {}",
            args,
            String::from_utf8_lossy(&output.stderr)
        );
        String::from_utf8_lossy(&output.stdout).trim().to_string()
    }

    /// Stage everything in `dir` and commit it with `message`.
    async fn commit_all(dir: &Path, message: &str) {
        git(dir, &["add", "."]).await;
        git(dir, &["commit", "-m", message]).await;
    }

    /// Create a temporary repository with one commit where `file.txt` is "initial".
    async fn setup_test_repo() -> TempDir {
        let dir = TempDir::new().unwrap();
        let path = dir.path();
        git(path, &["init"]).await;
        git(path, &["config", "user.email", "test@test.com"]).await;
        git(path, &["config", "user.name", "Test"]).await;
        // Override a global `commit.gpgsign=true`, which would make the test
        // commits fail on machines without a signing key.
        git(path, &["config", "commit.gpgsign", "false"]).await;
        fs::write(path.join("file.txt"), "initial").unwrap();
        commit_all(path, "initial").await;
        dir
    }

    #[tokio::test]
    async fn test_create_and_apply_patch() {
        let dir = setup_test_repo().await;
        let path = dir.path();
        // With a single commit, HEAD has no parent.
        assert!(get_parent_sha(path).await.is_err());

        fs::write(path.join("file.txt"), "modified").unwrap();
        commit_all(path, "modified").await;
        let base_sha = get_parent_sha(path).await.unwrap();

        let (patch_data, files_count) = create_patch(path, &base_sha).await.unwrap();
        assert!(!patch_data.is_empty());
        assert_eq!(files_count, 1);

        // Reset to base and apply the patch.
        checkout_commit(path, &base_sha).await.unwrap();
        assert_eq!(fs::read_to_string(path.join("file.txt")).unwrap(), "initial");
        apply_patch(path, &patch_data).await.unwrap();
        assert_eq!(
            fs::read_to_string(path.join("file.txt")).unwrap(),
            "modified"
        );
    }

    #[tokio::test]
    async fn test_empty_patch() {
        let dir = setup_test_repo().await;
        let path = dir.path();
        fs::write(path.join("file.txt"), "modified").unwrap();
        commit_all(path, "modified").await;
        let head_sha = git(path, &["rev-parse", "HEAD"]).await;
        // HEAD vs HEAD has no changes.
        let result = create_patch(path, &head_sha).await;
        assert!(matches!(result, Err(PatchError::EmptyPatch)));
    }

    #[tokio::test]
    async fn test_apply_patch_reports_git_failure() {
        let dir = setup_test_repo().await;
        let path = dir.path();
        fs::write(path.join("file.txt"), "modified").unwrap();
        commit_all(path, "modified").await;
        let base_sha = get_parent_sha(path).await.unwrap();
        let (patch_data, _) = create_patch(path, &base_sha).await.unwrap();
        // Still at HEAD: the patch's preimage ("initial") no longer matches.
        let result = apply_patch(path, &patch_data).await;
        assert!(matches!(result, Err(PatchError::ApplyFailed(_))));
    }

    #[tokio::test]
    async fn test_apply_patch_rejects_invalid_gzip() {
        let dir = setup_test_repo().await;
        let result = apply_patch(dir.path(), b"definitely not gzip").await;
        assert!(matches!(result, Err(PatchError::Compression(_))));
    }
}