diff options
Diffstat (limited to 'makima/src/llm/tools.rs')
| -rw-r--r-- | makima/src/llm/tools.rs | 308 |
1 files changed, 281 insertions, 27 deletions
diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs index 3bd102f..e6b2954 100644 --- a/makima/src/llm/tools.rs +++ b/makima/src/llm/tools.rs @@ -1,5 +1,6 @@ //! Tool definitions for file editing via LLM. +use jaq_interpret::FilterT; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -121,7 +122,7 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = }, Tool { name: "update_element".to_string(), - description: "Update an existing element in the file body".to_string(), + description: "Update an existing element in the file body. IMPORTANT: You must provide ALL required fields. For heading: type, level (1-6), text. For paragraph: type, text. For chart: type, chartType (line/bar/pie/area), data (array of objects).".to_string(), parameters: json!({ "type": "object", "properties": { @@ -129,12 +130,35 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "type": "integer", "description": "Index of element to update (0-indexed)" }, - "element": { - "type": "object", - "description": "New element data. Must include 'type' field (heading, paragraph, chart)." + "element_type": { + "type": "string", + "enum": ["heading", "paragraph", "chart"], + "description": "Type of element" + }, + "text": { + "type": "string", + "description": "Text content (required for heading and paragraph)" + }, + "level": { + "type": "integer", + "description": "Heading level 1-6 (required for heading)" + }, + "chartType": { + "type": "string", + "enum": ["line", "bar", "pie", "area"], + "description": "Chart type (required for chart)" + }, + "data": { + "type": "array", + "description": "Chart data array (required for chart)", + "items": { "type": "object" } + }, + "title": { + "type": "string", + "description": "Chart title (optional for chart)" } }, - "required": ["index", "element"] + "required": ["index", "element_type"] }), }, Tool { @@ -191,6 +215,23 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "properties": {} }), }, + Tool { + name: "jq".to_string(), + description: "Transform JSON data using jq expressions. Useful for filtering, mapping, grouping, and aggregating data before creating charts.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "input": { + "description": "The JSON data to transform (can be an array or object)" + }, + "filter": { + "type": "string", + "description": "The jq filter expression. Examples: '.[] | select(.value > 10)', 'group_by(.category) | map({name: .[0].category, count: length})', '[.[] | {name: .label, value: .amount}]'" + } + }, + "required": ["input", "filter"] + }), + }, ] }); @@ -219,6 +260,7 @@ pub fn execute_tool_call( "set_summary" => execute_set_summary(call, current_summary), "parse_csv" => execute_parse_csv(call), "clear_body" => execute_clear_body(), + "jq" => execute_jq(call), _ => ToolExecutionResult { result: ToolResult { success: false, @@ -415,7 +457,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { let index = call.arguments.get("index").and_then(|v| v.as_u64()); - let element_json = call.arguments.get("element"); + let element_type = call.arguments.get("element_type").and_then(|v| v.as_str()); let Some(index) = index else { return ToolExecutionResult { @@ -429,11 +471,11 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool }; }; - let Some(element_json) = element_json else { + let Some(element_type) = element_type else { return ToolExecutionResult { result: ToolResult { success: false, - message: "Missing element parameter".to_string(), + message: "Missing element_type parameter".to_string(), }, new_body: None, new_summary: None, @@ -454,31 +496,55 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool }; } - let new_element: Result<BodyElement, _> = serde_json::from_value(element_json.clone()); - match new_element { - Ok(element) => { - let mut new_body = current_body.to_vec(); - new_body[index] = element; - - ToolExecutionResult { + // Build the element based on type + let new_element = match element_type { + "heading" => { + let level = call.arguments.get("level").and_then(|v| v.as_u64()).unwrap_or(1) as u8; + let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Heading { level, text } + } + "paragraph" => { + let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Paragraph { text } + } + "chart" => { + let chart_type_str = call.arguments.get("chartType").and_then(|v| v.as_str()).unwrap_or("bar"); + let chart_type = match chart_type_str { + "line" => ChartType::Line, + "bar" => ChartType::Bar, + "pie" => ChartType::Pie, + "area" => ChartType::Area, + _ => ChartType::Bar, + }; + let title = call.arguments.get("title").and_then(|v| v.as_str()).map(|s| s.to_string()); + let data = call.arguments.get("data").cloned().unwrap_or(json!([])); + let config = call.arguments.get("config").cloned(); + BodyElement::Chart { chart_type, title, data, config } + } + _ => { + return ToolExecutionResult { result: ToolResult { - success: true, - message: format!("Updated element at index {}", index), + success: false, + message: format!("Unknown element_type: {}. Must be heading, paragraph, or chart.", element_type), }, - new_body: Some(new_body), + new_body: None, new_summary: None, parsed_data: None, - } + }; } - Err(e) => ToolExecutionResult { - result: ToolResult { - success: false, - message: format!("Invalid element format: {}", e), - }, - new_body: None, - new_summary: None, - parsed_data: None, + }; + + let mut new_body = current_body.to_vec(); + new_body[index] = new_element; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Updated element at index {} to {}", index, element_type), }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, } } @@ -616,3 +682,191 @@ fn execute_clear_body() -> ToolExecutionResult { parsed_data: None, } } + +fn execute_jq(call: &ToolCall) -> ToolExecutionResult { + let input = match call.arguments.get("input") { + Some(v) => v.clone(), + None => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing input parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + }; + + let filter = match call.arguments.get("filter").and_then(|v| v.as_str()) { + Some(f) => f, + None => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing filter parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + }; + + // Parse the jq filter + let mut defs = jaq_interpret::ParseCtx::new(Vec::new()); + defs.insert_natives(jaq_core::core()); + defs.insert_defs(jaq_std::std()); + + let (parsed_filter, errs) = jaq_parse::parse(filter, jaq_parse::main()); + if !errs.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Invalid jq filter: {:?}", errs), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let Some(parsed_filter) = parsed_filter else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Failed to parse jq filter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + // Compile the filter + let compiled = defs.compile(parsed_filter); + if !defs.errs.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Failed to compile jq filter ({} errors)", defs.errs.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + // Convert serde_json::Value to jaq Value + let jaq_input = json_to_jaq(&input); + + // Execute the filter + let inputs = jaq_interpret::RcIter::new(std::iter::empty()); + let mut results: Vec<serde_json::Value> = Vec::new(); + + for output in compiled.run((jaq_interpret::Ctx::new([], &inputs), jaq_input)) { + match output { + Ok(val) => { + results.push(jaq_to_json(&val)); + } + Err(e) => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("jq execution error: {:?}", e), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + } + } + + // Return single value or array based on results + let output = if results.len() == 1 { + results.into_iter().next().unwrap() + } else { + json!(results) + }; + + let preview = { + let s = output.to_string(); + if s.len() > 100 { + format!("{}...", &s[..100]) + } else { + s + } + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("jq transform complete: {}", preview), + }, + new_body: None, + new_summary: None, + parsed_data: Some(output), + } +} + +/// Convert serde_json::Value to jaq_interpret::Val +fn json_to_jaq(value: &serde_json::Value) -> jaq_interpret::Val { + match value { + serde_json::Value::Null => jaq_interpret::Val::Null, + serde_json::Value::Bool(b) => jaq_interpret::Val::Bool(*b), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + jaq_interpret::Val::Int(i as isize) + } else if let Some(f) = n.as_f64() { + jaq_interpret::Val::Float(f) + } else { + jaq_interpret::Val::Null + } + } + serde_json::Value::String(s) => jaq_interpret::Val::Str(s.clone().into()), + serde_json::Value::Array(arr) => { + jaq_interpret::Val::Arr(std::rc::Rc::new(arr.iter().map(json_to_jaq).collect())) + } + serde_json::Value::Object(obj) => { + let mut map: indexmap::IndexMap<std::rc::Rc<String>, jaq_interpret::Val, ahash::RandomState> = + indexmap::IndexMap::with_hasher(ahash::RandomState::new()); + for (k, v) in obj { + map.insert(std::rc::Rc::new(k.clone()), json_to_jaq(v)); + } + jaq_interpret::Val::Obj(std::rc::Rc::new(map)) + } + } +} + +/// Convert jaq_interpret::Val to serde_json::Value +fn jaq_to_json(value: &jaq_interpret::Val) -> serde_json::Value { + match value { + jaq_interpret::Val::Null => serde_json::Value::Null, + jaq_interpret::Val::Bool(b) => json!(*b), + jaq_interpret::Val::Int(i) => json!(*i), + jaq_interpret::Val::Float(f) => json!(*f), + jaq_interpret::Val::Num(n) => { + // Try to parse the number string + if let Ok(i) = n.parse::<i64>() { + json!(i) + } else if let Ok(f) = n.parse::<f64>() { + json!(f) + } else { + json!(n.as_ref()) + } + } + jaq_interpret::Val::Str(s) => json!(s.as_ref()), + jaq_interpret::Val::Arr(arr) => { + json!(arr.iter().map(jaq_to_json).collect::<Vec<_>>()) + } + jaq_interpret::Val::Obj(obj) => { + let mut map = serde_json::Map::new(); + for (k, v) in obj.iter() { + map.insert((**k).clone(), jaq_to_json(v)); + } + serde_json::Value::Object(map) + } + } +} |
