diff options
| author | soryu <soryu@soryu.co> | 2025-12-23 14:43:23 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 555061b179b8ec034cb70f9a2dd6c823ced0f637 (patch) | |
| tree | 0545b4395dab6d957884d8d36bf15b8da529dc1f /makima/src/llm/tools.rs | |
| parent | a32dc56d2e5447ef8988cb98b8686476cc94e70c (diff) | |
| download | soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.tar.gz soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.zip | |
Add file body and initial tool call system
Diffstat (limited to 'makima/src/llm/tools.rs')
| -rw-r--r-- | makima/src/llm/tools.rs | 618 |
1 files changed, 618 insertions, 0 deletions
diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs new file mode 100644 index 0000000..3bd102f --- /dev/null +++ b/makima/src/llm/tools.rs @@ -0,0 +1,618 @@ +//! Tool definitions for file editing via LLM. + +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::db::models::{BodyElement, ChartType}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ToolResult { + pub success: bool, + pub message: String, +} + +/// Available tools for file editing +pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = + once_cell::sync::Lazy::new(|| { + vec![ + Tool { + name: "add_heading".to_string(), + description: "Add a heading element to the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "level": { + "type": "integer", + "description": "Heading level (1-6)", + "minimum": 1, + "maximum": 6 + }, + "text": { + "type": "string", + "description": "The heading text" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["level", "text"] + }), + }, + Tool { + name: "add_paragraph".to_string(), + description: "Add a paragraph element to the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The paragraph text" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["text"] + }), + }, + Tool { + name: "add_chart".to_string(), + description: "Add a chart visualization to the file body. Supports line, bar, pie, and area charts.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "chart_type": { + "type": "string", + "enum": ["line", "bar", "pie", "area"], + "description": "Type of chart to create" + }, + "title": { + "type": "string", + "description": "Optional chart title" + }, + "data": { + "type": "array", + "description": "Array of data points. Each point should have a 'name' field and one or more numeric value fields.", + "items": { + "type": "object" + } + }, + "config": { + "type": "object", + "description": "Optional chart configuration (colors, axes, etc.)" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["chart_type", "data"] + }), + }, + Tool { + name: "remove_element".to_string(), + description: "Remove an element from the file body by index".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "Index of element to remove (0-indexed)" + } + }, + "required": ["index"] + }), + }, + Tool { + name: "update_element".to_string(), + description: "Update an existing element in the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "Index of element to update (0-indexed)" + }, + "element": { + "type": "object", + "description": "New element data. Must include 'type' field (heading, paragraph, chart)." + } + }, + "required": ["index", "element"] + }), + }, + Tool { + name: "reorder_elements".to_string(), + description: "Move an element from one position to another".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "from_index": { + "type": "integer", + "description": "Current index of the element" + }, + "to_index": { + "type": "integer", + "description": "New index for the element" + } + }, + "required": ["from_index", "to_index"] + }), + }, + Tool { + name: "set_summary".to_string(), + description: "Set the file summary text".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "The summary text" + } + }, + "required": ["summary"] + }), + }, + Tool { + name: "parse_csv".to_string(), + description: "Parse CSV data into JSON format suitable for charts".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "csv": { + "type": "string", + "description": "CSV data string with header row" + } + }, + "required": ["csv"] + }), + }, + Tool { + name: "clear_body".to_string(), + description: "Clear all elements from the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": {} + }), + }, + ] + }); + +/// Result of executing a tool call with modified file state +#[derive(Debug)] +pub struct ToolExecutionResult { + pub result: ToolResult, + pub new_body: Option<Vec<BodyElement>>, + pub new_summary: Option<String>, + pub parsed_data: Option<serde_json::Value>, +} + +/// Execute a tool call and return the result along with any state changes +pub fn execute_tool_call( + call: &ToolCall, + current_body: &[BodyElement], + current_summary: Option<&str>, +) -> ToolExecutionResult { + match call.name.as_str() { + "add_heading" => execute_add_heading(call, current_body), + "add_paragraph" => execute_add_paragraph(call, current_body), + "add_chart" => execute_add_chart(call, current_body), + "remove_element" => execute_remove_element(call, current_body), + "update_element" => execute_update_element(call, current_body), + "reorder_elements" => execute_reorder_elements(call, current_body), + "set_summary" => execute_set_summary(call, current_summary), + "parse_csv" => execute_parse_csv(call), + "clear_body" => execute_clear_body(), + _ => ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Unknown tool: {}", call.name), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }, + } +} + +fn execute_add_heading(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let level = call.arguments.get("level").and_then(|v| v.as_u64()).unwrap_or(1) as u8; + let text = call + .arguments + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Heading { level, text: text.clone() }; + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Added heading: {}", text), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_add_paragraph(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let text = call + .arguments + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Paragraph { text: text.clone() }; + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + let preview = if text.len() > 50 { + format!("{}...", &text[..50]) + } else { + text + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Added paragraph: {}", preview), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_add_chart(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let chart_type_str = call + .arguments + .get("chart_type") + .and_then(|v| v.as_str()) + .unwrap_or("bar"); + + let chart_type = match chart_type_str { + "line" => ChartType::Line, + "bar" => ChartType::Bar, + "pie" => ChartType::Pie, + "area" => ChartType::Area, + _ => ChartType::Bar, + }; + + let title = call + .arguments + .get("title") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let data = call + .arguments + .get("data") + .cloned() + .unwrap_or(json!([])); + + let config = call.arguments.get("config").cloned(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Chart { + chart_type, + title: title.clone(), + data, + config, + }; + + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!( + "Added {} chart{}", + chart_type_str, + title.map(|t| format!(": {}", t)).unwrap_or_default() + ), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let index = call.arguments.get("index").and_then(|v| v.as_u64()); + + let Some(index) = index else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let index = index as usize; + if index >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Index {} out of bounds (body has {} elements)", index, current_body.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let mut new_body = current_body.to_vec(); + new_body.remove(index); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Removed element at index {}", index), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let index = call.arguments.get("index").and_then(|v| v.as_u64()); + let element_json = call.arguments.get("element"); + + let Some(index) = index else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let Some(element_json) = element_json else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing element parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let index = index as usize; + if index >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Index {} out of bounds (body has {} elements)", index, current_body.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let new_element: Result<BodyElement, _> = serde_json::from_value(element_json.clone()); + match new_element { + Ok(element) => { + let mut new_body = current_body.to_vec(); + new_body[index] = element; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Updated element at index {}", index), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } + } + Err(e) => ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Invalid element format: {}", e), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }, + } +} + +fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let from_index = call.arguments.get("from_index").and_then(|v| v.as_u64()); + let to_index = call.arguments.get("to_index").and_then(|v| v.as_u64()); + + let (Some(from), Some(to)) = (from_index, to_index) else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing from_index or to_index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let from = from as usize; + let to = to as usize; + + if from >= current_body.len() || to >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!( + "Index out of bounds: from={}, to={}, body has {} elements", + from, to, current_body.len() + ), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let mut new_body = current_body.to_vec(); + let element = new_body.remove(from); + new_body.insert(to, element); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Moved element from index {} to {}", from, to), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_set_summary(call: &ToolCall, _current_summary: Option<&str>) -> ToolExecutionResult { + let summary = call + .arguments + .get("summary") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: "Summary updated".to_string(), + }, + new_body: None, + new_summary: Some(summary), + parsed_data: None, + } +} + +fn execute_parse_csv(call: &ToolCall) -> ToolExecutionResult { + let csv = call + .arguments + .get("csv") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let lines: Vec<&str> = csv.lines().collect(); + if lines.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Empty CSV data".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let headers: Vec<&str> = lines[0].split(',').map(|s| s.trim()).collect(); + let mut data: Vec<serde_json::Value> = Vec::new(); + + for line in lines.iter().skip(1) { + if line.trim().is_empty() { + continue; + } + let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect(); + let mut row = serde_json::Map::new(); + + for (i, header) in headers.iter().enumerate() { + if let Some(value) = values.get(i) { + // Try to parse as number, otherwise use string + if let Ok(num) = value.parse::<f64>() { + row.insert(header.to_string(), json!(num)); + } else { + row.insert(header.to_string(), json!(value)); + } + } + } + + data.push(serde_json::Value::Object(row)); + } + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Parsed {} rows from CSV", data.len()), + }, + new_body: None, + new_summary: None, + parsed_data: Some(json!(data)), + } +} + +fn execute_clear_body() -> ToolExecutionResult { + ToolExecutionResult { + result: ToolResult { + success: true, + message: "Cleared all body elements".to_string(), + }, + new_body: Some(vec![]), + new_summary: None, + parsed_data: None, + } +} |
