summaryrefslogtreecommitdiff
path: root/makima/src/llm/tools.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/llm/tools.rs')
-rw-r--r--makima/src/llm/tools.rs192
1 files changed, 192 insertions, 0 deletions
diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs
index 216b733..77fc8c6 100644
--- a/makima/src/llm/tools.rs
+++ b/makima/src/llm/tools.rs
@@ -232,6 +232,48 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> =
"required": ["input", "filter"]
}),
},
+ // Interactive tools
+ Tool {
+ name: "ask_user".to_string(),
+ description: "Ask the user one or more questions. Use this when you need clarification, want to offer choices, or need user input before proceeding. Each question can have multiple choice options and optionally allow custom answers. The conversation will pause until the user responds.".to_string(),
+ parameters: json!({
+ "type": "object",
+ "properties": {
+ "questions": {
+ "type": "array",
+ "description": "List of questions to ask the user",
+ "items": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string",
+ "description": "Unique identifier for this question (e.g., 'chart_type', 'color_scheme')"
+ },
+ "question": {
+ "type": "string",
+ "description": "The question to ask the user"
+ },
+ "options": {
+ "type": "array",
+ "items": { "type": "string" },
+ "description": "Multiple choice options for the user to select from"
+ },
+ "allowMultiple": {
+ "type": "boolean",
+ "description": "If true, user can select multiple options. Default false."
+ },
+ "allowCustom": {
+ "type": "boolean",
+ "description": "If true, user can provide a custom answer instead of selecting from options. Default true."
+ }
+ },
+ "required": ["id", "question", "options"]
+ }
+ }
+ },
+ "required": ["questions"]
+ }),
+ },
// Content viewing tools
Tool {
name: "view_body".to_string(),
@@ -321,6 +363,38 @@ pub enum VersionToolRequest {
RestoreVersion { target_version: i32, reason: Option<String> },
}
+/// A question to ask the user
+#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct UserQuestion {
+ /// Unique identifier for this question
+ pub id: String,
+ /// The question text
+ pub question: String,
+ /// Multiple choice options
+ pub options: Vec<String>,
+ /// Whether multiple options can be selected
+ #[serde(default)]
+ pub allow_multiple: bool,
+ /// Whether a custom answer is allowed
+ #[serde(default = "default_allow_custom")]
+ pub allow_custom: bool,
+}
+
+fn default_allow_custom() -> bool {
+ true
+}
+
+/// User's answer to a question
+#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct UserAnswer {
+ /// Question ID this answers
+ pub id: String,
+ /// Selected option(s) or custom answer
+ pub answers: Vec<String>,
+}
+
/// Result of executing a tool call with modified file state
#[derive(Debug)]
pub struct ToolExecutionResult {
@@ -330,6 +404,8 @@ pub struct ToolExecutionResult {
pub parsed_data: Option<serde_json::Value>,
/// Request for async version operations (handled by chat handler)
pub version_request: Option<VersionToolRequest>,
+ /// Questions to ask the user (pauses conversation until answered)
+ pub pending_questions: Option<Vec<UserQuestion>>,
}
/// Execute a tool call and return the result along with any state changes
@@ -350,6 +426,8 @@ pub fn execute_tool_call(
"parse_csv" => execute_parse_csv(call),
"clear_body" => execute_clear_body(),
"jq" => execute_jq(call),
+ // Interactive tools
+ "ask_user" => execute_ask_user(call),
// Content viewing tools
"view_body" => execute_view_body(current_body),
"read_element" => execute_read_element(call, current_body),
@@ -367,7 +445,84 @@ pub fn execute_tool_call(
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
+ },
+ }
+}
+
+fn execute_ask_user(call: &ToolCall) -> ToolExecutionResult {
+ let questions_value = call.arguments.get("questions");
+
+ let Some(questions_array) = questions_value.and_then(|v| v.as_array()) else {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: "Missing or invalid 'questions' parameter".to_string(),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ version_request: None,
+ pending_questions: None,
+ };
+ };
+
+ let mut questions: Vec<UserQuestion> = Vec::new();
+
+ for q in questions_array {
+ let id = q.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ let question = q.get("question").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ let options: Vec<String> = q
+ .get("options")
+ .and_then(|v| v.as_array())
+ .map(|arr| {
+ arr.iter()
+ .filter_map(|o| o.as_str())
+ .map(|s| s.to_string())
+ .collect()
+ })
+ .unwrap_or_default();
+ let allow_multiple = q.get("allowMultiple").and_then(|v| v.as_bool()).unwrap_or(false);
+ let allow_custom = q.get("allowCustom").and_then(|v| v.as_bool()).unwrap_or(true);
+
+ if id.is_empty() || question.is_empty() || options.is_empty() {
+ continue;
+ }
+
+ questions.push(UserQuestion {
+ id,
+ question,
+ options,
+ allow_multiple,
+ allow_custom,
+ });
+ }
+
+ if questions.is_empty() {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: "No valid questions provided".to_string(),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ version_request: None,
+ pending_questions: None,
+ };
+ }
+
+ let question_count = questions.len();
+ ToolExecutionResult {
+ result: ToolResult {
+ success: true,
+ message: format!("Asking user {} question(s). Waiting for response...", question_count),
},
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ version_request: None,
+ pending_questions: Some(questions),
}
}
@@ -404,6 +559,7 @@ fn execute_add_heading(call: &ToolCall, current_body: &[BodyElement]) -> ToolExe
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -445,6 +601,7 @@ fn execute_add_paragraph(call: &ToolCall, current_body: &[BodyElement]) -> ToolE
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -511,6 +668,7 @@ fn execute_add_chart(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecu
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -527,6 +685,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -541,6 +700,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -556,6 +716,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -573,6 +734,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -586,6 +748,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -600,6 +763,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -638,6 +802,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
};
@@ -654,6 +819,7 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -671,6 +837,7 @@ fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> To
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -690,6 +857,7 @@ fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> To
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -706,6 +874,7 @@ fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> To
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -726,6 +895,7 @@ fn execute_set_summary(call: &ToolCall, _current_summary: Option<&str>) -> ToolE
new_summary: Some(summary),
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -747,6 +917,7 @@ fn execute_parse_csv(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -783,6 +954,7 @@ fn execute_parse_csv(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: Some(json!(data)),
version_request: None,
+ pending_questions: None,
}
}
@@ -796,6 +968,7 @@ fn execute_clear_body() -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
}
}
@@ -812,6 +985,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
};
@@ -828,6 +1002,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
};
@@ -848,6 +1023,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -861,6 +1037,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -876,6 +1053,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -901,6 +1079,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
}
@@ -931,6 +1110,7 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: Some(output),
version_request: None,
+ pending_questions: None,
}
}
@@ -949,6 +1129,7 @@ fn execute_view_body(current_body: &[BodyElement]) -> ToolExecutionResult {
new_summary: None,
parsed_data: Some(json!([])),
version_request: None,
+ pending_questions: None,
};
}
@@ -996,6 +1177,7 @@ fn execute_view_body(current_body: &[BodyElement]) -> ToolExecutionResult {
new_summary: None,
parsed_data: Some(json!(elements)),
version_request: None,
+ pending_questions: None,
}
}
@@ -1012,6 +1194,7 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -1026,6 +1209,7 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
}
@@ -1075,6 +1259,7 @@ fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolEx
new_summary: None,
parsed_data: Some(element_data),
version_request: None,
+ pending_questions: None,
}
}
@@ -1089,6 +1274,7 @@ fn execute_view_transcript(transcript: &[TranscriptEntry]) -> ToolExecutionResul
new_summary: None,
parsed_data: Some(json!([])),
version_request: None,
+ pending_questions: None,
};
}
@@ -1125,6 +1311,7 @@ fn execute_view_transcript(transcript: &[TranscriptEntry]) -> ToolExecutionResul
new_summary: None,
parsed_data: Some(json!(entries)),
version_request: None,
+ pending_questions: None,
}
}
@@ -1144,6 +1331,7 @@ fn execute_list_versions() -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: Some(VersionToolRequest::ListVersions),
+ pending_questions: None,
}
}
@@ -1160,6 +1348,7 @@ fn execute_read_version(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -1172,6 +1361,7 @@ fn execute_read_version(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: Some(VersionToolRequest::ReadVersion { version: version as i32 }),
+ pending_questions: None,
}
}
@@ -1193,6 +1383,7 @@ fn execute_restore_version(call: &ToolCall) -> ToolExecutionResult {
new_summary: None,
parsed_data: None,
version_request: None,
+ pending_questions: None,
};
};
@@ -1208,6 +1399,7 @@ fn execute_restore_version(call: &ToolCall) -> ToolExecutionResult {
target_version: target_version as i32,
reason,
}),
+ pending_questions: None,
}
}