summaryrefslogtreecommitdiff
path: root/makima/src/llm/tools.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-23 18:24:42 +0000
committersoryu <soryu@soryu.co>2025-12-23 18:24:42 +0000
commit3c0adec8e3a9dd3bc34251e87e0fb5314793426d (patch)
tree9dfe61e55bd703aa09df03abfcbf8e7a8b2babce /makima/src/llm/tools.rs
parent555061b179b8ec034cb70f9a2dd6c823ced0f637 (diff)
downloadsoryu-3c0adec8e3a9dd3bc34251e87e0fb5314793426d.tar.gz
soryu-3c0adec8e3a9dd3bc34251e87e0fb5314793426d.zip
Add claude opus/sonnet support
Diffstat (limited to 'makima/src/llm/tools.rs')
-rw-r--r--makima/src/llm/tools.rs308
1 files changed, 281 insertions, 27 deletions
diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs
index 3bd102f..e6b2954 100644
--- a/makima/src/llm/tools.rs
+++ b/makima/src/llm/tools.rs
@@ -1,5 +1,6 @@
//! Tool definitions for file editing via LLM.
+use jaq_interpret::FilterT;
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -121,7 +122,7 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> =
},
Tool {
name: "update_element".to_string(),
- description: "Update an existing element in the file body".to_string(),
+ description: "Update an existing element in the file body. IMPORTANT: You must provide ALL required fields. For heading: type, level (1-6), text. For paragraph: type, text. For chart: type, chartType (line/bar/pie/area), data (array of objects).".to_string(),
parameters: json!({
"type": "object",
"properties": {
@@ -129,12 +130,35 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> =
"type": "integer",
"description": "Index of element to update (0-indexed)"
},
- "element": {
- "type": "object",
- "description": "New element data. Must include 'type' field (heading, paragraph, chart)."
+ "element_type": {
+ "type": "string",
+ "enum": ["heading", "paragraph", "chart"],
+ "description": "Type of element"
+ },
+ "text": {
+ "type": "string",
+ "description": "Text content (required for heading and paragraph)"
+ },
+ "level": {
+ "type": "integer",
+ "description": "Heading level 1-6 (required for heading)"
+ },
+ "chartType": {
+ "type": "string",
+ "enum": ["line", "bar", "pie", "area"],
+ "description": "Chart type (required for chart)"
+ },
+ "data": {
+ "type": "array",
+ "description": "Chart data array (required for chart)",
+ "items": { "type": "object" }
+ },
+ "title": {
+ "type": "string",
+ "description": "Chart title (optional for chart)"
}
},
- "required": ["index", "element"]
+ "required": ["index", "element_type"]
}),
},
Tool {
@@ -191,6 +215,23 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> =
"properties": {}
}),
},
+ Tool {
+ name: "jq".to_string(),
+ description: "Transform JSON data using jq expressions. Useful for filtering, mapping, grouping, and aggregating data before creating charts.".to_string(),
+ parameters: json!({
+ "type": "object",
+ "properties": {
+ "input": {
+ "description": "The JSON data to transform (can be an array or object)"
+ },
+ "filter": {
+ "type": "string",
+ "description": "The jq filter expression. Examples: '.[] | select(.value > 10)', 'group_by(.category) | map({name: .[0].category, count: length})', '[.[] | {name: .label, value: .amount}]'"
+ }
+ },
+ "required": ["input", "filter"]
+ }),
+ },
]
});
@@ -219,6 +260,7 @@ pub fn execute_tool_call(
"set_summary" => execute_set_summary(call, current_summary),
"parse_csv" => execute_parse_csv(call),
"clear_body" => execute_clear_body(),
+ "jq" => execute_jq(call),
_ => ToolExecutionResult {
result: ToolResult {
success: false,
@@ -415,7 +457,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult {
let index = call.arguments.get("index").and_then(|v| v.as_u64());
- let element_json = call.arguments.get("element");
+ let element_type = call.arguments.get("element_type").and_then(|v| v.as_str());
let Some(index) = index else {
return ToolExecutionResult {
@@ -429,11 +471,11 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
};
};
- let Some(element_json) = element_json else {
+ let Some(element_type) = element_type else {
return ToolExecutionResult {
result: ToolResult {
success: false,
- message: "Missing element parameter".to_string(),
+ message: "Missing element_type parameter".to_string(),
},
new_body: None,
new_summary: None,
@@ -454,31 +496,55 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool
};
}
- let new_element: Result<BodyElement, _> = serde_json::from_value(element_json.clone());
- match new_element {
- Ok(element) => {
- let mut new_body = current_body.to_vec();
- new_body[index] = element;
-
- ToolExecutionResult {
+ // Build the element based on type
+ let new_element = match element_type {
+ "heading" => {
+ let level = call.arguments.get("level").and_then(|v| v.as_u64()).unwrap_or(1) as u8;
+ let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ BodyElement::Heading { level, text }
+ }
+ "paragraph" => {
+ let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ BodyElement::Paragraph { text }
+ }
+ "chart" => {
+ let chart_type_str = call.arguments.get("chartType").and_then(|v| v.as_str()).unwrap_or("bar");
+ let chart_type = match chart_type_str {
+ "line" => ChartType::Line,
+ "bar" => ChartType::Bar,
+ "pie" => ChartType::Pie,
+ "area" => ChartType::Area,
+ _ => ChartType::Bar,
+ };
+ let title = call.arguments.get("title").and_then(|v| v.as_str()).map(|s| s.to_string());
+ let data = call.arguments.get("data").cloned().unwrap_or(json!([]));
+ let config = call.arguments.get("config").cloned();
+ BodyElement::Chart { chart_type, title, data, config }
+ }
+ _ => {
+ return ToolExecutionResult {
result: ToolResult {
- success: true,
- message: format!("Updated element at index {}", index),
+ success: false,
+ message: format!("Unknown element_type: {}. Must be heading, paragraph, or chart.", element_type),
},
- new_body: Some(new_body),
+ new_body: None,
new_summary: None,
parsed_data: None,
- }
+ };
}
- Err(e) => ToolExecutionResult {
- result: ToolResult {
- success: false,
- message: format!("Invalid element format: {}", e),
- },
- new_body: None,
- new_summary: None,
- parsed_data: None,
+ };
+
+ let mut new_body = current_body.to_vec();
+ new_body[index] = new_element;
+
+ ToolExecutionResult {
+ result: ToolResult {
+ success: true,
+ message: format!("Updated element at index {} to {}", index, element_type),
},
+ new_body: Some(new_body),
+ new_summary: None,
+ parsed_data: None,
}
}
@@ -616,3 +682,191 @@ fn execute_clear_body() -> ToolExecutionResult {
parsed_data: None,
}
}
+
+fn execute_jq(call: &ToolCall) -> ToolExecutionResult {
+ let input = match call.arguments.get("input") {
+ Some(v) => v.clone(),
+ None => {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: "Missing input parameter".to_string(),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ };
+ }
+ };
+
+ let filter = match call.arguments.get("filter").and_then(|v| v.as_str()) {
+ Some(f) => f,
+ None => {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: "Missing filter parameter".to_string(),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ };
+ }
+ };
+
+ // Parse the jq filter
+ let mut defs = jaq_interpret::ParseCtx::new(Vec::new());
+ defs.insert_natives(jaq_core::core());
+ defs.insert_defs(jaq_std::std());
+
+ let (parsed_filter, errs) = jaq_parse::parse(filter, jaq_parse::main());
+ if !errs.is_empty() {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: format!("Invalid jq filter: {:?}", errs),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ };
+ }
+
+ let Some(parsed_filter) = parsed_filter else {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: "Failed to parse jq filter".to_string(),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ };
+ };
+
+ // Compile the filter
+ let compiled = defs.compile(parsed_filter);
+ if !defs.errs.is_empty() {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: format!("Failed to compile jq filter ({} errors)", defs.errs.len()),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ };
+ }
+
+ // Convert serde_json::Value to jaq Value
+ let jaq_input = json_to_jaq(&input);
+
+ // Execute the filter
+ let inputs = jaq_interpret::RcIter::new(std::iter::empty());
+ let mut results: Vec<serde_json::Value> = Vec::new();
+
+ for output in compiled.run((jaq_interpret::Ctx::new([], &inputs), jaq_input)) {
+ match output {
+ Ok(val) => {
+ results.push(jaq_to_json(&val));
+ }
+ Err(e) => {
+ return ToolExecutionResult {
+ result: ToolResult {
+ success: false,
+ message: format!("jq execution error: {:?}", e),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: None,
+ };
+ }
+ }
+ }
+
+ // Return single value or array based on results
+ let output = if results.len() == 1 {
+ results.into_iter().next().unwrap()
+ } else {
+ json!(results)
+ };
+
+ let preview = {
+ let s = output.to_string();
+ if s.len() > 100 {
+ format!("{}...", &s[..100])
+ } else {
+ s
+ }
+ };
+
+ ToolExecutionResult {
+ result: ToolResult {
+ success: true,
+ message: format!("jq transform complete: {}", preview),
+ },
+ new_body: None,
+ new_summary: None,
+ parsed_data: Some(output),
+ }
+}
+
+/// Convert serde_json::Value to jaq_interpret::Val
+fn json_to_jaq(value: &serde_json::Value) -> jaq_interpret::Val {
+ match value {
+ serde_json::Value::Null => jaq_interpret::Val::Null,
+ serde_json::Value::Bool(b) => jaq_interpret::Val::Bool(*b),
+ serde_json::Value::Number(n) => {
+ if let Some(i) = n.as_i64() {
+ jaq_interpret::Val::Int(i as isize)
+ } else if let Some(f) = n.as_f64() {
+ jaq_interpret::Val::Float(f)
+ } else {
+ jaq_interpret::Val::Null
+ }
+ }
+ serde_json::Value::String(s) => jaq_interpret::Val::Str(s.clone().into()),
+ serde_json::Value::Array(arr) => {
+ jaq_interpret::Val::Arr(std::rc::Rc::new(arr.iter().map(json_to_jaq).collect()))
+ }
+ serde_json::Value::Object(obj) => {
+ let mut map: indexmap::IndexMap<std::rc::Rc<String>, jaq_interpret::Val, ahash::RandomState> =
+ indexmap::IndexMap::with_hasher(ahash::RandomState::new());
+ for (k, v) in obj {
+ map.insert(std::rc::Rc::new(k.clone()), json_to_jaq(v));
+ }
+ jaq_interpret::Val::Obj(std::rc::Rc::new(map))
+ }
+ }
+}
+
+/// Convert jaq_interpret::Val to serde_json::Value
+fn jaq_to_json(value: &jaq_interpret::Val) -> serde_json::Value {
+ match value {
+ jaq_interpret::Val::Null => serde_json::Value::Null,
+ jaq_interpret::Val::Bool(b) => json!(*b),
+ jaq_interpret::Val::Int(i) => json!(*i),
+ jaq_interpret::Val::Float(f) => json!(*f),
+ jaq_interpret::Val::Num(n) => {
+ // Try to parse the number string
+ if let Ok(i) = n.parse::<i64>() {
+ json!(i)
+ } else if let Ok(f) = n.parse::<f64>() {
+ json!(f)
+ } else {
+ json!(n.as_ref())
+ }
+ }
+ jaq_interpret::Val::Str(s) => json!(s.as_ref()),
+ jaq_interpret::Val::Arr(arr) => {
+ json!(arr.iter().map(jaq_to_json).collect::<Vec<_>>())
+ }
+ jaq_interpret::Val::Obj(obj) => {
+ let mut map = serde_json::Map::new();
+ for (k, v) in obj.iter() {
+ map.insert((**k).clone(), jaq_to_json(v));
+ }
+ serde_json::Value::Object(map)
+ }
+ }
+}