//! Authentication module for Makima server. //! //! Supports multiple authentication methods: //! - Supabase JWT tokens for web clients (ES256 or RS256 public key verification) //! - API keys for programmatic access (daemons, CLI) //! - Tool keys for orchestrator internal access use axum::{ extract::FromRequestParts, http::{header::AUTHORIZATION, request::Parts, HeaderMap, StatusCode}, response::IntoResponse, Json, }; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use chrono::{DateTime, Utc}; use dashmap::DashMap; use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; use rand::Rng; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use sqlx::{FromRow, PgPool, Row}; use std::time::{Duration, Instant}; use utoipa::ToSchema; use uuid::Uuid; use crate::server::messages::ApiError; use crate::server::state::SharedState; // ============================================================================= // Configuration // ============================================================================= /// JWT algorithm configuration. #[derive(Debug, Clone)] pub enum JwtAlgorithm { /// RS256 with RSA public key Rs256 { public_key: String }, /// ES256 with ECDSA public key (Supabase projects with JWT Signing Keys) Es256 { public_key: String }, } /// Authentication configuration loaded from environment. #[derive(Debug, Clone)] pub struct AuthConfig { /// Supabase project URL (e.g., https://your-project.supabase.co) pub supabase_url: String, /// JWT algorithm and key material pub algorithm: JwtAlgorithm, } impl AuthConfig { /// Load auth config from environment variables. /// /// Supports two modes (checked in order): /// - ES256: Set SUPABASE_URL and SUPABASE_JWT_PUBLIC_KEY (Supabase with ECDSA) /// - RS256: Set SUPABASE_URL and SUPABASE_JWT_RSA_PUBLIC_KEY (RSA public key) /// /// Returns None if auth is not configured. pub fn from_env() -> Option { let supabase_url = std::env::var("SUPABASE_URL").ok()?; // Try ES256 first (default for Supabase), then RS256 let algorithm = if let Ok(public_key) = std::env::var("SUPABASE_JWT_PUBLIC_KEY") { tracing::info!("Using ES256 JWT verification with ECDSA public key"); JwtAlgorithm::Es256 { public_key } } else if let Ok(public_key) = std::env::var("SUPABASE_JWT_RSA_PUBLIC_KEY") { tracing::info!("Using RS256 JWT verification with RSA public key"); JwtAlgorithm::Rs256 { public_key } } else { return None; }; Some(Self { supabase_url, algorithm, }) } } // ============================================================================= // JWT Claims // ============================================================================= /// JWT claims from Supabase Auth tokens. #[derive(Debug, Serialize, Deserialize)] pub struct SupabaseClaims { /// Audience (e.g., "authenticated") pub aud: String, /// Expiration time (Unix timestamp) pub exp: i64, /// Issued at (Unix timestamp) pub iat: i64, /// Issuer (Supabase project URL + /auth/v1) pub iss: String, /// Subject (user ID) pub sub: Uuid, /// User's email pub email: Option, /// User's phone pub phone: Option, /// App metadata (set by server/admin) pub app_metadata: Option, /// User metadata (set by user) pub user_metadata: Option, /// Role (e.g., "authenticated") pub role: Option, /// Session ID pub session_id: Option, } // ============================================================================= // JWT Verifier // ============================================================================= /// JWT verifier for Supabase tokens. pub struct JwtVerifier { supabase_url: String, decoding_key: DecodingKey, algorithm: Algorithm, } impl JwtVerifier { /// Create a new JWT verifier from auth config. /// /// Supports multiple key formats: /// - JWK (JSON Web Key) - detected by presence of `{` /// - PEM - detected by `-----BEGIN` /// - Base64-encoded DER - fallback pub fn new(config: AuthConfig) -> Result { let (decoding_key, algorithm) = match &config.algorithm { JwtAlgorithm::Rs256 { public_key } => { let key = Self::parse_public_key(public_key, "RSA")?; (key, Algorithm::RS256) } JwtAlgorithm::Es256 { public_key } => { let key = Self::parse_public_key(public_key, "EC")?; (key, Algorithm::ES256) } }; Ok(Self { supabase_url: config.supabase_url, decoding_key, algorithm, }) } /// Parse a public key from various formats (JWK, JWKS, PEM, or base64 DER). fn parse_public_key(key_data: &str, key_type: &str) -> Result { let trimmed = key_data.trim(); // Check for JSON format (JWK or JWKS) if trimmed.starts_with('{') { // First try to parse as a generic JSON value to inspect structure let mut json_value: serde_json::Value = serde_json::from_str(trimmed) .map_err(|e| AuthError::InvalidToken(format!("Invalid JSON: {}", e)))?; // Check if it's a JWKS (has "keys" array) if let Some(keys) = json_value.get_mut("keys").and_then(|k| k.as_array_mut()) { // Find the first signing key (or just use the first key) let jwk_value = keys.first_mut() .ok_or_else(|| AuthError::InvalidToken("JWKS has no keys".to_string()))?; // Remove private key component if present (user may have pasted full keypair) if let Some(obj) = jwk_value.as_object_mut() { if obj.remove("d").is_some() { tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); } } let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(jwk_value.clone()) .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK in JWKS: {}", e)))?; tracing::info!("Loaded JWT public key from JWKS (first key)"); return DecodingKey::from_jwk(&jwk) .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))); } // Remove private key component if present (user may have pasted full keypair) if let Some(obj) = json_value.as_object_mut() { if obj.remove("d").is_some() { tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification"); } } // Try as single JWK let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(json_value) .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK: {}", e)))?; tracing::info!("Loaded JWT public key from JWK"); DecodingKey::from_jwk(&jwk) .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e))) } // Check for PEM format else if trimmed.contains("-----BEGIN") { tracing::info!("Loaded JWT public key from PEM"); match key_type { "RSA" => DecodingKey::from_rsa_pem(trimmed.as_bytes()) .map_err(|e| AuthError::InvalidToken(format!("Invalid RSA PEM key: {}", e))), "EC" => DecodingKey::from_ec_pem(trimmed.as_bytes()) .map_err(|e| AuthError::InvalidToken(format!("Invalid EC PEM key: {}", e))), _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), } } // Assume base64-encoded DER else { tracing::info!("Loaded JWT public key from base64 DER"); let der_bytes = base64::engine::general_purpose::STANDARD .decode(trimmed) .map_err(|e| AuthError::InvalidToken(format!("Invalid base64 key: {}", e)))?; match key_type { "RSA" => Ok(DecodingKey::from_rsa_der(&der_bytes)), "EC" => Ok(DecodingKey::from_ec_der(&der_bytes)), _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))), } } } /// Verify a JWT token and return claims. pub fn verify(&self, token: &str) -> Result { // Decode header to check algorithm mismatch let header = jsonwebtoken::decode_header(token) .map_err(|e| AuthError::InvalidToken(format!("Invalid JWT header: {}", e)))?; tracing::debug!( "JWT header: algorithm={:?}, typ={:?}, kid={:?}", header.alg, header.typ, header.kid ); if header.alg != self.algorithm { let hint = match header.alg { Algorithm::ES256 => "Set SUPABASE_JWT_PUBLIC_KEY with the EC public key from Supabase Dashboard → Project Settings → API → JWT Settings", Algorithm::RS256 => "Set SUPABASE_JWT_RSA_PUBLIC_KEY with the RSA public key", _ => "Check your Supabase JWT configuration - only ES256 and RS256 are supported", }; tracing::warn!( "JWT algorithm mismatch: token uses {:?}, server configured for {:?}. {}", header.alg, self.algorithm, hint ); return Err(AuthError::InvalidToken(format!( "Algorithm mismatch: token is {:?}, expected {:?}", header.alg, self.algorithm ))); } let mut validation = Validation::new(self.algorithm); validation.set_audience(&["authenticated"]); validation.set_issuer(&[format!("{}/auth/v1", self.supabase_url)]); // First try with full validation let token_data = match decode::(token, &self.decoding_key, &validation) { Ok(data) => data, Err(e) => { // Log detailed error info tracing::warn!( "JWT verification failed: {} (algorithm: {:?}, issuer: {}/auth/v1)", e, self.algorithm, self.supabase_url ); // If it's InvalidAlgorithm, try to understand why by decoding payload manually if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidAlgorithm) { // Decode the payload part of the JWT manually (base64) let parts: Vec<&str> = token.split('.').collect(); if parts.len() >= 2 { if let Ok(payload_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) { if let Ok(payload_str) = String::from_utf8(payload_bytes) { if let Ok(claims) = serde_json::from_str::(&payload_str) { tracing::warn!( "JWT payload (unverified): iss={:?}, aud={:?}, sub={:?}", claims.get("iss"), claims.get("aud"), claims.get("sub") ); } } } } } return Err(AuthError::InvalidToken(e.to_string())); } }; Ok(token_data.claims) } /// Extract user ID from a token. pub fn get_user_id(&self, token: &str) -> Result { let claims = self.verify(token)?; Ok(claims.sub) } } // ============================================================================= // Auth Error // ============================================================================= /// Authentication error types. #[derive(Debug)] pub enum AuthError { /// No authentication token provided MissingToken, /// Token format is invalid InvalidToken(String), /// Token has expired ExpiredToken, /// User not found in database UserNotFound, /// API key is invalid or revoked InvalidApiKey, /// Database error during auth lookup DatabaseError(String), /// Authentication is not configured NotConfigured, /// Insufficient permissions for the operation InsufficientPermissions, } impl std::fmt::Display for AuthError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { AuthError::MissingToken => write!(f, "Missing authentication token"), AuthError::InvalidToken(msg) => write!(f, "Invalid token: {}", msg), AuthError::ExpiredToken => write!(f, "Token has expired"), AuthError::UserNotFound => write!(f, "User not found"), AuthError::InvalidApiKey => write!(f, "Invalid or revoked API key"), AuthError::DatabaseError(msg) => write!(f, "Database error: {}", msg), AuthError::NotConfigured => write!(f, "Authentication not configured"), AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"), } } } impl std::error::Error for AuthError {} impl IntoResponse for AuthError { fn into_response(self) -> axum::response::Response { let (status, code, message) = match &self { AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "MISSING_TOKEN", "Authentication required"), AuthError::InvalidToken(_) => (StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid authentication token"), AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "EXPIRED_TOKEN", "Token has expired"), AuthError::UserNotFound => (StatusCode::UNAUTHORIZED, "USER_NOT_FOUND", "User not found"), AuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, "INVALID_API_KEY", "Invalid or revoked API key"), AuthError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR", "Database error"), AuthError::NotConfigured => (StatusCode::SERVICE_UNAVAILABLE, "AUTH_NOT_CONFIGURED", "Authentication not configured"), AuthError::InsufficientPermissions => (StatusCode::FORBIDDEN, "FORBIDDEN", "Insufficient permissions"), }; (status, Json(ApiError::new(code, message))).into_response() } } // ============================================================================= // Auth Source // ============================================================================= /// Source of authentication. #[derive(Debug, Clone)] pub enum AuthSource { /// Authenticated via Supabase JWT (web client) Jwt, /// Authenticated via API key (daemon, CLI, integrations) ApiKey, /// Authenticated via tool key (orchestrator internal access) ToolKey(Uuid), } // ============================================================================= // Authenticated User // ============================================================================= /// Authenticated user context extracted from request. /// /// Contains the resolved user_id and owner_id for database operations. #[derive(Debug, Clone)] pub struct AuthenticatedUser { /// Supabase auth user ID (from auth.users) pub user_id: Uuid, /// Owner ID for data isolation (from users.default_owner_id) pub owner_id: Uuid, /// How the user was authenticated pub auth_source: AuthSource, /// User's email (if available) pub email: Option, } // ============================================================================= // Header Constants // ============================================================================= /// Header name for tool key authentication (orchestrators). pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key"; /// Header name for API key authentication. pub const API_KEY_HEADER: &str = "x-makima-api-key"; // ============================================================================= // Helper Functions // ============================================================================= /// Hash an API key for database lookup. pub fn hash_api_key(key: &str) -> String { let mut hasher = Sha256::new(); hasher.update(key.as_bytes()); hex::encode(hasher.finalize()) } // ============================================================================= // API Key Generation // ============================================================================= /// API key prefix for identification. pub const API_KEY_PREFIX: &str = "mk_"; /// Result of generating an API key. pub struct GeneratedApiKey { /// The full API key (shown only once to user) pub full_key: String, /// SHA-256 hash of the key (stored in database) pub key_hash: String, /// Prefix for display (first 8 chars after mk_) pub key_prefix: String, } /// Generate a new API key with mk_ prefix. /// /// Returns the full key (to show once), hash (to store), and prefix (for display). pub fn generate_api_key() -> GeneratedApiKey { let mut rng = rand::thread_rng(); let mut bytes = [0u8; 32]; rng.fill(&mut bytes); let key_bytes = URL_SAFE_NO_PAD.encode(bytes); let full_key = format!("{}{}", API_KEY_PREFIX, key_bytes); let key_hash = hash_api_key(&full_key); let key_prefix = format!("{}{}", API_KEY_PREFIX, &key_bytes[..8]); GeneratedApiKey { full_key, key_hash, key_prefix, } } // ============================================================================= // API Key Cache // ============================================================================= /// Cache entry for validated API keys. struct ApiKeyCacheEntry { user_id: Uuid, owner_id: Uuid, cached_at: Instant, } /// In-memory cache for API key validation to avoid database lookups on every request. pub struct ApiKeyCache { /// key_hash -> (user_id, owner_id, cached_at) cache: DashMap, /// Time-to-live for cache entries ttl: Duration, } impl ApiKeyCache { /// Create a new cache with the specified TTL in seconds. pub fn new(ttl_seconds: u64) -> Self { Self { cache: DashMap::new(), ttl: Duration::from_secs(ttl_seconds), } } /// Get cached user_id and owner_id for a key hash, if not expired. pub fn get(&self, key_hash: &str) -> Option<(Uuid, Uuid)> { self.cache.get(key_hash).and_then(|entry| { if entry.cached_at.elapsed() < self.ttl { Some((entry.user_id, entry.owner_id)) } else { None } }) } /// Cache a validated API key. pub fn set(&self, key_hash: String, user_id: Uuid, owner_id: Uuid) { self.cache.insert( key_hash, ApiKeyCacheEntry { user_id, owner_id, cached_at: Instant::now(), }, ); } /// Invalidate a cache entry (e.g., on key revocation). pub fn invalidate(&self, key_hash: &str) { self.cache.remove(key_hash); } /// Clear all cache entries. pub fn clear(&self) { self.cache.clear(); } } impl Default for ApiKeyCache { fn default() -> Self { // Default TTL: 5 minutes Self::new(300) } } // ============================================================================= // API Key Models // ============================================================================= /// API key record from the database. #[derive(Debug, Clone, FromRow, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct ApiKey { pub id: Uuid, pub user_id: Uuid, #[serde(skip)] pub key_hash: String, pub key_prefix: String, pub name: Option, pub last_used_at: Option>, pub created_at: DateTime, pub revoked_at: Option>, } /// Request to create a new API key. #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct CreateApiKeyRequest { /// User-provided label for the key pub name: Option, } /// Response after creating an API key (includes the full key - shown only once). #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct CreateApiKeyResponse { pub id: Uuid, /// The full API key - save this, it won't be shown again! pub key: String, pub prefix: String, pub name: Option, pub created_at: DateTime, } /// Response for getting API key info (excludes the full key). #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct ApiKeyInfoResponse { pub id: Uuid, pub prefix: String, pub name: Option, pub last_used_at: Option>, pub created_at: DateTime, } impl From for ApiKeyInfoResponse { fn from(key: ApiKey) -> Self { Self { id: key.id, prefix: key.key_prefix, name: key.name, last_used_at: key.last_used_at, created_at: key.created_at, } } } /// Request to refresh an API key. #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct RefreshApiKeyRequest { /// New name for the refreshed key pub name: Option, } /// Response after refreshing an API key. #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct RefreshApiKeyResponse { pub id: Uuid, /// The new API key - save this, it won't be shown again! pub key: String, pub prefix: String, pub name: Option, pub created_at: DateTime, pub previous_key_revoked: bool, } /// Response after revoking an API key. #[derive(Debug, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct RevokeApiKeyResponse { pub message: String, pub revoked_key_prefix: String, } /// API key event types for audit logging. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ApiKeyEventType { Created, Used, Revoked, Refreshed, } impl std::fmt::Display for ApiKeyEventType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ApiKeyEventType::Created => write!(f, "created"), ApiKeyEventType::Used => write!(f, "used"), ApiKeyEventType::Revoked => write!(f, "revoked"), ApiKeyEventType::Refreshed => write!(f, "refreshed"), } } } // ============================================================================= // API Keys Repository // ============================================================================= /// Repository error for API key operations. #[derive(Debug)] pub enum ApiKeyError { /// Database error Database(sqlx::Error), /// An active API key already exists for this user KeyAlreadyExists, /// No active API key found for this user KeyNotFound, } impl std::fmt::Display for ApiKeyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ApiKeyError::Database(e) => write!(f, "Database error: {}", e), ApiKeyError::KeyAlreadyExists => write!(f, "An active API key already exists"), ApiKeyError::KeyNotFound => write!(f, "No active API key found"), } } } impl std::error::Error for ApiKeyError {} impl From for ApiKeyError { fn from(e: sqlx::Error) -> Self { ApiKeyError::Database(e) } } /// Get the active API key for a user (if any). pub async fn get_active_api_key(pool: &PgPool, user_id: Uuid) -> Result, sqlx::Error> { sqlx::query_as::<_, ApiKey>( r#" SELECT id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at FROM api_keys WHERE user_id = $1 AND revoked_at IS NULL "#, ) .bind(user_id) .fetch_optional(pool) .await } /// Create a new API key for a user. /// /// Returns an error if the user already has an active key. /// The `generated` parameter should be created using `generate_api_key()`. pub async fn create_api_key( pool: &PgPool, user_id: Uuid, generated: &GeneratedApiKey, name: Option<&str>, ) -> Result { // Check if user already has an active key if let Some(_) = get_active_api_key(pool, user_id).await? { return Err(ApiKeyError::KeyAlreadyExists); } let key = sqlx::query_as::<_, ApiKey>( r#" INSERT INTO api_keys (user_id, key_hash, key_prefix, name) VALUES ($1, $2, $3, $4) RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at "#, ) .bind(user_id) .bind(&generated.key_hash) .bind(&generated.key_prefix) .bind(name) .fetch_one(pool) .await?; // Log the creation event let _ = log_api_key_event(pool, key.id, ApiKeyEventType::Created, None, None).await; Ok(key) } /// Revoke an API key by marking it with revoked_at timestamp. pub async fn revoke_api_key(pool: &PgPool, user_id: Uuid) -> Result { // Get the active key first let key = get_active_api_key(pool, user_id) .await? .ok_or(ApiKeyError::KeyNotFound)?; // Revoke it let revoked = sqlx::query_as::<_, ApiKey>( r#" UPDATE api_keys SET revoked_at = NOW() WHERE id = $1 RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at "#, ) .bind(key.id) .fetch_one(pool) .await?; // Log the revocation event let _ = log_api_key_event(pool, revoked.id, ApiKeyEventType::Revoked, None, None).await; Ok(revoked) } /// Refresh an API key: revoke the old one and create a new one atomically. /// /// Returns the new key. The caller should use `generate_api_key()` to create /// the `new_generated` parameter. pub async fn refresh_api_key( pool: &PgPool, user_id: Uuid, new_generated: &GeneratedApiKey, new_name: Option<&str>, ) -> Result<(ApiKey, Option), ApiKeyError> { // Get and revoke the old key (if exists) let old_prefix = if let Some(old_key) = get_active_api_key(pool, user_id).await? { let old_prefix = old_key.key_prefix.clone(); // Revoke the old key sqlx::query( r#" UPDATE api_keys SET revoked_at = NOW() WHERE id = $1 "#, ) .bind(old_key.id) .execute(pool) .await?; // Log the refresh event on the old key let _ = log_api_key_event(pool, old_key.id, ApiKeyEventType::Refreshed, None, None).await; Some(old_prefix) } else { None }; // Create the new key let new_key = sqlx::query_as::<_, ApiKey>( r#" INSERT INTO api_keys (user_id, key_hash, key_prefix, name) VALUES ($1, $2, $3, $4) RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at "#, ) .bind(user_id) .bind(&new_generated.key_hash) .bind(&new_generated.key_prefix) .bind(new_name) .fetch_one(pool) .await?; // Log the creation event on the new key let _ = log_api_key_event(pool, new_key.id, ApiKeyEventType::Created, None, None).await; Ok((new_key, old_prefix)) } /// Update last_used_at timestamp for an API key. pub async fn update_api_key_last_used(pool: &PgPool, key_hash: &str) -> Result<(), sqlx::Error> { sqlx::query( r#" UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1 AND revoked_at IS NULL "#, ) .bind(key_hash) .execute(pool) .await?; Ok(()) } /// Log an API key event for audit purposes. pub async fn log_api_key_event( pool: &PgPool, api_key_id: Uuid, event_type: ApiKeyEventType, ip_address: Option<&str>, user_agent: Option<&str>, ) -> Result<(), sqlx::Error> { sqlx::query( r#" INSERT INTO api_key_events (api_key_id, event_type, ip_address, user_agent) VALUES ($1, $2, $3::inet, $4) "#, ) .bind(api_key_id) .bind(event_type.to_string()) .bind(ip_address) .bind(user_agent) .execute(pool) .await?; Ok(()) } // ============================================================================= // Internal Helper Functions // ============================================================================= /// Resolve owner_id from user_id by looking up the users table. /// If the user doesn't exist, auto-creates them on first login. /// Uses ON CONFLICT to handle race conditions when multiple requests arrive simultaneously. async fn resolve_owner_id(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result { // First, try to get existing user let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") .bind(user_id) .fetch_optional(pool) .await .map_err(|e| AuthError::DatabaseError(e.to_string()))?; if let Some(row) = row { let owner_id: Option = row.try_get("default_owner_id") .map_err(|e| AuthError::DatabaseError(e.to_string()))?; return owner_id.ok_or(AuthError::UserNotFound); } // User doesn't exist - auto-create on first login tracing::info!("Creating new user record for {}", user_id); // Create owner first (use ON CONFLICT to handle race conditions) let owner_id = Uuid::new_v4(); sqlx::query("INSERT INTO owners (id, name) VALUES ($1, $2) ON CONFLICT DO NOTHING") .bind(owner_id) .bind(email.unwrap_or("Unknown")) .execute(pool) .await .map_err(|e| AuthError::DatabaseError(format!("Failed to create owner: {}", e)))?; // Create user with reference to owner (use ON CONFLICT to handle race conditions) sqlx::query( "INSERT INTO users (id, email, default_owner_id) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING" ) .bind(user_id) .bind(email) .bind(owner_id) .execute(pool) .await .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?; // Re-fetch the user to get the actual owner_id (in case another request created it first) let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1") .bind(user_id) .fetch_optional(pool) .await .map_err(|e| AuthError::DatabaseError(e.to_string()))?; match row { Some(row) => { let owner_id: Option = row.try_get("default_owner_id") .map_err(|e| AuthError::DatabaseError(e.to_string()))?; owner_id.ok_or(AuthError::UserNotFound) } None => Err(AuthError::DatabaseError("Failed to create user record".to_string())) } } /// Validate an API key and return (user_id, owner_id). async fn validate_api_key(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> { let key_hash = hash_api_key(key); // Look up the API key and join with users to get owner_id let row = sqlx::query( r#" SELECT ak.user_id, u.default_owner_id FROM api_keys ak JOIN users u ON u.id = ak.user_id WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL "#, ) .bind(&key_hash) .fetch_optional(pool) .await .map_err(|e| AuthError::DatabaseError(e.to_string()))?; match row { Some(row) => { let user_id: Uuid = row.try_get("user_id") .map_err(|e| AuthError::DatabaseError(e.to_string()))?; let owner_id: Option = row.try_get("default_owner_id") .map_err(|e| AuthError::DatabaseError(e.to_string()))?; let owner_id = owner_id.ok_or(AuthError::UserNotFound)?; // Update last_used_at asynchronously (fire and forget) let pool_clone = pool.clone(); let key_hash_clone = key_hash.clone(); tokio::spawn(async move { let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1") .bind(&key_hash_clone) .execute(&pool_clone) .await; }); Ok((user_id, owner_id)) } None => Err(AuthError::InvalidApiKey), } } /// Extract authentication from request headers. /// /// Tries authentication methods in order: /// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators /// 2. API Key (X-Makima-API-Key) - for daemons/CLI /// 3. JWT (Authorization: Bearer) - for web clients async fn extract_auth( state: &SharedState, headers: &HeaderMap, ) -> Result { // 1. Check for tool key (orchestrator access) if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) { if let Ok(key_str) = tool_key.to_str() { if let Some(task_id) = state.validate_tool_key(key_str) { // Tool keys are trusted - use a placeholder user/owner for orchestrator actions // The orchestrator inherits the owner_id from its task let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; // Get owner_id from the task let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1") .bind(task_id) .fetch_optional(pool) .await .map_err(|e| AuthError::DatabaseError(e.to_string()))? .ok_or(AuthError::UserNotFound)?; let task_owner: Uuid = row.try_get("owner_id") .map_err(|e| AuthError::DatabaseError(e.to_string()))?; return Ok(AuthenticatedUser { user_id: Uuid::nil(), // Tool keys don't have a user owner_id: task_owner, auth_source: AuthSource::ToolKey(task_id), email: None, }); } tracing::warn!("Invalid tool key provided"); } } // 2. Check for API key if let Some(api_key) = headers.get(API_KEY_HEADER) { if let Ok(key_str) = api_key.to_str() { let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; let (user_id, owner_id) = validate_api_key(pool, key_str).await?; return Ok(AuthenticatedUser { user_id, owner_id, auth_source: AuthSource::ApiKey, email: None, }); } } // 3. Check for JWT (Bearer token) if let Some(auth_header) = headers.get(AUTHORIZATION) { if let Ok(auth_str) = auth_header.to_str() { if let Some(token) = auth_str.strip_prefix("Bearer ") { let verifier = state .jwt_verifier .as_ref() .ok_or(AuthError::NotConfigured)?; let claims = verifier.verify(token)?; let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?; let owner_id = resolve_owner_id(pool, claims.sub, claims.email.as_deref()).await?; return Ok(AuthenticatedUser { user_id: claims.sub, owner_id, auth_source: AuthSource::Jwt, email: claims.email, }); } } } Err(AuthError::MissingToken) } // ============================================================================= // Extractors // ============================================================================= /// Extractor for authenticated requests. /// /// Tries authentication methods in order: /// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators /// 2. API Key (X-Makima-API-Key) - for daemons/CLI /// 3. JWT (Authorization: Bearer) - for web clients /// /// Returns 401 Unauthorized if no valid authentication is found. /// /// # Example /// ```ignore /// async fn protected_handler( /// Authenticated(user): Authenticated, /// ) -> impl IntoResponse { /// Json(format!("Hello user {}", user.user_id)) /// } /// ``` pub struct Authenticated(pub AuthenticatedUser); impl FromRequestParts for Authenticated { type Rejection = AuthError; async fn from_request_parts( parts: &mut Parts, state: &SharedState, ) -> Result { let user = extract_auth(state, &parts.headers).await?; Ok(Authenticated(user)) } } /// Extractor for user-only authentication (JWT or API key, no tool keys). /// /// Use this for endpoints that should only be accessible to actual users, /// not orchestrators with tool keys. /// /// Returns 401 Unauthorized if no valid user authentication is found. /// Returns 403 Forbidden if a tool key is used. /// /// # Example /// ```ignore /// async fn user_profile( /// UserOnly(user): UserOnly, /// ) -> impl IntoResponse { /// // Only actual users can access this /// Json(format!("User profile for {}", user.user_id)) /// } /// ``` pub struct UserOnly(pub AuthenticatedUser); impl FromRequestParts for UserOnly { type Rejection = AuthError; async fn from_request_parts( parts: &mut Parts, state: &SharedState, ) -> Result { let user = extract_auth(state, &parts.headers).await?; // Reject tool key authentication if matches!(user.auth_source, AuthSource::ToolKey(_)) { return Err(AuthError::InsufficientPermissions); } Ok(UserOnly(user)) } } /// Extractor for optional authentication. /// /// Returns Some(AuthenticatedUser) if valid auth is provided, None otherwise. /// Never returns an error - invalid auth is treated as no auth. /// /// # Example /// ```ignore /// async fn public_or_private( /// MaybeAuthenticated(user): MaybeAuthenticated, /// ) -> impl IntoResponse { /// match user { /// Some(u) => Json(format!("Hello {}", u.user_id)), /// None => Json("Hello anonymous".to_string()), /// } /// } /// ``` pub struct MaybeAuthenticated(pub Option); impl FromRequestParts for MaybeAuthenticated { type Rejection = std::convert::Infallible; async fn from_request_parts( parts: &mut Parts, state: &SharedState, ) -> Result { let user = extract_auth(state, &parts.headers).await.ok(); Ok(MaybeAuthenticated(user)) } } // ============================================================================= // Tests // ============================================================================= #[cfg(test)] mod tests { use super::*; #[test] fn test_hash_api_key() { let key = "mk_test123456789"; let hash = hash_api_key(key); // Hash should be consistent assert_eq!(hash, hash_api_key(key)); // Hash should be 64 characters (SHA-256 hex) assert_eq!(hash.len(), 64); } #[test] fn test_auth_error_display() { assert_eq!( AuthError::MissingToken.to_string(), "Missing authentication token" ); assert_eq!( AuthError::InvalidToken("bad".to_string()).to_string(), "Invalid token: bad" ); } #[test] fn test_generate_api_key_format() { let generated = generate_api_key(); // Full key should start with mk_ prefix assert!(generated.full_key.starts_with(API_KEY_PREFIX)); // Full key should be mk_ + 43 chars (32 bytes base64url encoded) assert_eq!(generated.full_key.len(), 3 + 43); // "mk_" + 43 // Prefix should be mk_ + first 8 chars assert!(generated.key_prefix.starts_with(API_KEY_PREFIX)); assert_eq!(generated.key_prefix.len(), 3 + 8); // Hash should be 64 hex chars (SHA-256) assert_eq!(generated.key_hash.len(), 64); } #[test] fn test_generate_api_key_uniqueness() { let key1 = generate_api_key(); let key2 = generate_api_key(); // Keys should be unique assert_ne!(key1.full_key, key2.full_key); assert_ne!(key1.key_hash, key2.key_hash); } #[test] fn test_api_key_cache_basic() { let cache = ApiKeyCache::new(300); let user_id = Uuid::new_v4(); let owner_id = Uuid::new_v4(); let key_hash = "test_hash_123"; // Cache miss initially assert!(cache.get(key_hash).is_none()); // Set and verify cache hit cache.set(key_hash.to_string(), user_id, owner_id); let result = cache.get(key_hash); assert!(result.is_some()); let (cached_user, cached_owner) = result.unwrap(); assert_eq!(cached_user, user_id); assert_eq!(cached_owner, owner_id); } #[test] fn test_api_key_cache_invalidate() { let cache = ApiKeyCache::new(300); let user_id = Uuid::new_v4(); let owner_id = Uuid::new_v4(); let key_hash = "test_hash_456"; cache.set(key_hash.to_string(), user_id, owner_id); assert!(cache.get(key_hash).is_some()); cache.invalidate(key_hash); assert!(cache.get(key_hash).is_none()); } #[test] fn test_api_key_cache_clear() { let cache = ApiKeyCache::new(300); cache.set("hash1".to_string(), Uuid::new_v4(), Uuid::new_v4()); cache.set("hash2".to_string(), Uuid::new_v4(), Uuid::new_v4()); assert!(cache.get("hash1").is_some()); assert!(cache.get("hash2").is_some()); cache.clear(); assert!(cache.get("hash1").is_none()); assert!(cache.get("hash2").is_none()); } #[test] fn test_api_key_event_type_display() { assert_eq!(ApiKeyEventType::Created.to_string(), "created"); assert_eq!(ApiKeyEventType::Used.to_string(), "used"); assert_eq!(ApiKeyEventType::Revoked.to_string(), "revoked"); assert_eq!(ApiKeyEventType::Refreshed.to_string(), "refreshed"); } }