//! Authentication module for Makima server.
//!
//! Supports multiple authentication methods:
//! - Supabase JWT tokens for web clients (ES256 or RS256 public key verification)
//! - API keys for programmatic access (daemons, CLI)
//! - Tool keys for orchestrator internal access
use axum::{
extract::FromRequestParts,
http::{header::AUTHORIZATION, request::Parts, HeaderMap, StatusCode},
response::IntoResponse,
Json,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use sqlx::{FromRow, PgPool, Row};
use std::time::{Duration, Instant};
use utoipa::ToSchema;
use uuid::Uuid;
use crate::server::messages::ApiError;
use crate::server::state::SharedState;
// =============================================================================
// Configuration
// =============================================================================
/// JWT algorithm configuration.
///
/// Selects which signature algorithm incoming tokens must use and carries the
/// public key text for it. The key string is stored unparsed here;
/// `JwtVerifier::new` turns it into a decoding key.
#[derive(Debug, Clone)]
pub enum JwtAlgorithm {
/// RS256 with RSA public key
Rs256 { public_key: String },
/// ES256 with ECDSA public key (Supabase projects with JWT Signing Keys)
Es256 { public_key: String },
}
/// Authentication configuration loaded from environment.
///
/// Built by `AuthConfig::from_env` and consumed by `JwtVerifier::new`.
#[derive(Debug, Clone)]
pub struct AuthConfig {
/// Supabase project URL (e.g., https://your-project.supabase.co)
///
/// The verifier derives the expected token issuer from it as `<supabase_url>/auth/v1`.
pub supabase_url: String,
/// JWT algorithm and key material
pub algorithm: JwtAlgorithm,
}
impl AuthConfig {
    /// Build the auth configuration from process environment variables.
    ///
    /// `SUPABASE_URL` is always required. The key variable decides the algorithm,
    /// and the ES256 variable wins when both are present:
    /// - `SUPABASE_JWT_PUBLIC_KEY` selects ES256 (ECDSA public key)
    /// - `SUPABASE_JWT_RSA_PUBLIC_KEY` selects RS256 (RSA public key)
    ///
    /// Returns `None` when the URL or both key variables are missing, i.e. when
    /// auth is not configured.
    pub fn from_env() -> Option<Self> {
        let supabase_url = std::env::var("SUPABASE_URL").ok()?;

        let algorithm = match std::env::var("SUPABASE_JWT_PUBLIC_KEY").ok() {
            // ES256 is tried first because it is the Supabase default.
            Some(public_key) => {
                tracing::info!("Using ES256 JWT verification with ECDSA public key");
                JwtAlgorithm::Es256 { public_key }
            }
            None => {
                // No EC key: fall back to RSA, or report "not configured".
                let public_key = std::env::var("SUPABASE_JWT_RSA_PUBLIC_KEY").ok()?;
                tracing::info!("Using RS256 JWT verification with RSA public key");
                JwtAlgorithm::Rs256 { public_key }
            }
        };

        Some(Self {
            supabase_url,
            algorithm,
        })
    }
}
// =============================================================================
// JWT Claims
// =============================================================================
/// JWT claims from Supabase Auth tokens.
///
/// `aud`, `exp`, `iat`, `iss` and `sub` are non-optional, so a token payload
/// missing any of them fails to deserialize into this type. The remaining
/// claims are optional and become `None` when absent.
#[derive(Debug, Serialize, Deserialize)]
pub struct SupabaseClaims {
/// Audience (e.g., "authenticated")
pub aud: String,
/// Expiration time (Unix timestamp)
pub exp: i64,
/// Issued at (Unix timestamp)
pub iat: i64,
/// Issuer (Supabase project URL + /auth/v1)
pub iss: String,
/// Subject (user ID)
pub sub: Uuid,
/// User's email
pub email: Option<String>,
/// User's phone
pub phone: Option<String>,
/// App metadata (set by server/admin)
pub app_metadata: Option<serde_json::Value>,
/// User metadata (set by user)
pub user_metadata: Option<serde_json::Value>,
/// Role (e.g., "authenticated")
pub role: Option<String>,
/// Session ID
pub session_id: Option<Uuid>,
}
// =============================================================================
// JWT Verifier
// =============================================================================
/// JWT verifier for Supabase tokens.
///
/// Holds everything needed to validate a token: the expected algorithm, the
/// parsed public key, and the project URL the issuer claim is checked against.
pub struct JwtVerifier {
/// Supabase project URL; the expected issuer is `<supabase_url>/auth/v1`.
supabase_url: String,
/// Public key used to check token signatures.
decoding_key: DecodingKey,
/// Algorithm every accepted token must declare in its header.
algorithm: Algorithm,
}
impl JwtVerifier {
/// Create a new JWT verifier from auth config.
///
/// Supports multiple key formats:
/// - JWK (JSON Web Key) - detected by presence of `{`
/// - PEM - detected by `-----BEGIN`
/// - Base64-encoded DER - fallback
pub fn new(config: AuthConfig) -> Result<Self, AuthError> {
let (decoding_key, algorithm) = match &config.algorithm {
JwtAlgorithm::Rs256 { public_key } => {
let key = Self::parse_public_key(public_key, "RSA")?;
(key, Algorithm::RS256)
}
JwtAlgorithm::Es256 { public_key } => {
let key = Self::parse_public_key(public_key, "EC")?;
(key, Algorithm::ES256)
}
};
Ok(Self {
supabase_url: config.supabase_url,
decoding_key,
algorithm,
})
}
/// Parse a public key from various formats (JWK, JWKS, PEM, or base64 DER).
fn parse_public_key(key_data: &str, key_type: &str) -> Result<DecodingKey, AuthError> {
let trimmed = key_data.trim();
// Check for JSON format (JWK or JWKS)
if trimmed.starts_with('{') {
// First try to parse as a generic JSON value to inspect structure
let mut json_value: serde_json::Value = serde_json::from_str(trimmed)
.map_err(|e| AuthError::InvalidToken(format!("Invalid JSON: {}", e)))?;
// Check if it's a JWKS (has "keys" array)
if let Some(keys) = json_value.get_mut("keys").and_then(|k| k.as_array_mut()) {
// Find the first signing key (or just use the first key)
let jwk_value = keys.first_mut()
.ok_or_else(|| AuthError::InvalidToken("JWKS has no keys".to_string()))?;
// Remove private key component if present (user may have pasted full keypair)
if let Some(obj) = jwk_value.as_object_mut() {
if obj.remove("d").is_some() {
tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification");
}
}
let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(jwk_value.clone())
.map_err(|e| AuthError::InvalidToken(format!("Invalid JWK in JWKS: {}", e)))?;
tracing::info!("Loaded JWT public key from JWKS (first key)");
return DecodingKey::from_jwk(&jwk)
.map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e)));
}
// Remove private key component if present (user may have pasted full keypair)
if let Some(obj) = json_value.as_object_mut() {
if obj.remove("d").is_some() {
tracing::warn!("Removed private key component 'd' from JWK - only public key is needed for verification");
}
}
// Try as single JWK
let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(json_value)
.map_err(|e| AuthError::InvalidToken(format!("Invalid JWK: {}", e)))?;
tracing::info!("Loaded JWT public key from JWK");
DecodingKey::from_jwk(&jwk)
.map_err(|e| AuthError::InvalidToken(format!("Failed to create key from JWK: {}", e)))
}
// Check for PEM format
else if trimmed.contains("-----BEGIN") {
tracing::info!("Loaded JWT public key from PEM");
match key_type {
"RSA" => DecodingKey::from_rsa_pem(trimmed.as_bytes())
.map_err(|e| AuthError::InvalidToken(format!("Invalid RSA PEM key: {}", e))),
"EC" => DecodingKey::from_ec_pem(trimmed.as_bytes())
.map_err(|e| AuthError::InvalidToken(format!("Invalid EC PEM key: {}", e))),
_ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))),
}
}
// Assume base64-encoded DER
else {
tracing::info!("Loaded JWT public key from base64 DER");
let der_bytes = base64::engine::general_purpose::STANDARD
.decode(trimmed)
.map_err(|e| AuthError::InvalidToken(format!("Invalid base64 key: {}", e)))?;
match key_type {
"RSA" => Ok(DecodingKey::from_rsa_der(&der_bytes)),
"EC" => Ok(DecodingKey::from_ec_der(&der_bytes)),
_ => Err(AuthError::InvalidToken(format!("Unknown key type: {}", key_type))),
}
}
}
/// Verify a JWT token and return claims.
pub fn verify(&self, token: &str) -> Result<SupabaseClaims, AuthError> {
// Decode header to check algorithm mismatch
let header = jsonwebtoken::decode_header(token)
.map_err(|e| AuthError::InvalidToken(format!("Invalid JWT header: {}", e)))?;
tracing::debug!(
"JWT header: algorithm={:?}, typ={:?}, kid={:?}",
header.alg,
header.typ,
header.kid
);
if header.alg != self.algorithm {
let hint = match header.alg {
Algorithm::ES256 => "Set SUPABASE_JWT_PUBLIC_KEY with the EC public key from Supabase Dashboard → Project Settings → API → JWT Settings",
Algorithm::RS256 => "Set SUPABASE_JWT_RSA_PUBLIC_KEY with the RSA public key",
_ => "Check your Supabase JWT configuration - only ES256 and RS256 are supported",
};
tracing::warn!(
"JWT algorithm mismatch: token uses {:?}, server configured for {:?}. {}",
header.alg,
self.algorithm,
hint
);
return Err(AuthError::InvalidToken(format!(
"Algorithm mismatch: token is {:?}, expected {:?}",
header.alg, self.algorithm
)));
}
let mut validation = Validation::new(self.algorithm);
validation.set_audience(&["authenticated"]);
validation.set_issuer(&[format!("{}/auth/v1", self.supabase_url)]);
// First try with full validation
let token_data = match decode::<SupabaseClaims>(token, &self.decoding_key, &validation) {
Ok(data) => data,
Err(e) => {
// Log detailed error info
tracing::warn!(
"JWT verification failed: {} (algorithm: {:?}, issuer: {}/auth/v1)",
e,
self.algorithm,
self.supabase_url
);
// If it's InvalidAlgorithm, try to understand why by decoding payload manually
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidAlgorithm) {
// Decode the payload part of the JWT manually (base64)
let parts: Vec<&str> = token.split('.').collect();
if parts.len() >= 2 {
if let Ok(payload_bytes) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(parts[1]) {
if let Ok(payload_str) = String::from_utf8(payload_bytes) {
if let Ok(claims) = serde_json::from_str::<serde_json::Value>(&payload_str) {
tracing::warn!(
"JWT payload (unverified): iss={:?}, aud={:?}, sub={:?}",
claims.get("iss"),
claims.get("aud"),
claims.get("sub")
);
}
}
}
}
}
return Err(AuthError::InvalidToken(e.to_string()));
}
};
Ok(token_data.claims)
}
/// Extract user ID from a token.
pub fn get_user_id(&self, token: &str) -> Result<Uuid, AuthError> {
let claims = self.verify(token)?;
Ok(claims.sub)
}
}
// =============================================================================
// Auth Error
// =============================================================================
/// Authentication error types.
///
/// Each variant maps to an HTTP status and error code in the `IntoResponse`
/// impl. The `String` payloads carry internal detail that appears in the
/// `Display` output but not in the HTTP response body.
#[derive(Debug)]
pub enum AuthError {
/// No authentication token provided
MissingToken,
/// Token format is invalid
InvalidToken(String),
/// Token has expired
ExpiredToken,
/// User not found in database
UserNotFound,
/// API key is invalid or revoked
InvalidApiKey,
/// Database error during auth lookup
DatabaseError(String),
/// Authentication is not configured
NotConfigured,
/// Insufficient permissions for the operation
InsufficientPermissions,
}
impl std::fmt::Display for AuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthError::MissingToken => write!(f, "Missing authentication token"),
AuthError::InvalidToken(msg) => write!(f, "Invalid token: {}", msg),
AuthError::ExpiredToken => write!(f, "Token has expired"),
AuthError::UserNotFound => write!(f, "User not found"),
AuthError::InvalidApiKey => write!(f, "Invalid or revoked API key"),
AuthError::DatabaseError(msg) => write!(f, "Database error: {}", msg),
AuthError::NotConfigured => write!(f, "Authentication not configured"),
AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"),
}
}
}
// Empty impl: `AuthError` uses the default `std::error::Error` methods
// (its `Debug` and `Display` impls supply the required supertraits).
impl std::error::Error for AuthError {}
impl IntoResponse for AuthError {
fn into_response(self) -> axum::response::Response {
let (status, code, message) = match &self {
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "MISSING_TOKEN", "Authentication required"),
AuthError::InvalidToken(_) => (StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid authentication token"),
AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "EXPIRED_TOKEN", "Token has expired"),
AuthError::UserNotFound => (StatusCode::UNAUTHORIZED, "USER_NOT_FOUND", "User not found"),
AuthError::InvalidApiKey => (StatusCode::UNAUTHORIZED, "INVALID_API_KEY", "Invalid or revoked API key"),
AuthError::DatabaseError(_) => (StatusCode::INTERNAL_SERVER_ERROR, "DB_ERROR", "Database error"),
AuthError::NotConfigured => (StatusCode::SERVICE_UNAVAILABLE, "AUTH_NOT_CONFIGURED", "Authentication not configured"),
AuthError::InsufficientPermissions => (StatusCode::FORBIDDEN, "FORBIDDEN", "Insufficient permissions"),
};
(status, Json(ApiError::new(code, message))).into_response()
}
}
// =============================================================================
// Auth Source
// =============================================================================
/// Source of authentication.
///
/// Records which credential type authenticated the request, so extractors can
/// restrict access by source (e.g. `UserOnly` rejects `ToolKey`).
#[derive(Debug, Clone)]
pub enum AuthSource {
/// Authenticated via Supabase JWT (web client)
Jwt,
/// Authenticated via API key (daemon, CLI, integrations)
ApiKey,
/// Authenticated via tool key (orchestrator internal access).
/// The payload is the task ID the tool key resolved to.
ToolKey(Uuid),
}
// =============================================================================
// Authenticated User
// =============================================================================
/// Authenticated user context extracted from request.
///
/// Contains the resolved user_id and owner_id for database operations.
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
/// Supabase auth user ID (from auth.users).
/// Set to `Uuid::nil()` for tool-key authentication, which has no user.
pub user_id: Uuid,
/// Owner ID for data isolation (from users.default_owner_id).
/// For tool keys this is the `owner_id` of the task the key belongs to.
pub owner_id: Uuid,
/// How the user was authenticated
pub auth_source: AuthSource,
/// User's email (if available); only populated for JWT authentication.
pub email: Option<String>,
}
// =============================================================================
// Header Constants
// =============================================================================
/// Header name for tool key authentication (orchestrators).
/// This is the first credential `extract_auth` looks for.
pub const TOOL_KEY_HEADER: &str = "x-makima-tool-key";
/// Header name for API key authentication.
/// Checked by `extract_auth` after the tool key and before the `Authorization` header.
pub const API_KEY_HEADER: &str = "x-makima-api-key";
// =============================================================================
// Helper Functions
// =============================================================================
/// Compute the lookup hash of an API key.
///
/// Returns the SHA-256 digest of the key's UTF-8 bytes as a hex string.
pub fn hash_api_key(key: &str) -> String {
    let digest = Sha256::digest(key.as_bytes());
    hex::encode(digest)
}
// =============================================================================
// API Key Generation
// =============================================================================
/// API key prefix for identification.
/// Prepended to every key produced by `generate_api_key`.
pub const API_KEY_PREFIX: &str = "mk_";
/// Result of generating an API key.
///
/// Produced by `generate_api_key`. Only `key_hash` and `key_prefix` are written
/// to the database by the repository functions; `full_key` is not.
pub struct GeneratedApiKey {
/// The full API key (shown only once to user)
pub full_key: String,
/// SHA-256 hash of the key (stored in database)
pub key_hash: String,
/// Prefix for display: `mk_` followed by the first 8 characters of the random part
pub key_prefix: String,
}
/// Create a fresh random API key carrying the `mk_` prefix.
///
/// 32 random bytes are encoded as unpadded URL-safe base64 and prefixed with
/// `API_KEY_PREFIX`. The result bundles the full key (to show once), its
/// SHA-256 hash (to store) and a short display prefix.
pub fn generate_api_key() -> GeneratedApiKey {
    let mut secret = [0u8; 32];
    rand::thread_rng().fill(&mut secret);

    let encoded = URL_SAFE_NO_PAD.encode(secret);
    let full_key = [API_KEY_PREFIX, encoded.as_str()].concat();
    // 32 bytes always encode to more than 8 characters, so the slice is in range.
    let key_prefix = [API_KEY_PREFIX, &encoded[..8]].concat();
    let key_hash = hash_api_key(&full_key);

    GeneratedApiKey {
        full_key,
        key_hash,
        key_prefix,
    }
}
// =============================================================================
// API Key Cache
// =============================================================================
/// Cache entry for validated API keys.
struct ApiKeyCacheEntry {
/// User the cached key belongs to.
user_id: Uuid,
/// Owner ID resolved for that user.
owner_id: Uuid,
/// When the entry was stored; compared against the cache TTL on lookup.
cached_at: Instant,
}
/// In-memory cache for API key validation to avoid database lookups on every request.
///
/// Entries are keyed by the key hash, never by the raw API key.
pub struct ApiKeyCache {
/// key_hash -> (user_id, owner_id, cached_at)
cache: DashMap<String, ApiKeyCacheEntry>,
/// Time-to-live for cache entries
ttl: Duration,
}
impl ApiKeyCache {
    /// Create a new cache with the specified TTL in seconds.
    pub fn new(ttl_seconds: u64) -> Self {
        Self {
            cache: DashMap::new(),
            ttl: Duration::from_secs(ttl_seconds),
        }
    }

    /// Get cached user_id and owner_id for a key hash, if not expired.
    ///
    /// Returns `None` for unknown or expired entries. An expired entry is also
    /// removed from the map, so stale hashes do not accumulate.
    pub fn get(&self, key_hash: &str) -> Option<(Uuid, Uuid)> {
        {
            let entry = self.cache.get(key_hash)?;
            if entry.cached_at.elapsed() < self.ttl {
                return Some((entry.user_id, entry.owner_id));
            }
            // The read guard must be released before the removal below; it is
            // dropped at the end of this scope.
        }
        // Re-check expiry inside `remove_if` so an entry refreshed by a
        // concurrent `set` between the two steps is not discarded.
        self.cache
            .remove_if(key_hash, |_, entry| entry.cached_at.elapsed() >= self.ttl);
        None
    }

    /// Cache a validated API key, replacing any previous entry for the hash.
    pub fn set(&self, key_hash: String, user_id: Uuid, owner_id: Uuid) {
        self.cache.insert(
            key_hash,
            ApiKeyCacheEntry {
                user_id,
                owner_id,
                cached_at: Instant::now(),
            },
        );
    }

    /// Invalidate a cache entry (e.g., on key revocation).
    pub fn invalidate(&self, key_hash: &str) {
        self.cache.remove(key_hash);
    }

    /// Clear all cache entries.
    pub fn clear(&self) {
        self.cache.clear();
    }
}
impl Default for ApiKeyCache {
    /// Build a cache whose entries live for five minutes.
    fn default() -> Self {
        const DEFAULT_TTL_SECONDS: u64 = 300;
        Self::new(DEFAULT_TTL_SECONDS)
    }
}
// =============================================================================
// API Key Models
// =============================================================================
/// API key record from the database.
// Field names mirror the columns selected from `api_keys` (the struct derives
// `FromRow`). Serialized as camelCase.
#[derive(Debug, Clone, FromRow, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ApiKey {
pub id: Uuid,
pub user_id: Uuid,
// Excluded from serialized output so the stored hash is not exposed.
#[serde(skip)]
pub key_hash: String,
// Short display prefix of the key (see `GeneratedApiKey::key_prefix`).
pub key_prefix: String,
pub name: Option<String>,
pub last_used_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
// `None` while the key is active; the repository queries treat
// `revoked_at IS NULL` as "active".
pub revoked_at: Option<DateTime<Utc>>,
}
/// Request to create a new API key.
// Deserialized from a camelCase JSON body.
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct CreateApiKeyRequest {
/// User-provided label for the key
pub name: Option<String>,
}
/// Response after creating an API key (includes the full key - shown only once).
// Serialized as camelCase (e.g. `created_at` becomes `createdAt`).
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct CreateApiKeyResponse {
pub id: Uuid,
/// The full API key - save this, it won't be shown again!
pub key: String,
// Display prefix of the key.
pub prefix: String,
pub name: Option<String>,
pub created_at: DateTime<Utc>,
}
/// Response for getting API key info (excludes the full key).
// Built from an `ApiKey` row via the `From<ApiKey>` impl; the hash and
// revocation timestamp are not carried over.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ApiKeyInfoResponse {
pub id: Uuid,
// Display prefix of the key (`ApiKey::key_prefix`).
pub prefix: String,
pub name: Option<String>,
pub last_used_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
}
impl From<ApiKey> for ApiKeyInfoResponse {
fn from(key: ApiKey) -> Self {
Self {
id: key.id,
prefix: key.key_prefix,
name: key.name,
last_used_at: key.last_used_at,
created_at: key.created_at,
}
}
}
/// Request to refresh an API key.
// Deserialized from a camelCase JSON body.
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct RefreshApiKeyRequest {
/// New name for the refreshed key
pub name: Option<String>,
}
/// Response after refreshing an API key.
// Serialized as camelCase (e.g. `previous_key_revoked` becomes `previousKeyRevoked`).
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct RefreshApiKeyResponse {
pub id: Uuid,
/// The new API key - save this, it won't be shown again!
pub key: String,
// Display prefix of the new key.
pub prefix: String,
pub name: Option<String>,
pub created_at: DateTime<Utc>,
// NOTE(review): presumably true when an older active key existed and was
// revoked by the refresh - confirm against the handler that fills it in.
pub previous_key_revoked: bool,
}
/// Response after revoking an API key.
// Serialized as camelCase (e.g. `revoked_key_prefix` becomes `revokedKeyPrefix`).
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct RevokeApiKeyResponse {
pub message: String,
// Display prefix of the key that was revoked.
pub revoked_key_prefix: String,
}
/// API key event types for audit logging.
///
/// The `Display` form (lowercase variant name) is the value written to the
/// `event_type` column by `log_api_key_event`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiKeyEventType {
/// A new key was inserted.
Created,
/// NOTE(review): not emitted anywhere in the visible code - presumably
/// recorded by callers when a key authenticates a request; confirm.
Used,
/// An active key was revoked.
Revoked,
/// Recorded on the old key when it is replaced by a refresh.
Refreshed,
}
impl std::fmt::Display for ApiKeyEventType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ApiKeyEventType::Created => write!(f, "created"),
ApiKeyEventType::Used => write!(f, "used"),
ApiKeyEventType::Revoked => write!(f, "revoked"),
ApiKeyEventType::Refreshed => write!(f, "refreshed"),
}
}
}
// =============================================================================
// API Keys Repository
// =============================================================================
/// Repository error for API key operations.
///
/// `sqlx::Error` converts into the `Database` variant through the `From` impl,
/// so repository functions can use `?` on query results.
#[derive(Debug)]
pub enum ApiKeyError {
/// Database error
Database(sqlx::Error),
/// An active API key already exists for this user
KeyAlreadyExists,
/// No active API key found for this user
KeyNotFound,
}
impl std::fmt::Display for ApiKeyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ApiKeyError::Database(e) => write!(f, "Database error: {}", e),
ApiKeyError::KeyAlreadyExists => write!(f, "An active API key already exists"),
ApiKeyError::KeyNotFound => write!(f, "No active API key found"),
}
}
}
// Empty impl: uses the default `std::error::Error` methods, so the wrapped
// `sqlx::Error` is not exposed through `source()`.
impl std::error::Error for ApiKeyError {}
impl From<sqlx::Error> for ApiKeyError {
fn from(e: sqlx::Error) -> Self {
ApiKeyError::Database(e)
}
}
/// Fetch the user's active (non-revoked) API key, if one exists.
///
/// Returns `Ok(None)` when the user has no key whose `revoked_at` is NULL.
pub async fn get_active_api_key(pool: &PgPool, user_id: Uuid) -> Result<Option<ApiKey>, sqlx::Error> {
    const SELECT_ACTIVE_KEY: &str = r#"
SELECT id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
FROM api_keys
WHERE user_id = $1 AND revoked_at IS NULL
"#;

    let lookup = sqlx::query_as::<_, ApiKey>(SELECT_ACTIVE_KEY).bind(user_id);
    lookup.fetch_optional(pool).await
}
/// Create a new API key for a user.
///
/// Returns an error if the user already has an active key.
/// The `generated` parameter should be created using `generate_api_key()`.
pub async fn create_api_key(
pool: &PgPool,
user_id: Uuid,
generated: &GeneratedApiKey,
name: Option<&str>,
) -> Result<ApiKey, ApiKeyError> {
// Check if user already has an active key
if let Some(_) = get_active_api_key(pool, user_id).await? {
return Err(ApiKeyError::KeyAlreadyExists);
}
let key = sqlx::query_as::<_, ApiKey>(
r#"
INSERT INTO api_keys (user_id, key_hash, key_prefix, name)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
"#,
)
.bind(user_id)
.bind(&generated.key_hash)
.bind(&generated.key_prefix)
.bind(name)
.fetch_one(pool)
.await?;
// Log the creation event
let _ = log_api_key_event(pool, key.id, ApiKeyEventType::Created, None, None).await;
Ok(key)
}
/// Revoke an API key by marking it with revoked_at timestamp.
pub async fn revoke_api_key(pool: &PgPool, user_id: Uuid) -> Result<ApiKey, ApiKeyError> {
// Get the active key first
let key = get_active_api_key(pool, user_id)
.await?
.ok_or(ApiKeyError::KeyNotFound)?;
// Revoke it
let revoked = sqlx::query_as::<_, ApiKey>(
r#"
UPDATE api_keys
SET revoked_at = NOW()
WHERE id = $1
RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
"#,
)
.bind(key.id)
.fetch_one(pool)
.await?;
// Log the revocation event
let _ = log_api_key_event(pool, revoked.id, ApiKeyEventType::Revoked, None, None).await;
Ok(revoked)
}
/// Refresh an API key: revoke the old one and create a new one atomically.
///
/// Returns the new key. The caller should use `generate_api_key()` to create
/// the `new_generated` parameter.
pub async fn refresh_api_key(
pool: &PgPool,
user_id: Uuid,
new_generated: &GeneratedApiKey,
new_name: Option<&str>,
) -> Result<(ApiKey, Option<String>), ApiKeyError> {
// Get and revoke the old key (if exists)
let old_prefix = if let Some(old_key) = get_active_api_key(pool, user_id).await? {
let old_prefix = old_key.key_prefix.clone();
// Revoke the old key
sqlx::query(
r#"
UPDATE api_keys
SET revoked_at = NOW()
WHERE id = $1
"#,
)
.bind(old_key.id)
.execute(pool)
.await?;
// Log the refresh event on the old key
let _ = log_api_key_event(pool, old_key.id, ApiKeyEventType::Refreshed, None, None).await;
Some(old_prefix)
} else {
None
};
// Create the new key
let new_key = sqlx::query_as::<_, ApiKey>(
r#"
INSERT INTO api_keys (user_id, key_hash, key_prefix, name)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, key_hash, key_prefix, name, last_used_at, created_at, revoked_at
"#,
)
.bind(user_id)
.bind(&new_generated.key_hash)
.bind(&new_generated.key_prefix)
.bind(new_name)
.fetch_one(pool)
.await?;
// Log the creation event on the new key
let _ = log_api_key_event(pool, new_key.id, ApiKeyEventType::Created, None, None).await;
Ok((new_key, old_prefix))
}
/// Stamp `last_used_at` with the current time for the active key matching `key_hash`.
///
/// Revoked keys are not touched. Succeeds even when no row matches.
pub async fn update_api_key_last_used(pool: &PgPool, key_hash: &str) -> Result<(), sqlx::Error> {
    const TOUCH_LAST_USED: &str = r#"
UPDATE api_keys
SET last_used_at = NOW()
WHERE key_hash = $1 AND revoked_at IS NULL
"#;

    let outcome = sqlx::query(TOUCH_LAST_USED).bind(key_hash).execute(pool).await;
    outcome.map(|_| ())
}
/// Append a row to `api_key_events` for audit purposes.
///
/// `event_type` is stored in its `Display` form. `ip_address` is cast to
/// `inet` by the query; both `ip_address` and `user_agent` may be `None`.
pub async fn log_api_key_event(
    pool: &PgPool,
    api_key_id: Uuid,
    event_type: ApiKeyEventType,
    ip_address: Option<&str>,
    user_agent: Option<&str>,
) -> Result<(), sqlx::Error> {
    const INSERT_EVENT: &str = r#"
INSERT INTO api_key_events (api_key_id, event_type, ip_address, user_agent)
VALUES ($1, $2, $3::inet, $4)
"#;

    let event_label = event_type.to_string();
    let outcome = sqlx::query(INSERT_EVENT)
        .bind(api_key_id)
        .bind(event_label)
        .bind(ip_address)
        .bind(user_agent)
        .execute(pool)
        .await;
    outcome.map(|_| ())
}
// =============================================================================
// Internal Helper Functions
// =============================================================================
/// Public wrapper for resolve_owner_id, used by SSE endpoints that authenticate via query params.
///
/// Forwards its arguments unchanged and returns the same result, including the
/// auto-creation of a missing user record.
pub async fn resolve_owner_id_public(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result<Uuid, AuthError> {
resolve_owner_id(pool, user_id, email).await
}
/// Look up `users.default_owner_id` for `user_id`.
///
/// Returns `Ok(None)` when no user row exists and `Err(AuthError::UserNotFound)`
/// when the row exists but has no default owner.
async fn fetch_default_owner_id(pool: &PgPool, user_id: Uuid) -> Result<Option<Uuid>, AuthError> {
    let row = sqlx::query("SELECT default_owner_id FROM users WHERE id = $1")
        .bind(user_id)
        .fetch_optional(pool)
        .await
        .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
    let Some(row) = row else {
        return Ok(None);
    };
    let owner_id: Option<Uuid> = row
        .try_get("default_owner_id")
        .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
    owner_id.map(Some).ok_or(AuthError::UserNotFound)
}

/// Resolve owner_id from user_id by looking up the users table.
/// If the user doesn't exist, auto-creates them on first login.
///
/// The owner and user rows are inserted in one transaction. When a concurrent
/// request created the user first (the user insert affects no rows), the
/// transaction is rolled back so no orphan owner row is left behind, and the
/// owner chosen by the winning request is returned.
///
/// # Errors
/// - `AuthError::UserNotFound` if the user row has no `default_owner_id`.
/// - `AuthError::DatabaseError` if a query fails or the user row is still
///   missing after creation.
async fn resolve_owner_id(pool: &PgPool, user_id: Uuid, email: Option<&str>) -> Result<Uuid, AuthError> {
    // Fast path: the user already exists.
    if let Some(owner_id) = fetch_default_owner_id(pool, user_id).await? {
        return Ok(owner_id);
    }

    // User doesn't exist - auto-create on first login.
    tracing::info!("Creating new user record for {}", user_id);
    let owner_id = Uuid::new_v4();
    let mut tx = pool
        .begin()
        .await
        .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?;

    sqlx::query("INSERT INTO owners (id, name) VALUES ($1, $2) ON CONFLICT DO NOTHING")
        .bind(owner_id)
        .bind(email.unwrap_or("Unknown"))
        .execute(&mut *tx)
        .await
        .map_err(|e| AuthError::DatabaseError(format!("Failed to create owner: {}", e)))?;

    // ON CONFLICT makes losing a creation race a no-op instead of an error.
    let user_inserted = sqlx::query(
        "INSERT INTO users (id, email, default_owner_id) VALUES ($1, $2, $3) ON CONFLICT (id) DO NOTHING"
    )
    .bind(user_id)
    .bind(email)
    .bind(owner_id)
    .execute(&mut *tx)
    .await
    .map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?
    .rows_affected()
        > 0;

    let finished = if user_inserted {
        tx.commit().await
    } else {
        // Another request won the race; discard the owner row created above.
        tx.rollback().await
    };
    finished.map_err(|e| AuthError::DatabaseError(format!("Failed to create user: {}", e)))?;

    // Re-fetch to get the actual owner_id (in case another request created it first).
    fetch_default_owner_id(pool, user_id)
        .await?
        .ok_or_else(|| AuthError::DatabaseError("Failed to create user record".to_string()))
}
/// Public wrapper for validate_api_key, used by SSE endpoints that authenticate via query params.
///
/// Forwards its arguments unchanged and returns `(user_id, owner_id)` on success.
pub async fn validate_api_key_public(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> {
validate_api_key(pool, key).await
}
/// Validate an API key and return (user_id, owner_id).
async fn validate_api_key(pool: &PgPool, key: &str) -> Result<(Uuid, Uuid), AuthError> {
let key_hash = hash_api_key(key);
// Look up the API key and join with users to get owner_id
let row = sqlx::query(
r#"
SELECT ak.user_id, u.default_owner_id
FROM api_keys ak
JOIN users u ON u.id = ak.user_id
WHERE ak.key_hash = $1 AND ak.revoked_at IS NULL
"#,
)
.bind(&key_hash)
.fetch_optional(pool)
.await
.map_err(|e| AuthError::DatabaseError(e.to_string()))?;
match row {
Some(row) => {
let user_id: Uuid = row.try_get("user_id")
.map_err(|e| AuthError::DatabaseError(e.to_string()))?;
let owner_id: Option<Uuid> = row.try_get("default_owner_id")
.map_err(|e| AuthError::DatabaseError(e.to_string()))?;
let owner_id = owner_id.ok_or(AuthError::UserNotFound)?;
// Update last_used_at asynchronously (fire and forget)
let pool_clone = pool.clone();
let key_hash_clone = key_hash.clone();
tokio::spawn(async move {
let _ = sqlx::query("UPDATE api_keys SET last_used_at = NOW() WHERE key_hash = $1")
.bind(&key_hash_clone)
.execute(&pool_clone)
.await;
});
Ok((user_id, owner_id))
}
None => Err(AuthError::InvalidApiKey),
}
}
/// Pull the token out of an `Authorization` header value using the Bearer scheme.
///
/// The scheme name is compared case-insensitively and whitespace around the
/// token is ignored. Returns `None` for any other scheme or a value without a
/// space separator.
fn bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, credentials) = header_value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("Bearer")
        .then(|| credentials.trim())
}

/// Extract authentication from request headers.
///
/// Tries authentication methods in order:
/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators
/// 2. API Key (X-Makima-API-Key) - for daemons/CLI
/// 3. JWT (Authorization: Bearer) - for web clients
///
/// A tool key that fails validation is logged and the remaining methods are
/// still tried. A header whose value is not valid visible ASCII is skipped.
/// An API key or bearer token that is present but invalid fails the request.
///
/// # Errors
/// - `AuthError::MissingToken` if no usable credential is present.
/// - `AuthError::NotConfigured` if the database pool or JWT verifier a
///   credential needs is absent.
/// - Any error from API key validation, JWT verification or owner resolution.
async fn extract_auth(
    state: &SharedState,
    headers: &HeaderMap,
) -> Result<AuthenticatedUser, AuthError> {
    // 1. Check for tool key (orchestrator access)
    if let Some(key_str) = headers.get(TOOL_KEY_HEADER).and_then(|v| v.to_str().ok()) {
        if let Some(task_id) = state.validate_tool_key(key_str) {
            // Tool keys are trusted - the orchestrator inherits the owner_id from its task.
            let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?;
            let row = sqlx::query("SELECT owner_id FROM tasks WHERE id = $1")
                .bind(task_id)
                .fetch_optional(pool)
                .await
                .map_err(|e| AuthError::DatabaseError(e.to_string()))?
                .ok_or(AuthError::UserNotFound)?;
            let task_owner: Uuid = row
                .try_get("owner_id")
                .map_err(|e| AuthError::DatabaseError(e.to_string()))?;
            return Ok(AuthenticatedUser {
                user_id: Uuid::nil(), // Tool keys don't have a user
                owner_id: task_owner,
                auth_source: AuthSource::ToolKey(task_id),
                email: None,
            });
        }
        tracing::warn!("Invalid tool key provided");
    }

    // 2. Check for API key
    if let Some(key_str) = headers.get(API_KEY_HEADER).and_then(|v| v.to_str().ok()) {
        let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?;
        let (user_id, owner_id) = validate_api_key(pool, key_str).await?;
        return Ok(AuthenticatedUser {
            user_id,
            owner_id,
            auth_source: AuthSource::ApiKey,
            email: None,
        });
    }

    // 3. Check for JWT (Bearer token)
    let token = headers
        .get(AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token);
    if let Some(token) = token {
        let verifier = state
            .jwt_verifier
            .as_ref()
            .ok_or(AuthError::NotConfigured)?;
        let claims = verifier.verify(token)?;
        let pool = state.db_pool.as_ref().ok_or(AuthError::NotConfigured)?;
        let owner_id = resolve_owner_id(pool, claims.sub, claims.email.as_deref()).await?;
        return Ok(AuthenticatedUser {
            user_id: claims.sub,
            owner_id,
            auth_source: AuthSource::Jwt,
            email: claims.email,
        });
    }

    Err(AuthError::MissingToken)
}
// =============================================================================
// Extractors
// =============================================================================
/// Extractor for authenticated requests.
///
/// Tries authentication methods in order:
/// 1. Tool Key (X-Makima-Tool-Key) - for orchestrators
/// 2. API Key (X-Makima-API-Key) - for daemons/CLI
/// 3. JWT (Authorization: Bearer) - for web clients
///
/// Returns 401 Unauthorized if no valid authentication is found.
/// Other failures are rejected with the status their `AuthError` maps to
/// (for example 500 for a database error, 503 when auth is not configured).
///
/// # Example
/// ```ignore
/// async fn protected_handler(
/// Authenticated(user): Authenticated,
/// ) -> impl IntoResponse {
/// Json(format!("Hello user {}", user.user_id))
/// }
/// ```
pub struct Authenticated(pub AuthenticatedUser);
impl FromRequestParts<SharedState> for Authenticated {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
state: &SharedState,
) -> Result<Self, Self::Rejection> {
let user = extract_auth(state, &parts.headers).await?;
Ok(Authenticated(user))
}
}
/// Extractor for user-only authentication (JWT or API key, no tool keys).
///
/// Use this for endpoints that should only be accessible to actual users,
/// not orchestrators with tool keys.
///
/// Returns 401 Unauthorized if no valid user authentication is found.
/// Returns 403 Forbidden if a tool key is used.
/// The 403 applies when the request authenticated successfully via a tool key.
///
/// # Example
/// ```ignore
/// async fn user_profile(
/// UserOnly(user): UserOnly,
/// ) -> impl IntoResponse {
/// // Only actual users can access this
/// Json(format!("User profile for {}", user.user_id))
/// }
/// ```
pub struct UserOnly(pub AuthenticatedUser);
impl FromRequestParts<SharedState> for UserOnly {
type Rejection = AuthError;
async fn from_request_parts(
parts: &mut Parts,
state: &SharedState,
) -> Result<Self, Self::Rejection> {
let user = extract_auth(state, &parts.headers).await?;
// Reject tool key authentication
if matches!(user.auth_source, AuthSource::ToolKey(_)) {
return Err(AuthError::InsufficientPermissions);
}
Ok(UserOnly(user))
}
}
/// Extractor for endpoints where authentication is optional.
///
/// Holds `Some(AuthenticatedUser)` when the request authenticates
/// successfully and `None` in every other case. Extraction itself cannot
/// fail: any error from `extract_auth` is discarded and the request is
/// treated as unauthenticated.
///
/// # Example
/// ```ignore
/// async fn public_or_private(
///     MaybeAuthenticated(user): MaybeAuthenticated,
/// ) -> impl IntoResponse {
///     match user {
///         Some(u) => Json(format!("Hello {}", u.user_id)),
///         None => Json("Hello anonymous".to_string()),
///     }
/// }
/// ```
pub struct MaybeAuthenticated(pub Option<AuthenticatedUser>);

impl FromRequestParts<SharedState> for MaybeAuthenticated {
    type Rejection = std::convert::Infallible;

    /// Attempts authentication and maps any failure to `None`.
    async fn from_request_parts(
        parts: &mut Parts,
        state: &SharedState,
    ) -> Result<Self, Self::Rejection> {
        let outcome = extract_auth(state, &parts.headers).await;
        // Failures are deliberately dropped: the caller only sees "no user".
        Ok(Self(outcome.ok()))
    }
}
// =============================================================================
// Tests
// =============================================================================
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hash_api_key() {
    // Hashing must be deterministic and yield a 64-character SHA-256 hex digest.
    let raw_key = "mk_test123456789";
    let first_digest = hash_api_key(raw_key);
    let second_digest = hash_api_key(raw_key);

    // Same input, same digest
    assert_eq!(first_digest, second_digest);
    // SHA-256 rendered as hex is 64 characters long
    assert_eq!(first_digest.len(), 64);
}
#[test]
fn test_auth_error_display() {
    // Each error variant paired with the message its Display impl must produce.
    let cases = [
        (AuthError::MissingToken, "Missing authentication token"),
        (
            AuthError::InvalidToken("bad".to_string()),
            "Invalid token: bad",
        ),
    ];

    for (error, expected) in cases {
        assert_eq!(error.to_string(), expected);
    }
}
#[test]
fn test_generate_api_key_format() {
    let key = generate_api_key();

    // Full key: "mk_" followed by 43 chars (32 bytes, base64url encoded)
    assert!(key.full_key.starts_with(API_KEY_PREFIX));
    let expected_full_len = 3 + 43;
    assert_eq!(key.full_key.len(), expected_full_len);

    // Display prefix: "mk_" followed by the first 8 chars
    assert!(key.key_prefix.starts_with(API_KEY_PREFIX));
    let expected_prefix_len = 3 + 8;
    assert_eq!(key.key_prefix.len(), expected_prefix_len);

    // Stored hash: 64 hex chars (SHA-256)
    assert_eq!(key.key_hash.len(), 64);
}
#[test]
fn test_generate_api_key_uniqueness() {
    // Two independently generated keys must differ in both key and hash.
    let (first, second) = (generate_api_key(), generate_api_key());

    assert_ne!(first.full_key, second.full_key);
    assert_ne!(first.key_hash, second.key_hash);
}
#[test]
fn test_api_key_cache_basic() {
    let cache = ApiKeyCache::new(300);
    let key_hash = "test_hash_123";
    let (user_id, owner_id) = (Uuid::new_v4(), Uuid::new_v4());

    // Nothing has been stored yet, so the lookup misses
    assert!(cache.get(key_hash).is_none());

    // After storing, the lookup returns exactly the ids that were stored
    cache.set(key_hash.to_string(), user_id, owner_id);
    let hit = cache.get(key_hash);
    assert!(hit.is_some());
    let (cached_user, cached_owner) = hit.unwrap();
    assert_eq!(cached_user, user_id);
    assert_eq!(cached_owner, owner_id);
}
#[test]
fn test_api_key_cache_invalidate() {
    let cache = ApiKeyCache::new(300);
    let key_hash = "test_hash_456";
    let (user_id, owner_id) = (Uuid::new_v4(), Uuid::new_v4());

    // The entry is present once stored
    cache.set(key_hash.to_string(), user_id, owner_id);
    assert!(cache.get(key_hash).is_some());

    // Invalidating that hash makes the next lookup miss
    cache.invalidate(key_hash);
    assert!(cache.get(key_hash).is_none());
}
#[test]
fn test_api_key_cache_clear() {
    let cache = ApiKeyCache::new(300);
    let hashes = ["hash1", "hash2"];

    // Populate the cache, then confirm every entry is visible
    for hash in hashes {
        cache.set(hash.to_string(), Uuid::new_v4(), Uuid::new_v4());
    }
    for hash in hashes {
        assert!(cache.get(hash).is_some());
    }

    // Clearing drops every entry at once
    cache.clear();
    for hash in hashes {
        assert!(cache.get(hash).is_none());
    }
}
#[test]
fn test_api_key_event_type_display() {
    // Each event type paired with the lowercase label its Display impl must produce.
    let expectations = [
        (ApiKeyEventType::Created, "created"),
        (ApiKeyEventType::Used, "used"),
        (ApiKeyEventType::Revoked, "revoked"),
        (ApiKeyEventType::Refreshed, "refreshed"),
    ];

    for (event, label) in expectations {
        assert_eq!(event.to_string(), label);
    }
}
}