diff options
Diffstat (limited to 'makima/src/server/handlers/mesh_supervisor.rs')
| -rw-r--r-- | makima/src/server/handlers/mesh_supervisor.rs | 321 |
1 files changed, 307 insertions, 14 deletions
diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs index ac59130..d0fa4d1 100644 --- a/makima/src/server/handlers/mesh_supervisor.rs +++ b/makima/src/server/handlers/mesh_supervisor.rs @@ -15,6 +15,7 @@ use uuid::Uuid; use crate::db::models::{CreateTaskRequest, Task, TaskSummary}; use crate::db::repository; +use crate::server::auth::Authenticated; use crate::server::handlers::mesh::{extract_auth, AuthSource}; use crate::server::messages::ApiError; use crate::server::state::{DaemonCommand, SharedState}; @@ -32,7 +33,7 @@ pub struct SpawnTaskRequest { pub contract_id: Uuid, pub parent_task_id: Option<Uuid>, pub checkpoint_sha: Option<String>, - /// Repository URL for the task (supervisor should provide this) + /// Repository URL for the task (optional - if not provided, will be looked up from contract). pub repository_url: Option<String>, } @@ -55,6 +56,67 @@ pub struct ReadWorktreeFileRequest { pub file_path: String, } +/// Request to ask a question and wait for user feedback. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct AskQuestionRequest { + /// The question to ask the user + pub question: String, + /// Optional choices (if empty, free-form text response) + #[serde(default)] + pub choices: Vec<String>, + /// Optional context about what this relates to + pub context: Option<String>, + /// How long to wait for a response (seconds) + #[serde(default = "default_question_timeout")] + pub timeout_seconds: i32, +} + +fn default_question_timeout() -> i32 { + 3600 // 1 hour default +} + +/// Response from asking a question. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct AskQuestionResponse { + /// The question ID for tracking + pub question_id: Uuid, + /// The user's response (None if timed out) + pub response: Option<String>, + /// Whether the question timed out + pub timed_out: bool, +} + +/// Request to answer a supervisor question. +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct AnswerQuestionRequest { + /// The user's response + pub response: String, +} + +/// Response to answering a question. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct AnswerQuestionResponse { + /// Whether the answer was accepted + pub success: bool, +} + +/// Pending question summary. +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct PendingQuestionSummary { + pub question_id: Uuid, + pub task_id: Uuid, + pub contract_id: Uuid, + pub question: String, + pub choices: Vec<String>, + pub context: Option<String>, + pub created_at: chrono::DateTime<chrono::Utc>, +} + /// Request to create a checkpoint. #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] @@ -321,23 +383,49 @@ pub async fn spawn_task( } }; - // Get repository URL from the contract's primary repository - let repo_url = match repository::list_contract_repositories(pool, request.contract_id).await { - Ok(repos) => { - // Prefer primary repo, fallback to first repo - repos.iter() - .find(|r| r.is_primary) - .or(repos.first()) - .and_then(|r| r.repository_url.clone()) - } - Err(e) => { - tracing::warn!(error = %e, "Failed to get contract repositories, continuing without repo URL"); + // Get repository URL - either from request or from contract's repositories + let repo_url = if let Some(url) = request.repository_url.clone() { + if !url.trim().is_empty() { + Some(url) + } else { None } + } else { + None }; - // Supervisor can override with explicit repository_url - let repo_url = request.repository_url.clone().or(repo_url); + // If no repo URL provided, look it up from the contract + let repo_url = match repo_url { + Some(url) => Some(url), + None => { + match repository::list_contract_repositories(pool, request.contract_id).await { + Ok(repos) => { + // Prefer primary repo, fallback to first repo + let repo = repos.iter() + .find(|r| r.is_primary) + .or(repos.first()); + + // Use repository_url if set, otherwise use local_path + repo.and_then(|r| { + r.repository_url.clone() + .or_else(|| r.local_path.clone()) + }) + } + Err(e) => { + tracing::warn!(error = %e, "Failed to get contract repositories"); + None + } + } + } + }; + + // Validate that we have a repo URL + if repo_url.is_none() { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("MISSING_REPO_URL", "No repository URL found. Either provide one or ensure the contract has repositories configured.")), + ).into_response(); + } // Create task request let create_req = CreateTaskRequest { @@ -1151,3 +1239,208 @@ pub async fn get_task_diff( }), ).into_response() } + +// ============================================================================= +// Supervisor Question Handlers +// ============================================================================= + +/// Ask a question and wait for user feedback. +/// +/// The supervisor calls this to ask a question. The endpoint will poll until +/// either the user responds or the timeout is reached. +#[utoipa::path( + post, + path = "/api/v1/mesh/supervisor/questions", + request_body = AskQuestionRequest, + responses( + (status = 200, description = "Question answered", body = AskQuestionResponse), + (status = 408, description = "Question timed out", body = AskQuestionResponse), + (status = 401, description = "Unauthorized"), + (status = 403, description = "Forbidden - not a supervisor"), + (status = 500, description = "Internal server error"), + ), + security( + ("tool_key" = []) + ), + tag = "Mesh Supervisor" +)] +pub async fn ask_question( + State(state): State<SharedState>, + headers: HeaderMap, + Json(request): Json<AskQuestionRequest>, +) -> impl IntoResponse { + let (supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await { + Ok(ids) => ids, + Err(e) => return e.into_response(), + }; + + let pool = state.db_pool.as_ref().unwrap(); + + // Get the supervisor task to find its contract + let supervisor = match repository::get_task_for_owner(pool, supervisor_id, owner_id).await { + Ok(Some(t)) => t, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Supervisor task not found")), + ).into_response(); + } + Err(e) => { + tracing::error!(error = %e, "Failed to get supervisor task"); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")), + ).into_response(); + } + }; + + let Some(contract_id) = supervisor.contract_id else { + return ( + StatusCode::BAD_REQUEST, + Json(ApiError::new("NO_CONTRACT", "Supervisor has no associated contract")), + ).into_response(); + }; + + // Add the question + let question_id = state.add_supervisor_question( + supervisor_id, + contract_id, + owner_id, + request.question.clone(), + request.choices.clone(), + request.context.clone(), + ); + + // Poll for response with timeout + let timeout_duration = std::time::Duration::from_secs(request.timeout_seconds.max(1) as u64); + let start = std::time::Instant::now(); + let poll_interval = std::time::Duration::from_millis(500); + + loop { + // Check if response has been submitted + if let Some(response) = state.get_question_response(question_id) { + // Clean up the response + state.cleanup_question_response(question_id); + + return ( + StatusCode::OK, + Json(AskQuestionResponse { + question_id, + response: Some(response.response), + timed_out: false, + }), + ).into_response(); + } + + // Check timeout + if start.elapsed() >= timeout_duration { + // Remove the pending question on timeout + state.remove_pending_question(question_id); + + return ( + StatusCode::REQUEST_TIMEOUT, + Json(AskQuestionResponse { + question_id, + response: None, + timed_out: true, + }), + ).into_response(); + } + + // Wait before polling again + tokio::time::sleep(poll_interval).await; + } +} + +/// Get all pending questions for the current user. +#[utoipa::path( + get, + path = "/api/v1/mesh/questions", + responses( + (status = 200, description = "List of pending questions", body = Vec<PendingQuestionSummary>), + (status = 401, description = "Unauthorized"), + (status = 500, description = "Internal server error"), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn list_pending_questions( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, +) -> impl IntoResponse { + let questions: Vec<PendingQuestionSummary> = state + .get_pending_questions_for_owner(auth.owner_id) + .into_iter() + .map(|q| PendingQuestionSummary { + question_id: q.question_id, + task_id: q.task_id, + contract_id: q.contract_id, + question: q.question, + choices: q.choices, + context: q.context, + created_at: q.created_at, + }) + .collect(); + + Json(questions).into_response() +} + +/// Answer a pending supervisor question. +#[utoipa::path( + post, + path = "/api/v1/mesh/questions/{question_id}/answer", + params( + ("question_id" = Uuid, Path, description = "Question ID") + ), + request_body = AnswerQuestionRequest, + responses( + (status = 200, description = "Question answered", body = AnswerQuestionResponse), + (status = 401, description = "Unauthorized"), + (status = 404, description = "Question not found"), + (status = 500, description = "Internal server error"), + ), + security( + ("bearer_auth" = []), + ("api_key" = []) + ), + tag = "Mesh" +)] +pub async fn answer_question( + State(state): State<SharedState>, + Authenticated(auth): Authenticated, + Path(question_id): Path<Uuid>, + Json(request): Json<AnswerQuestionRequest>, +) -> impl IntoResponse { + // Verify the question exists and belongs to this owner + let question = match state.get_pending_question(question_id) { + Some(q) if q.owner_id == auth.owner_id => q, + Some(_) => { + return ( + StatusCode::FORBIDDEN, + Json(ApiError::new("FORBIDDEN", "Question belongs to another user")), + ).into_response(); + } + None => { + return ( + StatusCode::NOT_FOUND, + Json(ApiError::new("NOT_FOUND", "Question not found or already answered")), + ).into_response(); + } + }; + + // Submit the response + let success = state.submit_question_response(question_id, request.response); + + if success { + tracing::info!( + question_id = %question_id, + task_id = %question.task_id, + "User answered supervisor question" + ); + } + + Json(AnswerQuestionResponse { success }).into_response() +} |
