diff options
Diffstat (limited to 'makima/src/server')
| -rw-r--r-- | makima/src/server/auth.rs | 1238 | ||||
| -rw-r--r-- | makima/src/server/handlers/api_keys.rs | 282 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 115 | ||||
| -rw-r--r-- | makima/src/server/handlers/files.rs | 53 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh.rs | 1679 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_chat.rs | 2088 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_daemon.rs | 959 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_merge.rs | 441 | ||||
| -rw-r--r-- | makima/src/server/handlers/mesh_ws.rs | 346 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 7 | ||||
| -rw-r--r-- | makima/src/server/handlers/users.rs | 972 | ||||
| -rw-r--r-- | makima/src/server/mod.rs | 59 | ||||
| -rw-r--r-- | makima/src/server/openapi.rs | 96 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 467 |
14 files changed, 8775 insertions, 27 deletions
diff --git a/makima/src/server/auth.rs b/makima/src/server/auth.rs new file mode 100644 index 0000000..b694df6 --- /dev/null +++ b/makima/src/server/auth.rs @@ -0,0 +1,1238 @@ +//! Authentication module for Makima server. +//! +//! Supports multiple authentication methods: +//! - Supabase JWT tokens for web clients (ES256 or RS256 public key verification) +//! - API keys for programmatic access (daemons, CLI) +//! - Tool keys for orchestrator internal access + +use axum::{ + extract::FromRequestParts, + http::{header::AUTHORIZATION, request::Parts, HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use chrono::{DateTime, Utc}; +use dashmap::DashMap; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use sqlx::{FromRow, PgPool, Row}; +use std::time::{Duration, Instant}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +// ============================================================================= +// Configuration +// ============================================================================= + +/// JWT algorithm configuration. +#[derive(Debug, Clone)] +pub enum JwtAlgorithm { + /// RS256 with RSA public key + Rs256 { public_key: String }, + /// ES256 with ECDSA public key (Supabase projects with JWT Signing Keys) + Es256 { public_key: String }, +} + +/// Authentication configuration loaded from environment. +#[derive(Debug, Clone)] +pub struct AuthConfig { + /// Supabase project URL (e.g., https://your-project.supabase.co) + pub supabase_url: String, + /// JWT algorithm and key material + pub algorithm: JwtAlgorithm, +} + +impl AuthConfig { + /// Load auth config from environment variables. + /// + /// Supports two modes (checked in order): + /// - ES256: Set SUPABASE_URL and SUPABASE_JWT_PUBLIC_KEY (Supabase with ECDSA) + /// - RS256: Set SUPABASE_URL and SUPABASE_JWT_RSA_PUBLIC_KEY (RSA public key) + /// + /// Returns None if auth is not configured. + pub fn from_env() -> Option<Self> { + let supabase_url = std::env::var("SUPABASE_URL").ok()?; + + // Try ES256 first (default for Supabase), then RS256 + let algorithm = if let Ok(public_key) = std::env::var("SUPABASE_JWT_PUBLIC_KEY") { + tracing::info!("Using ES256 JWT verification with ECDSA public key"); + JwtAlgorithm::Es256 { public_key } + } else if let Ok(public_key) = std::env::var("SUPABASE_JWT_RSA_PUBLIC_KEY") { + tracing::info!("Using RS256 JWT verification with RSA public key"); + JwtAlgorithm::Rs256 { public_key } + } else { + return None; + }; + + Some(Self { + supabase_url, + algorithm, + }) + } +} + +// ============================================================================= +// JWT Claims +// ============================================================================= + +/// JWT claims from Supabase Auth tokens. +#[derive(Debug, Serialize, Deserialize)] +pub struct SupabaseClaims { + /// Audience (e.g., "authenticated") + pub aud: String, + /// Expiration time (Unix timestamp) + pub exp: i64, + /// Issued at (Unix timestamp) + pub iat: i64, + /// Issuer (Supabase project URL + /auth/v1) + pub iss: String, + /// Subject (user ID) + pub sub: Uuid, + /// User's email + pub email: Option<String>, + /// User's phone + pub phone: Option<String>, + /// App metadata (set by server/admin) + pub app_metadata: Option<serde_json::Value>, + /// User metadata (set by user) + pub user_metadata: Option<serde_json::Value>, + /// Role (e.g., "authenticated") + pub role: Option<String>, + /// Session ID + pub session_id: Option<Uuid>, +} + +// ============================================================================= +// JWT Verifier +// ============================================================================= + +/// JWT verifier for Supabase tokens. +pub struct JwtVerifier { + supabase_url: String, + decoding_key: DecodingKey, + algorithm: Algorithm, +} + +impl JwtVerifier { + /// Create a new JWT verifier from auth config. + /// + /// Supports multiple key formats: + /// - JWK (JSON Web Key) - detected by presence of `{` + /// - PEM - detected by `-----BEGIN` + /// - Base64-encoded DER - fallback + pub fn new(config: AuthConfig) -> Result<Self, AuthError> { + let (decoding_key, algorithm) = match &config.algorithm { + JwtAlgorithm::Rs256 { public_key } => { + let key = Self::parse_public_key(public_key, "RSA")?; + (key, Algorithm::RS256) + } + JwtAlgorithm::Es256 { public_key } => { + let key = Self::parse_public_key(public_key, "EC")?; + (key, Algorithm::ES256) + } + }; + + Ok(Self { + supabase_url: config.supabase_url, + decoding_key, + algorithm, + }) + } + + /// Parse a public key from various formats (JWK, JWKS, PEM, or base64 DER). + fn parse_public_key(key_data: &str, key_type: &str) -> Result<DecodingKey, AuthError> { + let trimmed = key_data.trim(); + + // Check for JSON format (JWK or JWKS) + if trimmed.starts_with('{') { + // First try to parse as a generic JSON value to inspect structure + let mut json_value: serde_json::Value = serde_json::from_str(trimmed) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JSON: {}", e)))?; + + // Check if it's a JWKS (has "keys" array) + if let Some(keys) = json_value.get_mut("keys").and_then(|k| k.as_array_mut()) { + // Find the first signing key (or just use the first key) + let jwk_value = keys.first_mut() + .ok_or_else(|| AuthError::InvalidToken("JWKS has no keys".to_string()))?; + + // Remove private key component if present (user may have pasted full keypair) + if let Some(obj) = jwk_value.as_object_mut() { + if obj.remove("d").is_some() { + tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); + } + } + + let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(jwk_value.clone()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK in JWKS: {}", e)))?; + + tracing::info!("Loaded JWT public key from JWKS (first key)"); + return DecodingKey::from_jwk(&jwk) + .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))); + } + + // Remove private key component if present (user may have pasted full keypair) + if let Some(obj) = json_value.as_object_mut() { + if obj.remove("d").is_some() { + tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); + } + } + + // Try as single JWK + let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(json_value) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK: {}", e)))?; + + tracing::info!("Loaded JWT public key from JWK"); + DecodingKey::from_jwk(&jwk) + .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))) + } + // Check for PEM format + else if trimmed.contains("-----BEGIN") { + tracing::info!("Loaded JWT public key from PEM"); + match key_type { + "RSA" => DecodingKey::from_rsa_pem(trimmed.as_bytes()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid RSA PEM key: {}", e))), + "EC" => DecodingKey::from_ec_pem(trimmed.as_bytes()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid EC PEM key: {}", e))), + _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), + } + } + // Assume base64-encoded DER + else { + tracing::info!("Loaded JWT public key from base64 DER"); + let der_bytes = base64::engine::general_purpose::STANDARD + .decode(trimmed) + .map_err(|e| AuthError::InvalidToken(format!("Invalid base64 key: {}", e)))?; + + match key_type { + "RSA" => Ok(DecodingKey::from_rsa_der(&der_bytes)), + "EC" => Ok(DecodingKey::from_ec_der(&der_bytes)), + _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), + } + } + } + + /// Verify a JWT token and return claims. + pub fn verify(&self, token: &str) -> Result<SupabaseClaims, AuthError> { + // Decode header to check algorithm mismatch + let header = jsonwebtoken::decode_header(token) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWT header: {}", e)))?; + + tracing::debug!( + "JWT header: algorithm={:?}, typ={:?}, kid={:?}", + header.alg, + header.typ, + header.kid + ); + + if header.alg != self.algorithm { + let hint = match header.alg { + Algorithm::ES256 => "Set SUPABASE_JWT_PUBLIC_KEY with the EC public key from Supabase Dashboard → Project Settings → API → JWT Settings", + Algorithm::RS256 => "Set SUPABASE_JWT_RSA_PUBLIC_KEY with the RSA public key", + _ => "Check your Supabase JWT configuration - only ES256 and RS256 are supported", + }; + tracing::warn!( + "JWT algorithm mismatch: token uses {:?}, server configured for {:?}. {}", + header.alg, + self.algorithm, + hint + ); + return Err(AuthError::InvalidToken(format!( + "Algorithm mismatch: token is {:?}, expected {:?}", + header.alg, self.algorithm + ))); + } + + let mut validation = Validation::new(self.algorithm); + validation.set_audience(&["authenticated"]); + validation.set_issuer(&[format!("{}/auth/v1", self.supabase_url)]); + + // First try with full validation + let token_data = match decode::<SupabaseClaims>(token, &self.decoding_key, &validation) { + Ok(data) => data, + Err(e) => { + // Log detailed error info + tracing::warn!( + "JWT verification failed: {} (algorithm: {:?}, issuer: {}/auth/v1)", + e, + self.algorithm, + self.supabase_url + ); + + // If it's InvalidAlgorithm, try to understand why by decoding payload manually + if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidAlgorithm) { + // Decode the payload part of the JWT manually (base64) + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() >= 2 { + if let Ok(payload_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) { + if let Ok(payload_str) = String::from_utf8(payload_bytes) { + if let Ok(claims) = serde_json::from_str::<serde_json::Value>(&payload_str) { + tracing::warn!( + "JWT payload (unverified): iss={:?}, aud={:?}, sub={:?}", + claims.get("iss"), + claims.get("aud"), + claims.get("sub") + ); + } + } + } + } + } + + return Err(AuthError::InvalidToken(e.to_string())); + } + }; + + Ok(token_data.claims) + } + + /// Extract user ID from a token. + pub fn get_user_id(&self, token: &str) -> Result<Uuid, AuthError> { + let claims = self.verify(token)?; + Ok(claims.sub) + } +} + +// ============================================================================= +// Auth Error +// ============================================================================= + +/// Authentication error types. +#[derive(Debug)] +pub enum AuthError { + /// No authentication token provided + MissingToken, + /// Token format is invalid + InvalidToken(String), + /// Token has expired + ExpiredToken, + /// User not found in database + UserNotFound, + /// API key is invalid or revoked + InvalidApiKey, + /// Database error during auth lookup + DatabaseError(String), + /// Authentication is not configured + NotConfigured, + /// Insufficient permissions for the operation + InsufficientPermissions, +} + +impl std::fmt::Display for AuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthError::MissingToken => write!(f, "Missing authentication token"), + AuthError::InvalidToken(msg) => write!(f, "Invalid token: {}", msg), + AuthError::ExpiredToken => write!(f, "Token has expired"), + AuthError::UserNotFound => write!(f, "User not found"), + AuthError::InvalidApiKey => write!(f, "Invalid or revoked API key"), + AuthError::DatabaseError(msg) => write!(f, "Database error: {}", msg), + AuthError::NotConfigured => write!(f, "Authentication not configured"), + AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"), + } + } +} + +impl std::error::Error for AuthError {} + +impl IntoResponse for AuthError { + fn into_response(self) -> axum::response::Response { + let (status, code, message) = match &self { + AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "MISSING_TOKEN", "Authentication required"), + AuthError::InvalidToken(_) => (StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid authentication token"), + AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "EXPIRED_TOKEN", "Token has expired"), + AuthError::UserNotFound => (StatusCode::UNAUTHORIZED, "USER_NOT_FOUND", "User not found"), + AuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, "INVALID_API_KEY", "Invalid or revoked API key"), + AuthError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR", "Database error"), + AuthError::NotConfigured => (StatusCode::SERVICE_UNAVAILABLE, "AUTH_NOT_CONFIGURED", "Authentication not configured"), + AuthError::InsufficientPermissions => (StatusCode::FORBIDDEN, "FORBIDDEN", "Insufficient permissions"), + }; + + (status, Json(ApiError::new(code, message))).into_response() + } +} + +// ============================================================================= +// Auth Source +// ============================================================================= + +/// Source of authentication. +#[derive(Debug, Clone)] +pub enum AuthSource { + /// Authenticated via Supabase JWT (web client) + Jwt, + /// Authenticated via API key (daemon, CLI, integrations) + ApiKey, + /// Authenticated via tool key (orchestrator internal access) + ToolKey(Uuid), +} + +// ============================================================================= +// Authenticated User +// ============================================================================= + +/// Authenticated user context extracted from request. +/// +/// Contains the resolved user_id and owner_id for database operations. +#[derive(Debug, Clone)] +pub struct AuthenticatedUser { + /// Supabase auth user ID (from auth.users) + pub user_id: Uuid, + /// Owner ID for data isolation (from users.default_owner_id) + pub owner_id: Uuid, + /// How the user was authenticated + pub auth_source: AuthSource, + /// User's email (if available) + pub email: Option<String>, +} + +// ============================================================================= +// Header Constants +// ============================================================================= + +/// Header name for tool key authentication (orchestrators). +pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key"; + +/// Header name for API key authentication. +pub const API_KEY_HEADER: &str = "x-makima-api-key"; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Hash an API key for database lookup. +pub fn hash_api_key(key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + hex::encode(hasher.finalize()) +} + +// ============================================================================= +// API Key Generation +// ============================================================================= + +/// API key prefix for identification. +pub const API_KEY_PREFIX: &str = "mk_"; + +/// Result of generating an API key. +pub struct GeneratedApiKey { + /// The full API key (shown only once to user) + pub full_key: String, + /// SHA-256 hash of the key (stored in database) + pub key_hash: String, + /// Prefix for display (first 8 chars after mk_) + pub key_prefix: String, +} + +/// Generate a new API key with mk_ prefix. +/// +/// Returns the full key (to show once), hash (to store), and prefix (for display). +pub fn generate_api_key() -> GeneratedApiKey { + let mut rng = rand::thread_rng(); + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + + let key_bytes = URL_SAFE_NO_PAD.encode(bytes); + let full_key = format!("{}{}", API_KEY_PREFIX, key_bytes); + + let key_hash = hash_api_key(&full_key); + let key_prefix = format!("{}{}", API_KEY_PREFIX, &key_bytes[..8]); + + GeneratedApiKey { + full_key, + key_hash, + key_prefix, + } +} + +// ============================================================================= +// API Key Cache +// ============================================================================= + +/// Cache entry for validated API keys. +struct ApiKeyCacheEntry { + user_id: Uuid, + owner_id: Uuid, + cached_at: Instant, +} + +/// In-memory cache for API key validation to avoid database lookups on every request. +pub struct ApiKeyCache { + /// key_hash -> (user_id, owner_id, cached_at) + cache: DashMap<String, ApiKeyCacheEntry>, + /// Time-to-live for cache entries + ttl: Duration, +} + +impl ApiKeyCache { + /// Create a new cache with the specified TTL in seconds. + pub fn new(ttl_seconds: u64) -> Self { + Self { + cache: DashMap::new(), + ttl: Duration::from_secs(ttl_seconds), + } + } + + /// Get cached user_id and owner_id for a key hash, if not expired. + pub fn get(&self, key_hash: &str) -> Option<(Uuid, Uuid)> { + self.cache.get(key_hash).and_then(|entry| { + if entry.cached_at.elapsed() < self.ttl { + Some((entry.user_id, entry.owner_id)) + } else { + None + } + }) + } + + /// Cache a validated API key. + pub fn set(&self, key_hash: String, user_id: Uuid, owner_id: Uuid) { + self.cache.insert( + key_hash, + ApiKeyCacheEntry { + user_id, + owner_id, + cached_at: Instant::now(), + }, + ); + } + + /// Invalidate a cache entry (e.g., on key revocation). + pub fn invalidate(&self, key_hash: &str) { + self.cache.remove(key_hash); + } + + /// Clear all cache entries. + pub fn clear(&self) { + self.cache.clear(); + } +} + +impl Default for ApiKeyCache { + fn default() -> Self { + // Default TTL: 5 minutes + Self::new(300) + } +} + +// ============================================================================= +// API Key Models +// ============================================================================= + +/// API key record from the database. +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ApiKey { + pub id: Uuid, + pub user_id: Uuid, + #[serde(skip)] + pub key_hash: String, + pub key_prefix: String, + pub name: Option<String>, + pub last_used_at: Option<DateTime<Utc>>, + pub created_at: DateTime<Utc>, + pub revoked_at: Option<DateTime<Utc>>, +} + +/// Request to create a new API key. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateApiKeyRequest { + /// User-provided label for the key + pub name: Option<String>, +} + +/// Response after creating an API key (includes the full key - shown only once). +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateApiKeyResponse { + pub id: Uuid, + /// The full API key - save this, it won't be shown again! + pub key: String, + pub prefix: String, + pub name: Option<String>, + pub created_at: DateTime<Utc>, +} + +/// Response for getting API key info (excludes the full key). +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ApiKeyInfoResponse { + pub id: Uuid, + pub prefix: String, + pub name: Option<String>, + pub last_used_at: Option<DateTime<Utc>>, + pub created_at: DateTime<Utc>, +} + +impl From<ApiKey> for ApiKeyInfoResponse { + fn from(key: ApiKey) -> Self { + Self { + id: key.id, + prefix: key.key_prefix, + name: key.name, + last_used_at: key.last_used_at, + created_at: key.created_at, + } + } +} + +/// Request to refresh an API key. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RefreshApiKeyRequest { + /// New name for the refreshed key + pub name: Option<String>, +} + +/// Response after refreshing an API key. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RefreshApiKeyResponse { + pub id: Uuid, + /// The new API key - save this, it won't be shown again! + pub key: String, + pub prefix: String, + pub name: Option<String>, + pub created_at: DateTime<Utc>, + pub previous_key_revoked: bool, +} + +/// Response after revoking an API key. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RevokeApiKeyResponse { + pub message: String, + pub revoked_key_prefix: String, +} + +/// API key event types for audit logging. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ApiKeyEventType { + Created, + Used, + Revoked, + Refreshed, +} + +impl std::fmt::Display for ApiKeyEventType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeyEventType::Created => write!(f, "created"), + ApiKeyEventType::Used => write!(f, "used"), + ApiKeyEventType::Revoked => write!(f, "revoked"), + ApiKeyEventType::Refreshed => write!(f, "refreshed"), + } + } +} + +// ============================================================================= +// API Keys Repository +// ============================================================================= + +/// Repository error for API key operations. +#[derive(Debug)] +pub enum ApiKeyError { + /// Database error + Database(sqlx::Error), + /// An active API key already exists for this user + KeyAlreadyExists, + /// No active API key found for this user + KeyNotFound, +} + +impl std::fmt::Display for ApiKeyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeyError::Database(e) => write!(f, "Database error: {}", e), + ApiKeyError::KeyAlreadyExists => write!(f, "An active API key already exists"), + ApiKeyError::KeyNotFound => write!(f, "No active API key found"), + } + } +} + +impl std::error::Error for ApiKeyError {} + +impl From<sqlx::Error> for ApiKeyError { + fn from(e: sqlx::Error) -> Self { + ApiKeyError::Database(e) + } +} + +/// Get the active API key for a user (if any). +pub async fn get_active_api_key(pool: &PgPool, user_id: Uuid) -> Result<Option<ApiKey>, sqlx::Error> { + sqlx::query_as::<_, ApiKey>( + r#" + SELECT id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + FROM api_keys + WHERE user_id = $1 AND revoked_at IS NULL + "#, + ) + .bind(user_id) + .fetch_optional(pool) + .await +} + +/// Create a new API key for a user. +/// +/// Returns an error if the user already has an active key. +/// The `generated` parameter should be created using `generate_api_key()`. +pub async fn create_api_key( + pool: &PgPool, + user_id: Uuid, + generated: &GeneratedApiKey, + name: Option<&str>, +) -> Result<ApiKey, ApiKeyError> { + // Check if user already has an active key + if let Some(_) = get_active_api_key(pool, user_id).await? { + return Err(ApiKeyError::KeyAlreadyExists); + } + + let key = sqlx::query_as::<_, ApiKey>( + r#" + INSERT INTO api_keys (user_id, key_hash, key_prefix, name) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(user_id) + .bind(&generated.key_hash) + .bind(&generated.key_prefix) + .bind(name) + .fetch_one(pool) + .await?; + + // Log the creation event + let _ = log_api_key_event(pool, key.id, ApiKeyEventType::Created, None, None).await; + + Ok(key) +} + +/// Revoke an API key by marking it with revoked_at timestamp. +pub async fn revoke_api_key(pool: &PgPool, user_id: Uuid) -> Result<ApiKey, ApiKeyError> { + // Get the active key first + let key = get_active_api_key(pool, user_id) + .await? + .ok_or(ApiKeyError::KeyNotFound)?; + + // Revoke it + let revoked = sqlx::query_as::<_, ApiKey>( + r#" + UPDATE api_keys + SET revoked_at = NOW() + WHERE id = $1 + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(key.id) + .fetch_one(pool) + .await?; + + // Log the revocation event + let _ = log_api_key_event(pool, revoked.id, ApiKeyEventType::Revoked, None, None).await; + + Ok(revoked) +} + +/// Refresh an API key: revoke the old one and create a new one atomically. +/// +/// Returns the new key. The caller should use `generate_api_key()` to create +/// the `new_generated` parameter. +pub async fn refresh_api_key( + pool: &PgPool, + user_id: Uuid, + new_generated: &GeneratedApiKey, + new_name: Option<&str>, +) -> Result<(ApiKey, Option<String>), ApiKeyError> { + // Get and revoke the old key (if exists) + let old_prefix = if let Some(old_key) = get_active_api_key(pool, user_id).await? { + let old_prefix = old_key.key_prefix.clone(); + + // Revoke the old key + sqlx::query( + r#" + UPDATE api_keys + SET revoked_at = NOW() + WHERE id = $1 + "#, + ) + .bind(old_key.id) + .execute(pool) + .await?; + + // Log the refresh event on the old key + let _ = log_api_key_event(pool, old_key.id, ApiKeyEventType::Refreshed, None, None).await; + + Some(old_prefix) + } else { + None + }; + + // Create the new key + let new_key = sqlx::query_as::<_, ApiKey>( + r#" + INSERT INTO api_keys (user_id, key_hash, key_prefix, name) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(user_id) + .bind(&new_generated.key_hash) + .bind(&new_generated.key_prefix) + .bind(new_name) + .fetch_one(pool) + .await?; + + // Log the creation event on the new key + let _ = log_api_key_event(pool, new_key.id, ApiKeyEventType::Created, None, None).await; + + Ok((new_key, old_prefix)) +} + +/// Update last_used_at timestamp for an API key. +pub async fn update_api_key_last_used(pool: &PgPool, key_hash: &str) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + UPDATE api_keys + SET last_used_at = NOW() + WHERE key_hash = $1 AND revoked_at IS NULL + "#, + ) + .bind(key_hash) + .execute(pool) + .await?; + + Ok(()) +} + +/// Log an API key event for audit purposes. +pub async fn log_api_key_event( + pool: &PgPool, + api_key_id: Uuid, + event_type: ApiKeyEventType, + ip_address: Option<&str>, + user_agent: Option<&str>, +) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + INSERT INTO api_key_events (api_key_id, event_type, ip_address, user_agent) + VALUES ($1, $2, $3::inet, $4) + "#, + ) + .bind(api_key_id) + .bind(event_type.to_string()) + .bind(ip_address) + .bind(user_agent) + .execute(pool) + .await?; + + Ok(()) +} + +// ============================================================================= +// Internal Helper Functions +// ============================================================================= + +/// Resolve owner_id from user_id by looking up the users table. +/// If the user doesn't exist, auto-creates them on first login. +/// Uses ON CONFLICT to handle race conditions when multiple requests arrive simultaneously. +async fn resolve_owner_id(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result<Uuid, AuthError> { + // First, try to get existing user + let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + if let Some(row) = row { + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + return owner_id.ok_or(AuthError::UserNotFound); + } + + // User doesn't exist - auto-create on first login + tracing::info!("Creating new user record for {}", user_id); + + // Create owner first (use ON CONFLICT to handle race conditions) + let owner_id = Uuid::new_v4(); + sqlx::query("INSERT INTO owners (id, name) VALUES ($1, $2) ON CONFLICT DO NOTHING") + .bind(owner_id) + .bind(email.unwrap_or("Unknown")) + .execute(pool) + .await + .map_err(|e| AuthError::DatabaseError(format!("Failed to create owner: {}", e)))?; + + // Create user with reference to owner (use ON CONFLICT to handle race conditions) + sqlx::query( + "INSERT INTO users (id, email, default_owner_id) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING" + ) + .bind(user_id) + .bind(email) + .bind(owner_id) + .execute(pool) + .await + .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?; + + // Re-fetch the user to get the actual owner_id (in case another request created it first) + let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + match row { + Some(row) => { + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + owner_id.ok_or(AuthError::UserNotFound) + } + None => Err(AuthError::DatabaseError("Failed to create user record".to_string())) + } +} + +/// Validate an API key and return (user_id, owner_id). +async fn validate_api_key(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> { + let key_hash = hash_api_key(key); + + // Look up the API key and join with users to get owner_id + let row = sqlx::query( + r#" + SELECT ak.user_id, u.default_owner_id + FROM api_keys ak + JOIN users u ON u.id = ak.user_id + WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL + "#, + ) + .bind(&key_hash) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + match row { + Some(row) => { + let user_id: Uuid = row.try_get("user_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + let owner_id = owner_id.ok_or(AuthError::UserNotFound)?; + + // Update last_used_at asynchronously (fire and forget) + let pool_clone = pool.clone(); + let key_hash_clone = key_hash.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") + .bind(&key_hash_clone) + .execute(&pool_clone) + .await; + }); + + Ok((user_id, owner_id)) + } + None => Err(AuthError::InvalidApiKey), + } +} + +/// Extract authentication from request headers. +/// +/// Tries authentication methods in order: +/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators +/// 2. API Key (X-Makima-API-Key) - for daemons/CLI +/// 3. JWT (Authorization: Bearer) - for web clients +async fn extract_auth( + state: &SharedState, + headers: &HeaderMap, +) -> Result<AuthenticatedUser, AuthError> { + // 1. Check for tool key (orchestrator access) + if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) { + if let Ok(key_str) = tool_key.to_str() { + if let Some(task_id) = state.validate_tool_key(key_str) { + // Tool keys are trusted - use a placeholder user/owner for orchestrator actions + // The orchestrator inherits the owner_id from its task + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + + // Get owner_id from the task + let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))? + .ok_or(AuthError::UserNotFound)?; + + let task_owner: Uuid = row.try_get("owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + return Ok(AuthenticatedUser { + user_id: Uuid::nil(), // Tool keys don't have a user + owner_id: task_owner, + auth_source: AuthSource::ToolKey(task_id), + email: None, + }); + } + tracing::warn!("Invalid tool key provided"); + } + } + + // 2. Check for API key + if let Some(api_key) = headers.get(API_KEY_HEADER) { + if let Ok(key_str) = api_key.to_str() { + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + let (user_id, owner_id) = validate_api_key(pool, key_str).await?; + + return Ok(AuthenticatedUser { + user_id, + owner_id, + auth_source: AuthSource::ApiKey, + email: None, + }); + } + } + + // 3. Check for JWT (Bearer token) + if let Some(auth_header) = headers.get(AUTHORIZATION) { + if let Ok(auth_str) = auth_header.to_str() { + if let Some(token) = auth_str.strip_prefix("Bearer ") { + let verifier = state + .jwt_verifier + .as_ref() + .ok_or(AuthError::NotConfigured)?; + + let claims = verifier.verify(token)?; + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + let owner_id = resolve_owner_id(pool, claims.sub, claims.email.as_deref()).await?; + + return Ok(AuthenticatedUser { + user_id: claims.sub, + owner_id, + auth_source: AuthSource::Jwt, + email: claims.email, + }); + } + } + } + + Err(AuthError::MissingToken) +} + +// ============================================================================= +// Extractors +// ============================================================================= + +/// Extractor for authenticated requests. +/// +/// Tries authentication methods in order: +/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators +/// 2. API Key (X-Makima-API-Key) - for daemons/CLI +/// 3. JWT (Authorization: Bearer) - for web clients +/// +/// Returns 401 Unauthorized if no valid authentication is found. +/// +/// # Example +/// ```ignore +/// async fn protected_handler( +/// Authenticated(user): Authenticated, +/// ) -> impl IntoResponse { +/// Json(format!("Hello user {}", user.user_id)) +/// } +/// ``` +pub struct Authenticated(pub AuthenticatedUser); + +impl FromRequestParts<SharedState> for Authenticated { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await?; + Ok(Authenticated(user)) + } +} + +/// Extractor for user-only authentication (JWT or API key, no tool keys). +/// +/// Use this for endpoints that should only be accessible to actual users, +/// not orchestrators with tool keys. +/// +/// Returns 401 Unauthorized if no valid user authentication is found. +/// Returns 403 Forbidden if a tool key is used. +/// +/// # Example +/// ```ignore +/// async fn user_profile( +/// UserOnly(user): UserOnly, +/// ) -> impl IntoResponse { +/// // Only actual users can access this +/// Json(format!("User profile for {}", user.user_id)) +/// } +/// ``` +pub struct UserOnly(pub AuthenticatedUser); + +impl FromRequestParts<SharedState> for UserOnly { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await?; + + // Reject tool key authentication + if matches!(user.auth_source, AuthSource::ToolKey(_)) { + return Err(AuthError::InsufficientPermissions); + } + + Ok(UserOnly(user)) + } +} + +/// Extractor for optional authentication. +/// +/// Returns Some(AuthenticatedUser) if valid auth is provided, None otherwise. +/// Never returns an error - invalid auth is treated as no auth. +/// +/// # Example +/// ```ignore +/// async fn public_or_private( +/// MaybeAuthenticated(user): MaybeAuthenticated, +/// ) -> impl IntoResponse { +/// match user { +/// Some(u) => Json(format!("Hello {}", u.user_id)), +/// None => Json("Hello anonymous".to_string()), +/// } +/// } +/// ``` +pub struct MaybeAuthenticated(pub Option<AuthenticatedUser>); + +impl FromRequestParts<SharedState> for MaybeAuthenticated { + type Rejection = std::convert::Infallible; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await.ok(); + Ok(MaybeAuthenticated(user)) + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash_api_key() { + let key = "mk_test123456789"; + let hash = hash_api_key(key); + + // Hash should be consistent + assert_eq!(hash, hash_api_key(key)); + + // Hash should be 64 characters (SHA-256 hex) + assert_eq!(hash.len(), 64); + } + + #[test] + fn test_auth_error_display() { + assert_eq!( + AuthError::MissingToken.to_string(), + "Missing authentication token" + ); + assert_eq!( + AuthError::InvalidToken("bad".to_string()).to_string(), + "Invalid token: bad" + ); + } + + #[test] + fn test_generate_api_key_format() { + let generated = generate_api_key(); + + // Full key should start with mk_ prefix + assert!(generated.full_key.starts_with(API_KEY_PREFIX)); + + // Full key should be mk_ + 43 chars (32 bytes base64url encoded) + assert_eq!(generated.full_key.len(), 3 + 43); // "mk_" + 43 + + // Prefix should be mk_ + first 8 chars + assert!(generated.key_prefix.starts_with(API_KEY_PREFIX)); + assert_eq!(generated.key_prefix.len(), 3 + 8); + + // Hash should be 64 hex chars (SHA-256) + assert_eq!(generated.key_hash.len(), 64); + } + + #[test] + fn test_generate_api_key_uniqueness() { + let key1 = generate_api_key(); + let key2 = generate_api_key(); + + // Keys should be unique + assert_ne!(key1.full_key, key2.full_key); + assert_ne!(key1.key_hash, key2.key_hash); + } + + #[test] + fn test_api_key_cache_basic() { + let cache = ApiKeyCache::new(300); + let user_id = Uuid::new_v4(); + let owner_id = Uuid::new_v4(); + let key_hash = "test_hash_123"; + + // Cache miss initially + assert!(cache.get(key_hash).is_none()); + + // Set and verify cache hit + cache.set(key_hash.to_string(), user_id, owner_id); + let result = cache.get(key_hash); + assert!(result.is_some()); + let (cached_user, cached_owner) = result.unwrap(); + assert_eq!(cached_user, user_id); + assert_eq!(cached_owner, owner_id); + } + + #[test] + fn test_api_key_cache_invalidate() { + let cache = ApiKeyCache::new(300); + let user_id = Uuid::new_v4(); + let owner_id = Uuid::new_v4(); + let key_hash = "test_hash_456"; + + cache.set(key_hash.to_string(), user_id, owner_id); + assert!(cache.get(key_hash).is_some()); + + cache.invalidate(key_hash); + assert!(cache.get(key_hash).is_none()); + } + + #[test] + fn test_api_key_cache_clear() { + let cache = ApiKeyCache::new(300); + + cache.set("hash1".to_string(), Uuid::new_v4(), Uuid::new_v4()); + cache.set("hash2".to_string(), Uuid::new_v4(), Uuid::new_v4()); + + assert!(cache.get("hash1").is_some()); + assert!(cache.get("hash2").is_some()); + + cache.clear(); + + assert!(cache.get("hash1").is_none()); + assert!(cache.get("hash2").is_none()); + } + + #[test] + fn test_api_key_event_type_display() { + assert_eq!(ApiKeyEventType::Created.to_string(), "created"); + assert_eq!(ApiKeyEventType::Used.to_string(), "used"); + assert_eq!(ApiKeyEventType::Revoked.to_string(), "revoked"); + assert_eq!(ApiKeyEventType::Refreshed.to_string(), "refreshed"); + } +} diff --git a/makima/src/server/handlers/api_keys.rs b/makima/src/server/handlers/api_keys.rs new file mode 100644 index 0000000..5a678a2 --- /dev/null +++ b/makima/src/server/handlers/api_keys.rs @@ -0,0 +1,282 @@ +//! HTTP handlers for API key management. +//! +//! These endpoints allow users to create, view, refresh, and revoke their API keys. +//! API keys are used for daemon authentication and programmatic access. + +use axum::{ + extract::State, + http::StatusCode, + response::IntoResponse, + Json, +}; + +use crate::server::auth::{ + create_api_key, generate_api_key, get_active_api_key, refresh_api_key, revoke_api_key, + ApiKeyError, ApiKeyInfoResponse, CreateApiKeyRequest, CreateApiKeyResponse, + RefreshApiKeyRequest, RefreshApiKeyResponse, RevokeApiKeyResponse, UserOnly, +}; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +/// Create a new API key for the authenticated user. +/// +/// Each user can only have one active API key at a time. If an existing key +/// exists, this will return a 409 Conflict error - use the refresh endpoint +/// to replace the existing key, or revoke it first. +#[utoipa::path( + post, + path = "/api/v1/auth/api-keys", + request_body = CreateApiKeyRequest, + responses( + (status = 201, description = "API key created", body = CreateApiKeyResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 409, description = "API key already exists", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn create_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, + Json(req): Json<CreateApiKeyRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Generate a new API key + let generated = generate_api_key(); + + match create_api_key(pool, user.user_id, &generated, req.name.as_deref()).await { + Ok(key) => { + let response = CreateApiKeyResponse { + id: key.id, + key: generated.full_key, + prefix: key.key_prefix, + name: key.name, + created_at: key.created_at, + }; + (StatusCode::CREATED, Json(response)).into_response() + } + Err(ApiKeyError::KeyAlreadyExists) => ( + StatusCode::CONFLICT, + Json(ApiError::new( + "KEY_EXISTS", + "An active API key already exists. Revoke it first or use refresh.", + )), + ) + .into_response(), + Err(ApiKeyError::Database(e)) => { + tracing::error!("Failed to create API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + Err(e) => { + tracing::error!("Failed to create API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get information about the current active API key. +/// +/// Returns the key's ID, prefix (for identification), name, and timestamps. +/// The full key is never returned - it was only shown once when created. +#[utoipa::path( + get, + path = "/api/v1/auth/api-keys", + responses( + (status = 200, description = "API key info", body = ApiKeyInfoResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 404, description = "No active API key", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn get_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match get_active_api_key(pool, user.user_id).await { + Ok(Some(key)) => { + let response: ApiKeyInfoResponse = key.into(); + Json(response).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NO_KEY", "No active API key found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to get API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Refresh the current API key. +/// +/// This revokes the existing key (if any) and creates a new one atomically. +/// Use this for key rotation without downtime. +#[utoipa::path( + post, + path = "/api/v1/auth/api-keys/refresh", + request_body = RefreshApiKeyRequest, + responses( + (status = 200, description = "API key refreshed", body = RefreshApiKeyResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn refresh_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, + Json(req): Json<RefreshApiKeyRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Generate a new API key + let generated = generate_api_key(); + + match refresh_api_key(pool, user.user_id, &generated, req.name.as_deref()).await { + Ok((key, old_prefix)) => { + // Invalidate cache for the old key if we had a cache + // (The cache lookup is by hash, but we revoked the old key in DB so it won't match) + + let response = RefreshApiKeyResponse { + id: key.id, + key: generated.full_key, + prefix: key.key_prefix, + name: key.name, + created_at: key.created_at, + previous_key_revoked: old_prefix.is_some(), + }; + Json(response).into_response() + } + Err(ApiKeyError::Database(e)) => { + tracing::error!("Failed to refresh API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + Err(e) => { + tracing::error!("Failed to refresh API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Revoke the current active API key. +/// +/// After revocation, the key can no longer be used for authentication. +/// A new key can be created after revocation. +#[utoipa::path( + delete, + path = "/api/v1/auth/api-keys", + responses( + (status = 200, description = "API key revoked", body = RevokeApiKeyResponse), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 404, description = "No active API key", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "API Keys" +)] +pub async fn revoke_api_key_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match revoke_api_key(pool, user.user_id).await { + Ok(key) => { + let response = RevokeApiKeyResponse { + message: "API key revoked successfully".to_string(), + revoked_key_prefix: key.key_prefix, + }; + Json(response).into_response() + } + Err(ApiKeyError::KeyNotFound) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NO_KEY", "No active API key found")), + ) + .into_response(), + Err(ApiKeyError::Database(e)) => { + tracing::error!("Failed to revoke API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + Err(e) => { + tracing::error!("Failed to revoke API key: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", e.to_string())), + ) + .into_response() + } + } +} diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index 51f17c1..dfdb64e 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -53,6 +53,9 @@ pub struct ChatRequest { /// Optional conversation history for context continuity #[serde(default)] pub history: Option<Vec<ChatHistoryMessage>>, + /// Optional focused element index (for targeted editing) + #[serde(default)] + pub focused_element_index: Option<usize>, } #[derive(Debug, Serialize, ToSchema)] @@ -232,6 +235,9 @@ pub async fn chat_handler( // Build context about the file let file_context = build_file_context(&file); + // Build focused element context if specified + let focused_context = build_focused_element_context(&file.body, request.focused_element_index); + // Build agentic system prompt let system_prompt = format!( r#"You are an intelligent document editing agent. You help users view, analyze, and modify document files. @@ -274,13 +280,14 @@ You have access to tools for: ## Current Document Context {file_context} - +{focused_context} ## Important Notes - Body element indices are 0-based - When updating elements, provide ALL required fields for that element type - The transcript is read-only (you cannot modify it, only read it) - Changes are saved automatically after tool execution"#, - file_context = file_context + file_context = file_context, + focused_context = focused_context ); // Build initial messages (Groq/OpenAI format - will be converted for Claude) @@ -690,12 +697,25 @@ fn build_file_context(file: &crate::db::models::File) -> String { let desc = match element { BodyElement::Heading { level, text } => format!("H{}: {}", level, text), BodyElement::Paragraph { text } => { - let preview = if text.len() > 50 { - format!("{}...", &text[..50]) + let preview: String = text.chars().take(50).collect(); + if text.chars().count() > 50 { + format!("Paragraph: {}...", preview) } else { - text.clone() - }; - format!("Paragraph: {}", preview) + format!("Paragraph: {}", preview) + } + } + BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or("plain"); + let preview: String = content.chars().take(50).collect(); + if content.chars().count() > 50 { + format!("Code ({}): {}...", lang, preview) + } else { + format!("Code ({}): {}", lang, preview) + } + } + BodyElement::List { ordered, items } => { + let list_type = if *ordered { "ordered" } else { "unordered" }; + format!("List ({}): {} items", list_type, items.len()) } BodyElement::Chart { chart_type, title, .. } => { format!( @@ -726,6 +746,64 @@ fn build_file_context(file: &crate::db::models::File) -> String { context } +/// Build context for a focused element +fn build_focused_element_context(body: &[BodyElement], focused_index: Option<usize>) -> String { + let Some(index) = focused_index else { + return String::new(); + }; + + let Some(element) = body.get(index) else { + return format!( + "\n## Focused Element\nNote: User focused on element [{}] but it doesn't exist (document has {} elements).\n", + index, + body.len() + ); + }; + + let (element_type, full_content) = match element { + BodyElement::Heading { level, text } => { + (format!("Heading (level {})", level), text.clone()) + } + BodyElement::Paragraph { text } => { + ("Paragraph".to_string(), text.clone()) + } + BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or("plain"); + (format!("Code ({})", lang), content.clone()) + } + BodyElement::List { ordered, items } => { + let list_type = if *ordered { "Ordered list" } else { "Unordered list" }; + let content = items.iter() + .enumerate() + .map(|(i, item)| format!("{}. {}", i + 1, item)) + .collect::<Vec<_>>() + .join("\n"); + (list_type.to_string(), content) + } + BodyElement::Chart { chart_type, title, .. } => { + let title_str = title.as_deref().unwrap_or("untitled"); + (format!("Chart ({:?})", chart_type), title_str.to_string()) + } + BodyElement::Image { alt, caption, .. } => { + let desc = alt.as_deref().or(caption.as_deref()).unwrap_or("no description"); + ("Image".to_string(), desc.to_string()) + } + }; + + format!( + r#" +## Focused Element +The user is focusing on element [{}]: {} +Full content of focused element: +--- +{} +--- +When the user's request is ambiguous about which element to modify, prioritize this focused element. +"#, + index, element_type, full_content + ) +} + /// Result of handling a version tool request struct VersionRequestResult { result: ToolResult, @@ -795,12 +873,25 @@ async fn handle_version_request( let desc = match element { BodyElement::Heading { level, text } => format!("H{}: {}", level, text), BodyElement::Paragraph { text } => { - let preview = if text.len() > 100 { - format!("{}...", &text[..100]) + let preview: String = text.chars().take(100).collect(); + if text.chars().count() > 100 { + format!("Paragraph: {}...", preview) } else { - text.clone() - }; - format!("Paragraph: {}", preview) + format!("Paragraph: {}", preview) + } + } + BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or("plain"); + let preview: String = content.chars().take(100).collect(); + if content.chars().count() > 100 { + format!("Code ({}): {}...", lang, preview) + } else { + format!("Code ({}): {}", lang, preview) + } + } + BodyElement::List { ordered, items } => { + let list_type = if *ordered { "ordered" } else { "unordered" }; + format!("List ({}): {} items", list_type, items.len()) } BodyElement::Chart { chart_type, title, .. } => { format!( diff --git a/makima/src/server/handlers/files.rs b/makima/src/server/handlers/files.rs index c65eed5..9634b73 100644 --- a/makima/src/server/handlers/files.rs +++ b/makima/src/server/handlers/files.rs @@ -10,21 +10,30 @@ use uuid::Uuid; use crate::db::models::{CreateFileRequest, FileListResponse, FileSummary, UpdateFileRequest}; use crate::db::repository::{self, RepositoryError}; +use crate::server::auth::Authenticated; use crate::server::messages::ApiError; use crate::server::state::{FileUpdateNotification, SharedState}; -/// List all files for the current owner. +/// List all files for the authenticated user's owner. #[utoipa::path( get, path = "/api/v1/files", responses( (status = 200, description = "List of files", body = FileListResponse), + (status = 401, description = "Unauthorized", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] -pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { +pub async fn list_files( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { return ( StatusCode::SERVICE_UNAVAILABLE, @@ -33,7 +42,7 @@ pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { .into_response(); }; - match repository::list_files(pool).await { + match repository::list_files_for_owner(pool, auth.owner_id).await { Ok(files) => { let summaries: Vec<FileSummary> = files.into_iter().map(FileSummary::from).collect(); let total = summaries.len() as i64; @@ -54,7 +63,7 @@ pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { } } -/// Get a single file by ID. +/// Get a single file by ID (scoped by owner). #[utoipa::path( get, path = "/api/v1/files/{id}", @@ -63,14 +72,20 @@ pub async fn list_files(State(state): State<SharedState>) -> impl IntoResponse { ), responses( (status = 200, description = "File details", body = crate::db::models::File), + (status = 401, description = "Unauthorized", body = ApiError), (status = 404, description = "File not found", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn get_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Path(id): Path<Uuid>, ) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { @@ -81,7 +96,7 @@ pub async fn get_file( .into_response(); }; - match repository::get_file(pool, id).await { + match repository::get_file_for_owner(pool, id, auth.owner_id).await { Ok(Some(file)) => Json(file).into_response(), Ok(None) => ( StatusCode::NOT_FOUND, @@ -107,13 +122,19 @@ pub async fn get_file( responses( (status = 201, description = "File created", body = crate::db::models::File), (status = 400, description = "Invalid request", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn create_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Json(req): Json<CreateFileRequest>, ) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { @@ -124,7 +145,7 @@ pub async fn create_file( .into_response(); }; - match repository::create_file(pool, req).await { + match repository::create_file_for_owner(pool, auth.owner_id, req).await { Ok(file) => (StatusCode::CREATED, Json(file)).into_response(), Err(e) => { tracing::error!("Failed to create file: {}", e); @@ -137,7 +158,7 @@ pub async fn create_file( } } -/// Update an existing file. +/// Update an existing file (scoped by owner). #[utoipa::path( put, path = "/api/v1/files/{id}", @@ -147,15 +168,21 @@ pub async fn create_file( request_body = UpdateFileRequest, responses( (status = 200, description = "File updated", body = crate::db::models::File), + (status = 401, description = "Unauthorized", body = ApiError), (status = 404, description = "File not found", body = ApiError), (status = 409, description = "Version conflict", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn update_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Path(id): Path<Uuid>, Json(req): Json<UpdateFileRequest>, ) -> impl IntoResponse { @@ -185,7 +212,7 @@ pub async fn update_file( updated_fields.push("body".to_string()); } - match repository::update_file(pool, id, req).await { + match repository::update_file_for_owner(pool, id, auth.owner_id, req).await { Ok(Some(file)) => { // Broadcast update notification state.broadcast_file_update(FileUpdateNotification { @@ -233,7 +260,7 @@ pub async fn update_file( } } -/// Delete a file. +/// Delete a file (scoped by owner). #[utoipa::path( delete, path = "/api/v1/files/{id}", @@ -242,14 +269,20 @@ pub async fn update_file( ), responses( (status = 204, description = "File deleted"), + (status = 401, description = "Unauthorized", body = ApiError), (status = 404, description = "File not found", body = ApiError), (status = 503, description = "Database not configured", body = ApiError), (status = 500, description = "Internal server error", body = ApiError), ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), tag = "Files" )] pub async fn delete_file( State(state): State<SharedState>, + Authenticated(auth): Authenticated, Path(id): Path<Uuid>, ) -> impl IntoResponse { let Some(ref pool) = state.db_pool else { @@ -260,7 +293,7 @@ pub async fn delete_file( .into_response(); }; - match repository::delete_file(pool, id).await { + match repository::delete_file_for_owner(pool, id, auth.owner_id).await { Ok(true) => StatusCode::NO_CONTENT.into_response(), Ok(false) => ( StatusCode::NOT_FOUND, diff --git a/makima/src/server/handlers/mesh.rs b/makima/src/server/handlers/mesh.rs new file mode 100644 index 0000000..760740c --- /dev/null +++ b/makima/src/server/handlers/mesh.rs @@ -0,0 +1,1679 @@ +//! HTTP handlers for task and daemon mesh operations. + +use axum::{ + extract::{Path, State}, + http::{HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use uuid::Uuid; + +use crate::db::models::{ + CreateTaskRequest, DaemonDirectory, DaemonDirectoriesResponse, DaemonListResponse, + SendMessageRequest, Task, TaskEventListResponse, TaskListResponse, TaskOutputEntry, + TaskOutputResponse, TaskWithSubtasks, UpdateTaskRequest, +}; +use crate::db::repository::{self, RepositoryError}; +use crate::server::auth::Authenticated; +use crate::server::messages::ApiError; +use crate::server::state::{DaemonCommand, SharedState, TaskUpdateNotification}; + +// ============================================================================= +// Authentication Types +// ============================================================================= + +/// Source of authentication for mesh endpoints. +#[derive(Debug, Clone)] +pub enum AuthSource { + /// Authenticated via tool key (orchestrator accessing API). + /// Contains the task ID that owns this key. + ToolKey(Uuid), + /// Authenticated via user token (web client). + /// Contains the user ID. (Not implemented yet) + #[allow(dead_code)] + UserToken(Uuid), + /// No authentication provided (anonymous access). + Anonymous, +} + +/// Header name for tool key authentication. +pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key"; + +/// Extract authentication source from request headers. +/// +/// Checks for: +/// 1. `X-Makima-Tool-Key` header for orchestrator tool access +/// 2. `Authorization: Bearer` header for user access (future) +/// 3. Falls back to Anonymous if no auth provided +pub fn extract_auth(state: &SharedState, headers: &HeaderMap) -> AuthSource { + // Check for tool key header first + if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) { + if let Ok(key_str) = tool_key.to_str() { + if let Some(task_id) = state.validate_tool_key(key_str) { + return AuthSource::ToolKey(task_id); + } + tracing::warn!("Invalid tool key provided"); + } + } + + // Check for Authorization header (future user auth) + if let Some(auth_header) = headers.get("authorization") { + if let Ok(auth_str) = auth_header.to_str() { + if auth_str.starts_with("Bearer ") { + // Future: validate JWT and extract user ID + tracing::debug!("Bearer token auth not yet implemented"); + } + } + } + + // Default to anonymous + AuthSource::Anonymous +} + +// ============================================================================= +// Task Handlers +// ============================================================================= + +/// List all tasks for the current owner. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks", + responses( + (status = 200, description = "List of tasks", body = TaskListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_tasks( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::list_tasks_for_owner(pool, auth.owner_id).await { + Ok(tasks) => { + let total = tasks.len() as i64; + Json(TaskListResponse { tasks, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list tasks: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get a single task by ID with its subtasks (scoped by owner). +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task details with subtasks", body = TaskWithSubtasks), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(task)) => { + // Get subtasks for this task (also scoped by owner) + match repository::list_subtasks_for_owner(pool, id, auth.owner_id).await { + Ok(subtasks) => Json(TaskWithSubtasks { task, subtasks }).into_response(), + Err(e) => { + tracing::error!("Failed to get subtasks for task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Create a new task. +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks", + request_body = CreateTaskRequest, + responses( + (status = 201, description = "Task created", body = Task), + (status = 400, description = "Invalid request", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn create_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Json(req): Json<CreateTaskRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::create_task_for_owner(pool, auth.owner_id, req).await { + Ok(task) => (StatusCode::CREATED, Json(task)).into_response(), + Err(e) => { + tracing::error!("Failed to create task: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Update an existing task (scoped by owner). +#[utoipa::path( + put, + path = "/api/v1/mesh/tasks/{id}", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = UpdateTaskRequest, + responses( + (status = 200, description = "Task updated", body = Task), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 409, description = "Version conflict", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn update_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(req): Json<UpdateTaskRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Track which fields are being updated for the notification + let mut updated_fields = Vec::new(); + if req.name.is_some() { + updated_fields.push("name".to_string()); + } + if req.description.is_some() { + updated_fields.push("description".to_string()); + } + if req.status.is_some() { + updated_fields.push("status".to_string()); + } + if req.priority.is_some() { + updated_fields.push("priority".to_string()); + } + if req.plan.is_some() { + updated_fields.push("plan".to_string()); + } + if req.progress_summary.is_some() { + updated_fields.push("progress_summary".to_string()); + } + if req.error_message.is_some() { + updated_fields.push("error_message".to_string()); + } + + match repository::update_task_for_owner(pool, id, auth.owner_id, req).await { + Ok(Some(task)) => { + // Broadcast task update notification + state.broadcast_task_update(TaskUpdateNotification { + task_id: task.id, + owner_id: Some(auth.owner_id), + version: task.version, + status: task.status.clone(), + updated_fields, + updated_by: "user".to_string(), + }); + Json(task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(RepositoryError::VersionConflict { expected, actual }) => { + tracing::info!( + "Version conflict on task {}: expected {}, actual {}", + id, + expected, + actual + ); + ( + StatusCode::CONFLICT, + Json(serde_json::json!({ + "code": "VERSION_CONFLICT", + "message": format!( + "Task was modified by another user. Expected version {}, actual version {}", + expected, actual + ), + "expectedVersion": expected, + "actualVersion": actual, + })), + ) + .into_response() + } + Err(RepositoryError::Database(e)) => { + tracing::error!("Failed to update task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Delete a task (scoped by owner). +#[utoipa::path( + delete, + path = "/api/v1/mesh/tasks/{id}", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 204, description = "Task deleted"), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn delete_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task first to check if it's running and needs to be stopped + if let Ok(Some(task)) = repository::get_task_for_owner(pool, id, auth.owner_id).await { + let is_active = matches!( + task.status.as_str(), + "running" | "starting" | "initializing" | "paused" + ); + + // If task is active and has a daemon, send interrupt command + if is_active { + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::InterruptTask { + task_id: id, + graceful: false, + }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + tracing::warn!( + task_id = %id, + daemon_id = %daemon_id, + "Failed to send InterruptTask before delete: {}", + e + ); + } else { + tracing::info!( + task_id = %id, + daemon_id = %daemon_id, + "Sent InterruptTask before delete" + ); + } + } + } + } + + match repository::delete_task_for_owner(pool, id, auth.owner_id).await { + Ok(true) => StatusCode::NO_CONTENT.into_response(), + Ok(false) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to delete task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Start a task by sending it to an available daemon (scoped by owner). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/start", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task started", body = Task), + (status = 400, description = "Task cannot be started", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or no daemons available", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn start_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + headers: HeaderMap, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + // Extract authentication to log who is starting the task + let legacy_auth = extract_auth(&state, &headers); + match &legacy_auth { + AuthSource::ToolKey(orchestrator_id) => { + tracing::info!( + task_id = %id, + orchestrator_task_id = %orchestrator_id, + owner_id = %auth.owner_id, + "Orchestrator starting subtask via tool key" + ); + } + AuthSource::Anonymous => { + tracing::info!( + task_id = %id, + owner_id = %auth.owner_id, + "Starting task (user request)" + ); + } + AuthSource::UserToken(user_id) => { + tracing::info!( + task_id = %id, + user_id = %user_id, + owner_id = %auth.owner_id, + "Starting task via user token" + ); + } + } + + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task can be started (allow pending, failed, interrupted, done, or merged) + let startable_statuses = ["pending", "failed", "interrupted", "done", "merged"]; + if !startable_statuses.contains(&task.status.as_str()) { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!("Task cannot be started from status: {}", task.status), + )), + ) + .into_response(); + } + + // Find an available daemon belonging to this owner + let target_daemon_id = match state.daemon_connections + .iter() + .find(|d| d.value().owner_id == auth.owner_id) + { + Some(d) => d.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "No daemons connected for your account. Cannot start task.", + )), + ) + .into_response(); + } + }; + + // Check if this is an orchestrator (depth 0 with subtasks) + let subtask_count = match repository::list_subtasks_for_owner(pool, id, auth.owner_id).await { + Ok(subtasks) => { + tracing::info!( + task_id = %id, + subtask_count = subtasks.len(), + subtask_ids = ?subtasks.iter().map(|s| s.id.to_string()).collect::<Vec<_>>(), + "Counted subtasks for orchestrator check" + ); + subtasks.len() + }, + Err(e) => { + tracing::warn!("Failed to check subtasks for {}: {}", id, e); + 0 + } + }; + let is_orchestrator = task.depth == 0 && subtask_count > 0; + + tracing::info!( + task_id = %id, + task_depth = task.depth, + subtask_count = subtask_count, + is_orchestrator = is_orchestrator, + "Starting task with orchestrator determination" + ); + + // IMPORTANT: Update database FIRST to assign daemon_id before sending command + // This prevents race conditions where the task starts but daemon_id is not set + let update_req = UpdateTaskRequest { + status: Some("starting".to_string()), + daemon_id: Some(target_daemon_id), + version: Some(task.version), + ..Default::default() + }; + + let updated_task = match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Send SpawnTask command to daemon + let command = DaemonCommand::SpawnTask { + task_id: id, + task_name: task.name.clone(), + plan: task.plan.clone(), + repo_url: task.repository_url.clone(), + base_branch: task.base_branch.clone(), + target_branch: task.target_branch.clone(), + parent_task_id: task.parent_task_id, + depth: task.depth, + is_orchestrator, + target_repo_path: task.target_repo_path.clone(), + completion_action: task.completion_action.clone(), + continue_from_task_id: task.continue_from_task_id, + copy_files: task.copy_files.as_ref().and_then(|v| serde_json::from_value(v.clone()).ok()), + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::error!("Failed to send SpawnTask command: {}", e); + // Rollback: clear daemon_id and reset status since command failed + let rollback_req = UpdateTaskRequest { + status: Some("pending".to_string()), + clear_daemon_id: true, // Explicitly clear daemon_id + ..Default::default() + }; + let _ = repository::update_task_for_owner(pool, id, auth.owner_id, rollback_req).await; + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + // Broadcast task update notification + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "starting".to_string(), + updated_fields: vec!["status".to_string(), "daemon_id".to_string()], + updated_by: "system".to_string(), + }); + + Json(updated_task).into_response() +} + +/// Stop a running task (scoped by owner). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/stop", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task stopped", body = Task), + (status = 400, description = "Task is not running", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn stop_task( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task is running/active + let is_active = matches!( + task.status.as_str(), + "running" | "starting" | "initializing" | "paused" + ); + if !is_active { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!("Task cannot be stopped from status: {}", task.status), + )), + ) + .into_response(); + } + + // Find the daemon running this task + let target_daemon_id = if let Some(daemon_id) = task.daemon_id { + daemon_id + } else { + // No daemon assigned, just update status directly + let update_req = UpdateTaskRequest { + status: Some("failed".to_string()), + error_message: Some("Task stopped by user".to_string()), + version: Some(task.version), + ..Default::default() + }; + + return match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "failed".to_string(), + updated_fields: vec!["status".to_string(), "error_message".to_string()], + updated_by: "user".to_string(), + }); + Json(updated_task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + }; + }; + + // Send InterruptTask command to daemon + let command = DaemonCommand::InterruptTask { + task_id: id, + graceful: false, + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::warn!("Failed to send InterruptTask command: {}", e); + // Daemon might be disconnected - update task status directly + let update_req = UpdateTaskRequest { + status: Some("failed".to_string()), + error_message: Some("Task stopped by user (daemon unavailable)".to_string()), + version: Some(task.version), + ..Default::default() + }; + + return match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "failed".to_string(), + updated_fields: vec!["status".to_string(), "error_message".to_string()], + updated_by: "user".to_string(), + }); + Json(updated_task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + }; + } + + // Update task status to "failed" (stopped) + let update_req = UpdateTaskRequest { + status: Some("failed".to_string()), + error_message: Some("Task stopped by user".to_string()), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, id, auth.owner_id, update_req).await { + Ok(Some(updated_task)) => { + // Broadcast task update notification + state.broadcast_task_update(TaskUpdateNotification { + task_id: id, + owner_id: Some(auth.owner_id), + version: updated_task.version, + status: "failed".to_string(), + updated_fields: vec!["status".to_string(), "error_message".to_string()], + updated_by: "user".to_string(), + }); + + Json(updated_task).into_response() + } + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to update task status: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Send a message to a running task's stdin (scoped by owner). +/// +/// This can be used to provide input to Claude Code when it's waiting for user input, +/// or to inject context/instructions into a running task. +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/message", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = SendMessageRequest, + responses( + (status = 200, description = "Message sent successfully"), + (status = 400, description = "Task is not running", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn send_message( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(req): Json<SendMessageRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task is running + if task.status != "running" { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!( + "Cannot send message to task in status: {}. Task must be running.", + task.status + ), + )), + ) + .into_response(); + } + + // Find the daemon running this task + let target_daemon_id = if let Some(daemon_id) = task.daemon_id { + daemon_id + } else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "Task has no assigned daemon. Cannot send message.", + )), + ) + .into_response(); + }; + + // Send SendMessage command to daemon + let command = DaemonCommand::SendMessage { + task_id: id, + message: req.message.clone(), + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::error!("Failed to send SendMessage command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + tracing::info!(task_id = %id, message_len = req.message.len(), "Message sent to task"); + + // Return success + ( + StatusCode::OK, + Json(serde_json::json!({ + "success": true, + "taskId": id, + "messageLength": req.message.len() + })), + ) + .into_response() +} + +/// Get task output history (scoped by owner). +/// +/// Retrieves all recorded output from a task's Claude Code process. +/// This allows the frontend to fetch missed output when subscribing late +/// or reconnecting after a disconnect. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/output", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Task output history", body = TaskOutputResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_task_output( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Verify task exists and belongs to owner + match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + } + + // Get output history (task already verified to belong to owner) + match repository::get_task_output(pool, id, None).await { + Ok(events) => { + let entries: Vec<TaskOutputEntry> = events + .into_iter() + .filter_map(TaskOutputEntry::from_task_event) + .collect(); + let total = entries.len(); + + Json(TaskOutputResponse { + entries, + total, + task_id: id, + }) + .into_response() + } + Err(e) => { + tracing::error!("Failed to get task output: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// List subtasks for a parent task (scoped by owner). +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/subtasks", + params( + ("id" = Uuid, Path, description = "Parent task ID") + ), + responses( + (status = 200, description = "List of subtasks", body = TaskListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_subtasks( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + match repository::list_subtasks_for_owner(pool, id, auth.owner_id).await { + Ok(tasks) => { + let total = tasks.len() as i64; + Json(TaskListResponse { tasks, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list subtasks for task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// List events for a task (scoped by owner). +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/events", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "List of task events", body = TaskEventListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_task_events( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Verify task exists and belongs to owner + match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(_)) => {} + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + } + + match repository::list_task_events(pool, id, None).await { + Ok(events) => { + let total = events.len() as i64; + Json(TaskEventListResponse { events, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list events for task {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Retry completion action for a completed task (scoped by owner). +/// +/// This allows retrying a completion action (push branch, merge, create PR) +/// after filling in the target_repo_path if it wasn't set when the task completed. +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/retry-completion", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 200, description = "Completion action initiated"), + (status = 400, description = "Invalid request (task not completed, no completion action, etc.)", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn retry_completion_action( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Check if task is in a terminal state + let terminal_statuses = ["done", "failed", "merged"]; + if !terminal_statuses.contains(&task.status.as_str()) { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!( + "Task must be completed to retry completion action. Current status: {}", + task.status + ), + )), + ) + .into_response(); + } + + // Check if completion action is set + let action = match &task.completion_action { + Some(action) if action != "none" => action.clone(), + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "NO_COMPLETION_ACTION", + "Task has no completion action configured (or is set to 'none')", + )), + ) + .into_response(); + } + }; + + // Check if target_repo_path is set + let target_repo_path = match &task.target_repo_path { + Some(path) if !path.is_empty() => path.clone(), + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "NO_TARGET_REPO", + "Target repository path must be set before retrying completion action", + )), + ) + .into_response(); + } + }; + + // Note: We don't check overlay_path here because the server may not have it + // The daemon will scan its worktrees directory to find the worktree by task ID + + // Find a daemon to execute the action (must belong to this owner) + // Prefer the daemon that ran the task, but fall back to any available daemon for this owner + let target_daemon_id = if let Some(daemon_id) = task.daemon_id { + // Check if this daemon is still connected and belongs to this owner + if state.daemon_connections.iter().any(|d| d.value().id == daemon_id && d.value().owner_id == auth.owner_id) { + daemon_id + } else { + // Fall back to any connected daemon for this owner + match state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id) { + Some(d) => d.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "No daemons connected for your account. Cannot execute completion action.", + )), + ) + .into_response(); + } + } + } + } else { + // No daemon assigned - use any available for this owner + match state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id) { + Some(d) => d.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "NO_DAEMON", + "No daemons connected for your account. Cannot execute completion action.", + )), + ) + .into_response(); + } + } + }; + + // Send RetryCompletionAction command to daemon + let command = DaemonCommand::RetryCompletionAction { + task_id: id, + task_name: task.name.clone(), + action: action.clone(), + target_repo_path: target_repo_path.clone(), + target_branch: task.target_branch.clone(), + }; + + if let Err(e) = state.send_daemon_command(target_daemon_id, command).await { + tracing::error!("Failed to send RetryCompletionAction command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + tracing::info!( + task_id = %id, + action = %action, + target_repo = %target_repo_path, + "Retry completion action initiated" + ); + + ( + StatusCode::OK, + Json(serde_json::json!({ + "success": true, + "taskId": id, + "action": action, + "targetRepoPath": target_repo_path, + "message": "Completion action initiated. Check task output for results." + })), + ) + .into_response() +} + +// ============================================================================= +// Daemon Handlers +// ============================================================================= + +/// List all connected daemons (requires authentication). +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons", + responses( + (status = 200, description = "List of daemons", body = DaemonListResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_daemons( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Only list daemons belonging to this owner + match repository::list_daemons_for_owner(pool, auth.owner_id).await { + Ok(daemons) => { + let total = daemons.len() as i64; + Json(DaemonListResponse { daemons, total }).into_response() + } + Err(e) => { + tracing::error!("Failed to list daemons: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get a single daemon by ID (requires authentication). +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/{id}", + params( + ("id" = Uuid, Path, description = "Daemon ID") + ), + responses( + (status = 200, description = "Daemon details", body = crate::db::models::Daemon), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Daemon not found", body = ApiError), + (status = 503, description = "Database not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_daemon( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Only get daemon if it belongs to this owner + match repository::get_daemon_for_owner(pool, id, auth.owner_id).await { + Ok(Some(daemon)) => Json(daemon).into_response(), + Ok(None) => ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Daemon not found")), + ) + .into_response(), + Err(e) => { + tracing::error!("Failed to get daemon {}: {}", id, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response() + } + } +} + +/// Get suggested directories from connected daemons (requires authentication). +/// +/// Returns directories that can be used as target_repo_path for completion actions. +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/directories", + responses( + (status = 200, description = "List of suggested directories", body = DaemonDirectoriesResponse), + (status = 401, description = "Unauthorized", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_daemon_directories( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let mut directories = Vec::new(); + + // Iterate over connected daemons belonging to this owner and collect their directories + for entry in state.daemon_connections.iter() { + let daemon = entry.value(); + + // Only include daemons belonging to this owner + if daemon.owner_id != auth.owner_id { + continue; + } + + // Add working directory if available + if let Some(ref working_dir) = daemon.working_directory { + directories.push(DaemonDirectory { + path: working_dir.clone(), + label: "Working Directory".to_string(), + directory_type: "working".to_string(), + hostname: daemon.hostname.clone(), + exists: None, + }); + } + + // Add home directory if available (for cloning completed work) + if let Some(ref home_dir) = daemon.home_directory { + directories.push(DaemonDirectory { + path: home_dir.clone(), + label: "Makima Home".to_string(), + directory_type: "home".to_string(), + hostname: daemon.hostname.clone(), + exists: None, + }); + } + } + + Json(DaemonDirectoriesResponse { directories }) +} + +/// Request to clone a worktree to a target directory. +#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CloneWorktreeRequest { + /// Path to the target directory. + pub target_dir: String, +} + +/// Clone a task's worktree to a target directory (scoped by owner). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/clone", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = CloneWorktreeRequest, + responses( + (status = 200, description = "Clone command sent"), + (status = 400, description = "Invalid request or task not completed", body = ApiError), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 404, description = "Task not found", body = ApiError), + (status = 503, description = "Database not configured or daemon not connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn clone_worktree( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(body): Json<CloneWorktreeRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get the task (scoped by owner) + let task = match repository::get_task_for_owner(pool, id, auth.owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Task not found")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to get task {}: {}", id, e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", e.to_string())), + ) + .into_response(); + } + }; + + // Verify task is in a completed state + let is_completed = matches!(task.status.as_str(), "done" | "failed" | "merged"); + if !is_completed { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_STATE", + format!("Task must be completed to clone (current status: {})", task.status), + )), + ) + .into_response(); + } + + // Find a connected daemon belonging to this owner to send the command + let daemon_entry = state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id); + let daemon_id = match daemon_entry { + Some(entry) => entry.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("NO_DAEMON", "No daemon connected for your account")), + ) + .into_response(); + } + }; + + // Send CloneWorktree command to daemon + let command = DaemonCommand::CloneWorktree { + task_id: id, + target_dir: body.target_dir.clone(), + }; + + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + tracing::error!("Failed to send CloneWorktree command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + Json(serde_json::json!({ + "status": "cloning", + "taskId": id.to_string(), + "targetDir": body.target_dir, + })) + .into_response() +} + +/// Request to check if a target directory exists. +#[derive(Debug, serde::Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CheckTargetExistsRequest { + /// Path to check. + pub target_dir: String, +} + +/// Response for check target exists. +#[derive(Debug, serde::Serialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CheckTargetExistsResponse { + /// Whether the target directory exists. + pub exists: bool, + /// The path that was checked (expanded). + pub target_dir: String, +} + +/// Check if a target directory exists (for clone validation, requires authentication). +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/check-target", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = CheckTargetExistsRequest, + responses( + (status = 200, description = "Check result", body = CheckTargetExistsResponse), + (status = 401, description = "Unauthorized", body = ApiError), + (status = 503, description = "No daemon connected", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn check_target_exists( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(id): Path<Uuid>, + Json(body): Json<CheckTargetExistsRequest>, +) -> impl IntoResponse { + // Find a connected daemon belonging to this owner to send the command + let daemon_entry = state.daemon_connections.iter().find(|d| d.value().owner_id == auth.owner_id); + let daemon_id = match daemon_entry { + Some(entry) => entry.value().id, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("NO_DAEMON", "No daemon connected for your account")), + ) + .into_response(); + } + }; + + // Send CheckTargetExists command to daemon + let command = DaemonCommand::CheckTargetExists { + task_id: id, + target_dir: body.target_dir.clone(), + }; + + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + tracing::error!("Failed to send CheckTargetExists command: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DAEMON_ERROR", e)), + ) + .into_response(); + } + + // The actual result will be sent back via WebSocket + // For now, just acknowledge the request was sent + Json(serde_json::json!({ + "status": "checking", + "taskId": id.to_string(), + "targetDir": body.target_dir, + })) + .into_response() +} diff --git a/makima/src/server/handlers/mesh_chat.rs b/makima/src/server/handlers/mesh_chat.rs new file mode 100644 index 0000000..5d6d2ee --- /dev/null +++ b/makima/src/server/handlers/mesh_chat.rs @@ -0,0 +1,2088 @@ +//! Chat endpoint for LLM-powered task orchestration. +//! +//! This handler provides an agentic loop for managing tasks, daemons, and +//! overlay operations through LLM tool calling. + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::db::{models::CreateTaskRequest, repository}; +use crate::llm::{ + claude::{self, ClaudeClient, ClaudeError, ClaudeModel}, + groq::{GroqClient, GroqError, Message, ToolCallResponse}, + parse_mesh_tool_call, LlmModel, MeshToolRequest, ToolCall, ToolResult, UserQuestion, + MESH_TOOLS, +}; +use crate::server::auth::Authenticated; +use crate::server::state::{DaemonCommand, SharedState, TaskUpdateNotification}; + +/// Maximum number of tool-calling rounds to prevent infinite loops +const MAX_TOOL_ROUNDS: usize = 30; + +#[derive(Debug, Clone, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatHistoryMessage { + /// Role: "user" or "assistant" + pub role: String, + /// Message content + pub content: String, +} + +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatRequest { + /// The user's message/instruction + pub message: String, + /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq" + #[serde(default)] + pub model: Option<String>, + /// Optional conversation history for context continuity (deprecated - now loaded from DB) + #[serde(default)] + pub history: Option<Vec<MeshChatHistoryMessage>>, + /// Context type: "mesh", "task", or "subtask" + #[serde(default)] + pub context_type: Option<String>, + /// Task ID if context is task/subtask + #[serde(default)] + pub context_task_id: Option<Uuid>, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshChatResponse { + /// The LLM's response message + pub response: String, + /// Tool calls that were executed + pub tool_calls: Vec<MeshToolCallInfo>, + /// Questions pending user answers (pauses conversation) + #[serde(skip_serializing_if = "Option::is_none")] + pub pending_questions: Option<Vec<UserQuestion>>, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct MeshToolCallInfo { + pub name: String, + pub result: ToolResult, +} + +/// Enum to hold LLM clients +enum LlmClient { + Groq(GroqClient), + Claude(ClaudeClient), +} + +/// Unified result from LLM call +struct LlmResult { + content: Option<String>, + tool_calls: Vec<ToolCall>, + raw_tool_calls: Vec<ToolCallResponse>, + finish_reason: String, +} + +/// Chat with mesh orchestrator at the top level (no specific task context) +#[utoipa::path( + post, + path = "/api/v1/mesh/chat", + request_body = MeshChatRequest, + responses( + (status = 200, description = "Chat completed successfully", body = MeshChatResponse), + (status = 401, description = "Unauthorized"), + (status = 500, description = "Internal server error") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn mesh_toplevel_chat_handler( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Json(request): Json<MeshChatRequest>, +) -> impl IntoResponse { + // Check if database is configured + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + // Parse model selection (default to Claude Sonnet) + let model = request + .model + .as_ref() + .and_then(|m| LlmModel::from_str(m)) + .unwrap_or(LlmModel::ClaudeSonnet); + + tracing::info!("Mesh top-level chat using LLM model: {:?}", model); + + // Initialize the appropriate LLM client + let llm_client = match model { + LlmModel::ClaudeSonnet => match ClaudeClient::from_env(ClaudeModel::Sonnet) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::ClaudeOpus => match ClaudeClient::from_env(ClaudeModel::Opus) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::GroqKimi => match GroqClient::from_env() { + Ok(client) => LlmClient::Groq(client), + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "GROQ_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Groq client error: {}", e) })), + ) + .into_response(); + } + }, + }; + + // Build context about all tasks and daemons + let mesh_context = build_mesh_overview_context(pool, &state, auth.owner_id).await; + + // Build agentic system prompt for top-level mesh orchestration + let system_prompt = format!( + r#"You are an intelligent task orchestration agent. You help users manage and coordinate tasks running on connected daemons with Claude Code containers. + +## Your Capabilities +You have access to tools for: +- **Task Lifecycle**: create_task, run_task, pause_task, resume_task, interrupt_task, discard_task +- **Task Queries**: query_task_status, list_tasks, list_subtasks, list_siblings, list_daemons +- **File Access**: list_files, read_file (read documents from the files system) +- **Task Communication**: send_message_to_task, update_task_plan +- **Overlay/Merge Operations**: peek_sibling_overlay, get_overlay_diff, preview_merge, merge_subtask, complete_task, set_merge_mode + +## Current Mesh Overview +{mesh_context} + +## Agentic Behavior Guidelines + +### 1. Analyze Before Acting +- For complex orchestration requests, first gather information using query_task_status, list_tasks, or list_daemons +- Understand the current state before making changes +- For simple, direct requests (e.g., "create a new task"), you can act immediately + +### 2. Plan Multi-Step Operations +- Break complex orchestration into logical steps +- For parallel execution: create multiple subtasks, then run them on different daemons +- For sequential execution: create subtasks and run them in order + +### 3. Create and Manage Tasks +- Use create_task to create new top-level tasks or subtasks +- Assign appropriate priorities and plans +- **Repository Default**: When creating tasks, use the daemon's working directory as the repository_url by default (shown as "Default Repository" above). Only omit repository_url if the task doesn't involve code, or use a different URL if the user explicitly requests it. +- If a working directory is a git repository, use it as the repository_url for code-related tasks + +### 4. Coordinate Multiple Tasks +- Use list_tasks to see all tasks and their statuses +- Use list_daemons to see available compute resources +- Balance workload across daemons + +### 5. Be Efficient +- Don't over-analyze simple requests +- Use the minimum number of tool calls needed +- Provide clear summaries of actions taken + +## Important Notes +- Task IDs are UUIDs - ensure you use the correct format +- Running a task requires at least one connected daemon +- When creating subtasks, specify the parent_task_id +- Always confirm destructive operations (discard_task) with the user"#, + mesh_context = mesh_context + ); + + // Run the shared agentic loop + run_mesh_agentic_loop(pool, &state, &llm_client, system_prompt, &request, auth.owner_id).await +} + +/// Chat with task mesh orchestrator using LLM tool calling (scoped by owner) +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/chat", + request_body = MeshChatRequest, + responses( + (status = 200, description = "Chat completed successfully", body = MeshChatResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Task not found"), + (status = 500, description = "Internal server error") + ), + params( + ("id" = Uuid, Path, description = "Task ID (context for orchestration)") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn mesh_chat_handler( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(task_id): Path<Uuid>, + Json(request): Json<MeshChatRequest>, +) -> impl IntoResponse { + // Check if database is configured + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + // Get the context task (scoped by owner) + let task = match repository::get_task_for_owner(pool, task_id, auth.owner_id).await { + Ok(Some(task)) => task, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(json!({ "error": "Task not found" })), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Database error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Database error: {}", e) })), + ) + .into_response(); + } + }; + + // Parse model selection (default to Claude Sonnet) + let model = request + .model + .as_ref() + .and_then(|m| LlmModel::from_str(m)) + .unwrap_or(LlmModel::ClaudeSonnet); + + tracing::info!("Mesh chat using LLM model: {:?}", model); + + // Initialize the appropriate LLM client + let llm_client = match model { + LlmModel::ClaudeSonnet => match ClaudeClient::from_env(ClaudeModel::Sonnet) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::ClaudeOpus => match ClaudeClient::from_env(ClaudeModel::Opus) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "ANTHROPIC_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Claude client error: {}", e) })), + ) + .into_response(); + } + }, + LlmModel::GroqKimi => match GroqClient::from_env() { + Ok(client) => LlmClient::Groq(client), + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "GROQ_API_KEY not configured" })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Groq client error: {}", e) })), + ) + .into_response(); + } + }, + }; + + // Build context about the current task and mesh state + let task_context = build_task_context(&task); + + // Build agentic system prompt for task orchestration + let system_prompt = format!( + r#"You are an intelligent task orchestration agent. You help users manage and coordinate tasks running on connected daemons with Claude Code containers. + +## Your Capabilities +You have access to tools for: +- **Task Lifecycle**: create_task, run_task, pause_task, resume_task, interrupt_task, discard_task +- **Task Queries**: query_task_status, list_tasks, list_subtasks, list_siblings, list_daemons +- **File Access**: list_files, read_file (read documents from the files system) +- **Task Communication**: send_message_to_task, update_task_plan +- **Overlay/Merge Operations**: peek_sibling_overlay, get_overlay_diff, preview_merge, merge_subtask, complete_task, set_merge_mode + +## Current Context +{task_context} + +## Agentic Behavior Guidelines + +### 1. Analyze Before Acting +- For complex orchestration requests, first gather information using query_task_status, list_tasks, or list_daemons +- Understand the current state before making changes +- For simple, direct requests (e.g., "pause this task"), you can act immediately + +### 2. Plan Multi-Step Operations +- Break complex orchestration into logical steps +- For parallel execution: create multiple subtasks, then run them on different daemons +- For sequential execution: create subtasks and run them in order + +### 3. Monitor Task Progress +- Use query_task_status to check on running tasks +- Watch for status changes and react accordingly +- Handle failures gracefully (retry, escalate, or report) + +### 4. Coordinate Sibling Tasks +- Use peek_sibling_overlay to see what other tasks have changed +- Preview merges before completing to catch conflicts +- Coordinate timing when multiple tasks need to merge + +### 5. Be Efficient +- Don't over-analyze simple requests +- Use the minimum number of tool calls needed +- Provide clear summaries of actions taken + +## Important Notes +- Task IDs are UUIDs - ensure you use the correct format +- Running a task requires at least one connected daemon +- Overlay operations require the task to have been run at least once +- Always confirm destructive operations (discard_task) with the user +- When creating subtasks for this task, use parent_task_id: {task_id}"#, + task_context = task_context, + task_id = task_id + ); + + // Run the shared agentic loop + run_mesh_agentic_loop(pool, &state, &llm_client, system_prompt, &request, auth.owner_id).await +} + +fn build_task_context(task: &crate::db::models::Task) -> String { + let mut context = format!( + "Current Task: {} (ID: {})\n", + task.name, task.id + ); + context.push_str(&format!("Status: {}\n", task.status)); + context.push_str(&format!("Priority: {}\n", task.priority)); + + if let Some(ref desc) = task.description { + context.push_str(&format!("Description: {}\n", desc)); + } + + // Truncate plan preview if too long + let plan_preview = if task.plan.len() > 200 { + format!("{}...", &task.plan[..200]) + } else { + task.plan.clone() + }; + context.push_str(&format!("Plan: {}\n", plan_preview)); + + if let Some(ref summary) = task.progress_summary { + context.push_str(&format!("Progress: {}\n", summary)); + } + + if let Some(ref error) = task.error_message { + context.push_str(&format!("Error: {}\n", error)); + } + + // Repository info + if let Some(ref url) = task.repository_url { + context.push_str(&format!("Repository: {}\n", url)); + } + if let Some(ref branch) = task.base_branch { + context.push_str(&format!("Base branch: {}\n", branch)); + } + + context +} + +/// Build overview context for top-level mesh orchestration +async fn build_mesh_overview_context(pool: &sqlx::PgPool, state: &SharedState, owner_id: Uuid) -> String { + let mut context = String::new(); + + // Get task counts by status + match repository::list_tasks_for_owner(pool, owner_id).await { + Ok(tasks) => { + let total = tasks.len(); + let pending = tasks.iter().filter(|t| t.status == "pending").count(); + let running = tasks.iter().filter(|t| t.status == "running").count(); + let paused = tasks.iter().filter(|t| t.status == "paused").count(); + let done = tasks.iter().filter(|t| t.status == "done").count(); + let failed = tasks.iter().filter(|t| t.status == "failed").count(); + + context.push_str(&format!( + "Tasks: {} total ({} pending, {} running, {} paused, {} done, {} failed)\n", + total, pending, running, paused, done, failed + )); + + // List recent/active tasks + if !tasks.is_empty() { + context.push_str("\nRecent Tasks:\n"); + for task in tasks.iter().take(5) { + context.push_str(&format!( + " - {} (ID: {}, Status: {})\n", + task.name, task.id, task.status + )); + } + if tasks.len() > 5 { + context.push_str(&format!(" ... and {} more\n", tasks.len() - 5)); + } + } + } + Err(e) => { + context.push_str(&format!("Error fetching tasks: {}\n", e)); + } + } + + // Get connected daemons for this owner + let owner_daemons: Vec<_> = state.daemon_connections.iter() + .filter(|e| e.value().owner_id == owner_id) + .collect(); + let daemon_count = owner_daemons.len(); + context.push_str(&format!("\nConnected Daemons: {}\n", daemon_count)); + + for entry in owner_daemons.iter().take(3) { + let daemon = entry.value(); + let working_dir = daemon.working_directory.as_deref().unwrap_or("not set"); + context.push_str(&format!( + " - {} (ID: {}, Working Directory: {})\n", + daemon.hostname.as_deref().unwrap_or("unknown"), + daemon.id, + working_dir + )); + } + + // Add default repository guidance if there's exactly one daemon with a working directory + let daemons_with_working_dir: Vec<_> = owner_daemons.iter() + .filter(|e| e.value().working_directory.is_some()) + .collect(); + + if daemons_with_working_dir.len() == 1 { + if let Some(dir) = &daemons_with_working_dir[0].value().working_directory { + context.push_str(&format!( + "\nDefault Repository: {} (use this as repository_url when creating tasks unless user specifies otherwise)\n", + dir + )); + } + } + + context +} + +/// Run the shared agentic loop for mesh chat +async fn run_mesh_agentic_loop( + pool: &sqlx::PgPool, + state: &SharedState, + llm_client: &LlmClient, + system_prompt: String, + request: &MeshChatRequest, + owner_id: Uuid, +) -> axum::response::Response { + // Get or create conversation for storing messages + let conversation = match repository::get_or_create_active_conversation(pool, owner_id).await { + Ok(c) => c, + Err(e) => { + tracing::error!("Failed to get/create conversation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("Failed to initialize conversation: {}", e) })), + ) + .into_response(); + } + }; + + // Build initial messages + let mut messages = vec![Message { + role: "system".to_string(), + content: Some(system_prompt), + tool_calls: None, + tool_call_id: None, + }]; + + // Load conversation history from database (or use provided for backwards compatibility) + if let Some(history) = &request.history { + // Legacy: use provided history + for hist_msg in history { + messages.push(Message { + role: hist_msg.role.clone(), + content: Some(hist_msg.content.clone()), + tool_calls: None, + tool_call_id: None, + }); + } + tracing::info!( + history_messages = history.len(), + "Loaded mesh conversation history from request (legacy)" + ); + } else { + // New: load from database + match repository::list_chat_messages(pool, conversation.id, Some(50)).await { + Ok(db_messages) => { + for msg in db_messages { + messages.push(Message { + role: msg.role.clone(), + content: Some(msg.content.clone()), + tool_calls: None, + tool_call_id: None, + }); + } + tracing::info!( + history_messages = messages.len() - 1, // minus system message + "Loaded mesh conversation history from database" + ); + } + Err(e) => { + tracing::warn!("Failed to load chat history: {}", e); + // Continue without history + } + } + } + + // Add current user message + messages.push(Message { + role: "user".to_string(), + content: Some(request.message.clone()), + tool_calls: None, + tool_call_id: None, + }); + + // State for tracking + let mut all_tool_call_infos: Vec<MeshToolCallInfo> = Vec::new(); + let mut final_response: Option<String> = None; + let mut consecutive_failures = 0; + const MAX_CONSECUTIVE_FAILURES: usize = 3; + let mut pending_questions: Option<Vec<UserQuestion>> = None; + + // Multi-turn agentic tool calling loop + for round in 0..MAX_TOOL_ROUNDS { + tracing::info!( + round = round, + total_tool_calls = all_tool_call_infos.len(), + "Mesh agentic loop iteration" + ); + + // Check consecutive failures + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES { + tracing::warn!( + "Breaking mesh loop due to {} consecutive failures", + consecutive_failures + ); + final_response = Some( + "I encountered multiple consecutive errors and stopped. \ + Please check the task state and try again." + .to_string(), + ); + break; + } + + // Call the appropriate LLM API + let result = match llm_client { + LlmClient::Groq(groq) => { + match groq.chat_with_tools(messages.clone(), &MESH_TOOLS).await { + Ok(r) => LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls: r.raw_tool_calls, + finish_reason: r.finish_reason, + }, + Err(e) => { + tracing::error!("Groq API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("LLM API error: {}", e) })), + ) + .into_response(); + } + } + } + LlmClient::Claude(claude_client) => { + let claude_messages = claude::groq_messages_to_claude(&messages); + match claude_client + .chat_with_tools(claude_messages, &MESH_TOOLS) + .await + { + Ok(r) => { + let raw_tool_calls: Vec<ToolCallResponse> = r + .tool_calls + .iter() + .map(|tc| ToolCallResponse { + id: tc.id.clone(), + call_type: "function".to_string(), + function: crate::llm::groq::FunctionCall { + name: tc.name.clone(), + arguments: tc.arguments.to_string(), + }, + }) + .collect(); + + LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls, + finish_reason: r.stop_reason, + } + } + Err(e) => { + tracing::error!("Claude API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": format!("LLM API error: {}", e) })), + ) + .into_response(); + } + } + } + }; + + // Check if there are tool calls to execute + if result.tool_calls.is_empty() { + final_response = result.content; + break; + } + + // Add assistant message with tool calls to conversation + messages.push(Message { + role: "assistant".to_string(), + content: result.content.clone(), + tool_calls: Some(result.raw_tool_calls.clone()), + tool_call_id: None, + }); + + // Execute each tool call + for (i, tool_call) in result.tool_calls.iter().enumerate() { + tracing::info!(tool = %tool_call.name, round = round, "Executing mesh tool call"); + + // Parse the tool call + let mut execution_result = parse_mesh_tool_call(tool_call); + + // Handle async mesh tool requests + if let Some(mesh_request) = execution_result.request.take() { + let async_result = handle_mesh_request(pool, state, mesh_request, owner_id).await; + execution_result.success = async_result.success; + execution_result.message = async_result.message; + execution_result.data = async_result.data; + } + + // Track consecutive failures + if execution_result.success { + consecutive_failures = 0; + } else { + consecutive_failures += 1; + tracing::warn!( + tool = %tool_call.name, + consecutive_failures = consecutive_failures, + "Mesh tool call failed" + ); + } + + // Check for pending user questions + if let Some(questions) = execution_result.pending_questions { + tracing::info!( + question_count = questions.len(), + "Mesh LLM requesting user input" + ); + pending_questions = Some(questions); + all_tool_call_infos.push(MeshToolCallInfo { + name: tool_call.name.clone(), + result: ToolResult { + success: execution_result.success, + message: execution_result.message.clone(), + }, + }); + break; + } + + // Build tool result message + let result_content = if let Some(data) = &execution_result.data { + json!({ + "success": execution_result.success, + "message": execution_result.message, + "data": data + }) + .to_string() + } else { + json!({ + "success": execution_result.success, + "message": execution_result.message + }) + .to_string() + }; + + // Add tool result message + let tool_call_id = match llm_client { + LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(), + LlmClient::Claude(_) => tool_call.id.clone(), + }; + + messages.push(Message { + role: "tool".to_string(), + content: Some(result_content), + tool_calls: None, + tool_call_id: Some(tool_call_id), + }); + + // Track for response + all_tool_call_infos.push(MeshToolCallInfo { + name: tool_call.name.clone(), + result: ToolResult { + success: execution_result.success, + message: execution_result.message, + }, + }); + } + + // If user questions are pending, pause + if pending_questions.is_some() { + final_response = result.content; + break; + } + + // If finish reason indicates completion, exit loop + let finish_lower = result.finish_reason.to_lowercase(); + if finish_lower == "stop" || finish_lower == "end_turn" { + final_response = result.content; + break; + } + } + + // Build response + let response_text = final_response.unwrap_or_else(|| { + if all_tool_call_infos.is_empty() { + "I couldn't understand your request. Please try rephrasing.".to_string() + } else { + format!( + "Done! Executed {} tool{}.", + all_tool_call_infos.len(), + if all_tool_call_infos.len() == 1 { + "" + } else { + "s" + } + ) + } + }); + + // Save messages to database (only if not using legacy history mode) + if request.history.is_none() { + let context_type = request.context_type.clone().unwrap_or_else(|| "mesh".to_string()); + + // Validate context_task_id exists before using it (to avoid FK constraint violation) + let context_task_id = if let Some(task_id) = request.context_task_id { + match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(_)) => Some(task_id), + Ok(None) => { + tracing::warn!("context_task_id {} not found, ignoring", task_id); + None + } + Err(e) => { + tracing::warn!("Failed to validate context_task_id {}: {}", task_id, e); + None + } + } + } else { + None + }; + + // Save user message + if let Err(e) = repository::add_chat_message( + pool, + conversation.id, + "user", + &request.message, + &context_type, + context_task_id, + None, + None, + ) + .await + { + tracing::warn!("Failed to save user message to DB: {}", e); + } + + // Serialize tool calls for storage + let tool_calls_json = if all_tool_call_infos.is_empty() { + None + } else { + Some(serde_json::to_value(&all_tool_call_infos).unwrap_or_default()) + }; + + // Serialize pending questions for storage + let pending_questions_json = pending_questions + .as_ref() + .map(|q| serde_json::to_value(q).unwrap_or_default()); + + // Save assistant message + if let Err(e) = repository::add_chat_message( + pool, + conversation.id, + "assistant", + &response_text, + &context_type, + context_task_id, + tool_calls_json, + pending_questions_json, + ) + .await + { + tracing::warn!("Failed to save assistant message to DB: {}", e); + } + + tracing::info!( + conversation_id = %conversation.id, + context_type = %context_type, + "Saved mesh chat messages to database" + ); + } + + ( + StatusCode::OK, + Json(MeshChatResponse { + response: response_text, + tool_calls: all_tool_call_infos, + pending_questions, + }), + ) + .into_response() +} + +/// Result from handling an async mesh tool request +struct MeshRequestResult { + success: bool, + message: String, + data: Option<serde_json::Value>, +} + +/// Handle async mesh tool requests that require database/daemon access +async fn handle_mesh_request( + pool: &sqlx::PgPool, + state: &SharedState, + request: MeshToolRequest, + owner_id: Uuid, +) -> MeshRequestResult { + match request { + MeshToolRequest::CreateTask { + name, + plan, + parent_task_id, + repository_url, + base_branch, + merge_mode, + priority, + } => { + // Check if repository_url matches a daemon's working directory (for this owner) + let is_daemon_working_dir = repository_url.as_ref().map(|url| { + state.daemon_connections.iter().any(|entry| { + entry.value().owner_id == owner_id && + entry.value().working_directory.as_ref() == Some(url) + }) + }).unwrap_or(false); + + // Derive completion_action from merge_mode, or default to "branch" if using daemon working dir + let (completion_action, target_repo_path) = if let Some(ref mode) = merge_mode { + // Explicit merge_mode provided - derive from it + let action = match mode.as_str() { + "pr" => "pr".to_string(), + "auto" => "merge".to_string(), + "manual" => "branch".to_string(), + _ => "none".to_string(), + }; + // If using daemon working dir and action involves the repo, set target_repo_path + let target = if is_daemon_working_dir && action != "none" { + repository_url.clone() + } else { + None + }; + (Some(action), target) + } else if is_daemon_working_dir { + // No merge_mode but using daemon working dir - default to "branch" + (Some("branch".to_string()), repository_url.clone()) + } else { + (None, None) + }; + + let create_req = CreateTaskRequest { + name: name.clone(), + description: None, + plan, + parent_task_id, + repository_url, + base_branch, + target_branch: None, + merge_mode, + priority: priority.unwrap_or(0), + target_repo_path, + completion_action, + continue_from_task_id: None, + copy_files: None, + }; + + match repository::create_task_for_owner(pool, owner_id, create_req).await { + Ok(task) => MeshRequestResult { + success: true, + message: format!("Created task '{}' with ID {}", name, task.id), + data: Some(json!({ + "taskId": task.id, + "name": task.name, + "status": task.status, + })), + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to create task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::RunTask { task_id, daemon_id } => { + // Get task to check status + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "pending" && task.status != "paused" { + return MeshRequestResult { + success: false, + message: format!( + "Task cannot be run - status is '{}' (must be 'pending' or 'paused')", + task.status + ), + data: None, + }; + } + + // Find a daemon to run on (must belong to this owner) + let target_daemon_id = if let Some(id) = daemon_id { + // Verify the specified daemon belongs to this owner + if !state.daemon_connections.iter().any(|d| d.value().id == id && d.value().owner_id == owner_id) { + return MeshRequestResult { + success: false, + message: "Specified daemon not found or not accessible.".to_string(), + data: None, + }; + } + id + } else { + // Find any connected daemon for this owner + let daemon = state.daemon_connections.iter().find(|d| d.value().owner_id == owner_id); + match daemon { + Some(d) => d.value().id, + None => { + return MeshRequestResult { + success: false, + message: "No daemons connected for your account. Cannot run task.".to_string(), + data: None, + } + } + } + }; + + // Check if this is an orchestrator (depth 0 with subtasks) + let subtask_count = match repository::list_subtasks_for_owner(pool, task_id, owner_id).await { + Ok(subtasks) => subtasks.len(), + Err(_) => 0, + }; + let is_orchestrator = task.depth == 0 && subtask_count > 0; + + // Send SpawnTask command to daemon + let command = DaemonCommand::SpawnTask { + task_id, + task_name: task.name.clone(), + plan: task.plan.clone(), + repo_url: task.repository_url.clone(), + base_branch: task.base_branch.clone(), + target_branch: task.target_branch.clone(), + parent_task_id: task.parent_task_id, + depth: task.depth, + is_orchestrator, + target_repo_path: task.target_repo_path.clone(), + completion_action: task.completion_action.clone(), + continue_from_task_id: task.continue_from_task_id, + copy_files: task.copy_files.as_ref().and_then(|v| serde_json::from_value(v.clone()).ok()), + }; + + match state.send_daemon_command(target_daemon_id, command).await { + Ok(()) => { + // Update task status to running + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("running".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "running".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!("Task {} is now running on daemon {}", task_id, target_daemon_id), + data: Some(json!({ + "taskId": task_id, + "daemonId": target_daemon_id, + "status": "running", + })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to start task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::PauseTask { task_id } => { + // Get task and its daemon + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "running" { + return MeshRequestResult { + success: false, + message: format!("Task is not running (status: {})", task.status), + data: None, + }; + } + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::PauseTask { task_id }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + return MeshRequestResult { + success: false, + message: format!("Failed to pause task: {}", e), + data: None, + }; + } + } + + // Update status + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("paused".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "paused".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!("Task {} paused", task_id), + data: Some(json!({ "taskId": task_id, "status": "paused" })), + } + } + + MeshToolRequest::ResumeTask { task_id } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "paused" { + return MeshRequestResult { + success: false, + message: format!("Task is not paused (status: {})", task.status), + data: None, + }; + } + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::ResumeTask { task_id }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + return MeshRequestResult { + success: false, + message: format!("Failed to resume task: {}", e), + data: None, + }; + } + } + + // Update status + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("running".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "running".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!("Task {} resumed", task_id), + data: Some(json!({ "taskId": task_id, "status": "running" })), + } + } + + MeshToolRequest::InterruptTask { task_id, graceful } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::InterruptTask { task_id, graceful }; + if let Err(e) = state.send_daemon_command(daemon_id, command).await { + return MeshRequestResult { + success: false, + message: format!("Failed to interrupt task: {}", e), + data: None, + }; + } + } + + // Update status + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("paused".to_string()), + version: Some(task.version), + ..Default::default() + }; + + if let Ok(Some(updated)) = repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "paused".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + } + + MeshRequestResult { + success: true, + message: format!( + "Task {} {}interrupted", + task_id, + if graceful { "gracefully " } else { "" } + ), + data: Some(json!({ "taskId": task_id, "status": "paused" })), + } + } + + MeshToolRequest::DiscardTask { task_id } => { + match repository::delete_task_for_owner(pool, task_id, owner_id).await { + Ok(true) => MeshRequestResult { + success: true, + message: format!("Task {} discarded", task_id), + data: Some(json!({ "taskId": task_id, "deleted": true })), + }, + Ok(false) => MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to delete task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::QueryTaskStatus { task_id } => { + match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(task)) => MeshRequestResult { + success: true, + message: format!("Task '{}' is {}", task.name, task.status), + data: Some(json!({ + "taskId": task.id, + "name": task.name, + "status": task.status, + "priority": task.priority, + "description": task.description, + "plan": task.plan, + "progressSummary": task.progress_summary, + "errorMessage": task.error_message, + "repositoryUrl": task.repository_url, + "baseBranch": task.base_branch, + "targetBranch": task.target_branch, + "mergeMode": task.merge_mode, + "prUrl": task.pr_url, + "daemonId": task.daemon_id, + "containerId": task.container_id, + "createdAt": task.created_at, + "startedAt": task.started_at, + "completedAt": task.completed_at, + })), + }, + Ok(None) => MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListTasks { + status_filter, + parent_task_id, + } => { + // TODO: Add filtering support to repository + match repository::list_tasks_for_owner(pool, owner_id).await { + Ok(mut tasks) => { + // Apply filters + if let Some(ref status) = status_filter { + tasks.retain(|t| &t.status == status); + } + if let Some(ref parent_id) = parent_task_id { + tasks.retain(|t| t.parent_task_id.as_ref() == Some(parent_id)); + } + + let task_data: Vec<serde_json::Value> = tasks + .iter() + .map(|t| { + json!({ + "taskId": t.id, + "name": t.name, + "status": t.status, + "priority": t.priority, + "parentTaskId": t.parent_task_id, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} tasks", tasks.len()), + data: Some(json!({ "tasks": task_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListSubtasks { task_id } => { + match repository::list_subtasks_for_owner(pool, task_id, owner_id).await { + Ok(subtasks) => { + let subtask_data: Vec<serde_json::Value> = subtasks + .iter() + .map(|t| { + json!({ + "taskId": t.id, + "name": t.name, + "status": t.status, + "priority": t.priority, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} subtasks", subtasks.len()), + data: Some(json!({ "subtasks": subtask_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListSiblings { task_id } => { + // Get task to find parent + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + let Some(parent_id) = task.parent_task_id else { + return MeshRequestResult { + success: true, + message: "Task has no parent, so no siblings".to_string(), + data: Some(json!({ "siblings": [] })), + }; + }; + + // Get all subtasks of parent, excluding current task + match repository::list_subtasks_for_owner(pool, parent_id, owner_id).await { + Ok(siblings) => { + let sibling_data: Vec<serde_json::Value> = siblings + .iter() + .filter(|t| t.id != task_id) + .map(|t| { + json!({ + "taskId": t.id, + "name": t.name, + "status": t.status, + "priority": t.priority, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} sibling tasks", sibling_data.len()), + data: Some(json!({ "siblings": sibling_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ListDaemons => { + // Only list daemons belonging to this owner + let daemons: Vec<serde_json::Value> = state + .daemon_connections + .iter() + .filter(|entry| entry.value().owner_id == owner_id) + .map(|entry| { + let d = entry.value(); + json!({ + "daemonId": d.id, + "connectionId": d.connection_id, + "hostname": d.hostname, + "machineId": d.machine_id, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("{} daemon(s) connected", daemons.len()), + data: Some(json!({ "daemons": daemons })), + } + } + + MeshToolRequest::ListDaemonDirectories => { + let mut directories: Vec<serde_json::Value> = Vec::new(); + + // Only list directories from daemons belonging to this owner + for entry in state.daemon_connections.iter() { + let daemon = entry.value(); + + // Only include daemons belonging to this owner + if daemon.owner_id != owner_id { + continue; + } + + // Add working directory if available + if let Some(ref working_dir) = daemon.working_directory { + directories.push(json!({ + "path": working_dir, + "label": "Working Directory", + "directoryType": "working", + "hostname": daemon.hostname, + })); + } + + // Add home directory if available + if let Some(ref home_dir) = daemon.home_directory { + directories.push(json!({ + "path": home_dir, + "label": "Makima Home", + "directoryType": "home", + "hostname": daemon.hostname, + })); + } + } + + MeshRequestResult { + success: true, + message: format!("Found {} available directories", directories.len()), + data: Some(json!({ "directories": directories })), + } + } + + MeshToolRequest::ListFiles => { + match repository::list_files_for_owner(pool, owner_id).await { + Ok(files) => { + let file_data: Vec<serde_json::Value> = files + .iter() + .map(|f| { + json!({ + "fileId": f.id, + "name": f.name, + "description": f.description, + "createdAt": f.created_at, + "updatedAt": f.updated_at, + }) + }) + .collect(); + + MeshRequestResult { + success: true, + message: format!("Found {} files", files.len()), + data: Some(json!({ "files": file_data })), + } + } + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::ReadFile { file_id } => { + match repository::get_file_for_owner(pool, file_id, owner_id).await { + Ok(Some(file)) => { + // Convert body elements to readable text + let body_content: Vec<serde_json::Value> = file + .body + .iter() + .map(|elem| { + match elem { + crate::db::models::BodyElement::Heading { level, text } => { + json!({ "type": "heading", "level": level, "text": text }) + } + crate::db::models::BodyElement::Paragraph { text } => { + json!({ "type": "paragraph", "text": text }) + } + crate::db::models::BodyElement::Code { language, content } => { + json!({ "type": "code", "language": language, "content": content }) + } + crate::db::models::BodyElement::List { ordered, items } => { + json!({ "type": "list", "ordered": ordered, "items": items }) + } + crate::db::models::BodyElement::Chart { chart_type, title, data, config: _ } => { + let data_count = data.as_array().map(|arr| arr.len()).unwrap_or(0); + json!({ "type": "chart", "chartType": chart_type, "title": title, "dataPoints": data_count }) + } + crate::db::models::BodyElement::Image { src, alt, caption } => { + json!({ "type": "image", "src": src, "alt": alt, "caption": caption }) + } + } + }) + .collect(); + + // Also build a plain text version for easier reading + let plain_text: String = file + .body + .iter() + .filter_map(|elem| { + match elem { + crate::db::models::BodyElement::Heading { level, text } => { + Some(format!("{} {}", "#".repeat(*level as usize), text)) + } + crate::db::models::BodyElement::Paragraph { text } => { + Some(text.clone()) + } + crate::db::models::BodyElement::Code { language, content } => { + let lang = language.as_deref().unwrap_or(""); + Some(format!("```{}\n{}\n```", lang, content)) + } + crate::db::models::BodyElement::List { ordered, items } => { + let list_text: Vec<String> = items.iter().enumerate().map(|(i, item)| { + if *ordered { + format!("{}. {}", i + 1, item) + } else { + format!("- {}", item) + } + }).collect(); + Some(list_text.join("\n")) + } + _ => None, + } + }) + .collect::<Vec<_>>() + .join("\n\n"); + + // Convert transcript entries to JSON + let transcript: Vec<serde_json::Value> = file + .transcript + .iter() + .map(|entry| { + json!({ + "id": entry.id, + "speaker": entry.speaker, + "start": entry.start, + "end": entry.end, + "text": entry.text, + }) + }) + .collect(); + + // Build a plain text transcript for easier reading + let transcript_text: String = file + .transcript + .iter() + .map(|entry| { + format!("[{:.1}s] {}: {}", entry.start, entry.speaker, entry.text) + }) + .collect::<Vec<_>>() + .join("\n"); + + MeshRequestResult { + success: true, + message: format!("Read file '{}'", file.name), + data: Some(json!({ + "fileId": file.id, + "name": file.name, + "description": file.description, + "summary": file.summary, + "body": body_content, + "plainText": plain_text, + "transcript": transcript, + "transcriptText": transcript_text, + "transcriptCount": file.transcript.len(), + "createdAt": file.created_at, + "updatedAt": file.updated_at, + })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: format!("File {} not found", file_id), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + }, + } + } + + MeshToolRequest::SendMessageToTask { task_id, message } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + if task.status != "running" { + return MeshRequestResult { + success: false, + message: format!("Task is not running (status: {})", task.status), + data: None, + }; + } + + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::SendMessage { task_id, message }; + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => MeshRequestResult { + success: true, + message: "Message sent to task".to_string(), + data: Some(json!({ "taskId": task_id })), + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to send message: {}", e), + data: None, + }, + } + } else { + MeshRequestResult { + success: false, + message: "Task has no daemon assigned".to_string(), + data: None, + } + } + } + + MeshToolRequest::UpdateTaskPlan { + task_id, + new_plan, + interrupt_if_running, + } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + // Interrupt if running and requested + if task.status == "running" && interrupt_if_running { + if let Some(daemon_id) = task.daemon_id { + let command = DaemonCommand::InterruptTask { + task_id, + graceful: true, + }; + let _ = state.send_daemon_command(daemon_id, command).await; + } + } + + let update_req = crate::db::models::UpdateTaskRequest { + plan: Some(new_plan), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + Ok(Some(updated)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: updated.status.clone(), + updated_fields: vec!["plan".to_string()], + updated_by: "system".to_string(), + }); + MeshRequestResult { + success: true, + message: "Task plan updated".to_string(), + data: Some(json!({ "taskId": task_id })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: "Task not found".to_string(), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to update task: {}", e), + data: None, + }, + } + } + + // Overlay operations - these require daemon communication + // For now, return placeholder responses since daemon implementation is separate + MeshToolRequest::PeekSiblingOverlay { sibling_task_id } => MeshRequestResult { + success: false, + message: format!( + "Overlay operations require a connected daemon. Task {} may not have overlay data yet.", + sibling_task_id + ), + data: None, + }, + + MeshToolRequest::GetOverlayDiff { task_id } => MeshRequestResult { + success: false, + message: format!( + "Overlay operations require a connected daemon. Task {} may not have overlay data yet.", + task_id + ), + data: None, + }, + + MeshToolRequest::PreviewMerge { task_id } => MeshRequestResult { + success: false, + message: format!( + "Merge preview requires a connected daemon. Task {} may not have overlay data yet.", + task_id + ), + data: None, + }, + + MeshToolRequest::MergeSubtask { task_id } => MeshRequestResult { + success: false, + message: format!( + "Merge operations require a connected daemon. Task {}", + task_id + ), + data: None, + }, + + MeshToolRequest::CompleteTask { task_id } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + // Update status to done + let update_req = crate::db::models::UpdateTaskRequest { + status: Some("done".to_string()), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + Ok(Some(updated)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: "done".to_string(), + updated_fields: vec!["status".to_string()], + updated_by: "system".to_string(), + }); + let merge_mode = task.merge_mode.unwrap_or_else(|| "pr".to_string()); + MeshRequestResult { + success: true, + message: format!( + "Task {} completed. Merge mode: {}", + task_id, + &merge_mode + ), + data: Some(json!({ + "taskId": task_id, + "status": "done", + "mergeMode": merge_mode, + })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: "Task not found".to_string(), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to complete task: {}", e), + data: None, + }, + } + } + + MeshToolRequest::SetMergeMode { task_id, mode } => { + let task = match repository::get_task_for_owner(pool, task_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return MeshRequestResult { + success: false, + message: format!("Task {} not found", task_id), + data: None, + } + } + Err(e) => { + return MeshRequestResult { + success: false, + message: format!("Database error: {}", e), + data: None, + } + } + }; + + let update_req = crate::db::models::UpdateTaskRequest { + merge_mode: Some(mode.clone()), + version: Some(task.version), + ..Default::default() + }; + + match repository::update_task_for_owner(pool, task_id, owner_id, update_req).await { + Ok(Some(updated)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(task.owner_id), + version: updated.version, + status: updated.status, + updated_fields: vec!["merge_mode".to_string()], + updated_by: "system".to_string(), + }); + MeshRequestResult { + success: true, + message: format!("Merge mode set to '{}'", mode), + data: Some(json!({ "taskId": task_id, "mergeMode": mode })), + } + } + Ok(None) => MeshRequestResult { + success: false, + message: "Task not found".to_string(), + data: None, + }, + Err(e) => MeshRequestResult { + success: false, + message: format!("Failed to update merge mode: {}", e), + data: None, + }, + } + } + } +} + +// ============================================================================= +// Chat History Endpoints +// ============================================================================= + +use crate::db::models::MeshChatHistoryResponse; + +/// Get chat history for the current conversation (requires authentication) +#[utoipa::path( + get, + path = "/api/v1/mesh/chat/history", + responses( + (status = 200, description = "Chat history", body = MeshChatHistoryResponse), + (status = 401, description = "Unauthorized"), + (status = 503, description = "Database not configured"), + (status = 500, description = "Internal server error") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn get_chat_history( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + let conversation = match repository::get_or_create_active_conversation(pool, auth.owner_id).await { + Ok(c) => c, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() })), + ) + .into_response() + } + }; + + let messages = match repository::list_chat_messages(pool, conversation.id, None).await { + Ok(m) => m, + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() })), + ) + .into_response() + } + }; + + ( + StatusCode::OK, + Json(MeshChatHistoryResponse { + conversation_id: conversation.id, + messages, + }), + ) + .into_response() +} + +/// Clear chat history (archives current conversation and starts new, requires authentication) +#[utoipa::path( + delete, + path = "/api/v1/mesh/chat/history", + responses( + (status = 200, description = "History cleared"), + (status = 401, description = "Unauthorized"), + (status = 503, description = "Database not configured"), + (status = 500, description = "Internal server error") + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn clear_chat_history( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "error": "Database not configured" })), + ) + .into_response(); + }; + + match repository::clear_conversation(pool, auth.owner_id).await { + Ok(new_conv) => ( + StatusCode::OK, + Json(json!({ "success": true, "conversationId": new_conv.id })), + ) + .into_response(), + Err(e) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": e.to_string() })), + ) + .into_response(), + } +} diff --git a/makima/src/server/handlers/mesh_daemon.rs b/makima/src/server/handlers/mesh_daemon.rs new file mode 100644 index 0000000..644d0bc --- /dev/null +++ b/makima/src/server/handlers/mesh_daemon.rs @@ -0,0 +1,959 @@ +//! WebSocket handler for daemon connections. +//! +//! Daemons connect to report task progress, stream output, and receive commands. +//! Each daemon manages Claude Code containers on its local machine. +//! +//! ## Authentication +//! +//! Daemons authenticate via the `X-Api-Key` header in the WebSocket upgrade request. +//! The API key is validated against the database and the daemon is associated with +//! the corresponding owner_id for data isolation. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; +use futures::{SinkExt, StreamExt}; +use serde::Deserialize; +use sqlx::Row; +use tokio::sync::mpsc; +use uuid::Uuid; + +use crate::db::repository; +use crate::server::auth::{hash_api_key, API_KEY_HEADER}; +use crate::server::messages::ApiError; +use crate::server::state::{ + DaemonCommand, SharedState, TaskOutputNotification, TaskUpdateNotification, +}; + +// ============================================================================= +// Claude Code JSON Output Parsing +// ============================================================================= + +/// Claude Code stream-json message structure +#[derive(Debug, Deserialize)] +struct ClaudeMessage { + #[serde(rename = "type")] + msg_type: String, + subtype: Option<String>, + message: Option<ClaudeMessageContent>, + tool_name: Option<String>, + tool_input: Option<serde_json::Value>, + tool_result: Option<ClaudeToolResult>, + result: Option<String>, + cost_usd: Option<f64>, + duration_ms: Option<u64>, + error: Option<String>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeMessageContent { + content: Option<Vec<ClaudeContentBlock>>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeContentBlock { + #[serde(rename = "type")] + block_type: String, + text: Option<String>, + name: Option<String>, + input: Option<serde_json::Value>, +} + +#[derive(Debug, Deserialize)] +struct ClaudeToolResult { + content: Option<String>, + is_error: Option<bool>, +} + +/// Parse a line of Claude Code output into a structured notification +fn parse_claude_output(task_id: Uuid, owner_id: Uuid, line: &str, is_partial: bool) -> Option<TaskOutputNotification> { + let trimmed = line.trim(); + if trimmed.is_empty() { + return None; + } + + // Try to parse as JSON + if trimmed.starts_with('{') { + if let Ok(msg) = serde_json::from_str::<ClaudeMessage>(trimmed) { + return parse_claude_message(task_id, owner_id, msg, is_partial); + } + } + + // Not JSON or failed to parse - treat as raw output + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "raw".to_string(), + content: trimmed.to_string(), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) +} + +fn parse_claude_message(task_id: Uuid, owner_id: Uuid, msg: ClaudeMessage, is_partial: bool) -> Option<TaskOutputNotification> { + match msg.msg_type.as_str() { + "system" => { + // System messages (init, etc.) - include subtype info + let content = match msg.subtype.as_deref() { + Some("init") => "Session started".to_string(), + Some(sub) => format!("System: {}", sub), + None => "System message".to_string(), + }; + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content, + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + "assistant" => { + // Extract text content from message blocks + if let Some(message) = msg.message { + if let Some(blocks) = message.content { + // Check for text blocks + let text_content: Vec<String> = blocks + .iter() + .filter(|b| b.block_type == "text") + .filter_map(|b| b.text.clone()) + .collect(); + + if !text_content.is_empty() { + return Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "assistant".to_string(), + content: text_content.join("\n"), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }); + } + + // Check for tool_use blocks + if let Some(tool_block) = blocks.iter().find(|b| b.block_type == "tool_use") { + return Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_use".to_string(), + content: format!("Using tool: {}", tool_block.name.as_deref().unwrap_or("unknown")), + tool_name: tool_block.name.clone(), + tool_input: tool_block.input.clone(), + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }); + } + } + } + None + } + + "tool_use" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_use".to_string(), + content: format!("Using tool: {}", msg.tool_name.as_deref().unwrap_or("unknown")), + tool_name: msg.tool_name, + tool_input: msg.tool_input, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + "tool_result" => { + if let Some(result) = msg.tool_result { + let content = result.content.unwrap_or_default(); + // Truncate long results + let content = if content.len() > 500 { + format!("{}...", &content[..500]) + } else { + content + }; + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "tool_result".to_string(), + content, + tool_name: None, + tool_input: None, + is_error: result.is_error, + cost_usd: None, + duration_ms: None, + is_partial, + }) + } else { + None + } + } + + "result" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "result".to_string(), + content: msg.result.unwrap_or_else(|| "Task completed".to_string()), + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: msg.cost_usd, + duration_ms: msg.duration_ms, + is_partial, + }) + } + + "error" => { + Some(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "error".to_string(), + content: msg.error.unwrap_or_else(|| "An error occurred".to_string()), + tool_name: None, + tool_input: None, + is_error: Some(true), + cost_usd: None, + duration_ms: None, + is_partial, + }) + } + + _ => None, // Skip unknown message types + } +} + +/// Message from daemon to server. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonMessage { + /// Authentication request (first message required) + Authenticate { + #[serde(rename = "apiKey")] + api_key: String, + #[serde(rename = "machineId")] + machine_id: String, + hostname: String, + #[serde(rename = "maxConcurrentTasks")] + max_concurrent_tasks: i32, + }, + /// Periodic heartbeat with current status + Heartbeat { + #[serde(rename = "activeTasks")] + active_tasks: Vec<Uuid>, + }, + /// Task output streaming (stdout/stderr from Claude Code) + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + output: String, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + /// Task status change notification + TaskStatusChange { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "oldStatus")] + old_status: String, + #[serde(rename = "newStatus")] + new_status: String, + }, + /// Task progress update with summary + TaskProgress { + #[serde(rename = "taskId")] + task_id: Uuid, + summary: String, + }, + /// Task completion notification + TaskComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + error: Option<String>, + }, + /// Register a tool key for orchestrator API access + RegisterToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + /// The API key for this orchestrator to use when calling mesh endpoints + key: String, + }, + /// Revoke a tool key when task completes + RevokeToolKey { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Response to RetryCompletionAction command + CompletionActionResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// PR URL if action was "pr" and successful + #[serde(rename = "prUrl")] + pr_url: Option<String>, + }, + /// Report daemon's available directories for task output + DaemonDirectories { + /// Current working directory of the daemon + #[serde(rename = "workingDirectory")] + working_directory: String, + /// Path to ~/.makima/home directory (for cloning completed work) + #[serde(rename = "homeDirectory")] + home_directory: String, + /// Path to worktrees directory (~/.makima/worktrees) + #[serde(rename = "worktreesDirectory")] + worktrees_directory: String, + }, + /// Response to CloneWorktree command + CloneWorktreeResult { + #[serde(rename = "taskId")] + task_id: Uuid, + success: bool, + message: String, + /// The path where the worktree was cloned + #[serde(rename = "targetDir")] + target_dir: Option<String>, + }, + /// Response to CheckTargetExists command + CheckTargetExistsResult { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Whether the target directory exists + exists: bool, + /// The path that was checked + #[serde(rename = "targetDir")] + target_dir: String, + }, +} + +/// Validated daemon authentication result. +#[derive(Debug, Clone)] +struct DaemonAuthResult { + /// User ID from the API key + user_id: Uuid, + /// Owner ID for data isolation + owner_id: Uuid, +} + +/// Validate an API key and return (user_id, owner_id). +async fn validate_daemon_api_key(pool: &sqlx::PgPool, key: &str) -> Result<DaemonAuthResult, String> { + let key_hash = hash_api_key(key); + + // Look up the API key and join with users to get owner_id + let row = sqlx::query( + r#" + SELECT ak.user_id, u.default_owner_id + FROM api_keys ak + JOIN users u ON u.id = ak.user_id + WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL + "#, + ) + .bind(&key_hash) + .fetch_optional(pool) + .await + .map_err(|e| format!("Database error: {}", e))?; + + match row { + Some(row) => { + let user_id: Uuid = row.try_get("user_id") + .map_err(|e| format!("Failed to get user_id: {}", e))?; + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| format!("Failed to get owner_id: {}", e))?; + let owner_id = owner_id.ok_or_else(|| "User has no default owner".to_string())?; + + // Update last_used_at asynchronously (fire and forget) + let pool_clone = pool.clone(); + let key_hash_clone = key_hash.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") + .bind(&key_hash_clone) + .execute(&pool_clone) + .await; + }); + + Ok(DaemonAuthResult { user_id, owner_id }) + } + None => Err("Invalid or revoked API key".to_string()), + } +} + +/// WebSocket upgrade handler for daemon connections. +/// +/// Daemons must authenticate via the `X-Api-Key` header in the WebSocket upgrade request. +/// The API key is validated against the database and used to determine the owner_id +/// for data isolation. +#[utoipa::path( + get, + path = "/api/v1/mesh/daemons/connect", + params( + ("X-Api-Key" = String, Header, description = "API key for daemon authentication"), + ), + responses( + (status = 101, description = "WebSocket connection established"), + (status = 401, description = "Missing or invalid API key"), + (status = 503, description = "Database not configured"), + ), + tag = "Mesh" +)] +pub async fn daemon_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, + headers: HeaderMap, +) -> Response { + // Extract API key from headers + let api_key = match headers.get(API_KEY_HEADER).or_else(|| headers.get("x-api-key")) { + Some(value) => match value.to_str() { + Ok(key) if !key.is_empty() => key.to_string(), + _ => { + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("INVALID_API_KEY", "Invalid API key header value")), + ) + .into_response(); + } + }, + None => { + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("MISSING_API_KEY", "X-Api-Key header required")), + ) + .into_response(); + } + }; + + // Validate API key against database + let pool = match state.db_pool.as_ref() { + Some(pool) => pool, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + axum::Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + } + }; + + let auth_result = match validate_daemon_api_key(pool, &api_key).await { + Ok(result) => result, + Err(e) => { + tracing::warn!("Daemon authentication failed: {}", e); + return ( + StatusCode::UNAUTHORIZED, + axum::Json(ApiError::new("AUTH_FAILED", e)), + ) + .into_response(); + } + }; + + tracing::info!( + user_id = %auth_result.user_id, + owner_id = %auth_result.owner_id, + "Daemon authenticated via API key" + ); + + ws.on_upgrade(move |socket| handle_daemon_connection(socket, state, auth_result)) +} + +async fn handle_daemon_connection(socket: WebSocket, state: SharedState, auth_result: DaemonAuthResult) { + let (mut sender, mut receiver) = socket.split(); + + // Generate a unique connection ID and daemon ID + let connection_id = Uuid::new_v4().to_string(); + let daemon_id = Uuid::new_v4(); + let owner_id = auth_result.owner_id; + + // Create command channel for sending commands to this daemon + let (cmd_tx, mut cmd_rx) = mpsc::channel::<DaemonCommand>(64); + + // Wait for the daemon to send its registration info (hostname, machine_id, etc.) + // The daemon is already authenticated via API key header, but we need metadata + #[allow(unused_assignments)] + let mut registered = false; + + // Wait for registration message with metadata + loop { + tokio::select! { + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<DaemonMessage>(&text) { + Ok(DaemonMessage::Authenticate { api_key: _, machine_id, hostname, max_concurrent_tasks }) => { + // API key was already validated via headers, but we use this message + // for backward compatibility to get the machine_id and hostname + + tracing::info!( + daemon_id = %daemon_id, + owner_id = %owner_id, + hostname = %hostname, + machine_id = %machine_id, + max_concurrent_tasks = max_concurrent_tasks, + "Daemon registered" + ); + + // Register daemon in state with owner_id + state.register_daemon( + connection_id.clone(), + daemon_id, + owner_id, + Some(hostname), + Some(machine_id), + cmd_tx.clone(), + ); + + registered = true; + + // Send authentication confirmation + let response = DaemonCommand::Authenticated { daemon_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + + break; // Exit registration loop, continue to main loop + } + Ok(_) => { + // Non-auth message before registration - still requires registration message + let response = DaemonCommand::Error { + code: "NOT_REGISTERED".into(), + message: "Must send registration message (Authenticate) first".into(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + Err(e) => { + let response = DaemonCommand::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Daemon disconnected during registration"); + return; + } + Some(Err(e)) => { + tracing::warn!("Daemon WebSocket error during registration: {}", e); + return; + } + _ => {} + } + } + } + } + + if !registered { + return; + } + + let daemon_uuid = daemon_id; + + // Main message loop after authentication + loop { + tokio::select! { + // Handle incoming messages from daemon + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<DaemonMessage>(&text) { + Ok(DaemonMessage::Heartbeat { active_tasks }) => { + tracing::trace!( + "Daemon {} heartbeat: {} active tasks", + daemon_uuid, active_tasks.len() + ); + // TODO: Update daemon last_heartbeat_at in DB + } + Ok(DaemonMessage::TaskOutput { task_id, output, is_partial }) => { + // Parse the output line and broadcast structured data + if let Some(notification) = parse_claude_output(task_id, owner_id, &output, is_partial) { + // Broadcast to connected clients + state.broadcast_task_output(notification.clone()); + + // Persist to database (fire and forget) + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let notification = notification.clone(); + tokio::spawn(async move { + if let Err(e) = repository::save_task_output( + &pool, + notification.task_id, + ¬ification.message_type, + ¬ification.content, + notification.tool_name.as_deref(), + notification.tool_input.clone(), + notification.is_error, + notification.cost_usd, + notification.duration_ms, + ).await { + tracing::warn!( + task_id = %notification.task_id, + "Failed to persist task output: {}", + e + ); + } + }); + } + } + } + Ok(DaemonMessage::TaskStatusChange { task_id, old_status, new_status }) => { + tracing::info!( + "Task {} status change: {} -> {}", + task_id, old_status, new_status + ); + + // Update task status in database and broadcast + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let state = state.clone(); + let new_status_owned = new_status.clone(); + tokio::spawn(async move { + match repository::update_task_status( + &pool, + task_id, + &new_status_owned, + None, + ).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: updated_task.version, + status: new_status_owned, + updated_fields: vec!["status".into()], + updated_by: "daemon".into(), + }); + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + "Task not found when updating status" + ); + } + Err(e) => { + tracing::error!( + task_id = %task_id, + "Failed to update task status: {}", + e + ); + } + } + }); + } else { + // No DB, just broadcast + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: new_status, + updated_fields: vec!["status".into()], + updated_by: "daemon".into(), + }); + } + } + Ok(DaemonMessage::TaskProgress { task_id, summary }) => { + tracing::debug!("Task {} progress: {}", task_id, summary); + // TODO: Update task progress_summary in database + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: "running".into(), + updated_fields: vec!["progress_summary".into()], + updated_by: "daemon".into(), + }); + } + Ok(DaemonMessage::TaskComplete { task_id, success, error }) => { + let status = if success { "done" } else { "failed" }; + tracing::info!( + "Task {} completed: success={}, error={:?}", + task_id, success, error + ); + + // Revoke any tool keys for this task + state.revoke_tool_key(task_id); + + // Update task in database with completion info + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + let state = state.clone(); + let error_clone = error.clone(); + tokio::spawn(async move { + match repository::complete_task( + &pool, + task_id, + success, + error_clone.as_deref(), + ).await { + Ok(Some(updated_task)) => { + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: updated_task.version, + status: updated_task.status.clone(), + updated_fields: vec![ + "status".into(), + "completed_at".into(), + "error_message".into(), + ], + updated_by: "daemon".into(), + }); + } + Ok(None) => { + tracing::warn!( + task_id = %task_id, + "Task not found when completing" + ); + } + Err(e) => { + tracing::error!( + task_id = %task_id, + "Failed to complete task: {}", + e + ); + } + } + }); + } else { + // No DB, just broadcast + state.broadcast_task_update(TaskUpdateNotification { + task_id, + owner_id: Some(owner_id), + version: 0, + status: status.into(), + updated_fields: vec!["status".into(), "completed_at".into()], + updated_by: "daemon".into(), + }); + } + } + Ok(DaemonMessage::Authenticate { .. }) => { + // Already authenticated, ignore + } + Ok(DaemonMessage::RegisterToolKey { task_id, key }) => { + tracing::info!( + task_id = %task_id, + "Registering tool key for orchestrator" + ); + state.register_tool_key(key, task_id); + } + Ok(DaemonMessage::RevokeToolKey { task_id }) => { + tracing::info!( + task_id = %task_id, + "Revoking tool key for task" + ); + state.revoke_tool_key(task_id); + } + Ok(DaemonMessage::DaemonDirectories { working_directory, home_directory, worktrees_directory }) => { + tracing::info!( + daemon_id = %daemon_uuid, + working_directory = %working_directory, + home_directory = %home_directory, + worktrees_directory = %worktrees_directory, + "Daemon directories received" + ); + state.update_daemon_directories( + &connection_id, + working_directory, + home_directory, + worktrees_directory, + ); + } + Ok(DaemonMessage::CompletionActionResult { task_id, success, message, pr_url }) => { + tracing::info!( + task_id = %task_id, + success = success, + message = %message, + pr_url = ?pr_url, + "Completion action result received" + ); + + // Update task with PR URL if created + if let Some(ref url) = pr_url { + if let Some(ref pool) = state.db_pool { + let update_req = crate::db::models::UpdateTaskRequest { + pr_url: Some(url.clone()), + ..Default::default() + }; + if let Err(e) = crate::db::repository::update_task(pool, task_id, update_req).await { + tracing::error!("Failed to update task PR URL: {}", e); + } + } + } + + // Broadcast as task output so UI can see the result + let output_text = if success { + format!("✓ Completion action succeeded: {}", message) + } else { + format!("✗ Completion action failed: {}", message) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: Some(!success), + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Ok(DaemonMessage::CloneWorktreeResult { task_id, success, message, target_dir }) => { + tracing::info!( + task_id = %task_id, + success = success, + message = %message, + target_dir = ?target_dir, + "Clone worktree result received" + ); + + // Broadcast as task output so UI can see the result + let output_text = if success { + format!("✓ Clone succeeded: {}", message) + } else { + format!("✗ Clone failed: {}", message) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: Some(!success), + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Ok(DaemonMessage::CheckTargetExistsResult { task_id, exists, target_dir }) => { + tracing::debug!( + task_id = %task_id, + exists = exists, + target_dir = %target_dir, + "Check target exists result received" + ); + + // Broadcast as task output so UI can use the result + let output_text = if exists { + format!("Target directory exists: {}", target_dir) + } else { + format!("Target directory does not exist: {}", target_dir) + }; + state.broadcast_task_output(TaskOutputNotification { + task_id, + owner_id: Some(owner_id), + message_type: "system".to_string(), + content: output_text, + tool_name: None, + tool_input: None, + is_error: None, + cost_usd: None, + duration_ms: None, + is_partial: false, + }); + } + Err(e) => { + tracing::warn!("Failed to parse daemon message: {}", e); + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::info!("Daemon {} disconnected", daemon_uuid); + break; + } + Some(Err(e)) => { + tracing::warn!("Daemon {} WebSocket error: {}", daemon_uuid, e); + break; + } + _ => {} + } + } + + // Handle commands to send to daemon + cmd = cmd_rx.recv() => { + match cmd { + Some(command) => { + let json = serde_json::to_string(&command).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + tracing::warn!("Failed to send command to daemon {}", daemon_uuid); + break; + } + } + None => { + // Channel closed + break; + } + } + } + } + } + + // Cleanup on disconnect + state.unregister_daemon(&connection_id); + + // Clear daemon_id from any tasks that were running on this daemon + if let Some(ref pool) = state.db_pool { + let pool = pool.clone(); + tokio::spawn(async move { + // Find tasks assigned to this daemon that are still active + if let Err(e) = clear_daemon_from_tasks(&pool, daemon_uuid).await { + tracing::error!( + daemon_id = %daemon_uuid, + error = %e, + "Failed to clear daemon from tasks on disconnect" + ); + } + }); + } +} + +/// Clear daemon_id from tasks when daemon disconnects +async fn clear_daemon_from_tasks(pool: &sqlx::PgPool, daemon_id: Uuid) -> Result<(), sqlx::Error> { + // Update tasks that were running on this daemon to failed state + let result = sqlx::query( + r#" + UPDATE tasks + SET daemon_id = NULL, + status = 'failed', + error_message = 'Daemon disconnected', + updated_at = NOW() + WHERE daemon_id = $1 + AND status IN ('starting', 'running', 'initializing') + "#, + ) + .bind(daemon_id) + .execute(pool) + .await?; + + if result.rows_affected() > 0 { + tracing::warn!( + daemon_id = %daemon_id, + tasks_affected = result.rows_affected(), + "Marked tasks as failed due to daemon disconnect" + ); + } + + Ok(()) +} diff --git a/makima/src/server/handlers/mesh_merge.rs b/makima/src/server/handlers/mesh_merge.rs new file mode 100644 index 0000000..2d7c742 --- /dev/null +++ b/makima/src/server/handlers/mesh_merge.rs @@ -0,0 +1,441 @@ +//! Merge operation handlers for orchestrator tasks. +//! +//! These endpoints allow orchestrators to merge subtask branches. +//! Commands are forwarded to the daemon via WebSocket; the daemon +//! responds asynchronously through the WebSocket channel. + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use uuid::Uuid; + +use crate::db::models::{ + BranchListResponse, MergeCommitRequest, MergeCompleteCheckResponse, MergeResolveRequest, + MergeResultResponse, MergeSkipRequest, MergeStartRequest, MergeStatusResponse, +}; +use crate::db::repository; +use crate::server::messages::ApiError; +use crate::server::state::{DaemonCommand, SharedState}; + +/// Get the daemon ID for a task, returning error if not found. +async fn get_task_daemon_id( + state: &SharedState, + task_id: Uuid, +) -> Result<Uuid, (StatusCode, Json<ApiError>)> { + let pool = state.db_pool.as_ref().ok_or_else(|| { + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("service_unavailable", "Database not configured")), + ) + })?; + + // Get task and its daemon_id + let task = repository::get_task(pool, task_id) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("database_error", format!("Database error: {}", e))), + ) + })? + .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + Json(ApiError::new("not_found", format!("Task {} not found", task_id))), + ) + })?; + + task.daemon_id.ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("bad_request", "Task has no assigned daemon")), + ) + }) +} + +/// List all subtask branches for a task. +/// +/// GET /api/v1/mesh/tasks/{id}/branches +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/branches", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Command sent to daemon"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured") + ), + tag = "Mesh" +)] +pub async fn list_branches( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::ListBranches { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(BranchListResponse { branches: vec![] }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Start merging a subtask branch. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/start +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/start", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeStartRequest, + responses( + (status = 202, description = "Merge command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_start( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeStartRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeStart { + task_id, + source_branch: req.source_branch, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Merge command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Get current merge status. +/// +/// GET /api/v1/mesh/tasks/{id}/merge/status +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/merge/status", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Status request sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_status( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeStatus { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeStatusResponse { + in_progress: false, + source_branch: None, + conflicted_files: vec![], + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Resolve a merge conflict. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/resolve +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/resolve", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeResolveRequest, + responses( + (status = 202, description = "Resolve command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_resolve( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeResolveRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeResolve { + task_id, + file: req.file, + strategy: req.strategy, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Resolve command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Commit the current merge. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/commit +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/commit", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeCommitRequest, + responses( + (status = 202, description = "Commit command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_commit( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeCommitRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeCommit { + task_id, + message: req.message, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Commit command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Abort the current merge. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/abort +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/abort", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Abort command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_abort( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeAbort { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Abort command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Skip merging a subtask branch. +/// +/// POST /api/v1/mesh/tasks/{id}/merge/skip +#[utoipa::path( + post, + path = "/api/v1/mesh/tasks/{id}/merge/skip", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + request_body = MergeSkipRequest, + responses( + (status = 202, description = "Skip command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_skip( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, + Json(req): Json<MergeSkipRequest>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::MergeSkip { + task_id, + subtask_id: req.subtask_id, + reason: req.reason, + }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeResultResponse { + success: true, + message: "Skip command sent".to_string(), + commit_sha: None, + conflicts: None, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} + +/// Check if all branches are merged or skipped. +/// +/// GET /api/v1/mesh/tasks/{id}/merge/check +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/{id}/merge/check", + params( + ("id" = Uuid, Path, description = "Task ID") + ), + responses( + (status = 202, description = "Check command sent"), + (status = 404, description = "Task not found"), + (status = 503, description = "Database not configured or daemon not connected") + ), + tag = "Mesh" +)] +pub async fn merge_check( + State(state): State<SharedState>, + Path(task_id): Path<Uuid>, +) -> impl IntoResponse { + let daemon_id = match get_task_daemon_id(&state, task_id).await { + Ok(id) => id, + Err(e) => return e.into_response(), + }; + + let command = DaemonCommand::CheckMergeComplete { task_id }; + + match state.send_daemon_command(daemon_id, command).await { + Ok(()) => ( + StatusCode::ACCEPTED, + Json(MergeCompleteCheckResponse { + can_complete: true, + unmerged_branches: vec![], + merged_count: 0, + skipped_count: 0, + }), + ) + .into_response(), + Err(e) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("daemon_error", e)), + ) + .into_response(), + } +} diff --git a/makima/src/server/handlers/mesh_ws.rs b/makima/src/server/handlers/mesh_ws.rs new file mode 100644 index 0000000..d15fba7 --- /dev/null +++ b/makima/src/server/handlers/mesh_ws.rs @@ -0,0 +1,346 @@ +//! WebSocket handler for task change subscriptions and output streaming. +//! +//! Clients can subscribe to specific tasks or all tasks to receive real-time notifications +//! when tasks are updated. They can also subscribe to task output for live terminal streaming. +//! +//! ## Owner-scoped filtering +//! +//! Notifications are filtered by owner_id. If a notification has an owner_id set, +//! it will only be delivered to clients who are subscribed to tasks belonging to that owner. +//! The task's owner_id is looked up from the database when the client subscribes. + +use axum::{ + extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade}, + response::Response, +}; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use sqlx::Row; +use std::collections::HashMap; +use uuid::Uuid; + +use crate::server::state::SharedState; + +/// Client message for task subscription management. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum TaskClientMessage { + /// Subscribe to updates for a specific task + Subscribe { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscribe from updates for a specific task + Unsubscribe { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Subscribe to all task updates + SubscribeAll, + /// Unsubscribe from all task updates + UnsubscribeAll, + /// Subscribe to live output streaming for a specific task + SubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscribe from output streaming for a specific task + UnsubscribeOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + }, +} + +/// Server message for task subscription WebSocket. +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum TaskServerMessage { + /// Subscription confirmed for specific task + Subscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Unsubscription confirmed for specific task + Unsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Subscribed to all task updates + SubscribedAll, + /// Unsubscribed from all task updates + UnsubscribedAll, + /// Task was updated + TaskUpdated { + #[serde(rename = "taskId")] + task_id: Uuid, + version: i32, + status: String, + #[serde(rename = "updatedFields")] + updated_fields: Vec<String>, + #[serde(rename = "updatedBy")] + updated_by: String, + }, + /// Live output from Claude Code container (parsed and structured) + TaskOutput { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + #[serde(rename = "messageType")] + message_type: String, + /// Main text content + content: String, + /// Tool name if tool_use message + #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")] + tool_name: Option<String>, + /// Tool input JSON if tool_use message + #[serde(rename = "toolInput", skip_serializing_if = "Option::is_none")] + tool_input: Option<serde_json::Value>, + /// Whether tool result was an error + #[serde(rename = "isError", skip_serializing_if = "Option::is_none")] + is_error: Option<bool>, + /// Cost in USD if result message + #[serde(rename = "costUsd", skip_serializing_if = "Option::is_none")] + cost_usd: Option<f64>, + /// Duration in ms if result message + #[serde(rename = "durationMs", skip_serializing_if = "Option::is_none")] + duration_ms: Option<u64>, + #[serde(rename = "isPartial")] + is_partial: bool, + }, + /// Output subscription confirmed + OutputSubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Output unsubscription confirmed + OutputUnsubscribed { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Error occurred + Error { code: String, message: String }, +} + +/// WebSocket upgrade handler for task subscriptions. +#[utoipa::path( + get, + path = "/api/v1/mesh/tasks/subscribe", + responses( + (status = 101, description = "WebSocket connection established"), + ), + tag = "Mesh" +)] +pub async fn task_subscription_handler( + ws: WebSocketUpgrade, + State(state): State<SharedState>, +) -> Response { + ws.on_upgrade(|socket| handle_task_subscription(socket, state)) +} + +/// Look up the owner_id for a task from the database. +async fn get_task_owner_id(pool: &sqlx::PgPool, task_id: Uuid) -> Option<Uuid> { + let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(pool) + .await + .ok()??; + + row.try_get("owner_id").ok() +} + +async fn handle_task_subscription(socket: WebSocket, state: SharedState) { + let (mut sender, mut receiver) = socket.split(); + + // Map of task IDs to their owner_ids for this client's subscriptions + let mut task_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new(); + // Whether client is subscribed to all task updates (not owner-scoped) + let mut subscribed_all = false; + // Map of task IDs to their owner_ids for output streaming subscriptions + let mut output_subscriptions: HashMap<Uuid, Option<Uuid>> = HashMap::new(); + + // Subscribe to broadcast channels + let mut task_update_rx = state.task_updates.subscribe(); + let mut task_output_rx = state.task_output.subscribe(); + + loop { + tokio::select! { + // Handle incoming WebSocket messages from client + msg = receiver.next() => { + match msg { + Some(Ok(Message::Text(text))) => { + match serde_json::from_str::<TaskClientMessage>(&text) { + Ok(TaskClientMessage::Subscribe { task_id }) => { + // Look up owner_id for this task + let owner_id = if let Some(ref pool) = state.db_pool { + get_task_owner_id(pool, task_id).await + } else { + None + }; + task_subscriptions.insert(task_id, owner_id); + let response = TaskServerMessage::Subscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to task {} (owner: {:?})", task_id, owner_id); + } + Ok(TaskClientMessage::Unsubscribe { task_id }) => { + task_subscriptions.remove(&task_id); + let response = TaskServerMessage::Unsubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from task {}", task_id); + } + Ok(TaskClientMessage::SubscribeAll) => { + subscribed_all = true; + let response = TaskServerMessage::SubscribedAll; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to all tasks"); + } + Ok(TaskClientMessage::UnsubscribeAll) => { + subscribed_all = false; + let response = TaskServerMessage::UnsubscribedAll; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from all tasks"); + } + Ok(TaskClientMessage::SubscribeOutput { task_id }) => { + // Look up owner_id for this task + let owner_id = if let Some(ref pool) = state.db_pool { + get_task_owner_id(pool, task_id).await + } else { + None + }; + output_subscriptions.insert(task_id, owner_id); + let response = TaskServerMessage::OutputSubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client subscribed to output for task {} (owner: {:?})", task_id, owner_id); + } + Ok(TaskClientMessage::UnsubscribeOutput { task_id }) => { + output_subscriptions.remove(&task_id); + let response = TaskServerMessage::OutputUnsubscribed { task_id }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + tracing::debug!("Client unsubscribed from output for task {}", task_id); + } + Err(e) => { + let response = TaskServerMessage::Error { + code: "PARSE_ERROR".into(), + message: e.to_string(), + }; + let json = serde_json::to_string(&response).unwrap(); + let _ = sender.send(Message::Text(json.into())).await; + } + } + } + Some(Ok(Message::Close(_))) | None => { + tracing::debug!("Client disconnected from task subscription"); + break; + } + Some(Err(e)) => { + tracing::warn!("Task WebSocket error: {}", e); + break; + } + _ => {} + } + } + + // Handle task update broadcasts + notification = task_update_rx.recv() => { + match notification { + Ok(notification) => { + // Check if client should receive this notification + let should_forward = if subscribed_all { + // SubscribeAll gets all notifications (typically for admin views) + true + } else if let Some(subscribed_owner) = task_subscriptions.get(¬ification.task_id) { + // Client is subscribed to this specific task + // Verify owner_id matches (if set on both sides) + match (notification.owner_id, subscribed_owner) { + (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, + _ => true, // Allow if owner_id not set on either side + } + } else { + false + }; + + if should_forward { + let response = TaskServerMessage::TaskUpdated { + task_id: notification.task_id, + version: notification.version, + status: notification.status, + updated_fields: notification.updated_fields, + updated_by: notification.updated_by, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Task subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + + // Handle task output broadcasts + output = task_output_rx.recv() => { + match output { + Ok(output) => { + // Check if client should receive this output + let should_forward = if let Some(subscribed_owner) = output_subscriptions.get(&output.task_id) { + // Client is subscribed to output for this task + // Verify owner_id matches (if set on both sides) + match (output.owner_id, subscribed_owner) { + (Some(notif_owner), Some(sub_owner)) => notif_owner == *sub_owner, + _ => true, // Allow if owner_id not set on either side + } + } else { + false + }; + + if should_forward { + let response = TaskServerMessage::TaskOutput { + task_id: output.task_id, + message_type: output.message_type, + content: output.content, + tool_name: output.tool_name, + tool_input: output.tool_input, + is_error: output.is_error, + cost_usd: output.cost_usd, + duration_ms: output.duration_ms, + is_partial: output.is_partial, + }; + let json = serde_json::to_string(&response).unwrap(); + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Task output subscription client lagged, skipped {} messages", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + } + } +} diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index 3211f94..8681104 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -1,7 +1,14 @@ //! HTTP and WebSocket request handlers. +pub mod api_keys; pub mod chat; pub mod file_ws; pub mod files; pub mod listen; +pub mod mesh; +pub mod mesh_chat; +pub mod mesh_daemon; +pub mod mesh_merge; +pub mod mesh_ws; +pub mod users; pub mod versions; diff --git a/makima/src/server/handlers/users.rs b/makima/src/server/handlers/users.rs new file mode 100644 index 0000000..0b2ccdd --- /dev/null +++ b/makima/src/server/handlers/users.rs @@ -0,0 +1,972 @@ +//! HTTP handlers for user account management. +//! +//! These endpoints allow users to manage their account settings: +//! - Change password +//! - Change email +//! - Delete account + +use axum::{ + extract::State, + http::{HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +use crate::server::auth::UserOnly; +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +// ============================================================================= +// Request/Response Types +// ============================================================================= + +/// Request to change password. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangePasswordRequest { + /// The user's current password for verification + pub current_password: String, + /// The new password to set + pub new_password: String, +} + +/// Response after changing password. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangePasswordResponse { + pub success: bool, + pub message: String, +} + +/// Request to change email. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangeEmailRequest { + /// The user's password for verification + pub password: String, + /// The new email address to set + pub new_email: String, +} + +/// Response after changing email. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChangeEmailResponse { + pub success: bool, + pub message: String, + /// Whether a verification email was sent to the new address + pub verification_sent: bool, +} + +/// Request to delete account. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DeleteAccountRequest { + /// The user's password for verification + pub password: String, + /// Confirmation text - must match the user's email + pub confirmation: String, +} + +/// Response after deleting account. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct DeleteAccountResponse { + pub success: bool, + pub message: String, +} + +// ============================================================================= +// Password Validation +// ============================================================================= + +/// Password strength validation result. +#[derive(Debug)] +pub struct PasswordValidation { + pub is_valid: bool, + pub errors: Vec<String>, +} + +/// Validate password strength. +/// Requirements: +/// - At least 6 characters (matches login form) +fn validate_password_strength(password: &str) -> PasswordValidation { + let mut errors = Vec::new(); + + if password.len() < 6 { + errors.push("Password must be at least 6 characters long".to_string()); + } + + PasswordValidation { + is_valid: errors.is_empty(), + errors, + } +} + +/// Validate email format. +fn validate_email(email: &str) -> bool { + // Basic email validation - must contain @ and at least one . after @ + let parts: Vec<&str> = email.split('@').collect(); + if parts.len() != 2 { + return false; + } + let local = parts[0]; + let domain = parts[1]; + // Local part must not be empty + if local.is_empty() { + return false; + } + // Domain must have at least one dot and not start/end with dot + domain.contains('.') && !domain.starts_with('.') && !domain.ends_with('.') +} + +// ============================================================================= +// Supabase Admin Client +// ============================================================================= + +/// Supabase Admin API client for user management operations. +/// Uses the service role key for admin-level operations. +pub struct SupabaseAdminClient { + base_url: String, + secret_api_key: String, + client: reqwest::Client, +} + +impl SupabaseAdminClient { + /// Create a new Supabase admin client from environment variables. + pub fn from_env() -> Option<Self> { + let base_url = std::env::var("SUPABASE_URL").ok()?; + let secret_api_key = std::env::var("SUPABASE_SECRET_API_KEY").ok()?; + + Some(Self { + base_url, + secret_api_key, + client: reqwest::Client::new(), + }) + } + + /// Verify a user's password by attempting to sign in. + pub async fn verify_password(&self, email: &str, password: &str) -> Result<bool, String> { + let url = format!("{}/auth/v1/token?grant_type=password", self.base_url); + + let response = self + .client + .post(&url) + .header("apikey", &self.secret_api_key) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": email, + "password": password + })) + .send() + .await + .map_err(|e| format!("Failed to verify password: {}", e))?; + + Ok(response.status().is_success()) + } + + /// Update a user's password. + pub async fn update_password( + &self, + user_id: &str, + new_password: &str, + ) -> Result<(), String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .put(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "password": new_password + })) + .send() + .await + .map_err(|e| format!("Failed to update password: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update password: {}", error_text)) + } + } + + /// Update a user's email. + pub async fn update_email( + &self, + user_id: &str, + new_email: &str, + ) -> Result<(), String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .put(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": new_email, + "email_confirm": true + })) + .send() + .await + .map_err(|e| format!("Failed to update email: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update email: {}", error_text)) + } + } + + /// Delete a user from Supabase Auth. + pub async fn delete_user(&self, user_id: &str) -> Result<(), String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .delete(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .send() + .await + .map_err(|e| format!("Failed to delete user: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to delete user: {}", error_text)) + } + } + + /// Get user info including email. + pub async fn get_user(&self, user_id: &str) -> Result<Option<String>, String> { + let url = format!("{}/auth/v1/admin/users/{}", self.base_url, user_id); + + let response = self + .client + .get(&url) + .header("apikey", &self.secret_api_key) + .header("Authorization", format!("Bearer {}", self.secret_api_key)) + .send() + .await + .map_err(|e| format!("Failed to get user: {}", e))?; + + if response.status().is_success() { + let json: serde_json::Value = response + .json() + .await + .map_err(|e| format!("Failed to parse user data: {}", e))?; + Ok(json.get("email").and_then(|e| e.as_str()).map(String::from)) + } else if response.status() == reqwest::StatusCode::NOT_FOUND { + Ok(None) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to get user: {}", error_text)) + } + } +} + +// ============================================================================= +// Supabase User Client (uses user's JWT, no admin required) +// ============================================================================= + +/// Supabase User API client for self-service operations. +/// Uses the user's JWT token - no admin/service role key required. +pub struct SupabaseUserClient { + base_url: String, + anon_key: String, + jwt_token: String, + client: reqwest::Client, +} + +impl SupabaseUserClient { + /// Create a new Supabase user client from environment and JWT token. + pub fn new(jwt_token: String) -> Option<Self> { + let base_url = std::env::var("SUPABASE_URL").ok()?; + let anon_key = std::env::var("SUPABASE_ANON_KEY").ok()?; + + Some(Self { + base_url, + anon_key, + jwt_token, + client: reqwest::Client::new(), + }) + } + + /// Update the user's password using their own JWT. + pub async fn update_password(&self, new_password: &str) -> Result<(), String> { + let url = format!("{}/auth/v1/user", self.base_url); + + let response = self + .client + .put(&url) + .header("apikey", &self.anon_key) + .header("Authorization", format!("Bearer {}", self.jwt_token)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "password": new_password + })) + .send() + .await + .map_err(|e| format!("Failed to update password: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update password: {}", error_text)) + } + } + + /// Update the user's email using their own JWT. + pub async fn update_email(&self, new_email: &str) -> Result<(), String> { + let url = format!("{}/auth/v1/user", self.base_url); + + let response = self + .client + .put(&url) + .header("apikey", &self.anon_key) + .header("Authorization", format!("Bearer {}", self.jwt_token)) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": new_email + })) + .send() + .await + .map_err(|e| format!("Failed to update email: {}", e))?; + + if response.status().is_success() { + Ok(()) + } else { + let error_text = response.text().await.unwrap_or_default(); + Err(format!("Failed to update email: {}", error_text)) + } + } + + /// Verify current password by attempting to sign in. + pub async fn verify_password(&self, email: &str, password: &str) -> Result<bool, String> { + let url = format!("{}/auth/v1/token?grant_type=password", self.base_url); + + let response = self + .client + .post(&url) + .header("apikey", &self.anon_key) + .header("Content-Type", "application/json") + .json(&serde_json::json!({ + "email": email, + "password": password + })) + .send() + .await + .map_err(|e| format!("Failed to verify password: {}", e))?; + + Ok(response.status().is_success()) + } +} + +// ============================================================================= +// Handlers +// ============================================================================= + +/// Change the authenticated user's password. +/// +/// Requires verification of the current password before allowing the change. +/// The new password must meet strength requirements. +#[utoipa::path( + put, + path = "/api/v1/users/me/password", + request_body = ChangePasswordRequest, + responses( + (status = 200, description = "Password changed successfully", body = ChangePasswordResponse), + (status = 400, description = "Invalid request (weak password, wrong current password)", body = ApiError), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Supabase not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "Users" +)] +pub async fn change_password_handler( + State(_state): State<SharedState>, + headers: HeaderMap, + UserOnly(user): UserOnly, + Json(req): Json<ChangePasswordRequest>, +) -> impl IntoResponse { + // Validate new password strength + let validation = validate_password_strength(&req.new_password); + if !validation.is_valid { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "WEAK_PASSWORD", + &validation.errors.join("; "), + )), + ) + .into_response(); + } + + // Get user's email (required for password verification) + let email = match &user.email { + Some(email) => email.clone(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("EMAIL_REQUIRED", "User email not available")), + ) + .into_response(); + } + }; + + // Extract JWT from Authorization header for user-level API calls + let jwt_token = headers + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .map(|s| s.to_string()); + + // Try user client first (uses JWT, no admin required), fall back to admin client + if let Some(token) = jwt_token { + if let Some(user_client) = SupabaseUserClient::new(token) { + // Verify current password + match user_client.verify_password(&email, &req.current_password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Current password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update password using user's JWT + return match user_client.update_password(&req.new_password).await { + Ok(()) => { + tracing::info!("Password changed for user {}", user.user_id); + Json(ChangePasswordResponse { + success: true, + message: "Password changed successfully".to_string(), + }) + .into_response() + } + Err(e) => { + tracing::error!("Failed to update password: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update password")), + ) + .into_response() + } + }; + } + } + + // Fall back to admin client if user client not available + let admin_client = match SupabaseAdminClient::from_env() { + Some(client) => client, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "SUPABASE_NOT_CONFIGURED", + "Supabase not configured. Ensure SUPABASE_URL and SUPABASE_ANON_KEY are set.", + )), + ) + .into_response(); + } + }; + + // Verify current password + match admin_client.verify_password(&email, &req.current_password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Current password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update password in Supabase + match admin_client + .update_password(&user.user_id.to_string(), &req.new_password) + .await + { + Ok(()) => { + tracing::info!("Password changed for user {}", user.user_id); + Json(ChangePasswordResponse { + success: true, + message: "Password changed successfully".to_string(), + }) + .into_response() + } + Err(e) => { + tracing::error!("Failed to update password: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update password")), + ) + .into_response() + } + } +} + +/// Change the authenticated user's email address. +/// +/// Requires password verification before allowing the change. +/// The new email will be updated directly (Supabase handles verification if configured). +#[utoipa::path( + put, + path = "/api/v1/users/me/email", + request_body = ChangeEmailRequest, + responses( + (status = 200, description = "Email changed successfully", body = ChangeEmailResponse), + (status = 400, description = "Invalid request (invalid email, wrong password)", body = ApiError), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Supabase not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "Users" +)] +pub async fn change_email_handler( + State(state): State<SharedState>, + headers: HeaderMap, + UserOnly(user): UserOnly, + Json(req): Json<ChangeEmailRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Validate new email format + if !validate_email(&req.new_email) { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_EMAIL", "Invalid email format")), + ) + .into_response(); + } + + // Get user's current email (required for password verification) + let current_email = match &user.email { + Some(email) => email.clone(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("EMAIL_REQUIRED", "User email not available")), + ) + .into_response(); + } + }; + + // Extract JWT from Authorization header for user-level API calls + let jwt_token = headers + .get("Authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .map(|s| s.to_string()); + + // Try user client first (uses JWT, no admin required), fall back to admin client + if let Some(token) = jwt_token { + if let Some(user_client) = SupabaseUserClient::new(token) { + // Verify password + match user_client.verify_password(¤t_email, &req.password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update email using user's JWT + if let Err(e) = user_client.update_email(&req.new_email).await { + tracing::error!("Failed to update email in Supabase: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update email")), + ) + .into_response(); + } + + // Update email in our database + if let Err(e) = sqlx::query("UPDATE users SET email = $1, updated_at = NOW() WHERE id = $2") + .bind(&req.new_email) + .bind(user.user_id) + .execute(pool) + .await + { + tracing::error!("Failed to update email in database: {}", e); + } + + tracing::info!( + "Email changed for user {} from {} to {}", + user.user_id, + current_email, + req.new_email + ); + + return Json(ChangeEmailResponse { + success: true, + message: "Email changed successfully".to_string(), + verification_sent: false, + }) + .into_response(); + } + } + + // Fall back to admin client if user client not available + let admin_client = match SupabaseAdminClient::from_env() { + Some(client) => client, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "SUPABASE_NOT_CONFIGURED", + "Supabase not configured. Ensure SUPABASE_URL and SUPABASE_ANON_KEY are set.", + )), + ) + .into_response(); + } + }; + + // Verify password + match admin_client.verify_password(¤t_email, &req.password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Update email in Supabase + if let Err(e) = admin_client + .update_email(&user.user_id.to_string(), &req.new_email) + .await + { + tracing::error!("Failed to update email in Supabase: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to update email")), + ) + .into_response(); + } + + // Update email in our database + if let Err(e) = sqlx::query("UPDATE users SET email = $1, updated_at = NOW() WHERE id = $2") + .bind(&req.new_email) + .bind(user.user_id) + .execute(pool) + .await + { + tracing::error!("Failed to update email in database: {}", e); + } + + tracing::info!( + "Email changed for user {} from {} to {}", + user.user_id, + current_email, + req.new_email + ); + + Json(ChangeEmailResponse { + success: true, + message: "Email changed successfully".to_string(), + verification_sent: false, + }) + .into_response() +} + +/// Delete the authenticated user's account. +/// +/// This permanently deletes: +/// - The user's Supabase Auth account +/// - The user's record in our database +/// - All associated data (API keys, files, tasks, etc. via CASCADE) +/// +/// Requires password verification and confirmation text matching the user's email. +#[utoipa::path( + delete, + path = "/api/v1/users/me", + request_body = DeleteAccountRequest, + responses( + (status = 200, description = "Account deleted successfully", body = DeleteAccountResponse), + (status = 400, description = "Invalid request (wrong password, wrong confirmation)", body = ApiError), + (status = 401, description = "Not authenticated", body = ApiError), + (status = 403, description = "Forbidden (tool keys not allowed)", body = ApiError), + (status = 503, description = "Supabase not configured", body = ApiError), + (status = 500, description = "Internal server error", body = ApiError), + ), + security( + ("bearer_auth" = []) + ), + tag = "Users" +)] +pub async fn delete_account_handler( + State(state): State<SharedState>, + UserOnly(user): UserOnly, + Json(req): Json<DeleteAccountRequest>, +) -> impl IntoResponse { + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new("DB_UNAVAILABLE", "Database not configured")), + ) + .into_response(); + }; + + // Get Supabase admin client - required for full account deletion + let admin_client = match SupabaseAdminClient::from_env() { + Some(client) => client, + None => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(ApiError::new( + "SUPABASE_ADMIN_NOT_CONFIGURED", + "Account deletion requires SUPABASE_SECRET_API_KEY to be configured", + )), + ) + .into_response(); + } + }; + + // Verify confirmation is "DELETE MY ACCOUNT" + const REQUIRED_CONFIRMATION: &str = "DELETE MY ACCOUNT"; + if req.confirmation != REQUIRED_CONFIRMATION { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new( + "INVALID_CONFIRMATION", + format!("Confirmation text must be exactly: {}", REQUIRED_CONFIRMATION), + )), + ) + .into_response(); + } + + // Get user's email (required for password verification) + let email = match &user.email { + Some(e) => e.clone(), + None => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("EMAIL_REQUIRED", "User email not available")), + ) + .into_response(); + } + }; + + // Verify password + match admin_client.verify_password(&email, &req.password).await { + Ok(true) => {} + Ok(false) => { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("INVALID_PASSWORD", "Password is incorrect")), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Failed to verify password: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to verify password")), + ) + .into_response(); + } + } + + // Delete from our database first (this will cascade to related records) + // Get the owner_id before deleting + let owner_id = user.owner_id; + + // Delete API keys for this user (explicit deletion for audit purposes) + if let Err(e) = sqlx::query("UPDATE api_keys SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL") + .bind(user.user_id) + .execute(pool) + .await + { + tracing::warn!("Failed to revoke API keys during account deletion: {}", e); + } + + // Delete user record + if let Err(e) = sqlx::query("DELETE FROM users WHERE id = $1") + .bind(user.user_id) + .execute(pool) + .await + { + tracing::error!("Failed to delete user from database: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("INTERNAL_ERROR", "Failed to delete account")), + ) + .into_response(); + } + + // Delete files owned by this user + if let Err(e) = sqlx::query("DELETE FROM files WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete user files: {}", e); + } + + // Delete tasks owned by this user + if let Err(e) = sqlx::query("DELETE FROM tasks WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete user tasks: {}", e); + } + + // Delete mesh chat conversations owned by this user + if let Err(e) = sqlx::query("DELETE FROM mesh_chat_conversations WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete mesh chat conversations: {}", e); + } + + // Delete daemons owned by this user + if let Err(e) = sqlx::query("DELETE FROM daemons WHERE owner_id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete user daemons: {}", e); + } + + // Delete owner record + if let Err(e) = sqlx::query("DELETE FROM owners WHERE id = $1") + .bind(owner_id) + .execute(pool) + .await + { + tracing::warn!("Failed to delete owner record: {}", e); + } + + // Delete from Supabase Auth + if let Err(e) = admin_client.delete_user(&user.user_id.to_string()).await { + tracing::error!("Failed to delete user from Supabase Auth: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new( + "SUPABASE_DELETE_FAILED", + "Failed to delete user from authentication system", + )), + ) + .into_response(); + } + + tracing::info!("Account deleted for user {} ({})", user.user_id, email); + + Json(DeleteAccountResponse { + success: true, + message: "Account deleted successfully".to_string(), + }) + .into_response() +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_password_validation_success() { + // Minimum 6 characters + let result = validate_password_strength("abcdef"); + assert!(result.is_valid); + assert!(result.errors.is_empty()); + + let result = validate_password_strength("Password123"); + assert!(result.is_valid); + assert!(result.errors.is_empty()); + } + + #[test] + fn test_password_validation_too_short() { + let result = validate_password_strength("12345"); + assert!(!result.is_valid); + assert!(result.errors.iter().any(|e| e.contains("6 characters"))); + } + + #[test] + fn test_email_validation_valid() { + assert!(validate_email("user@example.com")); + assert!(validate_email("user.name@example.co.uk")); + assert!(validate_email("user+tag@example.org")); + } + + #[test] + fn test_email_validation_invalid() { + assert!(!validate_email("userexample.com")); + assert!(!validate_email("user@")); + assert!(!validate_email("@example.com")); + assert!(!validate_email("user@.com")); + assert!(!validate_email("user@example.")); + } +} diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs index ee5e9bd..a096a5c 100644 --- a/makima/src/server/mod.rs +++ b/makima/src/server/mod.rs @@ -1,5 +1,6 @@ //! Web server module for the makima audio API. +pub mod auth; pub mod handlers; pub mod messages; pub mod openapi; @@ -17,7 +18,7 @@ use tower_http::trace::TraceLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use crate::server::handlers::{chat, file_ws, files, listen, versions}; +use crate::server::handlers::{api_keys, chat, file_ws, files, listen, mesh, mesh_chat, mesh_daemon, mesh_merge, mesh_ws, users, versions}; use crate::server::openapi::ApiDoc; use crate::server::state::SharedState; @@ -56,6 +57,62 @@ pub fn make_router(state: SharedState) -> Router { .route("/files/{id}/versions", get(versions::list_versions)) .route("/files/{id}/versions/{version}", get(versions::get_version)) .route("/files/{id}/versions/restore", post(versions::restore_version)) + // Mesh/task orchestration endpoints + .route( + "/mesh/tasks", + get(mesh::list_tasks).post(mesh::create_task), + ) + .route( + "/mesh/tasks/{id}", + get(mesh::get_task) + .put(mesh::update_task) + .delete(mesh::delete_task), + ) + .route("/mesh/tasks/{id}/subtasks", get(mesh::list_subtasks)) + .route("/mesh/tasks/{id}/events", get(mesh::list_task_events)) + .route("/mesh/tasks/{id}/output", get(mesh::get_task_output)) + .route("/mesh/tasks/{id}/start", post(mesh::start_task)) + .route("/mesh/tasks/{id}/stop", post(mesh::stop_task)) + .route("/mesh/tasks/{id}/message", post(mesh::send_message)) + .route("/mesh/tasks/{id}/retry-completion", post(mesh::retry_completion_action)) + .route("/mesh/tasks/{id}/clone", post(mesh::clone_worktree)) + .route("/mesh/tasks/{id}/check-target", post(mesh::check_target_exists)) + .route("/mesh/chat", post(mesh_chat::mesh_toplevel_chat_handler)) + .route( + "/mesh/chat/history", + get(mesh_chat::get_chat_history).delete(mesh_chat::clear_chat_history), + ) + .route("/mesh/tasks/{id}/chat", post(mesh_chat::mesh_chat_handler)) + .route("/mesh/daemons", get(mesh::list_daemons)) + .route("/mesh/daemons/directories", get(mesh::get_daemon_directories)) + .route("/mesh/daemons/{id}", get(mesh::get_daemon)) + // Merge endpoints for orchestrators + .route("/mesh/tasks/{id}/branches", get(mesh_merge::list_branches)) + .route("/mesh/tasks/{id}/merge/start", post(mesh_merge::merge_start)) + .route("/mesh/tasks/{id}/merge/status", get(mesh_merge::merge_status)) + .route("/mesh/tasks/{id}/merge/resolve", post(mesh_merge::merge_resolve)) + .route("/mesh/tasks/{id}/merge/commit", post(mesh_merge::merge_commit)) + .route("/mesh/tasks/{id}/merge/abort", post(mesh_merge::merge_abort)) + .route("/mesh/tasks/{id}/merge/skip", post(mesh_merge::merge_skip)) + .route("/mesh/tasks/{id}/merge/check", get(mesh_merge::merge_check)) + // Mesh WebSocket endpoints + .route("/mesh/tasks/subscribe", get(mesh_ws::task_subscription_handler)) + .route("/mesh/daemons/connect", get(mesh_daemon::daemon_handler)) + // API key management endpoints + .route( + "/auth/api-keys", + post(api_keys::create_api_key_handler) + .get(api_keys::get_api_key_handler) + .delete(api_keys::revoke_api_key_handler), + ) + .route("/auth/api-keys/refresh", post(api_keys::refresh_api_key_handler)) + // User account management endpoints + .route( + "/users/me", + axum::routing::delete(users::delete_account_handler), + ) + .route("/users/me/password", axum::routing::put(users::change_password_handler)) + .route("/users/me/email", axum::routing::put(users::change_email_handler)) .with_state(state); let swagger = SwaggerUi::new("/swagger-ui") diff --git a/makima/src/server/openapi.rs b/makima/src/server/openapi.rs index b946ff3..425c466 100644 --- a/makima/src/server/openapi.rs +++ b/makima/src/server/openapi.rs @@ -3,9 +3,19 @@ use utoipa::OpenApi; use crate::db::models::{ - CreateFileRequest, File, FileListResponse, FileSummary, TranscriptEntry, UpdateFileRequest, + BranchInfo, BranchListResponse, CreateFileRequest, CreateTaskRequest, Daemon, + DaemonDirectoriesResponse, DaemonDirectory, DaemonListResponse, File, FileListResponse, + FileSummary, MergeCommitRequest, MergeCompleteCheckResponse, MergeResolveRequest, + MergeResultResponse, MergeSkipRequest, MergeStartRequest, MergeStatusResponse, + MeshChatConversation, MeshChatHistoryResponse, MeshChatMessageRecord, SendMessageRequest, + Task, TaskEventListResponse, TaskListResponse, TaskSummary, TaskWithSubtasks, TranscriptEntry, + UpdateFileRequest, UpdateTaskRequest, }; -use crate::server::handlers::{files, listen}; +use crate::server::auth::{ + ApiKey, ApiKeyInfoResponse, CreateApiKeyRequest, CreateApiKeyResponse, + RefreshApiKeyRequest, RefreshApiKeyResponse, RevokeApiKeyResponse, +}; +use crate::server::handlers::{api_keys, files, listen, mesh, mesh_chat, mesh_merge, users}; use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage, TranscriptMessage}; #[derive(OpenApi)] @@ -23,6 +33,44 @@ use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage files::create_file, files::update_file, files::delete_file, + // Mesh endpoints + mesh::list_tasks, + mesh::get_task, + mesh::create_task, + mesh::update_task, + mesh::delete_task, + mesh::list_subtasks, + mesh::list_task_events, + mesh::get_task_output, + mesh::start_task, + mesh::stop_task, + mesh::send_message, + mesh::retry_completion_action, + mesh::list_daemons, + mesh::get_daemon, + mesh::get_daemon_directories, + mesh::clone_worktree, + mesh::check_target_exists, + mesh_chat::get_chat_history, + mesh_chat::clear_chat_history, + // Merge endpoints + mesh_merge::list_branches, + mesh_merge::merge_start, + mesh_merge::merge_status, + mesh_merge::merge_resolve, + mesh_merge::merge_commit, + mesh_merge::merge_abort, + mesh_merge::merge_skip, + mesh_merge::merge_check, + // API key endpoints + api_keys::create_api_key_handler, + api_keys::get_api_key_handler, + api_keys::refresh_api_key_handler, + api_keys::revoke_api_key_handler, + // User account management endpoints + users::change_password_handler, + users::change_email_handler, + users::delete_account_handler, ), components( schemas( @@ -38,11 +86,55 @@ use crate::server::messages::{ApiError, AudioEncoding, StartMessage, StopMessage CreateFileRequest, UpdateFileRequest, TranscriptEntry, + // Mesh/Task schemas + Task, + TaskSummary, + TaskListResponse, + TaskWithSubtasks, + CreateTaskRequest, + UpdateTaskRequest, + SendMessageRequest, + TaskEventListResponse, + Daemon, + DaemonListResponse, + DaemonDirectoriesResponse, + DaemonDirectory, + MeshChatConversation, + MeshChatMessageRecord, + MeshChatHistoryResponse, + // Merge schemas + BranchInfo, + BranchListResponse, + MergeStartRequest, + MergeStatusResponse, + MergeResolveRequest, + MergeCommitRequest, + MergeSkipRequest, + MergeResultResponse, + MergeCompleteCheckResponse, + // API key schemas + ApiKey, + ApiKeyInfoResponse, + CreateApiKeyRequest, + CreateApiKeyResponse, + RefreshApiKeyRequest, + RefreshApiKeyResponse, + RevokeApiKeyResponse, + // User account management schemas + users::ChangePasswordRequest, + users::ChangePasswordResponse, + users::ChangeEmailRequest, + users::ChangeEmailResponse, + users::DeleteAccountRequest, + users::DeleteAccountResponse, ) ), tags( (name = "Listen", description = "Speech-to-text streaming endpoints"), (name = "Files", description = "Transcript file management"), + (name = "Mesh", description = "Task orchestration for Claude Code instances"), + (name = "API Keys", description = "API key management for programmatic access"), + (name = "Users", description = "User account management"), ) )] pub struct ApiDoc; diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index 239ab77..e89197a 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -1,11 +1,13 @@ //! Application state holding shared ML models and database pool. use std::sync::Arc; +use dashmap::DashMap; use sqlx::PgPool; -use tokio::sync::{broadcast, Mutex}; +use tokio::sync::{broadcast, mpsc, Mutex}; use uuid::Uuid; use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer}; +use crate::server::auth::{AuthConfig, JwtVerifier}; /// Notification payload for file updates (broadcast to WebSocket subscribers). #[derive(Debug, Clone)] @@ -20,6 +22,262 @@ pub struct FileUpdateNotification { pub updated_by: String, } +// ============================================================================= +// Task/Mesh Notifications +// ============================================================================= + +/// Notification payload for task updates (broadcast to WebSocket subscribers). +#[derive(Debug, Clone)] +pub struct TaskUpdateNotification { + /// ID of the updated task + pub task_id: Uuid, + /// Owner ID for data isolation (notifications are scoped to owner) + pub owner_id: Option<Uuid>, + /// New version number after update + pub version: i32, + /// Current task status + pub status: String, + /// List of fields that were updated + pub updated_fields: Vec<String>, + /// Source of the update: "user", "daemon", or "system" + pub updated_by: String, +} + +/// Notification for streaming task output from Claude Code containers. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TaskOutputNotification { + /// ID of the task producing output + pub task_id: Uuid, + /// Owner ID for data isolation (notifications are scoped to owner) + #[serde(skip)] + pub owner_id: Option<Uuid>, + /// Type of message: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" + pub message_type: String, + /// Main text content of the message + pub content: String, + /// Tool name if this is a tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option<String>, + /// Tool input (JSON) if this is a tool_use message + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_input: Option<serde_json::Value>, + /// Whether tool result was an error + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option<bool>, + /// Cost in USD if this is a result message + #[serde(skip_serializing_if = "Option::is_none")] + pub cost_usd: Option<f64>, + /// Duration in milliseconds if this is a result message + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option<u64>, + /// Whether this is a partial line (more coming) or complete + pub is_partial: bool, +} + +/// Command sent from server to daemon. +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum DaemonCommand { + /// Confirm successful authentication + Authenticated { + #[serde(rename = "daemonId")] + daemon_id: Uuid, + }, + /// Spawn a new task in a container + SpawnTask { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages) + #[serde(rename = "taskName")] + task_name: String, + plan: String, + #[serde(rename = "repoUrl")] + repo_url: Option<String>, + #[serde(rename = "baseBranch")] + base_branch: Option<String>, + /// Target branch to merge into (used for completion actions) + #[serde(rename = "targetBranch")] + target_branch: Option<String>, + /// Parent task ID if this is a subtask + #[serde(rename = "parentTaskId")] + parent_task_id: Option<Uuid>, + /// Depth in task hierarchy (0=top-level, 1=subtask, 2=sub-subtask) + depth: i32, + /// Whether this task should run as an orchestrator (true if depth==0 and has subtasks) + #[serde(rename = "isOrchestrator")] + is_orchestrator: bool, + /// Path to user's local repository (outside ~/.makima) for completion actions + #[serde(rename = "targetRepoPath")] + target_repo_path: Option<String>, + /// Action on completion: "none", "branch", "merge", "pr" + #[serde(rename = "completionAction")] + completion_action: Option<String>, + /// Task ID to continue from (copy worktree from this task) + #[serde(rename = "continueFromTaskId")] + continue_from_task_id: Option<Uuid>, + /// Files to copy from parent task's worktree + #[serde(rename = "copyFiles")] + copy_files: Option<Vec<String>>, + }, + /// Pause a running task + PauseTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Resume a paused task + ResumeTask { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Interrupt a task (gracefully or forced) + InterruptTask { + #[serde(rename = "taskId")] + task_id: Uuid, + graceful: bool, + }, + /// Send a message to a running task + SendMessage { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + /// Inject context about sibling task progress + InjectSiblingContext { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "siblingTaskId")] + sibling_task_id: Uuid, + #[serde(rename = "siblingName")] + sibling_name: String, + #[serde(rename = "siblingStatus")] + sibling_status: String, + #[serde(rename = "progressSummary")] + progress_summary: Option<String>, + #[serde(rename = "changedFiles")] + changed_files: Vec<String>, + }, + + // ========================================================================= + // Merge Commands (for orchestrators to merge subtask branches) + // ========================================================================= + + /// List all subtask branches for a task + ListBranches { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Start merging a subtask branch + MergeStart { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "sourceBranch")] + source_branch: String, + }, + /// Get current merge status + MergeStatus { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Resolve a merge conflict + MergeResolve { + #[serde(rename = "taskId")] + task_id: Uuid, + file: String, + /// "ours" or "theirs" + strategy: String, + }, + /// Commit the current merge + MergeCommit { + #[serde(rename = "taskId")] + task_id: Uuid, + message: String, + }, + /// Abort the current merge + MergeAbort { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + /// Skip merging a subtask branch (mark as intentionally not merged) + MergeSkip { + #[serde(rename = "taskId")] + task_id: Uuid, + #[serde(rename = "subtaskId")] + subtask_id: Uuid, + reason: String, + }, + /// Check if all subtask branches have been merged or skipped (completion gate) + CheckMergeComplete { + #[serde(rename = "taskId")] + task_id: Uuid, + }, + + // ========================================================================= + // Completion Action Commands + // ========================================================================= + + /// Retry a completion action for a completed task + RetryCompletionAction { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Human-readable task name (used for commit messages) + #[serde(rename = "taskName")] + task_name: String, + /// The action to execute: "branch", "merge", or "pr" + action: String, + /// Path to the target repository + #[serde(rename = "targetRepoPath")] + target_repo_path: String, + /// Target branch to merge into (for merge/pr actions) + #[serde(rename = "targetBranch")] + target_branch: Option<String>, + }, + + /// Clone worktree to a target directory + CloneWorktree { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to the target directory + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Check if a target directory exists + CheckTargetExists { + #[serde(rename = "taskId")] + task_id: Uuid, + /// Path to check + #[serde(rename = "targetDir")] + target_dir: String, + }, + + /// Error response + Error { code: String, message: String }, +} + +/// Active daemon connection info stored in state. +#[derive(Debug)] +pub struct DaemonConnectionInfo { + /// Database ID of the daemon + pub id: Uuid, + /// Owner ID for data isolation (from API key authentication) + pub owner_id: Uuid, + /// WebSocket connection identifier + pub connection_id: String, + /// Daemon hostname + pub hostname: Option<String>, + /// Machine identifier + pub machine_id: Option<String>, + /// Channel to send commands to this daemon + pub command_sender: mpsc::Sender<DaemonCommand>, + /// Current working directory of the daemon + pub working_directory: Option<String>, + /// Path to ~/.makima/home directory on daemon (for cloning completed work) + pub home_directory: Option<String>, + /// Path to worktrees directory (~/.makima/worktrees) on daemon + pub worktrees_directory: Option<String>, +} + /// Shared application state containing ML models and database pool. /// /// Models are wrapped in `Mutex` for thread-safe mutable access during inference. @@ -34,6 +292,16 @@ pub struct AppState { pub db_pool: Option<PgPool>, /// Broadcast channel for file update notifications pub file_updates: broadcast::Sender<FileUpdateNotification>, + /// Broadcast channel for task update notifications + pub task_updates: broadcast::Sender<TaskUpdateNotification>, + /// Broadcast channel for task output streaming + pub task_output: broadcast::Sender<TaskOutputNotification>, + /// Active daemon connections (keyed by connection_id) + pub daemon_connections: DashMap<String, DaemonConnectionInfo>, + /// Tool keys for orchestrator API access (key -> task_id) + pub tool_keys: DashMap<String, Uuid>, + /// JWT verifier for Supabase authentication (None if not configured) + pub jwt_verifier: Option<JwtVerifier>, } impl AppState { @@ -56,8 +324,38 @@ impl AppState { DiarizationConfig::callhome(), )?; - // Create broadcast channel with buffer for 256 messages + // Create broadcast channels with buffer for 256 messages let (file_updates, _) = broadcast::channel(256); + let (task_updates, _) = broadcast::channel(256); + let (task_output, _) = broadcast::channel(1024); // Larger buffer for output streaming + + // Initialize JWT verifier from environment (optional) + // Requires SUPABASE_URL and either SUPABASE_JWT_PUBLIC_KEY (RS256) or SUPABASE_JWT_SECRET (HS256) + let jwt_verifier = match AuthConfig::from_env() { + Some(config) => match JwtVerifier::new(config) { + Ok(verifier) => { + tracing::info!("JWT authentication configured"); + Some(verifier) + } + Err(e) => { + tracing::error!("Failed to initialize JWT verifier: {}", e); + None + } + }, + None => { + // Log which env vars are missing + let has_url = std::env::var("SUPABASE_URL").is_ok(); + let has_public_key = std::env::var("SUPABASE_JWT_PUBLIC_KEY").is_ok(); + let has_secret = std::env::var("SUPABASE_JWT_SECRET").is_ok(); + + if !has_url { + tracing::info!("JWT authentication not configured (SUPABASE_URL not set)"); + } else if !has_public_key && !has_secret { + tracing::info!("JWT authentication not configured (set SUPABASE_JWT_PUBLIC_KEY for RS256 or SUPABASE_JWT_SECRET for HS256)"); + } + None + } + }; Ok(Self { parakeet: Mutex::new(parakeet), @@ -65,6 +363,11 @@ impl AppState { sortformer: Mutex::new(sortformer), db_pool: None, file_updates, + task_updates, + task_output, + daemon_connections: DashMap::new(), + tool_keys: DashMap::new(), + jwt_verifier, }) } @@ -81,6 +384,166 @@ impl AppState { // Ignore send errors - they just mean no one is listening let _ = self.file_updates.send(notification); } + + /// Broadcast a task update notification to all subscribers. + /// + /// This is a no-op if there are no subscribers (ignores send errors). + pub fn broadcast_task_update(&self, notification: TaskUpdateNotification) { + let _ = self.task_updates.send(notification); + } + + /// Broadcast task output to all subscribers. + /// + /// Used for streaming Claude Code container output to frontend clients. + pub fn broadcast_task_output(&self, notification: TaskOutputNotification) { + let _ = self.task_output.send(notification); + } + + /// Register a new daemon connection. + /// + /// Returns the connection_id for later reference. + pub fn register_daemon( + &self, + connection_id: String, + daemon_id: Uuid, + owner_id: Uuid, + hostname: Option<String>, + machine_id: Option<String>, + command_sender: mpsc::Sender<DaemonCommand>, + ) { + self.daemon_connections.insert( + connection_id.clone(), + DaemonConnectionInfo { + id: daemon_id, + owner_id, + connection_id, + hostname, + machine_id, + command_sender, + working_directory: None, + home_directory: None, + worktrees_directory: None, + }, + ); + } + + /// Update daemon directory information. + pub fn update_daemon_directories( + &self, + connection_id: &str, + working_directory: String, + home_directory: String, + worktrees_directory: String, + ) { + if let Some(mut entry) = self.daemon_connections.get_mut(connection_id) { + entry.working_directory = Some(working_directory); + entry.home_directory = Some(home_directory); + entry.worktrees_directory = Some(worktrees_directory); + } + } + + /// Unregister a daemon connection. + pub fn unregister_daemon(&self, connection_id: &str) { + self.daemon_connections.remove(connection_id); + } + + /// Get a daemon connection by connection_id. + pub fn get_daemon(&self, connection_id: &str) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> { + self.daemon_connections.get(connection_id) + } + + /// Get a daemon by its database ID. + pub fn get_daemon_by_id(&self, daemon_id: Uuid) -> Option<dashmap::mapref::one::Ref<'_, String, DaemonConnectionInfo>> { + self.daemon_connections + .iter() + .find(|entry| entry.value().id == daemon_id) + .map(|entry| { + // Return a reference to the found entry + self.daemon_connections.get(entry.key()).unwrap() + }) + } + + /// Send a command to a specific daemon by its database ID. + pub async fn send_daemon_command(&self, daemon_id: Uuid, command: DaemonCommand) -> Result<(), String> { + if let Some(daemon) = self.daemon_connections + .iter() + .find(|entry| entry.value().id == daemon_id) + { + daemon.value().command_sender.send(command).await + .map_err(|e| format!("Failed to send command to daemon: {}", e)) + } else { + Err(format!("Daemon {} not connected", daemon_id)) + } + } + + /// Broadcast sibling progress to all running sibling tasks. + /// + /// This is used for sibling awareness - when a task makes progress, + /// its siblings are notified so they can adjust their work if needed. + pub async fn broadcast_sibling_progress( + &self, + source_task_id: Uuid, + source_task_name: &str, + source_task_status: &str, + progress_summary: Option<String>, + changed_files: Vec<String>, + running_sibling_daemon_ids: Vec<(Uuid, Uuid)>, // (task_id, daemon_id) + ) { + for (sibling_task_id, daemon_id) in running_sibling_daemon_ids { + let command = DaemonCommand::InjectSiblingContext { + task_id: sibling_task_id, + sibling_task_id: source_task_id, + sibling_name: source_task_name.to_string(), + sibling_status: source_task_status.to_string(), + progress_summary: progress_summary.clone(), + changed_files: changed_files.clone(), + }; + + // Fire and forget - don't block on sending to all daemons + if let Err(e) = self.send_daemon_command(daemon_id, command).await { + tracing::warn!( + "Failed to inject sibling context to task {}: {}", + sibling_task_id, + e + ); + } + } + } + + /// Get list of connected daemon IDs. + pub fn list_connected_daemon_ids(&self) -> Vec<Uuid> { + self.daemon_connections + .iter() + .map(|entry| entry.value().id) + .collect() + } + + // ========================================================================= + // Tool Key Management + // ========================================================================= + + /// Register a tool key for a task. + /// + /// This allows orchestrators to authenticate with the API using + /// the `X-Makima-Tool-Key` header. + pub fn register_tool_key(&self, key: String, task_id: Uuid) { + tracing::info!(task_id = %task_id, "Registering tool key"); + self.tool_keys.insert(key, task_id); + } + + /// Validate a tool key and return the associated task ID. + pub fn validate_tool_key(&self, key: &str) -> Option<Uuid> { + self.tool_keys.get(key).map(|entry| *entry.value()) + } + + /// Revoke a tool key for a task. + /// + /// This should be called when a task completes or is terminated. + pub fn revoke_tool_key(&self, task_id: Uuid) { + // Find and remove the key for this task + self.tool_keys.retain(|_, v| *v != task_id); + tracing::info!(task_id = %task_id, "Revoked tool key"); + } } /// Type alias for the shared application state. |
