diff options
Diffstat (limited to 'makima/src/server/auth.rs')
| -rw-r--r-- | makima/src/server/auth.rs | 1238 |
1 files changed, 1238 insertions, 0 deletions
diff --git a/makima/src/server/auth.rs b/makima/src/server/auth.rs new file mode 100644 index 0000000..b694df6 --- /dev/null +++ b/makima/src/server/auth.rs @@ -0,0 +1,1238 @@ +//! Authentication module for Makima server. +//! +//! Supports multiple authentication methods: +//! - Supabase JWT tokens for web clients (ES256 or RS256 public key verification) +//! - API keys for programmatic access (daemons, CLI) +//! - Tool keys for orchestrator internal access + +use axum::{ + extract::FromRequestParts, + http::{header::AUTHORIZATION, request::Parts, HeaderMap, StatusCode}, + response::IntoResponse, + Json, +}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use chrono::{DateTime, Utc}; +use dashmap::DashMap; +use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use sqlx::{FromRow, PgPool, Row}; +use std::time::{Duration, Instant}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::server::messages::ApiError; +use crate::server::state::SharedState; + +// ============================================================================= +// Configuration +// ============================================================================= + +/// JWT algorithm configuration. +#[derive(Debug, Clone)] +pub enum JwtAlgorithm { + /// RS256 with RSA public key + Rs256 { public_key: String }, + /// ES256 with ECDSA public key (Supabase projects with JWT Signing Keys) + Es256 { public_key: String }, +} + +/// Authentication configuration loaded from environment. +#[derive(Debug, Clone)] +pub struct AuthConfig { + /// Supabase project URL (e.g., https://your-project.supabase.co) + pub supabase_url: String, + /// JWT algorithm and key material + pub algorithm: JwtAlgorithm, +} + +impl AuthConfig { + /// Load auth config from environment variables. + /// + /// Supports two modes (checked in order): + /// - ES256: Set SUPABASE_URL and SUPABASE_JWT_PUBLIC_KEY (Supabase with ECDSA) + /// - RS256: Set SUPABASE_URL and SUPABASE_JWT_RSA_PUBLIC_KEY (RSA public key) + /// + /// Returns None if auth is not configured. + pub fn from_env() -> Option<Self> { + let supabase_url = std::env::var("SUPABASE_URL").ok()?; + + // Try ES256 first (default for Supabase), then RS256 + let algorithm = if let Ok(public_key) = std::env::var("SUPABASE_JWT_PUBLIC_KEY") { + tracing::info!("Using ES256 JWT verification with ECDSA public key"); + JwtAlgorithm::Es256 { public_key } + } else if let Ok(public_key) = std::env::var("SUPABASE_JWT_RSA_PUBLIC_KEY") { + tracing::info!("Using RS256 JWT verification with RSA public key"); + JwtAlgorithm::Rs256 { public_key } + } else { + return None; + }; + + Some(Self { + supabase_url, + algorithm, + }) + } +} + +// ============================================================================= +// JWT Claims +// ============================================================================= + +/// JWT claims from Supabase Auth tokens. +#[derive(Debug, Serialize, Deserialize)] +pub struct SupabaseClaims { + /// Audience (e.g., "authenticated") + pub aud: String, + /// Expiration time (Unix timestamp) + pub exp: i64, + /// Issued at (Unix timestamp) + pub iat: i64, + /// Issuer (Supabase project URL + /auth/v1) + pub iss: String, + /// Subject (user ID) + pub sub: Uuid, + /// User's email + pub email: Option<String>, + /// User's phone + pub phone: Option<String>, + /// App metadata (set by server/admin) + pub app_metadata: Option<serde_json::Value>, + /// User metadata (set by user) + pub user_metadata: Option<serde_json::Value>, + /// Role (e.g., "authenticated") + pub role: Option<String>, + /// Session ID + pub session_id: Option<Uuid>, +} + +// ============================================================================= +// JWT Verifier +// ============================================================================= + +/// JWT verifier for Supabase tokens. +pub struct JwtVerifier { + supabase_url: String, + decoding_key: DecodingKey, + algorithm: Algorithm, +} + +impl JwtVerifier { + /// Create a new JWT verifier from auth config. + /// + /// Supports multiple key formats: + /// - JWK (JSON Web Key) - detected by presence of `{` + /// - PEM - detected by `-----BEGIN` + /// - Base64-encoded DER - fallback + pub fn new(config: AuthConfig) -> Result<Self, AuthError> { + let (decoding_key, algorithm) = match &config.algorithm { + JwtAlgorithm::Rs256 { public_key } => { + let key = Self::parse_public_key(public_key, "RSA")?; + (key, Algorithm::RS256) + } + JwtAlgorithm::Es256 { public_key } => { + let key = Self::parse_public_key(public_key, "EC")?; + (key, Algorithm::ES256) + } + }; + + Ok(Self { + supabase_url: config.supabase_url, + decoding_key, + algorithm, + }) + } + + /// Parse a public key from various formats (JWK, JWKS, PEM, or base64 DER). + fn parse_public_key(key_data: &str, key_type: &str) -> Result<DecodingKey, AuthError> { + let trimmed = key_data.trim(); + + // Check for JSON format (JWK or JWKS) + if trimmed.starts_with('{') { + // First try to parse as a generic JSON value to inspect structure + let mut json_value: serde_json::Value = serde_json::from_str(trimmed) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JSON: {}", e)))?; + + // Check if it's a JWKS (has "keys" array) + if let Some(keys) = json_value.get_mut("keys").and_then(|k| k.as_array_mut()) { + // Find the first signing key (or just use the first key) + let jwk_value = keys.first_mut() + .ok_or_else(|| AuthError::InvalidToken("JWKS has no keys".to_string()))?; + + // Remove private key component if present (user may have pasted full keypair) + if let Some(obj) = jwk_value.as_object_mut() { + if obj.remove("d").is_some() { + tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); + } + } + + let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(jwk_value.clone()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK in JWKS: {}", e)))?; + + tracing::info!("Loaded JWT public key from JWKS (first key)"); + return DecodingKey::from_jwk(&jwk) + .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))); + } + + // Remove private key component if present (user may have pasted full keypair) + if let Some(obj) = json_value.as_object_mut() { + if obj.remove("d").is_some() { + tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); + } + } + + // Try as single JWK + let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(json_value) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK: {}", e)))?; + + tracing::info!("Loaded JWT public key from JWK"); + DecodingKey::from_jwk(&jwk) + .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))) + } + // Check for PEM format + else if trimmed.contains("-----BEGIN") { + tracing::info!("Loaded JWT public key from PEM"); + match key_type { + "RSA" => DecodingKey::from_rsa_pem(trimmed.as_bytes()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid RSA PEM key: {}", e))), + "EC" => DecodingKey::from_ec_pem(trimmed.as_bytes()) + .map_err(|e| AuthError::InvalidToken(format!("Invalid EC PEM key: {}", e))), + _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), + } + } + // Assume base64-encoded DER + else { + tracing::info!("Loaded JWT public key from base64 DER"); + let der_bytes = base64::engine::general_purpose::STANDARD + .decode(trimmed) + .map_err(|e| AuthError::InvalidToken(format!("Invalid base64 key: {}", e)))?; + + match key_type { + "RSA" => Ok(DecodingKey::from_rsa_der(&der_bytes)), + "EC" => Ok(DecodingKey::from_ec_der(&der_bytes)), + _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), + } + } + } + + /// Verify a JWT token and return claims. + pub fn verify(&self, token: &str) -> Result<SupabaseClaims, AuthError> { + // Decode header to check algorithm mismatch + let header = jsonwebtoken::decode_header(token) + .map_err(|e| AuthError::InvalidToken(format!("Invalid JWT header: {}", e)))?; + + tracing::debug!( + "JWT header: algorithm={:?}, typ={:?}, kid={:?}", + header.alg, + header.typ, + header.kid + ); + + if header.alg != self.algorithm { + let hint = match header.alg { + Algorithm::ES256 => "Set SUPABASE_JWT_PUBLIC_KEY with the EC public key from Supabase Dashboard → Project Settings → API → JWT Settings", + Algorithm::RS256 => "Set SUPABASE_JWT_RSA_PUBLIC_KEY with the RSA public key", + _ => "Check your Supabase JWT configuration - only ES256 and RS256 are supported", + }; + tracing::warn!( + "JWT algorithm mismatch: token uses {:?}, server configured for {:?}. {}", + header.alg, + self.algorithm, + hint + ); + return Err(AuthError::InvalidToken(format!( + "Algorithm mismatch: token is {:?}, expected {:?}", + header.alg, self.algorithm + ))); + } + + let mut validation = Validation::new(self.algorithm); + validation.set_audience(&["authenticated"]); + validation.set_issuer(&[format!("{}/auth/v1", self.supabase_url)]); + + // First try with full validation + let token_data = match decode::<SupabaseClaims>(token, &self.decoding_key, &validation) { + Ok(data) => data, + Err(e) => { + // Log detailed error info + tracing::warn!( + "JWT verification failed: {} (algorithm: {:?}, issuer: {}/auth/v1)", + e, + self.algorithm, + self.supabase_url + ); + + // If it's InvalidAlgorithm, try to understand why by decoding payload manually + if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidAlgorithm) { + // Decode the payload part of the JWT manually (base64) + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() >= 2 { + if let Ok(payload_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) { + if let Ok(payload_str) = String::from_utf8(payload_bytes) { + if let Ok(claims) = serde_json::from_str::<serde_json::Value>(&payload_str) { + tracing::warn!( + "JWT payload (unverified): iss={:?}, aud={:?}, sub={:?}", + claims.get("iss"), + claims.get("aud"), + claims.get("sub") + ); + } + } + } + } + } + + return Err(AuthError::InvalidToken(e.to_string())); + } + }; + + Ok(token_data.claims) + } + + /// Extract user ID from a token. + pub fn get_user_id(&self, token: &str) -> Result<Uuid, AuthError> { + let claims = self.verify(token)?; + Ok(claims.sub) + } +} + +// ============================================================================= +// Auth Error +// ============================================================================= + +/// Authentication error types. +#[derive(Debug)] +pub enum AuthError { + /// No authentication token provided + MissingToken, + /// Token format is invalid + InvalidToken(String), + /// Token has expired + ExpiredToken, + /// User not found in database + UserNotFound, + /// API key is invalid or revoked + InvalidApiKey, + /// Database error during auth lookup + DatabaseError(String), + /// Authentication is not configured + NotConfigured, + /// Insufficient permissions for the operation + InsufficientPermissions, +} + +impl std::fmt::Display for AuthError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthError::MissingToken => write!(f, "Missing authentication token"), + AuthError::InvalidToken(msg) => write!(f, "Invalid token: {}", msg), + AuthError::ExpiredToken => write!(f, "Token has expired"), + AuthError::UserNotFound => write!(f, "User not found"), + AuthError::InvalidApiKey => write!(f, "Invalid or revoked API key"), + AuthError::DatabaseError(msg) => write!(f, "Database error: {}", msg), + AuthError::NotConfigured => write!(f, "Authentication not configured"), + AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"), + } + } +} + +impl std::error::Error for AuthError {} + +impl IntoResponse for AuthError { + fn into_response(self) -> axum::response::Response { + let (status, code, message) = match &self { + AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "MISSING_TOKEN", "Authentication required"), + AuthError::InvalidToken(_) => (StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid authentication token"), + AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "EXPIRED_TOKEN", "Token has expired"), + AuthError::UserNotFound => (StatusCode::UNAUTHORIZED, "USER_NOT_FOUND", "User not found"), + AuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, "INVALID_API_KEY", "Invalid or revoked API key"), + AuthError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR", "Database error"), + AuthError::NotConfigured => (StatusCode::SERVICE_UNAVAILABLE, "AUTH_NOT_CONFIGURED", "Authentication not configured"), + AuthError::InsufficientPermissions => (StatusCode::FORBIDDEN, "FORBIDDEN", "Insufficient permissions"), + }; + + (status, Json(ApiError::new(code, message))).into_response() + } +} + +// ============================================================================= +// Auth Source +// ============================================================================= + +/// Source of authentication. +#[derive(Debug, Clone)] +pub enum AuthSource { + /// Authenticated via Supabase JWT (web client) + Jwt, + /// Authenticated via API key (daemon, CLI, integrations) + ApiKey, + /// Authenticated via tool key (orchestrator internal access) + ToolKey(Uuid), +} + +// ============================================================================= +// Authenticated User +// ============================================================================= + +/// Authenticated user context extracted from request. +/// +/// Contains the resolved user_id and owner_id for database operations. +#[derive(Debug, Clone)] +pub struct AuthenticatedUser { + /// Supabase auth user ID (from auth.users) + pub user_id: Uuid, + /// Owner ID for data isolation (from users.default_owner_id) + pub owner_id: Uuid, + /// How the user was authenticated + pub auth_source: AuthSource, + /// User's email (if available) + pub email: Option<String>, +} + +// ============================================================================= +// Header Constants +// ============================================================================= + +/// Header name for tool key authentication (orchestrators). +pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key"; + +/// Header name for API key authentication. +pub const API_KEY_HEADER: &str = "x-makima-api-key"; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Hash an API key for database lookup. +pub fn hash_api_key(key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + hex::encode(hasher.finalize()) +} + +// ============================================================================= +// API Key Generation +// ============================================================================= + +/// API key prefix for identification. +pub const API_KEY_PREFIX: &str = "mk_"; + +/// Result of generating an API key. +pub struct GeneratedApiKey { + /// The full API key (shown only once to user) + pub full_key: String, + /// SHA-256 hash of the key (stored in database) + pub key_hash: String, + /// Prefix for display (first 8 chars after mk_) + pub key_prefix: String, +} + +/// Generate a new API key with mk_ prefix. +/// +/// Returns the full key (to show once), hash (to store), and prefix (for display). +pub fn generate_api_key() -> GeneratedApiKey { + let mut rng = rand::thread_rng(); + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + + let key_bytes = URL_SAFE_NO_PAD.encode(bytes); + let full_key = format!("{}{}", API_KEY_PREFIX, key_bytes); + + let key_hash = hash_api_key(&full_key); + let key_prefix = format!("{}{}", API_KEY_PREFIX, &key_bytes[..8]); + + GeneratedApiKey { + full_key, + key_hash, + key_prefix, + } +} + +// ============================================================================= +// API Key Cache +// ============================================================================= + +/// Cache entry for validated API keys. +struct ApiKeyCacheEntry { + user_id: Uuid, + owner_id: Uuid, + cached_at: Instant, +} + +/// In-memory cache for API key validation to avoid database lookups on every request. +pub struct ApiKeyCache { + /// key_hash -> (user_id, owner_id, cached_at) + cache: DashMap<String, ApiKeyCacheEntry>, + /// Time-to-live for cache entries + ttl: Duration, +} + +impl ApiKeyCache { + /// Create a new cache with the specified TTL in seconds. + pub fn new(ttl_seconds: u64) -> Self { + Self { + cache: DashMap::new(), + ttl: Duration::from_secs(ttl_seconds), + } + } + + /// Get cached user_id and owner_id for a key hash, if not expired. + pub fn get(&self, key_hash: &str) -> Option<(Uuid, Uuid)> { + self.cache.get(key_hash).and_then(|entry| { + if entry.cached_at.elapsed() < self.ttl { + Some((entry.user_id, entry.owner_id)) + } else { + None + } + }) + } + + /// Cache a validated API key. + pub fn set(&self, key_hash: String, user_id: Uuid, owner_id: Uuid) { + self.cache.insert( + key_hash, + ApiKeyCacheEntry { + user_id, + owner_id, + cached_at: Instant::now(), + }, + ); + } + + /// Invalidate a cache entry (e.g., on key revocation). + pub fn invalidate(&self, key_hash: &str) { + self.cache.remove(key_hash); + } + + /// Clear all cache entries. + pub fn clear(&self) { + self.cache.clear(); + } +} + +impl Default for ApiKeyCache { + fn default() -> Self { + // Default TTL: 5 minutes + Self::new(300) + } +} + +// ============================================================================= +// API Key Models +// ============================================================================= + +/// API key record from the database. +#[derive(Debug, Clone, FromRow, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ApiKey { + pub id: Uuid, + pub user_id: Uuid, + #[serde(skip)] + pub key_hash: String, + pub key_prefix: String, + pub name: Option<String>, + pub last_used_at: Option<DateTime<Utc>>, + pub created_at: DateTime<Utc>, + pub revoked_at: Option<DateTime<Utc>>, +} + +/// Request to create a new API key. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateApiKeyRequest { + /// User-provided label for the key + pub name: Option<String>, +} + +/// Response after creating an API key (includes the full key - shown only once). +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct CreateApiKeyResponse { + pub id: Uuid, + /// The full API key - save this, it won't be shown again! + pub key: String, + pub prefix: String, + pub name: Option<String>, + pub created_at: DateTime<Utc>, +} + +/// Response for getting API key info (excludes the full key). +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ApiKeyInfoResponse { + pub id: Uuid, + pub prefix: String, + pub name: Option<String>, + pub last_used_at: Option<DateTime<Utc>>, + pub created_at: DateTime<Utc>, +} + +impl From<ApiKey> for ApiKeyInfoResponse { + fn from(key: ApiKey) -> Self { + Self { + id: key.id, + prefix: key.key_prefix, + name: key.name, + last_used_at: key.last_used_at, + created_at: key.created_at, + } + } +} + +/// Request to refresh an API key. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RefreshApiKeyRequest { + /// New name for the refreshed key + pub name: Option<String>, +} + +/// Response after refreshing an API key. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RefreshApiKeyResponse { + pub id: Uuid, + /// The new API key - save this, it won't be shown again! + pub key: String, + pub prefix: String, + pub name: Option<String>, + pub created_at: DateTime<Utc>, + pub previous_key_revoked: bool, +} + +/// Response after revoking an API key. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct RevokeApiKeyResponse { + pub message: String, + pub revoked_key_prefix: String, +} + +/// API key event types for audit logging. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ApiKeyEventType { + Created, + Used, + Revoked, + Refreshed, +} + +impl std::fmt::Display for ApiKeyEventType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeyEventType::Created => write!(f, "created"), + ApiKeyEventType::Used => write!(f, "used"), + ApiKeyEventType::Revoked => write!(f, "revoked"), + ApiKeyEventType::Refreshed => write!(f, "refreshed"), + } + } +} + +// ============================================================================= +// API Keys Repository +// ============================================================================= + +/// Repository error for API key operations. +#[derive(Debug)] +pub enum ApiKeyError { + /// Database error + Database(sqlx::Error), + /// An active API key already exists for this user + KeyAlreadyExists, + /// No active API key found for this user + KeyNotFound, +} + +impl std::fmt::Display for ApiKeyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeyError::Database(e) => write!(f, "Database error: {}", e), + ApiKeyError::KeyAlreadyExists => write!(f, "An active API key already exists"), + ApiKeyError::KeyNotFound => write!(f, "No active API key found"), + } + } +} + +impl std::error::Error for ApiKeyError {} + +impl From<sqlx::Error> for ApiKeyError { + fn from(e: sqlx::Error) -> Self { + ApiKeyError::Database(e) + } +} + +/// Get the active API key for a user (if any). +pub async fn get_active_api_key(pool: &PgPool, user_id: Uuid) -> Result<Option<ApiKey>, sqlx::Error> { + sqlx::query_as::<_, ApiKey>( + r#" + SELECT id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + FROM api_keys + WHERE user_id = $1 AND revoked_at IS NULL + "#, + ) + .bind(user_id) + .fetch_optional(pool) + .await +} + +/// Create a new API key for a user. +/// +/// Returns an error if the user already has an active key. +/// The `generated` parameter should be created using `generate_api_key()`. +pub async fn create_api_key( + pool: &PgPool, + user_id: Uuid, + generated: &GeneratedApiKey, + name: Option<&str>, +) -> Result<ApiKey, ApiKeyError> { + // Check if user already has an active key + if let Some(_) = get_active_api_key(pool, user_id).await? { + return Err(ApiKeyError::KeyAlreadyExists); + } + + let key = sqlx::query_as::<_, ApiKey>( + r#" + INSERT INTO api_keys (user_id, key_hash, key_prefix, name) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(user_id) + .bind(&generated.key_hash) + .bind(&generated.key_prefix) + .bind(name) + .fetch_one(pool) + .await?; + + // Log the creation event + let _ = log_api_key_event(pool, key.id, ApiKeyEventType::Created, None, None).await; + + Ok(key) +} + +/// Revoke an API key by marking it with revoked_at timestamp. +pub async fn revoke_api_key(pool: &PgPool, user_id: Uuid) -> Result<ApiKey, ApiKeyError> { + // Get the active key first + let key = get_active_api_key(pool, user_id) + .await? + .ok_or(ApiKeyError::KeyNotFound)?; + + // Revoke it + let revoked = sqlx::query_as::<_, ApiKey>( + r#" + UPDATE api_keys + SET revoked_at = NOW() + WHERE id = $1 + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(key.id) + .fetch_one(pool) + .await?; + + // Log the revocation event + let _ = log_api_key_event(pool, revoked.id, ApiKeyEventType::Revoked, None, None).await; + + Ok(revoked) +} + +/// Refresh an API key: revoke the old one and create a new one atomically. +/// +/// Returns the new key. The caller should use `generate_api_key()` to create +/// the `new_generated` parameter. +pub async fn refresh_api_key( + pool: &PgPool, + user_id: Uuid, + new_generated: &GeneratedApiKey, + new_name: Option<&str>, +) -> Result<(ApiKey, Option<String>), ApiKeyError> { + // Get and revoke the old key (if exists) + let old_prefix = if let Some(old_key) = get_active_api_key(pool, user_id).await? { + let old_prefix = old_key.key_prefix.clone(); + + // Revoke the old key + sqlx::query( + r#" + UPDATE api_keys + SET revoked_at = NOW() + WHERE id = $1 + "#, + ) + .bind(old_key.id) + .execute(pool) + .await?; + + // Log the refresh event on the old key + let _ = log_api_key_event(pool, old_key.id, ApiKeyEventType::Refreshed, None, None).await; + + Some(old_prefix) + } else { + None + }; + + // Create the new key + let new_key = sqlx::query_as::<_, ApiKey>( + r#" + INSERT INTO api_keys (user_id, key_hash, key_prefix, name) + VALUES ($1, $2, $3, $4) + RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at + "#, + ) + .bind(user_id) + .bind(&new_generated.key_hash) + .bind(&new_generated.key_prefix) + .bind(new_name) + .fetch_one(pool) + .await?; + + // Log the creation event on the new key + let _ = log_api_key_event(pool, new_key.id, ApiKeyEventType::Created, None, None).await; + + Ok((new_key, old_prefix)) +} + +/// Update last_used_at timestamp for an API key. +pub async fn update_api_key_last_used(pool: &PgPool, key_hash: &str) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + UPDATE api_keys + SET last_used_at = NOW() + WHERE key_hash = $1 AND revoked_at IS NULL + "#, + ) + .bind(key_hash) + .execute(pool) + .await?; + + Ok(()) +} + +/// Log an API key event for audit purposes. +pub async fn log_api_key_event( + pool: &PgPool, + api_key_id: Uuid, + event_type: ApiKeyEventType, + ip_address: Option<&str>, + user_agent: Option<&str>, +) -> Result<(), sqlx::Error> { + sqlx::query( + r#" + INSERT INTO api_key_events (api_key_id, event_type, ip_address, user_agent) + VALUES ($1, $2, $3::inet, $4) + "#, + ) + .bind(api_key_id) + .bind(event_type.to_string()) + .bind(ip_address) + .bind(user_agent) + .execute(pool) + .await?; + + Ok(()) +} + +// ============================================================================= +// Internal Helper Functions +// ============================================================================= + +/// Resolve owner_id from user_id by looking up the users table. +/// If the user doesn't exist, auto-creates them on first login. +/// Uses ON CONFLICT to handle race conditions when multiple requests arrive simultaneously. +async fn resolve_owner_id(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result<Uuid, AuthError> { + // First, try to get existing user + let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + if let Some(row) = row { + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + return owner_id.ok_or(AuthError::UserNotFound); + } + + // User doesn't exist - auto-create on first login + tracing::info!("Creating new user record for {}", user_id); + + // Create owner first (use ON CONFLICT to handle race conditions) + let owner_id = Uuid::new_v4(); + sqlx::query("INSERT INTO owners (id, name) VALUES ($1, $2) ON CONFLICT DO NOTHING") + .bind(owner_id) + .bind(email.unwrap_or("Unknown")) + .execute(pool) + .await + .map_err(|e| AuthError::DatabaseError(format!("Failed to create owner: {}", e)))?; + + // Create user with reference to owner (use ON CONFLICT to handle race conditions) + sqlx::query( + "INSERT INTO users (id, email, default_owner_id) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING" + ) + .bind(user_id) + .bind(email) + .bind(owner_id) + .execute(pool) + .await + .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?; + + // Re-fetch the user to get the actual owner_id (in case another request created it first) + let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") + .bind(user_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + match row { + Some(row) => { + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + owner_id.ok_or(AuthError::UserNotFound) + } + None => Err(AuthError::DatabaseError("Failed to create user record".to_string())) + } +} + +/// Validate an API key and return (user_id, owner_id). +async fn validate_api_key(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> { + let key_hash = hash_api_key(key); + + // Look up the API key and join with users to get owner_id + let row = sqlx::query( + r#" + SELECT ak.user_id, u.default_owner_id + FROM api_keys ak + JOIN users u ON u.id = ak.user_id + WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL + "#, + ) + .bind(&key_hash) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + match row { + Some(row) => { + let user_id: Uuid = row.try_get("user_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + let owner_id: Option<Uuid> = row.try_get("default_owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + let owner_id = owner_id.ok_or(AuthError::UserNotFound)?; + + // Update last_used_at asynchronously (fire and forget) + let pool_clone = pool.clone(); + let key_hash_clone = key_hash.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") + .bind(&key_hash_clone) + .execute(&pool_clone) + .await; + }); + + Ok((user_id, owner_id)) + } + None => Err(AuthError::InvalidApiKey), + } +} + +/// Extract authentication from request headers. +/// +/// Tries authentication methods in order: +/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators +/// 2. API Key (X-Makima-API-Key) - for daemons/CLI +/// 3. JWT (Authorization: Bearer) - for web clients +async fn extract_auth( + state: &SharedState, + headers: &HeaderMap, +) -> Result<AuthenticatedUser, AuthError> { + // 1. Check for tool key (orchestrator access) + if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) { + if let Ok(key_str) = tool_key.to_str() { + if let Some(task_id) = state.validate_tool_key(key_str) { + // Tool keys are trusted - use a placeholder user/owner for orchestrator actions + // The orchestrator inherits the owner_id from its task + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + + // Get owner_id from the task + let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") + .bind(task_id) + .fetch_optional(pool) + .await + .map_err(|e| AuthError::DatabaseError(e.to_string()))? + .ok_or(AuthError::UserNotFound)?; + + let task_owner: Uuid = row.try_get("owner_id") + .map_err(|e| AuthError::DatabaseError(e.to_string()))?; + + return Ok(AuthenticatedUser { + user_id: Uuid::nil(), // Tool keys don't have a user + owner_id: task_owner, + auth_source: AuthSource::ToolKey(task_id), + email: None, + }); + } + tracing::warn!("Invalid tool key provided"); + } + } + + // 2. Check for API key + if let Some(api_key) = headers.get(API_KEY_HEADER) { + if let Ok(key_str) = api_key.to_str() { + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + let (user_id, owner_id) = validate_api_key(pool, key_str).await?; + + return Ok(AuthenticatedUser { + user_id, + owner_id, + auth_source: AuthSource::ApiKey, + email: None, + }); + } + } + + // 3. Check for JWT (Bearer token) + if let Some(auth_header) = headers.get(AUTHORIZATION) { + if let Ok(auth_str) = auth_header.to_str() { + if let Some(token) = auth_str.strip_prefix("Bearer ") { + let verifier = state + .jwt_verifier + .as_ref() + .ok_or(AuthError::NotConfigured)?; + + let claims = verifier.verify(token)?; + let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; + let owner_id = resolve_owner_id(pool, claims.sub, claims.email.as_deref()).await?; + + return Ok(AuthenticatedUser { + user_id: claims.sub, + owner_id, + auth_source: AuthSource::Jwt, + email: claims.email, + }); + } + } + } + + Err(AuthError::MissingToken) +} + +// ============================================================================= +// Extractors +// ============================================================================= + +/// Extractor for authenticated requests. +/// +/// Tries authentication methods in order: +/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators +/// 2. API Key (X-Makima-API-Key) - for daemons/CLI +/// 3. JWT (Authorization: Bearer) - for web clients +/// +/// Returns 401 Unauthorized if no valid authentication is found. +/// +/// # Example +/// ```ignore +/// async fn protected_handler( +/// Authenticated(user): Authenticated, +/// ) -> impl IntoResponse { +/// Json(format!("Hello user {}", user.user_id)) +/// } +/// ``` +pub struct Authenticated(pub AuthenticatedUser); + +impl FromRequestParts<SharedState> for Authenticated { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await?; + Ok(Authenticated(user)) + } +} + +/// Extractor for user-only authentication (JWT or API key, no tool keys). +/// +/// Use this for endpoints that should only be accessible to actual users, +/// not orchestrators with tool keys. +/// +/// Returns 401 Unauthorized if no valid user authentication is found. +/// Returns 403 Forbidden if a tool key is used. +/// +/// # Example +/// ```ignore +/// async fn user_profile( +/// UserOnly(user): UserOnly, +/// ) -> impl IntoResponse { +/// // Only actual users can access this +/// Json(format!("User profile for {}", user.user_id)) +/// } +/// ``` +pub struct UserOnly(pub AuthenticatedUser); + +impl FromRequestParts<SharedState> for UserOnly { + type Rejection = AuthError; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await?; + + // Reject tool key authentication + if matches!(user.auth_source, AuthSource::ToolKey(_)) { + return Err(AuthError::InsufficientPermissions); + } + + Ok(UserOnly(user)) + } +} + +/// Extractor for optional authentication. +/// +/// Returns Some(AuthenticatedUser) if valid auth is provided, None otherwise. +/// Never returns an error - invalid auth is treated as no auth. +/// +/// # Example +/// ```ignore +/// async fn public_or_private( +/// MaybeAuthenticated(user): MaybeAuthenticated, +/// ) -> impl IntoResponse { +/// match user { +/// Some(u) => Json(format!("Hello {}", u.user_id)), +/// None => Json("Hello anonymous".to_string()), +/// } +/// } +/// ``` +pub struct MaybeAuthenticated(pub Option<AuthenticatedUser>); + +impl FromRequestParts<SharedState> for MaybeAuthenticated { + type Rejection = std::convert::Infallible; + + async fn from_request_parts( + parts: &mut Parts, + state: &SharedState, + ) -> Result<Self, Self::Rejection> { + let user = extract_auth(state, &parts.headers).await.ok(); + Ok(MaybeAuthenticated(user)) + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hash_api_key() { + let key = "mk_test123456789"; + let hash = hash_api_key(key); + + // Hash should be consistent + assert_eq!(hash, hash_api_key(key)); + + // Hash should be 64 characters (SHA-256 hex) + assert_eq!(hash.len(), 64); + } + + #[test] + fn test_auth_error_display() { + assert_eq!( + AuthError::MissingToken.to_string(), + "Missing authentication token" + ); + assert_eq!( + AuthError::InvalidToken("bad".to_string()).to_string(), + "Invalid token: bad" + ); + } + + #[test] + fn test_generate_api_key_format() { + let generated = generate_api_key(); + + // Full key should start with mk_ prefix + assert!(generated.full_key.starts_with(API_KEY_PREFIX)); + + // Full key should be mk_ + 43 chars (32 bytes base64url encoded) + assert_eq!(generated.full_key.len(), 3 + 43); // "mk_" + 43 + + // Prefix should be mk_ + first 8 chars + assert!(generated.key_prefix.starts_with(API_KEY_PREFIX)); + assert_eq!(generated.key_prefix.len(), 3 + 8); + + // Hash should be 64 hex chars (SHA-256) + assert_eq!(generated.key_hash.len(), 64); + } + + #[test] + fn test_generate_api_key_uniqueness() { + let key1 = generate_api_key(); + let key2 = generate_api_key(); + + // Keys should be unique + assert_ne!(key1.full_key, key2.full_key); + assert_ne!(key1.key_hash, key2.key_hash); + } + + #[test] + fn test_api_key_cache_basic() { + let cache = ApiKeyCache::new(300); + let user_id = Uuid::new_v4(); + let owner_id = Uuid::new_v4(); + let key_hash = "test_hash_123"; + + // Cache miss initially + assert!(cache.get(key_hash).is_none()); + + // Set and verify cache hit + cache.set(key_hash.to_string(), user_id, owner_id); + let result = cache.get(key_hash); + assert!(result.is_some()); + let (cached_user, cached_owner) = result.unwrap(); + assert_eq!(cached_user, user_id); + assert_eq!(cached_owner, owner_id); + } + + #[test] + fn test_api_key_cache_invalidate() { + let cache = ApiKeyCache::new(300); + let user_id = Uuid::new_v4(); + let owner_id = Uuid::new_v4(); + let key_hash = "test_hash_456"; + + cache.set(key_hash.to_string(), user_id, owner_id); + assert!(cache.get(key_hash).is_some()); + + cache.invalidate(key_hash); + assert!(cache.get(key_hash).is_none()); + } + + #[test] + fn test_api_key_cache_clear() { + let cache = ApiKeyCache::new(300); + + cache.set("hash1".to_string(), Uuid::new_v4(), Uuid::new_v4()); + cache.set("hash2".to_string(), Uuid::new_v4(), Uuid::new_v4()); + + assert!(cache.get("hash1").is_some()); + assert!(cache.get("hash2").is_some()); + + cache.clear(); + + assert!(cache.get("hash1").is_none()); + assert!(cache.get("hash2").is_none()); + } + + #[test] + fn test_api_key_event_type_display() { + assert_eq!(ApiKeyEventType::Created.to_string(), "created"); + assert_eq!(ApiKeyEventType::Used.to_string(), "used"); + assert_eq!(ApiKeyEventType::Revoked.to_string(), "revoked"); + assert_eq!(ApiKeyEventType::Refreshed.to_string(), "refreshed"); + } +} |
