summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/mesh_supervisor.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-15 03:26:28 +0000
committersoryu <soryu@soryu.co>2026-01-15 03:26:28 +0000
commiteeafe072bc6bb81459f7d087b48fc921afe9cc11 (patch)
tree7f835993edd732f8ff66d756391dedffe3d44e90 /makima/src/server/handlers/mesh_supervisor.rs
parentc61a2b9b9c988f5460f85980d4ddf285f1a730b5 (diff)
downloadsoryu-eeafe072bc6bb81459f7d087b48fc921afe9cc11.tar.gz
soryu-eeafe072bc6bb81459f7d087b48fc921afe9cc11.zip
Automatically derive repo URL and add notifications for input
Diffstat (limited to 'makima/src/server/handlers/mesh_supervisor.rs')
-rw-r--r--makima/src/server/handlers/mesh_supervisor.rs321
1 files changed, 307 insertions, 14 deletions
diff --git a/makima/src/server/handlers/mesh_supervisor.rs b/makima/src/server/handlers/mesh_supervisor.rs
index ac59130..d0fa4d1 100644
--- a/makima/src/server/handlers/mesh_supervisor.rs
+++ b/makima/src/server/handlers/mesh_supervisor.rs
@@ -15,6 +15,7 @@ use uuid::Uuid;
use crate::db::models::{CreateTaskRequest, Task, TaskSummary};
use crate::db::repository;
+use crate::server::auth::Authenticated;
use crate::server::handlers::mesh::{extract_auth, AuthSource};
use crate::server::messages::ApiError;
use crate::server::state::{DaemonCommand, SharedState};
@@ -32,7 +33,7 @@ pub struct SpawnTaskRequest {
pub contract_id: Uuid,
pub parent_task_id: Option<Uuid>,
pub checkpoint_sha: Option<String>,
- /// Repository URL for the task (supervisor should provide this)
+ /// Repository URL for the task (optional - if not provided, will be looked up from contract).
pub repository_url: Option<String>,
}
@@ -55,6 +56,67 @@ pub struct ReadWorktreeFileRequest {
pub file_path: String,
}
+/// Request to ask a question and wait for user feedback.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct AskQuestionRequest {
+ /// The question to ask the user
+ pub question: String,
+ /// Optional choices (if empty, free-form text response)
+ #[serde(default)]
+ pub choices: Vec<String>,
+ /// Optional context about what this relates to
+ pub context: Option<String>,
+ /// How long to wait for a response (seconds)
+ #[serde(default = "default_question_timeout")]
+ pub timeout_seconds: i32,
+}
+
+fn default_question_timeout() -> i32 {
+ 3600 // 1 hour default
+}
+
+/// Response from asking a question.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct AskQuestionResponse {
+ /// The question ID for tracking
+ pub question_id: Uuid,
+ /// The user's response (None if timed out)
+ pub response: Option<String>,
+ /// Whether the question timed out
+ pub timed_out: bool,
+}
+
+/// Request to answer a supervisor question.
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct AnswerQuestionRequest {
+ /// The user's response
+ pub response: String,
+}
+
+/// Response to answering a question.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct AnswerQuestionResponse {
+ /// Whether the answer was accepted
+ pub success: bool,
+}
+
+/// Pending question summary.
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct PendingQuestionSummary {
+ pub question_id: Uuid,
+ pub task_id: Uuid,
+ pub contract_id: Uuid,
+ pub question: String,
+ pub choices: Vec<String>,
+ pub context: Option<String>,
+ pub created_at: chrono::DateTime<chrono::Utc>,
+}
+
/// Request to create a checkpoint.
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
@@ -321,23 +383,49 @@ pub async fn spawn_task(
}
};
- // Get repository URL from the contract's primary repository
- let repo_url = match repository::list_contract_repositories(pool, request.contract_id).await {
- Ok(repos) => {
- // Prefer primary repo, fallback to first repo
- repos.iter()
- .find(|r| r.is_primary)
- .or(repos.first())
- .and_then(|r| r.repository_url.clone())
- }
- Err(e) => {
- tracing::warn!(error = %e, "Failed to get contract repositories, continuing without repo URL");
+ // Get repository URL - either from request or from contract's repositories
+ let repo_url = if let Some(url) = request.repository_url.clone() {
+ if !url.trim().is_empty() {
+ Some(url)
+ } else {
None
}
+ } else {
+ None
};
- // Supervisor can override with explicit repository_url
- let repo_url = request.repository_url.clone().or(repo_url);
+ // If no repo URL provided, look it up from the contract
+ let repo_url = match repo_url {
+ Some(url) => Some(url),
+ None => {
+ match repository::list_contract_repositories(pool, request.contract_id).await {
+ Ok(repos) => {
+ // Prefer primary repo, fallback to first repo
+ let repo = repos.iter()
+ .find(|r| r.is_primary)
+ .or(repos.first());
+
+ // Use repository_url if set, otherwise use local_path
+ repo.and_then(|r| {
+ r.repository_url.clone()
+ .or_else(|| r.local_path.clone())
+ })
+ }
+ Err(e) => {
+ tracing::warn!(error = %e, "Failed to get contract repositories");
+ None
+ }
+ }
+ }
+ };
+
+ // Validate that we have a repo URL
+ if repo_url.is_none() {
+ return (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new("MISSING_REPO_URL", "No repository URL found. Either provide one or ensure the contract has repositories configured.")),
+ ).into_response();
+ }
// Create task request
let create_req = CreateTaskRequest {
@@ -1151,3 +1239,208 @@ pub async fn get_task_diff(
}),
).into_response()
}
+
+// =============================================================================
+// Supervisor Question Handlers
+// =============================================================================
+
+/// Ask a question and wait for user feedback.
+///
+/// The supervisor calls this to ask a question. The endpoint will poll until
+/// either the user responds or the timeout is reached.
+#[utoipa::path(
+ post,
+ path = "/api/v1/mesh/supervisor/questions",
+ request_body = AskQuestionRequest,
+ responses(
+ (status = 200, description = "Question answered", body = AskQuestionResponse),
+ (status = 408, description = "Question timed out", body = AskQuestionResponse),
+ (status = 401, description = "Unauthorized"),
+ (status = 403, description = "Forbidden - not a supervisor"),
+ (status = 500, description = "Internal server error"),
+ ),
+ security(
+ ("tool_key" = [])
+ ),
+ tag = "Mesh Supervisor"
+)]
+pub async fn ask_question(
+ State(state): State<SharedState>,
+ headers: HeaderMap,
+ Json(request): Json<AskQuestionRequest>,
+) -> impl IntoResponse {
+ let (supervisor_id, owner_id) = match verify_supervisor_auth(&state, &headers, None).await {
+ Ok(ids) => ids,
+ Err(e) => return e.into_response(),
+ };
+
+ let pool = state.db_pool.as_ref().unwrap();
+
+ // Get the supervisor task to find its contract
+ let supervisor = match repository::get_task_for_owner(pool, supervisor_id, owner_id).await {
+ Ok(Some(t)) => t,
+ Ok(None) => {
+ return (
+ StatusCode::NOT_FOUND,
+ Json(ApiError::new("NOT_FOUND", "Supervisor task not found")),
+ ).into_response();
+ }
+ Err(e) => {
+ tracing::error!(error = %e, "Failed to get supervisor task");
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(ApiError::new("DB_ERROR", "Failed to get supervisor task")),
+ ).into_response();
+ }
+ };
+
+ let Some(contract_id) = supervisor.contract_id else {
+ return (
+ StatusCode::BAD_REQUEST,
+ Json(ApiError::new("NO_CONTRACT", "Supervisor has no associated contract")),
+ ).into_response();
+ };
+
+ // Add the question
+ let question_id = state.add_supervisor_question(
+ supervisor_id,
+ contract_id,
+ owner_id,
+ request.question.clone(),
+ request.choices.clone(),
+ request.context.clone(),
+ );
+
+ // Poll for response with timeout
+ let timeout_duration = std::time::Duration::from_secs(request.timeout_seconds.max(1) as u64);
+ let start = std::time::Instant::now();
+ let poll_interval = std::time::Duration::from_millis(500);
+
+ loop {
+ // Check if response has been submitted
+ if let Some(response) = state.get_question_response(question_id) {
+ // Clean up the response
+ state.cleanup_question_response(question_id);
+
+ return (
+ StatusCode::OK,
+ Json(AskQuestionResponse {
+ question_id,
+ response: Some(response.response),
+ timed_out: false,
+ }),
+ ).into_response();
+ }
+
+ // Check timeout
+ if start.elapsed() >= timeout_duration {
+ // Remove the pending question on timeout
+ state.remove_pending_question(question_id);
+
+ return (
+ StatusCode::REQUEST_TIMEOUT,
+ Json(AskQuestionResponse {
+ question_id,
+ response: None,
+ timed_out: true,
+ }),
+ ).into_response();
+ }
+
+ // Wait before polling again
+ tokio::time::sleep(poll_interval).await;
+ }
+}
+
+/// Get all pending questions for the current user.
+#[utoipa::path(
+ get,
+ path = "/api/v1/mesh/questions",
+ responses(
+ (status = 200, description = "List of pending questions", body = Vec<PendingQuestionSummary>),
+ (status = 401, description = "Unauthorized"),
+ (status = 500, description = "Internal server error"),
+ ),
+ security(
+ ("bearer_auth" = []),
+ ("api_key" = [])
+ ),
+ tag = "Mesh"
+)]
+pub async fn list_pending_questions(
+ State(state): State<SharedState>,
+ Authenticated(auth): Authenticated,
+) -> impl IntoResponse {
+ let questions: Vec<PendingQuestionSummary> = state
+ .get_pending_questions_for_owner(auth.owner_id)
+ .into_iter()
+ .map(|q| PendingQuestionSummary {
+ question_id: q.question_id,
+ task_id: q.task_id,
+ contract_id: q.contract_id,
+ question: q.question,
+ choices: q.choices,
+ context: q.context,
+ created_at: q.created_at,
+ })
+ .collect();
+
+ Json(questions).into_response()
+}
+
+/// Answer a pending supervisor question.
+#[utoipa::path(
+ post,
+ path = "/api/v1/mesh/questions/{question_id}/answer",
+ params(
+ ("question_id" = Uuid, Path, description = "Question ID")
+ ),
+ request_body = AnswerQuestionRequest,
+ responses(
+ (status = 200, description = "Question answered", body = AnswerQuestionResponse),
+ (status = 401, description = "Unauthorized"),
+ (status = 404, description = "Question not found"),
+ (status = 500, description = "Internal server error"),
+ ),
+ security(
+ ("bearer_auth" = []),
+ ("api_key" = [])
+ ),
+ tag = "Mesh"
+)]
+pub async fn answer_question(
+ State(state): State<SharedState>,
+ Authenticated(auth): Authenticated,
+ Path(question_id): Path<Uuid>,
+ Json(request): Json<AnswerQuestionRequest>,
+) -> impl IntoResponse {
+ // Verify the question exists and belongs to this owner
+ let question = match state.get_pending_question(question_id) {
+ Some(q) if q.owner_id == auth.owner_id => q,
+ Some(_) => {
+ return (
+ StatusCode::FORBIDDEN,
+ Json(ApiError::new("FORBIDDEN", "Question belongs to another user")),
+ ).into_response();
+ }
+ None => {
+ return (
+ StatusCode::NOT_FOUND,
+ Json(ApiError::new("NOT_FOUND", "Question not found or already answered")),
+ ).into_response();
+ }
+ };
+
+ // Submit the response
+ let success = state.submit_question_response(question_id, request.response);
+
+ if success {
+ tracing::info!(
+ question_id = %question_id,
+ task_id = %question.task_id,
+ "User answered supervisor question"
+ );
+ }
+
+ Json(AnswerQuestionResponse { success }).into_response()
+}