diff options
| author | soryu <soryu@soryu.co> | 2026-01-02 22:13:28 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2026-01-02 22:13:28 +0000 |
| commit | f79c416c58557d2f946aa5332989afdfa8c021cd (patch) | |
| tree | e64e8fef0bedd6b40d3a2314d39654aa5c073980 /makima/src | |
| parent | 2fab6904260099d9a011734763e62ebba91cf448 (diff) | |
| download | soryu-f79c416c58557d2f946aa5332989afdfa8c021cd.tar.gz soryu-f79c416c58557d2f946aa5332989afdfa8c021cd.zip | |
Add defined user input dialogue to LLM edit
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/llm/mod.rs | 5 | ||||
| -rw-r--r-- | makima/src/llm/tools.rs | 192 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 29 |
3 files changed, 224 insertions, 2 deletions
diff --git a/makima/src/llm/mod.rs b/makima/src/llm/mod.rs index 0df492d..1001854 100644 --- a/makima/src/llm/mod.rs +++ b/makima/src/llm/mod.rs @@ -6,7 +6,10 @@ pub mod tools; pub use claude::{ClaudeClient, ClaudeModel}; pub use groq::GroqClient; -pub use tools::{execute_tool_call, Tool, ToolCall, ToolResult, VersionToolRequest, AVAILABLE_TOOLS}; +pub use tools::{ + execute_tool_call, Tool, ToolCall, ToolResult, UserAnswer, UserQuestion, VersionToolRequest, + AVAILABLE_TOOLS, +}; /// Available LLM providers and models #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs index 216b733..77fc8c6 100644 --- a/makima/src/llm/tools.rs +++ b/makima/src/llm/tools.rs @@ -232,6 +232,48 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "required": ["input", "filter"] }), }, + // Interactive tools + Tool { + name: "ask_user".to_string(), + description: "Ask the user one or more questions. Use this when you need clarification, want to offer choices, or need user input before proceeding. Each question can have multiple choice options and optionally allow custom answers. The conversation will pause until the user responds.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "questions": { + "type": "array", + "description": "List of questions to ask the user", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for this question (e.g., 'chart_type', 'color_scheme')" + }, + "question": { + "type": "string", + "description": "The question to ask the user" + }, + "options": { + "type": "array", + "items": { "type": "string" }, + "description": "Multiple choice options for the user to select from" + }, + "allowMultiple": { + "type": "boolean", + "description": "If true, user can select multiple options. Default false." + }, + "allowCustom": { + "type": "boolean", + "description": "If true, user can provide a custom answer instead of selecting from options. Default true." + } + }, + "required": ["id", "question", "options"] + } + } + }, + "required": ["questions"] + }), + }, // Content viewing tools Tool { name: "view_body".to_string(), @@ -321,6 +363,38 @@ pub enum VersionToolRequest { RestoreVersion { target_version: i32, reason: Option<String> }, } +/// A question to ask the user +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct UserQuestion { + /// Unique identifier for this question + pub id: String, + /// The question text + pub question: String, + /// Multiple choice options + pub options: Vec<String>, + /// Whether multiple options can be selected + #[serde(default)] + pub allow_multiple: bool, + /// Whether a custom answer is allowed + #[serde(default = "default_allow_custom")] + pub allow_custom: bool, +} + +fn default_allow_custom() -> bool { + true +} + +/// User's answer to a question +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct UserAnswer { + /// Question ID this answers + pub id: String, + /// Selected option(s) or custom answer + pub answers: Vec<String>, +} + /// Result of executing a tool call with modified file state #[derive(Debug)] pub struct ToolExecutionResult { @@ -330,6 +404,8 @@ pub struct ToolExecutionResult { pub parsed_data: Option<serde_json::Value>, /// Request for async version operations (handled by chat handler) pub version_request: Option<VersionToolRequest>, + /// Questions to ask the user (pauses conversation until answered) + pub pending_questions: Option<Vec<UserQuestion>>, } /// Execute a tool call and return the result along with any state changes @@ -350,6 +426,8 @@ pub fn execute_tool_call( "parse_csv" => execute_parse_csv(call), "clear_body" => execute_clear_body(), "jq" => execute_jq(call), + // Interactive tools + "ask_user" => execute_ask_user(call), // Content viewing tools "view_body" => execute_view_body(current_body), "read_element" => execute_read_element(call, current_body), @@ -367,7 +445,84 @@ pub fn execute_tool_call( new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, + }, + } +} + +fn execute_ask_user(call: &ToolCall) -> ToolExecutionResult { + let questions_value = call.arguments.get("questions"); + + let Some(questions_array) = questions_value.and_then(|v| v.as_array()) else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing or invalid 'questions' parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + version_request: None, + pending_questions: None, + }; + }; + + let mut questions: Vec<UserQuestion> = Vec::new(); + + for q in questions_array { + let id = q.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let question = q.get("question").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let options: Vec<String> = q + .get("options") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|o| o.as_str()) + .map(|s| s.to_string()) + .collect() + }) + .unwrap_or_default(); + let allow_multiple = q.get("allowMultiple").and_then(|v| v.as_bool()).unwrap_or(false); + let allow_custom = q.get("allowCustom").and_then(|v| v.as_bool()).unwrap_or(true); + + if id.is_empty() || question.is_empty() || options.is_empty() { + continue; + } + + questions.push(UserQuestion { + id, + question, + options, + allow_multiple, + allow_custom, + }); + } + + if questions.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "No valid questions provided".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + version_request: None, + pending_questions: None, + }; + } + + let question_count = questions.len(); + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Asking user {} question(s). Waiting for response...", question_count), }, + new_body: None, + new_summary: None, + parsed_data: None, + version_request: None, + pending_questions: Some(questions), } } @@ -404,6 +559,7 @@ fn execute_add_heading(call: &ToolCall, current_body: &[BodyElement]) -> ToolExe new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -445,6 +601,7 @@ fn execute_add_paragraph(call: &ToolCall, current_body: &[BodyElement]) -> ToolE new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -511,6 +668,7 @@ fn execute_add_chart(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecu new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -527,6 +685,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -541,6 +700,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -556,6 +716,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -573,6 +734,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -586,6 +748,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -600,6 +763,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -638,6 +802,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } }; @@ -654,6 +819,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -671,6 +837,7 @@ fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> To new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -690,6 +857,7 @@ fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> To new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -706,6 +874,7 @@ fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> To new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -726,6 +895,7 @@ fn execute_set_summary(call: &ToolCall, _current_summary: Option<&str>) -> ToolE new_summary: Some(summary), parsed_data: None, version_request: None, + pending_questions: None, } } @@ -747,6 +917,7 @@ fn execute_parse_csv(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -783,6 +954,7 @@ fn execute_parse_csv(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: Some(json!(data)), version_request: None, + pending_questions: None, } } @@ -796,6 +968,7 @@ fn execute_clear_body() -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, } } @@ -812,6 +985,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } }; @@ -828,6 +1002,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } }; @@ -848,6 +1023,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -861,6 +1037,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -876,6 +1053,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -901,6 +1079,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } } @@ -931,6 +1110,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: Some(output), version_request: None, + pending_questions: None, } } @@ -949,6 +1129,7 @@ fn execute_view_body(current_body: &[BodyElement]) -> ToolExecutionResult { new_summary: None, parsed_data: Some(json!([])), version_request: None, + pending_questions: None, }; } @@ -996,6 +1177,7 @@ fn execute_view_body(current_body: &[BodyElement]) -> ToolExecutionResult { new_summary: None, parsed_data: Some(json!(elements)), version_request: None, + pending_questions: None, } } @@ -1012,6 +1194,7 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -1026,6 +1209,7 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; } @@ -1075,6 +1259,7 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx new_summary: None, parsed_data: Some(element_data), version_request: None, + pending_questions: None, } } @@ -1089,6 +1274,7 @@ fn execute_view_transcript(transcript: &[TranscriptEntry]) -> ToolExecutionResul new_summary: None, parsed_data: Some(json!([])), version_request: None, + pending_questions: None, }; } @@ -1125,6 +1311,7 @@ fn execute_view_transcript(transcript: &[TranscriptEntry]) -> ToolExecutionResul new_summary: None, parsed_data: Some(json!(entries)), version_request: None, + pending_questions: None, } } @@ -1144,6 +1331,7 @@ fn execute_list_versions() -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: Some(VersionToolRequest::ListVersions), + pending_questions: None, } } @@ -1160,6 +1348,7 @@ fn execute_read_version(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -1172,6 +1361,7 @@ fn execute_read_version(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: Some(VersionToolRequest::ReadVersion { version: version as i32 }), + pending_questions: None, } } @@ -1193,6 +1383,7 @@ fn execute_restore_version(call: &ToolCall) -> ToolExecutionResult { new_summary: None, parsed_data: None, version_request: None, + pending_questions: None, }; }; @@ -1208,6 +1399,7 @@ fn execute_restore_version(call: &ToolCall) -> ToolExecutionResult { target_version: target_version as i32, reason, }), + pending_questions: None, } } diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index 158805b..51f17c1 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -15,7 +15,7 @@ use crate::llm::{ claude::{self, ClaudeClient, ClaudeError, ClaudeModel}, execute_tool_call, groq::{GroqClient, GroqError, Message, ToolCallResponse}, - LlmModel, ToolCall, ToolResult, VersionToolRequest, AVAILABLE_TOOLS, + LlmModel, ToolCall, ToolResult, UserQuestion, VersionToolRequest, AVAILABLE_TOOLS, }; use crate::server::state::{FileUpdateNotification, SharedState}; @@ -66,6 +66,9 @@ pub struct ChatResponse { pub updated_body: Vec<BodyElement>, /// Updated summary (if changed) pub updated_summary: Option<String>, + /// Questions pending user answers (pauses conversation) + #[serde(skip_serializing_if = "Option::is_none")] + pub pending_questions: Option<Vec<UserQuestion>>, } #[derive(Debug, Serialize, ToSchema)] @@ -326,6 +329,8 @@ You have access to tools for: // Track consecutive failures for agentic retry logic let mut consecutive_failures = 0; const MAX_CONSECUTIVE_FAILURES: usize = 3; + // Track pending user questions (pauses the conversation) + let mut pending_questions: Option<Vec<UserQuestion>> = None; // Multi-turn agentic tool calling loop for round in 0..MAX_TOOL_ROUNDS { @@ -508,6 +513,21 @@ You have access to tools for: ); } + // Check for pending user questions (pauses the conversation) + if let Some(questions) = execution_result.pending_questions { + tracing::info!( + question_count = questions.len(), + "LLM requesting user input, pausing conversation" + ); + pending_questions = Some(questions); + // Track this tool call before breaking + all_tool_call_infos.push(ToolCallInfo { + name: tool_call.name.clone(), + result: execution_result.result, + }); + break; // Exit inner loop + } + // Build tool result message content with enhanced context for agentic reasoning let result_content = if let Some(parsed_data) = &execution_result.parsed_data { // Include parsed data in the result for the LLM to use @@ -559,6 +579,12 @@ You have access to tools for: }); } + // If user questions are pending, pause the conversation + if pending_questions.is_some() { + final_response = result.content; + break; + } + // If finish reason indicates completion, exit loop let finish_lower = result.finish_reason.to_lowercase(); if finish_lower == "stop" || finish_lower == "end_turn" { @@ -637,6 +663,7 @@ You have access to tools for: tool_calls: all_tool_call_infos, updated_body: current_body, updated_summary: current_summary, + pending_questions, }), ) .into_response() |
