diff options
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/llm/tools.rs | 234 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 307 |
2 files changed, 525 insertions, 16 deletions
diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs index 35f321f..216b733 100644 --- a/makima/src/llm/tools.rs +++ b/makima/src/llm/tools.rs @@ -4,7 +4,7 @@ use jaq_interpret::FilterT; use serde::{Deserialize, Serialize}; use serde_json::json; -use crate::db::models::{BodyElement, ChartType}; +use crate::db::models::{BodyElement, ChartType, TranscriptEntry}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { @@ -232,6 +232,39 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "required": ["input", "filter"] }), }, + // Content viewing tools + Tool { + name: "view_body".to_string(), + description: "View the complete body structure with full content of all elements. Returns detailed information about each element including type, index, and full text/data.".to_string(), + parameters: json!({ + "type": "object", + "properties": {}, + "required": [] + }), + }, + Tool { + name: "read_element".to_string(), + description: "Read the full content of a specific body element by its index. Use this to get complete details of a single element.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "Index of the element to read (0-indexed)" + } + }, + "required": ["index"] + }), + }, + Tool { + name: "view_transcript".to_string(), + description: "View the complete transcript of the file. Returns all transcript entries with speaker names, text, and timestamps.".to_string(), + parameters: json!({ + "type": "object", + "properties": {}, + "required": [] + }), + }, // Version history tools Tool { name: "list_versions".to_string(), @@ -304,6 +337,7 @@ pub fn execute_tool_call( call: &ToolCall, current_body: &[BodyElement], current_summary: Option<&str>, + transcript: &[TranscriptEntry], ) -> ToolExecutionResult { match call.name.as_str() { "add_heading" => execute_add_heading(call, current_body), @@ -316,6 +350,10 @@ pub fn execute_tool_call( "parse_csv" => execute_parse_csv(call), "clear_body" => execute_clear_body(), "jq" => execute_jq(call), + // Content viewing tools + "view_body" => execute_view_body(current_body), + "read_element" => execute_read_element(call, current_body), + "view_transcript" => execute_view_transcript(transcript), // Version history tools - return request for async handling "list_versions" => execute_list_versions(), "read_version" => execute_read_version(call), @@ -897,6 +935,200 @@ fn execute_jq(call: &ToolCall) -> ToolExecutionResult { } // ============================================================================= +// Content Viewing Tool Execution Functions +// ============================================================================= + +fn execute_view_body(current_body: &[BodyElement]) -> ToolExecutionResult { + if current_body.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: true, + message: "Body is empty (no elements)".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: Some(json!([])), + version_request: None, + }; + } + + let elements: Vec<serde_json::Value> = current_body + .iter() + .enumerate() + .map(|(i, element)| { + match element { + BodyElement::Heading { level, text } => json!({ + "index": i, + "type": "heading", + "level": level, + "text": text + }), + BodyElement::Paragraph { text } => json!({ + "index": i, + "type": "paragraph", + "text": text + }), + BodyElement::Chart { chart_type, title, data, config } => json!({ + "index": i, + "type": "chart", + "chartType": format!("{:?}", chart_type).to_lowercase(), + "title": title, + "data": data, + "config": config + }), + BodyElement::Image { src, alt, caption } => json!({ + "index": i, + "type": "image", + "src": src, + "alt": alt, + "caption": caption + }), + } + }) + .collect(); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Body contains {} element(s)", current_body.len()), + }, + new_body: None, + new_summary: None, + parsed_data: Some(json!(elements)), + version_request: None, + } +} + +fn execute_read_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let index = call.arguments.get("index").and_then(|v| v.as_u64()); + + let Some(index) = index else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + version_request: None, + }; + }; + + let index = index as usize; + if index >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Index {} out of bounds (body has {} elements)", index, current_body.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + version_request: None, + }; + } + + let element = ¤t_body[index]; + let element_data = match element { + BodyElement::Heading { level, text } => json!({ + "index": index, + "type": "heading", + "level": level, + "text": text + }), + BodyElement::Paragraph { text } => json!({ + "index": index, + "type": "paragraph", + "text": text + }), + BodyElement::Chart { chart_type, title, data, config } => json!({ + "index": index, + "type": "chart", + "chartType": format!("{:?}", chart_type).to_lowercase(), + "title": title, + "data": data, + "config": config + }), + BodyElement::Image { src, alt, caption } => json!({ + "index": index, + "type": "image", + "src": src, + "alt": alt, + "caption": caption + }), + }; + + let type_str = match element { + BodyElement::Heading { .. } => "heading", + BodyElement::Paragraph { .. } => "paragraph", + BodyElement::Chart { .. } => "chart", + BodyElement::Image { .. } => "image", + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Element {} is a {}", index, type_str), + }, + new_body: None, + new_summary: None, + parsed_data: Some(element_data), + version_request: None, + } +} + +fn execute_view_transcript(transcript: &[TranscriptEntry]) -> ToolExecutionResult { + if transcript.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: true, + message: "Transcript is empty".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: Some(json!([])), + version_request: None, + }; + } + + let entries: Vec<serde_json::Value> = transcript + .iter() + .enumerate() + .map(|(i, entry)| { + json!({ + "index": i, + "speaker": entry.speaker, + "text": entry.text, + "start": entry.start, + "end": entry.end + }) + }) + .collect(); + + // Calculate duration from timestamps + let duration_info = if let (Some(first), Some(last)) = (transcript.first(), transcript.last()) { + let duration_secs = last.end - first.start; + let minutes = (duration_secs / 60.0).floor() as u32; + let seconds = (duration_secs % 60.0).round() as u32; + format!(" (duration: {}:{:02})", minutes, seconds) + } else { + String::new() + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Transcript has {} entries{}", transcript.len(), duration_info), + }, + new_body: None, + new_summary: None, + parsed_data: Some(json!(entries)), + version_request: None, + } +} + +// ============================================================================= // Version History Tool Execution Functions // ============================================================================= // These return version_request instead of performing the operation directly, diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index 396c973..306093a 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -20,7 +20,18 @@ use crate::llm::{ use crate::server::state::{FileUpdateNotification, SharedState}; /// Maximum number of tool-calling rounds to prevent infinite loops -const MAX_TOOL_ROUNDS: usize = 10; +const MAX_TOOL_ROUNDS: usize = 20; + +/// Context limits for different models (in tokens) +/// Claude models have 200K context, Groq models vary +const CLAUDE_CONTEXT_LIMIT: usize = 200_000; +const GROQ_CONTEXT_LIMIT: usize = 32_000; + +/// Threshold for triggering context compaction (90% of limit) +const CONTEXT_COMPACTION_THRESHOLD: f32 = 0.90; + +/// Approximate characters per token (rough estimate for English text) +const CHARS_PER_TOKEN: usize = 4; #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] @@ -206,20 +217,62 @@ pub async fn chat_handler( // Build context about the file let file_context = build_file_context(&file); + // Build agentic system prompt + let system_prompt = format!( + r#"You are an intelligent document editing agent. You help users view, analyze, and modify document files. + +## Your Capabilities +You have access to tools for: +- **Viewing content**: view_body (see all elements), read_element (inspect specific element), view_transcript (read full transcript) +- **Adding content**: add_heading, add_paragraph, add_chart +- **Modifying content**: update_element, remove_element, reorder_elements, clear_body +- **Document metadata**: set_summary +- **Data processing**: parse_csv (convert CSV to JSON), jq (transform JSON data) +- **Version history**: list_versions, read_version, restore_version + +## Agentic Behavior Guidelines + +### 1. Analyze Before Acting +- For complex requests, first gather information using view_body, view_transcript, or read_element +- Understand the current state of the document before making changes +- For simple, direct requests (e.g., "add a heading called X"), you can act immediately without prior inspection + +### 2. Plan Multi-Step Operations +- Break complex tasks into logical steps +- For data visualization: parse_csv → (optionally jq to transform) → add_chart +- For restructuring: view_body → understand structure → make targeted changes + +### 3. Handle Errors Gracefully +- If a tool call fails, analyze the error message +- Try an alternative approach or different parameters +- Don't repeat the exact same failing call + +### 4. Know When to Stop +- Stop when you've completed the user's request +- Stop when you've provided the requested information +- Provide a clear summary of what you did in your final response + +### 5. Be Efficient +- Don't over-analyze simple requests +- Use the minimum number of tool calls needed +- Combine operations when possible + +## Current Document Context +{file_context} + +## Important Notes +- Body element indices are 0-based +- When updating elements, provide ALL required fields for that element type +- The transcript is read-only (you cannot modify it, only read it) +- Changes are saved automatically after tool execution"#, + file_context = file_context + ); + // Build initial messages (Groq/OpenAI format - will be converted for Claude) let mut messages = vec![ Message { role: "system".to_string(), - content: Some(format!( - "You are a helpful assistant that helps users edit and analyze document files. \ - You have access to tools to add headings, paragraphs, charts, and set summaries. \ - When the user asks you to modify the file, use the appropriate tools.\n\n\ - IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \ - and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \ - to call add_chart.\n\n\ - Current file context:\n{}", - file_context - )), + content: Some(system_prompt), tool_calls: None, tool_call_id: None, }, @@ -240,10 +293,48 @@ pub async fn chat_handler( let mut version_restored = false; // Track if there were modifications after a restore let mut has_changes_after_restore = false; + // Track consecutive failures for agentic retry logic + let mut consecutive_failures = 0; + const MAX_CONSECUTIVE_FAILURES: usize = 3; - // Multi-turn tool calling loop + // Multi-turn agentic tool calling loop for round in 0..MAX_TOOL_ROUNDS { - tracing::debug!(round = round, "LLM tool calling round"); + tracing::info!( + round = round, + body_elements = current_body.len(), + total_tool_calls = all_tool_call_infos.len(), + "Agentic loop iteration" + ); + + // Check if we've hit too many consecutive failures + if consecutive_failures >= MAX_CONSECUTIVE_FAILURES { + tracing::warn!("Breaking loop due to {} consecutive failures", consecutive_failures); + final_response = Some(format!( + "I encountered multiple consecutive errors and stopped to avoid an infinite loop. \ + Please try rephrasing your request or check if the document state is as expected." + )); + break; + } + + // Check context usage and compact if nearing limit + if is_context_near_limit(&messages, &model) { + let estimated_tokens = estimate_total_tokens(&messages); + tracing::warn!( + estimated_tokens = estimated_tokens, + round = round, + "Context nearing limit, compacting conversation" + ); + compact_conversation(&mut messages, &all_tool_call_infos); + + // Log the new token count + let new_tokens = estimate_total_tokens(&messages); + tracing::info!( + tokens_before = estimated_tokens, + tokens_after = new_tokens, + tokens_saved = estimated_tokens - new_tokens, + "Conversation compacted" + ); + } // Call the appropriate LLM API let result = match &llm_client { @@ -324,8 +415,14 @@ pub async fn chat_handler( // Execute each tool call and add results to conversation for (i, tool_call) in result.tool_calls.iter().enumerate() { + tracing::info!( + tool = %tool_call.name, + round = round, + "Executing tool call" + ); + let mut execution_result = - execute_tool_call(tool_call, ¤t_body, current_summary.as_deref()); + execute_tool_call(tool_call, ¤t_body, current_summary.as_deref(), &file.transcript); // Handle version tool requests that need async database access if let Some(version_request) = &execution_result.version_request { @@ -369,7 +466,19 @@ pub async fn chat_handler( } } - // Build tool result message content + // Track consecutive failures for agentic behavior + if execution_result.result.success { + consecutive_failures = 0; + } else { + consecutive_failures += 1; + tracing::warn!( + tool = %tool_call.name, + consecutive_failures = consecutive_failures, + "Tool call failed" + ); + } + + // Build tool result message content with enhanced context for agentic reasoning let result_content = if let Some(parsed_data) = &execution_result.parsed_data { // Include parsed data in the result for the LLM to use serde_json::json!({ @@ -378,6 +487,19 @@ pub async fn chat_handler( "data": parsed_data }) .to_string() + } else if !execution_result.result.success { + // On failure, include hints for the LLM + let hint = if consecutive_failures >= MAX_CONSECUTIVE_FAILURES { + " [HINT: Multiple consecutive failures detected. Consider a different approach or verify your parameters.]" + } else { + "" + }; + serde_json::json!({ + "success": false, + "message": format!("{}{}", execution_result.result.message, hint), + "currentBodyElementCount": current_body.len() + }) + .to_string() } else { serde_json::json!({ "success": execution_result.result.success, @@ -742,3 +864,158 @@ async fn handle_version_request( } } } + +/// Estimate the token count of a message +fn estimate_message_tokens(message: &Message) -> usize { + let mut chars = 0; + + // Count content characters + if let Some(ref content) = message.content { + chars += content.len(); + } + + // Count tool call characters (rough estimate) + if let Some(ref tool_calls) = message.tool_calls { + for tc in tool_calls { + chars += tc.function.name.len(); + chars += tc.function.arguments.len(); + } + } + + // Count tool call ID + if let Some(ref id) = message.tool_call_id { + chars += id.len(); + } + + // Add overhead for role and structure + chars += message.role.len() + 20; + + // Convert to tokens + chars / CHARS_PER_TOKEN +} + +/// Estimate total token count of all messages +fn estimate_total_tokens(messages: &[Message]) -> usize { + messages.iter().map(estimate_message_tokens).sum() +} + +/// Check if context is nearing the limit +fn is_context_near_limit(messages: &[Message], model: &LlmModel) -> bool { + let estimated_tokens = estimate_total_tokens(messages); + let limit = match model { + LlmModel::ClaudeSonnet | LlmModel::ClaudeOpus => CLAUDE_CONTEXT_LIMIT, + LlmModel::GroqKimi => GROQ_CONTEXT_LIMIT, + }; + let threshold = (limit as f32 * CONTEXT_COMPACTION_THRESHOLD) as usize; + + estimated_tokens >= threshold +} + +/// Compact the conversation by summarizing older messages +/// Keeps: system message, last N user/assistant exchanges, and a summary of older content +fn compact_conversation(messages: &mut Vec<Message>, tool_call_history: &[ToolCallInfo]) { + // Keep at least system message + 4 recent messages (2 exchanges) + const MIN_MESSAGES_TO_KEEP: usize = 5; + + if messages.len() <= MIN_MESSAGES_TO_KEEP { + return; + } + + // Extract system message (always first) + let system_message = messages.remove(0); + + // Calculate how many messages to summarize + // Keep the last ~1/3 of messages for recent context + let messages_to_keep = std::cmp::max(4, messages.len() / 3); + let messages_to_summarize = messages.len() - messages_to_keep; + + if messages_to_summarize < 2 { + // Not enough to summarize, just put system message back + messages.insert(0, system_message); + return; + } + + // Extract messages to summarize + let old_messages: Vec<Message> = messages.drain(..messages_to_summarize).collect(); + + // Build summary of old messages + let mut summary_parts: Vec<String> = Vec::new(); + + // Summarize user requests + let user_requests: Vec<&str> = old_messages + .iter() + .filter(|m| m.role == "user") + .filter_map(|m| m.content.as_deref()) + .collect(); + + if !user_requests.is_empty() { + summary_parts.push(format!( + "Previous user requests: {}", + user_requests.join("; ") + )); + } + + // Summarize tool calls executed so far + if !tool_call_history.is_empty() { + let tool_summary: Vec<String> = tool_call_history + .iter() + .map(|tc| { + if tc.result.success { + format!("{}(ok)", tc.name) + } else { + format!("{}(failed: {})", tc.name, tc.result.message) + } + }) + .collect(); + + summary_parts.push(format!( + "Tools executed: {}", + tool_summary.join(", ") + )); + } + + // Count assistant responses that were summarized + let assistant_responses = old_messages + .iter() + .filter(|m| m.role == "assistant" && m.content.is_some()) + .count(); + + if assistant_responses > 0 { + summary_parts.push(format!( + "({} previous assistant responses omitted for brevity)", + assistant_responses + )); + } + + // Create compacted context message + let compacted_content = format!( + "[CONTEXT SUMMARY - Earlier conversation compacted to save tokens]\n{}", + summary_parts.join("\n") + ); + + // Rebuild messages: system + summary + remaining recent messages + let mut new_messages = vec![ + system_message, + Message { + role: "user".to_string(), + content: Some(compacted_content), + tool_calls: None, + tool_call_id: None, + }, + Message { + role: "assistant".to_string(), + content: Some("Understood. I have context from the previous conversation and will continue from here.".to_string()), + tool_calls: None, + tool_call_id: None, + }, + ]; + + new_messages.append(messages); + *messages = new_messages; + + tracing::info!( + summarized_messages = messages_to_summarize, + remaining_messages = messages.len(), + "Compacted conversation to save context" + ); +} |
