summaryrefslogtreecommitdiff
path: root/makima/src/server/auth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server/auth.rs')
-rw-r--r--makima/src/server/auth.rs1238
1 files changed, 1238 insertions, 0 deletions
diff --git a/makima/src/server/auth.rs b/makima/src/server/auth.rs
new file mode 100644
index 0000000..b694df6
--- /dev/null
+++ b/makima/src/server/auth.rs
@@ -0,0 +1,1238 @@
+//! Authentication module for Makima server.
+//!
+//! Supports multiple authentication methods:
+//! - Supabase JWT tokens for web clients (ES256 or RS256 public key verification)
+//! - API keys for programmatic access (daemons, CLI)
+//! - Tool keys for orchestrator internal access
+
+use axum::{
+ extract::FromRequestParts,
+ http::{header::AUTHORIZATION, request::Parts, HeaderMap, StatusCode},
+ response::IntoResponse,
+ Json,
+};
+use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
+use chrono::{DateTime, Utc};
+use dashmap::DashMap;
+use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
+use rand::Rng;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use sqlx::{FromRow, PgPool, Row};
+use std::time::{Duration, Instant};
+use utoipa::ToSchema;
+use uuid::Uuid;
+
+use crate::server::messages::ApiError;
+use crate::server::state::SharedState;
+
+// =============================================================================
+// Configuration
+// =============================================================================
+
+/// JWT algorithm configuration.
+#[derive(Debug, Clone)]
+pub enum JwtAlgorithm {
+ /// RS256 with RSA public key
+ Rs256 { public_key: String },
+ /// ES256 with ECDSA public key (Supabase projects with JWT Signing Keys)
+ Es256 { public_key: String },
+}
+
+/// Authentication configuration loaded from environment.
+#[derive(Debug, Clone)]
+pub struct AuthConfig {
+ /// Supabase project URL (e.g., https://your-project.supabase.co)
+ pub supabase_url: String,
+ /// JWT algorithm and key material
+ pub algorithm: JwtAlgorithm,
+}
+
+impl AuthConfig {
+ /// Load auth config from environment variables.
+ ///
+ /// Supports two modes (checked in order):
+ /// - ES256: Set SUPABASE_URL and SUPABASE_JWT_PUBLIC_KEY (Supabase with ECDSA)
+ /// - RS256: Set SUPABASE_URL and SUPABASE_JWT_RSA_PUBLIC_KEY (RSA public key)
+ ///
+ /// Returns None if auth is not configured.
+ pub fn from_env() -> Option<Self> {
+ let supabase_url = std::env::var("SUPABASE_URL").ok()?;
+
+ // Try ES256 first (default for Supabase), then RS256
+ let algorithm = if let Ok(public_key) = std::env::var("SUPABASE_JWT_PUBLIC_KEY") {
+ tracing::info!("Using ES256 JWT verification with ECDSA public key");
+ JwtAlgorithm::Es256 { public_key }
+ } else if let Ok(public_key) = std::env::var("SUPABASE_JWT_RSA_PUBLIC_KEY") {
+ tracing::info!("Using RS256 JWT verification with RSA public key");
+ JwtAlgorithm::Rs256 { public_key }
+ } else {
+ return None;
+ };
+
+ Some(Self {
+ supabase_url,
+ algorithm,
+ })
+ }
+}
+
+// =============================================================================
+// JWT Claims
+// =============================================================================
+
+/// JWT claims from Supabase Auth tokens.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SupabaseClaims {
+ /// Audience (e.g., "authenticated")
+ pub aud: String,
+ /// Expiration time (Unix timestamp)
+ pub exp: i64,
+ /// Issued at (Unix timestamp)
+ pub iat: i64,
+ /// Issuer (Supabase project URL + /auth/v1)
+ pub iss: String,
+ /// Subject (user ID)
+ pub sub: Uuid,
+ /// User's email
+ pub email: Option<String>,
+ /// User's phone
+ pub phone: Option<String>,
+ /// App metadata (set by server/admin)
+ pub app_metadata: Option<serde_json::Value>,
+ /// User metadata (set by user)
+ pub user_metadata: Option<serde_json::Value>,
+ /// Role (e.g., "authenticated")
+ pub role: Option<String>,
+ /// Session ID
+ pub session_id: Option<Uuid>,
+}
+
+// =============================================================================
+// JWT Verifier
+// =============================================================================
+
+/// JWT verifier for Supabase tokens.
+pub struct JwtVerifier {
+ supabase_url: String,
+ decoding_key: DecodingKey,
+ algorithm: Algorithm,
+}
+
+impl JwtVerifier {
+ /// Create a new JWT verifier from auth config.
+ ///
+ /// Supports multiple key formats:
+ /// - JWK (JSON Web Key) - detected by presence of `{`
+ /// - PEM - detected by `-----BEGIN`
+ /// - Base64-encoded DER - fallback
+ pub fn new(config: AuthConfig) -> Result<Self, AuthError> {
+ let (decoding_key, algorithm) = match &config.algorithm {
+ JwtAlgorithm::Rs256 { public_key } => {
+ let key = Self::parse_public_key(public_key, "RSA")?;
+ (key, Algorithm::RS256)
+ }
+ JwtAlgorithm::Es256 { public_key } => {
+ let key = Self::parse_public_key(public_key, "EC")?;
+ (key, Algorithm::ES256)
+ }
+ };
+
+ Ok(Self {
+ supabase_url: config.supabase_url,
+ decoding_key,
+ algorithm,
+ })
+ }
+
+ /// Parse a public key from various formats (JWK, JWKS, PEM, or base64 DER).
+ fn parse_public_key(key_data: &str, key_type: &str) -> Result<DecodingKey, AuthError> {
+ let trimmed = key_data.trim();
+
+ // Check for JSON format (JWK or JWKS)
+ if trimmed.starts_with('{') {
+ // First try to parse as a generic JSON value to inspect structure
+ let mut json_value: serde_json::Value = serde_json::from_str(trimmed)
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid JSON: {}", e)))?;
+
+ // Check if it's a JWKS (has "keys" array)
+ if let Some(keys) = json_value.get_mut("keys").and_then(|k| k.as_array_mut()) {
+ // Find the first signing key (or just use the first key)
+ let jwk_value = keys.first_mut()
+ .ok_or_else(|| AuthError::InvalidToken("JWKS has no keys".to_string()))?;
+
+ // Remove private key component if present (user may have pasted full keypair)
+ if let Some(obj) = jwk_value.as_object_mut() {
+ if obj.remove("d").is_some() {
+ tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification");
+ }
+ }
+
+ let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(jwk_value.clone())
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK in JWKS: {}", e)))?;
+
+ tracing::info!("Loaded JWT public key from JWKS (first key)");
+ return DecodingKey::from_jwk(&jwk)
+ .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e)));
+ }
+
+ // Remove private key component if present (user may have pasted full keypair)
+ if let Some(obj) = json_value.as_object_mut() {
+ if obj.remove("d").is_some() {
+ tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification");
+ }
+ }
+
+ // Try as single JWK
+ let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(json_value)
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid JWK: {}", e)))?;
+
+ tracing::info!("Loaded JWT public key from JWK");
+ DecodingKey::from_jwk(&jwk)
+ .map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e)))
+ }
+ // Check for PEM format
+ else if trimmed.contains("-----BEGIN") {
+ tracing::info!("Loaded JWT public key from PEM");
+ match key_type {
+ "RSA" => DecodingKey::from_rsa_pem(trimmed.as_bytes())
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid RSA PEM key: {}", e))),
+ "EC" => DecodingKey::from_ec_pem(trimmed.as_bytes())
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid EC PEM key: {}", e))),
+ _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))),
+ }
+ }
+ // Assume base64-encoded DER
+ else {
+ tracing::info!("Loaded JWT public key from base64 DER");
+ let der_bytes = base64::engine::general_purpose::STANDARD
+ .decode(trimmed)
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid base64 key: {}", e)))?;
+
+ match key_type {
+ "RSA" => Ok(DecodingKey::from_rsa_der(&der_bytes)),
+ "EC" => Ok(DecodingKey::from_ec_der(&der_bytes)),
+ _ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))),
+ }
+ }
+ }
+
+ /// Verify a JWT token and return claims.
+ pub fn verify(&self, token: &str) -> Result<SupabaseClaims, AuthError> {
+ // Decode header to check algorithm mismatch
+ let header = jsonwebtoken::decode_header(token)
+ .map_err(|e| AuthError::InvalidToken(format!("Invalid JWT header: {}", e)))?;
+
+ tracing::debug!(
+ "JWT header: algorithm={:?}, typ={:?}, kid={:?}",
+ header.alg,
+ header.typ,
+ header.kid
+ );
+
+ if header.alg != self.algorithm {
+ let hint = match header.alg {
+ Algorithm::ES256 => "Set SUPABASE_JWT_PUBLIC_KEY with the EC public key from Supabase Dashboard → Project Settings → API → JWT Settings",
+ Algorithm::RS256 => "Set SUPABASE_JWT_RSA_PUBLIC_KEY with the RSA public key",
+ _ => "Check your Supabase JWT configuration - only ES256 and RS256 are supported",
+ };
+ tracing::warn!(
+ "JWT algorithm mismatch: token uses {:?}, server configured for {:?}. {}",
+ header.alg,
+ self.algorithm,
+ hint
+ );
+ return Err(AuthError::InvalidToken(format!(
+ "Algorithm mismatch: token is {:?}, expected {:?}",
+ header.alg, self.algorithm
+ )));
+ }
+
+ let mut validation = Validation::new(self.algorithm);
+ validation.set_audience(&["authenticated"]);
+ validation.set_issuer(&[format!("{}/auth/v1", self.supabase_url)]);
+
+ // First try with full validation
+ let token_data = match decode::<SupabaseClaims>(token, &self.decoding_key, &validation) {
+ Ok(data) => data,
+ Err(e) => {
+ // Log detailed error info
+ tracing::warn!(
+ "JWT verification failed: {} (algorithm: {:?}, issuer: {}/auth/v1)",
+ e,
+ self.algorithm,
+ self.supabase_url
+ );
+
+ // If it's InvalidAlgorithm, try to understand why by decoding payload manually
+ if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidAlgorithm) {
+ // Decode the payload part of the JWT manually (base64)
+ let parts: Vec<&str> = token.split('.').collect();
+ if parts.len() >= 2 {
+ if let Ok(payload_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) {
+ if let Ok(payload_str) = String::from_utf8(payload_bytes) {
+ if let Ok(claims) = serde_json::from_str::<serde_json::Value>(&payload_str) {
+ tracing::warn!(
+ "JWT payload (unverified): iss={:?}, aud={:?}, sub={:?}",
+ claims.get("iss"),
+ claims.get("aud"),
+ claims.get("sub")
+ );
+ }
+ }
+ }
+ }
+ }
+
+ return Err(AuthError::InvalidToken(e.to_string()));
+ }
+ };
+
+ Ok(token_data.claims)
+ }
+
+ /// Extract user ID from a token.
+ pub fn get_user_id(&self, token: &str) -> Result<Uuid, AuthError> {
+ let claims = self.verify(token)?;
+ Ok(claims.sub)
+ }
+}
+
+// =============================================================================
+// Auth Error
+// =============================================================================
+
+/// Authentication error types.
+#[derive(Debug)]
+pub enum AuthError {
+ /// No authentication token provided
+ MissingToken,
+ /// Token format is invalid
+ InvalidToken(String),
+ /// Token has expired
+ ExpiredToken,
+ /// User not found in database
+ UserNotFound,
+ /// API key is invalid or revoked
+ InvalidApiKey,
+ /// Database error during auth lookup
+ DatabaseError(String),
+ /// Authentication is not configured
+ NotConfigured,
+ /// Insufficient permissions for the operation
+ InsufficientPermissions,
+}
+
+impl std::fmt::Display for AuthError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ AuthError::MissingToken => write!(f, "Missing authentication token"),
+ AuthError::InvalidToken(msg) => write!(f, "Invalid token: {}", msg),
+ AuthError::ExpiredToken => write!(f, "Token has expired"),
+ AuthError::UserNotFound => write!(f, "User not found"),
+ AuthError::InvalidApiKey => write!(f, "Invalid or revoked API key"),
+ AuthError::DatabaseError(msg) => write!(f, "Database error: {}", msg),
+ AuthError::NotConfigured => write!(f, "Authentication not configured"),
+ AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"),
+ }
+ }
+}
+
+impl std::error::Error for AuthError {}
+
+impl IntoResponse for AuthError {
+ fn into_response(self) -> axum::response::Response {
+ let (status, code, message) = match &self {
+ AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "MISSING_TOKEN", "Authentication required"),
+ AuthError::InvalidToken(_) => (StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid authentication token"),
+ AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "EXPIRED_TOKEN", "Token has expired"),
+ AuthError::UserNotFound => (StatusCode::UNAUTHORIZED, "USER_NOT_FOUND", "User not found"),
+ AuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, "INVALID_API_KEY", "Invalid or revoked API key"),
+ AuthError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR", "Database error"),
+ AuthError::NotConfigured => (StatusCode::SERVICE_UNAVAILABLE, "AUTH_NOT_CONFIGURED", "Authentication not configured"),
+ AuthError::InsufficientPermissions => (StatusCode::FORBIDDEN, "FORBIDDEN", "Insufficient permissions"),
+ };
+
+ (status, Json(ApiError::new(code, message))).into_response()
+ }
+}
+
+// =============================================================================
+// Auth Source
+// =============================================================================
+
+/// Source of authentication.
+#[derive(Debug, Clone)]
+pub enum AuthSource {
+ /// Authenticated via Supabase JWT (web client)
+ Jwt,
+ /// Authenticated via API key (daemon, CLI, integrations)
+ ApiKey,
+ /// Authenticated via tool key (orchestrator internal access)
+ ToolKey(Uuid),
+}
+
+// =============================================================================
+// Authenticated User
+// =============================================================================
+
+/// Authenticated user context extracted from request.
+///
+/// Contains the resolved user_id and owner_id for database operations.
+#[derive(Debug, Clone)]
+pub struct AuthenticatedUser {
+ /// Supabase auth user ID (from auth.users)
+ pub user_id: Uuid,
+ /// Owner ID for data isolation (from users.default_owner_id)
+ pub owner_id: Uuid,
+ /// How the user was authenticated
+ pub auth_source: AuthSource,
+ /// User's email (if available)
+ pub email: Option<String>,
+}
+
+// =============================================================================
+// Header Constants
+// =============================================================================
+
+/// Header name for tool key authentication (orchestrators).
+pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key";
+
+/// Header name for API key authentication.
+pub const API_KEY_HEADER: &str = "x-makima-api-key";
+
+// =============================================================================
+// Helper Functions
+// =============================================================================
+
+/// Hash an API key for database lookup.
+pub fn hash_api_key(key: &str) -> String {
+ let mut hasher = Sha256::new();
+ hasher.update(key.as_bytes());
+ hex::encode(hasher.finalize())
+}
+
+// =============================================================================
+// API Key Generation
+// =============================================================================
+
+/// API key prefix for identification.
+pub const API_KEY_PREFIX: &str = "mk_";
+
+/// Result of generating an API key.
+pub struct GeneratedApiKey {
+ /// The full API key (shown only once to user)
+ pub full_key: String,
+ /// SHA-256 hash of the key (stored in database)
+ pub key_hash: String,
+ /// Prefix for display (first 8 chars after mk_)
+ pub key_prefix: String,
+}
+
+/// Generate a new API key with mk_ prefix.
+///
+/// Returns the full key (to show once), hash (to store), and prefix (for display).
+pub fn generate_api_key() -> GeneratedApiKey {
+ let mut rng = rand::thread_rng();
+ let mut bytes = [0u8; 32];
+ rng.fill(&mut bytes);
+
+ let key_bytes = URL_SAFE_NO_PAD.encode(bytes);
+ let full_key = format!("{}{}", API_KEY_PREFIX, key_bytes);
+
+ let key_hash = hash_api_key(&full_key);
+ let key_prefix = format!("{}{}", API_KEY_PREFIX, &key_bytes[..8]);
+
+ GeneratedApiKey {
+ full_key,
+ key_hash,
+ key_prefix,
+ }
+}
+
+// =============================================================================
+// API Key Cache
+// =============================================================================
+
+/// Cache entry for validated API keys.
+struct ApiKeyCacheEntry {
+ user_id: Uuid,
+ owner_id: Uuid,
+ cached_at: Instant,
+}
+
+/// In-memory cache for API key validation to avoid database lookups on every request.
+pub struct ApiKeyCache {
+ /// key_hash -> (user_id, owner_id, cached_at)
+ cache: DashMap<String, ApiKeyCacheEntry>,
+ /// Time-to-live for cache entries
+ ttl: Duration,
+}
+
+impl ApiKeyCache {
+ /// Create a new cache with the specified TTL in seconds.
+ pub fn new(ttl_seconds: u64) -> Self {
+ Self {
+ cache: DashMap::new(),
+ ttl: Duration::from_secs(ttl_seconds),
+ }
+ }
+
+ /// Get cached user_id and owner_id for a key hash, if not expired.
+ pub fn get(&self, key_hash: &str) -> Option<(Uuid, Uuid)> {
+ self.cache.get(key_hash).and_then(|entry| {
+ if entry.cached_at.elapsed() < self.ttl {
+ Some((entry.user_id, entry.owner_id))
+ } else {
+ None
+ }
+ })
+ }
+
+ /// Cache a validated API key.
+ pub fn set(&self, key_hash: String, user_id: Uuid, owner_id: Uuid) {
+ self.cache.insert(
+ key_hash,
+ ApiKeyCacheEntry {
+ user_id,
+ owner_id,
+ cached_at: Instant::now(),
+ },
+ );
+ }
+
+ /// Invalidate a cache entry (e.g., on key revocation).
+ pub fn invalidate(&self, key_hash: &str) {
+ self.cache.remove(key_hash);
+ }
+
+ /// Clear all cache entries.
+ pub fn clear(&self) {
+ self.cache.clear();
+ }
+}
+
+impl Default for ApiKeyCache {
+ fn default() -> Self {
+ // Default TTL: 5 minutes
+ Self::new(300)
+ }
+}
+
+// =============================================================================
+// API Key Models
+// =============================================================================
+
+/// API key record from the database.
+#[derive(Debug, Clone, FromRow, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct ApiKey {
+ pub id: Uuid,
+ pub user_id: Uuid,
+ #[serde(skip)]
+ pub key_hash: String,
+ pub key_prefix: String,
+ pub name: Option<String>,
+ pub last_used_at: Option<DateTime<Utc>>,
+ pub created_at: DateTime<Utc>,
+ pub revoked_at: Option<DateTime<Utc>>,
+}
+
+/// Request to create a new API key.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct CreateApiKeyRequest {
+ /// User-provided label for the key
+ pub name: Option<String>,
+}
+
+/// Response after creating an API key (includes the full key - shown only once).
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct CreateApiKeyResponse {
+ pub id: Uuid,
+ /// The full API key - save this, it won't be shown again!
+ pub key: String,
+ pub prefix: String,
+ pub name: Option<String>,
+ pub created_at: DateTime<Utc>,
+}
+
+/// Response for getting API key info (excludes the full key).
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct ApiKeyInfoResponse {
+ pub id: Uuid,
+ pub prefix: String,
+ pub name: Option<String>,
+ pub last_used_at: Option<DateTime<Utc>>,
+ pub created_at: DateTime<Utc>,
+}
+
+impl From<ApiKey> for ApiKeyInfoResponse {
+ fn from(key: ApiKey) -> Self {
+ Self {
+ id: key.id,
+ prefix: key.key_prefix,
+ name: key.name,
+ last_used_at: key.last_used_at,
+ created_at: key.created_at,
+ }
+ }
+}
+
+/// Request to refresh an API key.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct RefreshApiKeyRequest {
+ /// New name for the refreshed key
+ pub name: Option<String>,
+}
+
+/// Response after refreshing an API key.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct RefreshApiKeyResponse {
+ pub id: Uuid,
+ /// The new API key - save this, it won't be shown again!
+ pub key: String,
+ pub prefix: String,
+ pub name: Option<String>,
+ pub created_at: DateTime<Utc>,
+ pub previous_key_revoked: bool,
+}
+
+/// Response after revoking an API key.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct RevokeApiKeyResponse {
+ pub message: String,
+ pub revoked_key_prefix: String,
+}
+
+/// API key event types for audit logging.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ApiKeyEventType {
+ Created,
+ Used,
+ Revoked,
+ Refreshed,
+}
+
+impl std::fmt::Display for ApiKeyEventType {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ApiKeyEventType::Created => write!(f, "created"),
+ ApiKeyEventType::Used => write!(f, "used"),
+ ApiKeyEventType::Revoked => write!(f, "revoked"),
+ ApiKeyEventType::Refreshed => write!(f, "refreshed"),
+ }
+ }
+}
+
+// =============================================================================
+// API Keys Repository
+// =============================================================================
+
+/// Repository error for API key operations.
+#[derive(Debug)]
+pub enum ApiKeyError {
+ /// Database error
+ Database(sqlx::Error),
+ /// An active API key already exists for this user
+ KeyAlreadyExists,
+ /// No active API key found for this user
+ KeyNotFound,
+}
+
+impl std::fmt::Display for ApiKeyError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ApiKeyError::Database(e) => write!(f, "Database error: {}", e),
+ ApiKeyError::KeyAlreadyExists => write!(f, "An active API key already exists"),
+ ApiKeyError::KeyNotFound => write!(f, "No active API key found"),
+ }
+ }
+}
+
+impl std::error::Error for ApiKeyError {}
+
+impl From<sqlx::Error> for ApiKeyError {
+ fn from(e: sqlx::Error) -> Self {
+ ApiKeyError::Database(e)
+ }
+}
+
+/// Get the active API key for a user (if any).
+pub async fn get_active_api_key(pool: &PgPool, user_id: Uuid) -> Result<Option<ApiKey>, sqlx::Error> {
+ sqlx::query_as::<_, ApiKey>(
+ r#"
+ SELECT id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
+ FROM api_keys
+ WHERE user_id = $1 AND revoked_at IS NULL
+ "#,
+ )
+ .bind(user_id)
+ .fetch_optional(pool)
+ .await
+}
+
+/// Create a new API key for a user.
+///
+/// Returns an error if the user already has an active key.
+/// The `generated` parameter should be created using `generate_api_key()`.
+pub async fn create_api_key(
+ pool: &PgPool,
+ user_id: Uuid,
+ generated: &GeneratedApiKey,
+ name: Option<&str>,
+) -> Result<ApiKey, ApiKeyError> {
+ // Check if user already has an active key
+ if let Some(_) = get_active_api_key(pool, user_id).await? {
+ return Err(ApiKeyError::KeyAlreadyExists);
+ }
+
+ let key = sqlx::query_as::<_, ApiKey>(
+ r#"
+ INSERT INTO api_keys (user_id, key_hash, key_prefix, name)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
+ "#,
+ )
+ .bind(user_id)
+ .bind(&generated.key_hash)
+ .bind(&generated.key_prefix)
+ .bind(name)
+ .fetch_one(pool)
+ .await?;
+
+ // Log the creation event
+ let _ = log_api_key_event(pool, key.id, ApiKeyEventType::Created, None, None).await;
+
+ Ok(key)
+}
+
+/// Revoke an API key by marking it with revoked_at timestamp.
+pub async fn revoke_api_key(pool: &PgPool, user_id: Uuid) -> Result<ApiKey, ApiKeyError> {
+ // Get the active key first
+ let key = get_active_api_key(pool, user_id)
+ .await?
+ .ok_or(ApiKeyError::KeyNotFound)?;
+
+ // Revoke it
+ let revoked = sqlx::query_as::<_, ApiKey>(
+ r#"
+ UPDATE api_keys
+ SET revoked_at = NOW()
+ WHERE id = $1
+ RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
+ "#,
+ )
+ .bind(key.id)
+ .fetch_one(pool)
+ .await?;
+
+ // Log the revocation event
+ let _ = log_api_key_event(pool, revoked.id, ApiKeyEventType::Revoked, None, None).await;
+
+ Ok(revoked)
+}
+
+/// Refresh an API key: revoke the old one and create a new one atomically.
+///
+/// Returns the new key. The caller should use `generate_api_key()` to create
+/// the `new_generated` parameter.
+pub async fn refresh_api_key(
+ pool: &PgPool,
+ user_id: Uuid,
+ new_generated: &GeneratedApiKey,
+ new_name: Option<&str>,
+) -> Result<(ApiKey, Option<String>), ApiKeyError> {
+ // Get and revoke the old key (if exists)
+ let old_prefix = if let Some(old_key) = get_active_api_key(pool, user_id).await? {
+ let old_prefix = old_key.key_prefix.clone();
+
+ // Revoke the old key
+ sqlx::query(
+ r#"
+ UPDATE api_keys
+ SET revoked_at = NOW()
+ WHERE id = $1
+ "#,
+ )
+ .bind(old_key.id)
+ .execute(pool)
+ .await?;
+
+ // Log the refresh event on the old key
+ let _ = log_api_key_event(pool, old_key.id, ApiKeyEventType::Refreshed, None, None).await;
+
+ Some(old_prefix)
+ } else {
+ None
+ };
+
+ // Create the new key
+ let new_key = sqlx::query_as::<_, ApiKey>(
+ r#"
+ INSERT INTO api_keys (user_id, key_hash, key_prefix, name)
+ VALUES ($1, $2, $3, $4)
+ RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
+ "#,
+ )
+ .bind(user_id)
+ .bind(&new_generated.key_hash)
+ .bind(&new_generated.key_prefix)
+ .bind(new_name)
+ .fetch_one(pool)
+ .await?;
+
+ // Log the creation event on the new key
+ let _ = log_api_key_event(pool, new_key.id, ApiKeyEventType::Created, None, None).await;
+
+ Ok((new_key, old_prefix))
+}
+
+/// Update last_used_at timestamp for an API key.
+pub async fn update_api_key_last_used(pool: &PgPool, key_hash: &str) -> Result<(), sqlx::Error> {
+ sqlx::query(
+ r#"
+ UPDATE api_keys
+ SET last_used_at = NOW()
+ WHERE key_hash = $1 AND revoked_at IS NULL
+ "#,
+ )
+ .bind(key_hash)
+ .execute(pool)
+ .await?;
+
+ Ok(())
+}
+
+/// Log an API key event for audit purposes.
+pub async fn log_api_key_event(
+ pool: &PgPool,
+ api_key_id: Uuid,
+ event_type: ApiKeyEventType,
+ ip_address: Option<&str>,
+ user_agent: Option<&str>,
+) -> Result<(), sqlx::Error> {
+ sqlx::query(
+ r#"
+ INSERT INTO api_key_events (api_key_id, event_type, ip_address, user_agent)
+ VALUES ($1, $2, $3::inet, $4)
+ "#,
+ )
+ .bind(api_key_id)
+ .bind(event_type.to_string())
+ .bind(ip_address)
+ .bind(user_agent)
+ .execute(pool)
+ .await?;
+
+ Ok(())
+}
+
+// =============================================================================
+// Internal Helper Functions
+// =============================================================================
+
+/// Resolve owner_id from user_id by looking up the users table.
+/// If the user doesn't exist, auto-creates them on first login.
+/// Uses ON CONFLICT to handle race conditions when multiple requests arrive simultaneously.
+async fn resolve_owner_id(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result<Uuid, AuthError> {
+ // First, try to get existing user
+ let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1")
+ .bind(user_id)
+ .fetch_optional(pool)
+ .await
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+
+ if let Some(row) = row {
+ let owner_id: Option<Uuid> = row.try_get("default_owner_id")
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+ return owner_id.ok_or(AuthError::UserNotFound);
+ }
+
+ // User doesn't exist - auto-create on first login
+ tracing::info!("Creating new user record for {}", user_id);
+
+ // Create owner first (use ON CONFLICT to handle race conditions)
+ let owner_id = Uuid::new_v4();
+ sqlx::query("INSERT INTO owners (id, name) VALUES ($1, $2) ON CONFLICT DO NOTHING")
+ .bind(owner_id)
+ .bind(email.unwrap_or("Unknown"))
+ .execute(pool)
+ .await
+ .map_err(|e| AuthError::DatabaseError(format!("Failed to create owner: {}", e)))?;
+
+ // Create user with reference to owner (use ON CONFLICT to handle race conditions)
+ sqlx::query(
+ "INSERT INTO users (id, email, default_owner_id) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING"
+ )
+ .bind(user_id)
+ .bind(email)
+ .bind(owner_id)
+ .execute(pool)
+ .await
+ .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?;
+
+ // Re-fetch the user to get the actual owner_id (in case another request created it first)
+ let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1")
+ .bind(user_id)
+ .fetch_optional(pool)
+ .await
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+
+ match row {
+ Some(row) => {
+ let owner_id: Option<Uuid> = row.try_get("default_owner_id")
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+ owner_id.ok_or(AuthError::UserNotFound)
+ }
+ None => Err(AuthError::DatabaseError("Failed to create user record".to_string()))
+ }
+}
+
+/// Validate an API key and return (user_id, owner_id).
+async fn validate_api_key(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> {
+ let key_hash = hash_api_key(key);
+
+ // Look up the API key and join with users to get owner_id
+ let row = sqlx::query(
+ r#"
+ SELECT ak.user_id, u.default_owner_id
+ FROM api_keys ak
+ JOIN users u ON u.id = ak.user_id
+ WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL
+ "#,
+ )
+ .bind(&key_hash)
+ .fetch_optional(pool)
+ .await
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+
+ match row {
+ Some(row) => {
+ let user_id: Uuid = row.try_get("user_id")
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+ let owner_id: Option<Uuid> = row.try_get("default_owner_id")
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+ let owner_id = owner_id.ok_or(AuthError::UserNotFound)?;
+
+ // Update last_used_at asynchronously (fire and forget)
+ let pool_clone = pool.clone();
+ let key_hash_clone = key_hash.clone();
+ tokio::spawn(async move {
+ let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1")
+ .bind(&key_hash_clone)
+ .execute(&pool_clone)
+ .await;
+ });
+
+ Ok((user_id, owner_id))
+ }
+ None => Err(AuthError::InvalidApiKey),
+ }
+}
+
+/// Extract authentication from request headers.
+///
+/// Tries authentication methods in order:
+/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators
+/// 2. API Key (X-Makima-API-Key) - for daemons/CLI
+/// 3. JWT (Authorization: Bearer) - for web clients
+async fn extract_auth(
+ state: &SharedState,
+ headers: &HeaderMap,
+) -> Result<AuthenticatedUser, AuthError> {
+ // 1. Check for tool key (orchestrator access)
+ if let Some(tool_key) = headers.get(TOOL_KEY_HEADER) {
+ if let Ok(key_str) = tool_key.to_str() {
+ if let Some(task_id) = state.validate_tool_key(key_str) {
+ // Tool keys are trusted - use a placeholder user/owner for orchestrator actions
+ // The orchestrator inherits the owner_id from its task
+ let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?;
+
+ // Get owner_id from the task
+ let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1")
+ .bind(task_id)
+ .fetch_optional(pool)
+ .await
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?
+ .ok_or(AuthError::UserNotFound)?;
+
+ let task_owner: Uuid = row.try_get("owner_id")
+ .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
+
+ return Ok(AuthenticatedUser {
+ user_id: Uuid::nil(), // Tool keys don't have a user
+ owner_id: task_owner,
+ auth_source: AuthSource::ToolKey(task_id),
+ email: None,
+ });
+ }
+ tracing::warn!("Invalid tool key provided");
+ }
+ }
+
+ // 2. Check for API key
+ if let Some(api_key) = headers.get(API_KEY_HEADER) {
+ if let Ok(key_str) = api_key.to_str() {
+ let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?;
+ let (user_id, owner_id) = validate_api_key(pool, key_str).await?;
+
+ return Ok(AuthenticatedUser {
+ user_id,
+ owner_id,
+ auth_source: AuthSource::ApiKey,
+ email: None,
+ });
+ }
+ }
+
+ // 3. Check for JWT (Bearer token)
+ if let Some(auth_header) = headers.get(AUTHORIZATION) {
+ if let Ok(auth_str) = auth_header.to_str() {
+ if let Some(token) = auth_str.strip_prefix("Bearer ") {
+ let verifier = state
+ .jwt_verifier
+ .as_ref()
+ .ok_or(AuthError::NotConfigured)?;
+
+ let claims = verifier.verify(token)?;
+ let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?;
+ let owner_id = resolve_owner_id(pool, claims.sub, claims.email.as_deref()).await?;
+
+ return Ok(AuthenticatedUser {
+ user_id: claims.sub,
+ owner_id,
+ auth_source: AuthSource::Jwt,
+ email: claims.email,
+ });
+ }
+ }
+ }
+
+ Err(AuthError::MissingToken)
+}
+
+// =============================================================================
+// Extractors
+// =============================================================================
+
+/// Extractor for authenticated requests.
+///
+/// Tries authentication methods in order:
+/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators
+/// 2. API Key (X-Makima-API-Key) - for daemons/CLI
+/// 3. JWT (Authorization: Bearer) - for web clients
+///
+/// Returns 401 Unauthorized if no valid authentication is found.
+///
+/// # Example
+/// ```ignore
+/// async fn protected_handler(
+/// Authenticated(user): Authenticated,
+/// ) -> impl IntoResponse {
+/// Json(format!("Hello user {}", user.user_id))
+/// }
+/// ```
+pub struct Authenticated(pub AuthenticatedUser);
+
+impl FromRequestParts<SharedState> for Authenticated {
+ type Rejection = AuthError;
+
+ async fn from_request_parts(
+ parts: &mut Parts,
+ state: &SharedState,
+ ) -> Result<Self, Self::Rejection> {
+ let user = extract_auth(state, &parts.headers).await?;
+ Ok(Authenticated(user))
+ }
+}
+
+/// Extractor for user-only authentication (JWT or API key, no tool keys).
+///
+/// Use this for endpoints that should only be accessible to actual users,
+/// not orchestrators with tool keys.
+///
+/// Returns 401 Unauthorized if no valid user authentication is found.
+/// Returns 403 Forbidden if a tool key is used.
+///
+/// # Example
+/// ```ignore
+/// async fn user_profile(
+/// UserOnly(user): UserOnly,
+/// ) -> impl IntoResponse {
+/// // Only actual users can access this
+/// Json(format!("User profile for {}", user.user_id))
+/// }
+/// ```
+pub struct UserOnly(pub AuthenticatedUser);
+
+impl FromRequestParts<SharedState> for UserOnly {
+ type Rejection = AuthError;
+
+ async fn from_request_parts(
+ parts: &mut Parts,
+ state: &SharedState,
+ ) -> Result<Self, Self::Rejection> {
+ let user = extract_auth(state, &parts.headers).await?;
+
+ // Reject tool key authentication
+ if matches!(user.auth_source, AuthSource::ToolKey(_)) {
+ return Err(AuthError::InsufficientPermissions);
+ }
+
+ Ok(UserOnly(user))
+ }
+}
+
+/// Extractor for optional authentication.
+///
+/// Returns Some(AuthenticatedUser) if valid auth is provided, None otherwise.
+/// Never returns an error - invalid auth is treated as no auth.
+///
+/// # Example
+/// ```ignore
+/// async fn public_or_private(
+/// MaybeAuthenticated(user): MaybeAuthenticated,
+/// ) -> impl IntoResponse {
+/// match user {
+/// Some(u) => Json(format!("Hello {}", u.user_id)),
+/// None => Json("Hello anonymous".to_string()),
+/// }
+/// }
+/// ```
+pub struct MaybeAuthenticated(pub Option<AuthenticatedUser>);
+
+impl FromRequestParts<SharedState> for MaybeAuthenticated {
+ type Rejection = std::convert::Infallible;
+
+ async fn from_request_parts(
+ parts: &mut Parts,
+ state: &SharedState,
+ ) -> Result<Self, Self::Rejection> {
+ let user = extract_auth(state, &parts.headers).await.ok();
+ Ok(MaybeAuthenticated(user))
+ }
+}
+
+// =============================================================================
+// Tests
+// =============================================================================
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_hash_api_key() {
+ let key = "mk_test123456789";
+ let hash = hash_api_key(key);
+
+ // Hash should be consistent
+ assert_eq!(hash, hash_api_key(key));
+
+ // Hash should be 64 characters (SHA-256 hex)
+ assert_eq!(hash.len(), 64);
+ }
+
+ #[test]
+ fn test_auth_error_display() {
+ assert_eq!(
+ AuthError::MissingToken.to_string(),
+ "Missing authentication token"
+ );
+ assert_eq!(
+ AuthError::InvalidToken("bad".to_string()).to_string(),
+ "Invalid token: bad"
+ );
+ }
+
+ #[test]
+ fn test_generate_api_key_format() {
+ let generated = generate_api_key();
+
+ // Full key should start with mk_ prefix
+ assert!(generated.full_key.starts_with(API_KEY_PREFIX));
+
+ // Full key should be mk_ + 43 chars (32 bytes base64url encoded)
+ assert_eq!(generated.full_key.len(), 3 + 43); // "mk_" + 43
+
+ // Prefix should be mk_ + first 8 chars
+ assert!(generated.key_prefix.starts_with(API_KEY_PREFIX));
+ assert_eq!(generated.key_prefix.len(), 3 + 8);
+
+ // Hash should be 64 hex chars (SHA-256)
+ assert_eq!(generated.key_hash.len(), 64);
+ }
+
+ #[test]
+ fn test_generate_api_key_uniqueness() {
+ let key1 = generate_api_key();
+ let key2 = generate_api_key();
+
+ // Keys should be unique
+ assert_ne!(key1.full_key, key2.full_key);
+ assert_ne!(key1.key_hash, key2.key_hash);
+ }
+
+ #[test]
+ fn test_api_key_cache_basic() {
+ let cache = ApiKeyCache::new(300);
+ let user_id = Uuid::new_v4();
+ let owner_id = Uuid::new_v4();
+ let key_hash = "test_hash_123";
+
+ // Cache miss initially
+ assert!(cache.get(key_hash).is_none());
+
+ // Set and verify cache hit
+ cache.set(key_hash.to_string(), user_id, owner_id);
+ let result = cache.get(key_hash);
+ assert!(result.is_some());
+ let (cached_user, cached_owner) = result.unwrap();
+ assert_eq!(cached_user, user_id);
+ assert_eq!(cached_owner, owner_id);
+ }
+
+ #[test]
+ fn test_api_key_cache_invalidate() {
+ let cache = ApiKeyCache::new(300);
+ let user_id = Uuid::new_v4();
+ let owner_id = Uuid::new_v4();
+ let key_hash = "test_hash_456";
+
+ cache.set(key_hash.to_string(), user_id, owner_id);
+ assert!(cache.get(key_hash).is_some());
+
+ cache.invalidate(key_hash);
+ assert!(cache.get(key_hash).is_none());
+ }
+
+ #[test]
+ fn test_api_key_cache_clear() {
+ let cache = ApiKeyCache::new(300);
+
+ cache.set("hash1".to_string(), Uuid::new_v4(), Uuid::new_v4());
+ cache.set("hash2".to_string(), Uuid::new_v4(), Uuid::new_v4());
+
+ assert!(cache.get("hash1").is_some());
+ assert!(cache.get("hash2").is_some());
+
+ cache.clear();
+
+ assert!(cache.get("hash1").is_none());
+ assert!(cache.get("hash2").is_none());
+ }
+
+ #[test]
+ fn test_api_key_event_type_display() {
+ assert_eq!(ApiKeyEventType::Created.to_string(), "created");
+ assert_eq!(ApiKeyEventType::Used.to_string(), "used");
+ assert_eq!(ApiKeyEventType::Revoked.to_string(), "revoked");
+ assert_eq!(ApiKeyEventType::Refreshed.to_string(), "refreshed");
+ }
+}