diff options
Diffstat (limited to 'makima/src/server')
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 308 |
1 files changed, 247 insertions, 61 deletions
diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index e6d22ca..92c4ec8 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -12,17 +12,24 @@ use uuid::Uuid; use crate::db::{models::BodyElement, repository}; use crate::llm::{ + claude::{self, ClaudeClient, ClaudeError, ClaudeModel}, execute_tool_call, - groq::{GroqClient, GroqError, Message}, - ToolResult, AVAILABLE_TOOLS, + groq::{GroqClient, GroqError, Message, ToolCallResponse}, + LlmModel, ToolCall, ToolResult, AVAILABLE_TOOLS, }; use crate::server::state::SharedState; +/// Maximum number of tool-calling rounds to prevent infinite loops +const MAX_TOOL_ROUNDS: usize = 10; + #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct ChatRequest { /// The user's message/instruction pub message: String, + /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq" + #[serde(default)] + pub model: Option<String>, } #[derive(Debug, Serialize, ToSchema)] @@ -45,9 +52,18 @@ pub struct ToolCallInfo { pub result: ToolResult, } -#[derive(Debug, Serialize)] -struct ErrorResponse { - error: String, +/// Enum to hold LLM clients +enum LlmClient { + Groq(GroqClient), + Claude(ClaudeClient), +} + +/// Unified result from LLM call +struct LlmResult { + content: Option<String>, + tool_calls: Vec<ToolCall>, + raw_tool_calls: Vec<ToolCallResponse>, + finish_reason: String, } /// Chat with a file using LLM tool calling @@ -105,40 +121,102 @@ pub async fn chat_handler( } }; - // Initialize Groq client - let groq = match GroqClient::from_env() { - Ok(client) => client, - Err(GroqError::MissingApiKey) => { - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "error": "GROQ_API_KEY not configured" - })), - ) - .into_response(); + // Parse model selection (default to Claude Sonnet) + let model = request + .model + .as_ref() + .and_then(|m| LlmModel::from_str(m)) + .unwrap_or_default(); + + 
tracing::info!("Using LLM model: {:?}", model); + + // Initialize the appropriate LLM client + let llm_client = match model { + LlmModel::ClaudeSonnet => { + match ClaudeClient::from_env(ClaudeModel::Sonnet) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "ANTHROPIC_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Claude client error: {}", e) + })), + ) + .into_response(); + } + } } - Err(e) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("Groq client error: {}", e) - })), - ) - .into_response(); + LlmModel::ClaudeOpus => { + match ClaudeClient::from_env(ClaudeModel::Opus) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "ANTHROPIC_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Claude client error: {}", e) + })), + ) + .into_response(); + } + } + } + LlmModel::GroqKimi => { + match GroqClient::from_env() { + Ok(client) => LlmClient::Groq(client), + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "GROQ_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Groq client error: {}", e) + })), + ) + .into_response(); + } + } } }; // Build context about the file let file_context = build_file_context(&file); - // Build messages - let messages = vec![ + // Build initial messages (Groq/OpenAI format - will be converted for Claude) + let mut messages = vec![ Message { role: 
"system".to_string(), content: Some(format!( "You are a helpful assistant that helps users edit and analyze document files. \ You have access to tools to add headings, paragraphs, charts, and set summaries. \ When the user asks you to modify the file, use the appropriate tools.\n\n\ + IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \ + and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \ + to call add_chart.\n\n\ Current file context:\n{}", file_context )), @@ -153,46 +231,154 @@ pub async fn chat_handler( }, ]; - // Call Groq API - let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await { - Ok(result) => result, - Err(e) => { - tracing::error!("Groq API error: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("LLM API error: {}", e) - })), - ) - .into_response(); - } - }; - - // Execute tool calls + // State for tracking changes let mut current_body = file.body.clone(); let mut current_summary = file.summary.clone(); - let mut tool_call_infos = Vec::new(); + let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new(); + let mut final_response: Option<String> = None; - for tool_call in &result.tool_calls { - let execution_result = - execute_tool_call(tool_call, &current_body, current_summary.as_deref()); + // Multi-turn tool calling loop + for round in 0..MAX_TOOL_ROUNDS { + tracing::debug!(round = round, "LLM tool calling round"); - // Apply state changes - if let Some(new_body) = execution_result.new_body { - current_body = new_body; - } - if let Some(new_summary) = execution_result.new_summary { - current_summary = Some(new_summary); + // Call the appropriate LLM API + let result = match &llm_client { + LlmClient::Groq(groq) => { + match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await { + Ok(r) => LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls: r.raw_tool_calls, + 
finish_reason: r.finish_reason, + }, + Err(e) => { + tracing::error!("Groq API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("LLM API error: {}", e) + })), + ) + .into_response(); + } + } + } + LlmClient::Claude(claude_client) => { + // Convert messages to Claude format + let claude_messages = claude::groq_messages_to_claude(&messages); + match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await { + Ok(r) => { + // Convert Claude tool uses to Groq-style ToolCallResponse for consistency + let raw_tool_calls: Vec<ToolCallResponse> = r + .tool_calls + .iter() + .map(|tc| ToolCallResponse { + id: tc.id.clone(), + call_type: "function".to_string(), + function: crate::llm::groq::FunctionCall { + name: tc.name.clone(), + arguments: tc.arguments.to_string(), + }, + }) + .collect(); + + LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls, + finish_reason: r.stop_reason, + } + } + Err(e) => { + tracing::error!("Claude API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("LLM API error: {}", e) + })), + ) + .into_response(); + } + } + } + }; + + // Check if there are tool calls to execute + if result.tool_calls.is_empty() { + // No more tool calls - capture the final response and exit loop + final_response = result.content; + break; } - tool_call_infos.push(ToolCallInfo { - name: tool_call.name.clone(), - result: execution_result.result, + // Add assistant message with tool calls to conversation + messages.push(Message { + role: "assistant".to_string(), + content: result.content.clone(), + tool_calls: Some(result.raw_tool_calls.clone()), + tool_call_id: None, }); + + // Execute each tool call and add results to conversation + for (i, tool_call) in result.tool_calls.iter().enumerate() { + let execution_result = + execute_tool_call(tool_call, &current_body, current_summary.as_deref()); + + // Apply state changes + 
if let Some(new_body) = execution_result.new_body { + current_body = new_body; + } + if let Some(new_summary) = execution_result.new_summary { + current_summary = Some(new_summary); + } + + // Build tool result message content + let result_content = if let Some(parsed_data) = &execution_result.parsed_data { + // Include parsed data in the result for the LLM to use + serde_json::json!({ + "success": execution_result.result.success, + "message": execution_result.result.message, + "data": parsed_data + }) + .to_string() + } else { + serde_json::json!({ + "success": execution_result.result.success, + "message": execution_result.result.message + }) + .to_string() + }; + + // Add tool result message + // Use the appropriate ID format for each provider + let tool_call_id = match &llm_client { + LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(), + LlmClient::Claude(_) => tool_call.id.clone(), + }; + + messages.push(Message { + role: "tool".to_string(), + content: Some(result_content), + tool_calls: None, + tool_call_id: Some(tool_call_id), + }); + + // Track for response + all_tool_call_infos.push(ToolCallInfo { + name: tool_call.name.clone(), + result: execution_result.result, + }); + } + + // If finish reason indicates completion, exit loop + let finish_lower = result.finish_reason.to_lowercase(); + if finish_lower == "stop" || finish_lower == "end_turn" { + final_response = result.content; + break; + } } // Save changes to database if any tools were executed - if !result.tool_calls.is_empty() { + if !all_tool_call_infos.is_empty() { let update_req = crate::db::models::UpdateFileRequest { name: None, description: None, @@ -214,14 +400,14 @@ pub async fn chat_handler( } // Build response - let response_text = result.content.unwrap_or_else(|| { - if tool_call_infos.is_empty() { + let response_text = final_response.unwrap_or_else(|| { + if all_tool_call_infos.is_empty() { "I couldn't understand your request. 
Please try rephrasing.".to_string() } else { format!( "Done! Executed {} tool{}.", - tool_call_infos.len(), - if tool_call_infos.len() == 1 { "" } else { "s" } + all_tool_call_infos.len(), + if all_tool_call_infos.len() == 1 { "" } else { "s" } ) } }); @@ -230,7 +416,7 @@ pub async fn chat_handler( StatusCode::OK, Json(ChatResponse { response: response_text, - tool_calls: tool_call_infos, + tool_calls: all_tool_call_infos, updated_body: current_body, updated_summary: current_summary, }), |
