summaryrefslogtreecommitdiff
path: root/makima/src/server
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server')
-rw-r--r--makima/src/server/handlers/chat.rs308
1 files changed, 247 insertions, 61 deletions
diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs
index e6d22ca..92c4ec8 100644
--- a/makima/src/server/handlers/chat.rs
+++ b/makima/src/server/handlers/chat.rs
@@ -12,17 +12,24 @@ use uuid::Uuid;
use crate::db::{models::BodyElement, repository};
use crate::llm::{
+ claude::{self, ClaudeClient, ClaudeError, ClaudeModel},
execute_tool_call,
- groq::{GroqClient, GroqError, Message},
- ToolResult, AVAILABLE_TOOLS,
+ groq::{GroqClient, GroqError, Message, ToolCallResponse},
+ LlmModel, ToolCall, ToolResult, AVAILABLE_TOOLS,
};
use crate::server::state::SharedState;
+/// Maximum number of tool-calling rounds to prevent infinite loops
+const MAX_TOOL_ROUNDS: usize = 10;
+
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatRequest {
/// The user's message/instruction
pub message: String,
+ /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq"
+ #[serde(default)]
+ pub model: Option<String>,
}
#[derive(Debug, Serialize, ToSchema)]
@@ -45,9 +52,18 @@ pub struct ToolCallInfo {
pub result: ToolResult,
}
-#[derive(Debug, Serialize)]
-struct ErrorResponse {
- error: String,
+/// Enum to hold LLM clients
+enum LlmClient {
+ Groq(GroqClient),
+ Claude(ClaudeClient),
+}
+
+/// Unified result from LLM call
+struct LlmResult {
+ content: Option<String>,
+ tool_calls: Vec<ToolCall>,
+ raw_tool_calls: Vec<ToolCallResponse>,
+ finish_reason: String,
}
/// Chat with a file using LLM tool calling
@@ -105,40 +121,102 @@ pub async fn chat_handler(
}
};
- // Initialize Groq client
- let groq = match GroqClient::from_env() {
- Ok(client) => client,
- Err(GroqError::MissingApiKey) => {
- return (
- StatusCode::SERVICE_UNAVAILABLE,
- Json(serde_json::json!({
- "error": "GROQ_API_KEY not configured"
- })),
- )
- .into_response();
+ // Parse model selection (default to Claude Sonnet)
+ let model = request
+ .model
+ .as_ref()
+ .and_then(|m| LlmModel::from_str(m))
+ .unwrap_or_default();
+
+ tracing::info!("Using LLM model: {:?}", model);
+
+ // Initialize the appropriate LLM client
+ let llm_client = match model {
+ LlmModel::ClaudeSonnet => {
+ match ClaudeClient::from_env(ClaudeModel::Sonnet) {
+ Ok(client) => LlmClient::Claude(client),
+ Err(ClaudeError::MissingApiKey) => {
+ return (
+ StatusCode::SERVICE_UNAVAILABLE,
+ Json(serde_json::json!({
+ "error": "ANTHROPIC_API_KEY not configured"
+ })),
+ )
+ .into_response();
+ }
+ Err(e) => {
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("Claude client error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ }
}
- Err(e) => {
- return (
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": format!("Groq client error: {}", e)
- })),
- )
- .into_response();
+ LlmModel::ClaudeOpus => {
+ match ClaudeClient::from_env(ClaudeModel::Opus) {
+ Ok(client) => LlmClient::Claude(client),
+ Err(ClaudeError::MissingApiKey) => {
+ return (
+ StatusCode::SERVICE_UNAVAILABLE,
+ Json(serde_json::json!({
+ "error": "ANTHROPIC_API_KEY not configured"
+ })),
+ )
+ .into_response();
+ }
+ Err(e) => {
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("Claude client error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ }
+ }
+ LlmModel::GroqKimi => {
+ match GroqClient::from_env() {
+ Ok(client) => LlmClient::Groq(client),
+ Err(GroqError::MissingApiKey) => {
+ return (
+ StatusCode::SERVICE_UNAVAILABLE,
+ Json(serde_json::json!({
+ "error": "GROQ_API_KEY not configured"
+ })),
+ )
+ .into_response();
+ }
+ Err(e) => {
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("Groq client error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ }
}
};
// Build context about the file
let file_context = build_file_context(&file);
- // Build messages
- let messages = vec![
+ // Build initial messages (Groq/OpenAI format - will be converted for Claude)
+ let mut messages = vec![
Message {
role: "system".to_string(),
content: Some(format!(
"You are a helpful assistant that helps users edit and analyze document files. \
You have access to tools to add headings, paragraphs, charts, and set summaries. \
When the user asks you to modify the file, use the appropriate tools.\n\n\
+ IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \
+ and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \
+ to call add_chart.\n\n\
Current file context:\n{}",
file_context
)),
@@ -153,46 +231,154 @@ pub async fn chat_handler(
},
];
- // Call Groq API
- let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await {
- Ok(result) => result,
- Err(e) => {
- tracing::error!("Groq API error: {}", e);
- return (
- StatusCode::INTERNAL_SERVER_ERROR,
- Json(serde_json::json!({
- "error": format!("LLM API error: {}", e)
- })),
- )
- .into_response();
- }
- };
-
- // Execute tool calls
+ // State for tracking changes
let mut current_body = file.body.clone();
let mut current_summary = file.summary.clone();
- let mut tool_call_infos = Vec::new();
+ let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new();
+ let mut final_response: Option<String> = None;
- for tool_call in &result.tool_calls {
- let execution_result =
- execute_tool_call(tool_call, &current_body, current_summary.as_deref());
+ // Multi-turn tool calling loop
+ for round in 0..MAX_TOOL_ROUNDS {
+ tracing::debug!(round = round, "LLM tool calling round");
- // Apply state changes
- if let Some(new_body) = execution_result.new_body {
- current_body = new_body;
- }
- if let Some(new_summary) = execution_result.new_summary {
- current_summary = Some(new_summary);
+ // Call the appropriate LLM API
+ let result = match &llm_client {
+ LlmClient::Groq(groq) => {
+ match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await {
+ Ok(r) => LlmResult {
+ content: r.content,
+ tool_calls: r.tool_calls,
+ raw_tool_calls: r.raw_tool_calls,
+ finish_reason: r.finish_reason,
+ },
+ Err(e) => {
+ tracing::error!("Groq API error: {}", e);
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("LLM API error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ }
+ }
+ LlmClient::Claude(claude_client) => {
+ // Convert messages to Claude format
+ let claude_messages = claude::groq_messages_to_claude(&messages);
+ match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await {
+ Ok(r) => {
+ // Convert Claude tool uses to Groq-style ToolCallResponse for consistency
+ let raw_tool_calls: Vec<ToolCallResponse> = r
+ .tool_calls
+ .iter()
+ .map(|tc| ToolCallResponse {
+ id: tc.id.clone(),
+ call_type: "function".to_string(),
+ function: crate::llm::groq::FunctionCall {
+ name: tc.name.clone(),
+ arguments: tc.arguments.to_string(),
+ },
+ })
+ .collect();
+
+ LlmResult {
+ content: r.content,
+ tool_calls: r.tool_calls,
+ raw_tool_calls,
+ finish_reason: r.stop_reason,
+ }
+ }
+ Err(e) => {
+ tracing::error!("Claude API error: {}", e);
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("LLM API error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ }
+ }
+ };
+
+ // Check if there are tool calls to execute
+ if result.tool_calls.is_empty() {
+ // No more tool calls - capture the final response and exit loop
+ final_response = result.content;
+ break;
}
- tool_call_infos.push(ToolCallInfo {
- name: tool_call.name.clone(),
- result: execution_result.result,
+ // Add assistant message with tool calls to conversation
+ messages.push(Message {
+ role: "assistant".to_string(),
+ content: result.content.clone(),
+ tool_calls: Some(result.raw_tool_calls.clone()),
+ tool_call_id: None,
});
+
+ // Execute each tool call and add results to conversation
+ for (i, tool_call) in result.tool_calls.iter().enumerate() {
+ let execution_result =
+ execute_tool_call(tool_call, &current_body, current_summary.as_deref());
+
+ // Apply state changes
+ if let Some(new_body) = execution_result.new_body {
+ current_body = new_body;
+ }
+ if let Some(new_summary) = execution_result.new_summary {
+ current_summary = Some(new_summary);
+ }
+
+ // Build tool result message content
+ let result_content = if let Some(parsed_data) = &execution_result.parsed_data {
+ // Include parsed data in the result for the LLM to use
+ serde_json::json!({
+ "success": execution_result.result.success,
+ "message": execution_result.result.message,
+ "data": parsed_data
+ })
+ .to_string()
+ } else {
+ serde_json::json!({
+ "success": execution_result.result.success,
+ "message": execution_result.result.message
+ })
+ .to_string()
+ };
+
+ // Add tool result message
+ // Use the appropriate ID format for each provider
+ let tool_call_id = match &llm_client {
+ LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(),
+ LlmClient::Claude(_) => tool_call.id.clone(),
+ };
+
+ messages.push(Message {
+ role: "tool".to_string(),
+ content: Some(result_content),
+ tool_calls: None,
+ tool_call_id: Some(tool_call_id),
+ });
+
+ // Track for response
+ all_tool_call_infos.push(ToolCallInfo {
+ name: tool_call.name.clone(),
+ result: execution_result.result,
+ });
+ }
+
+ // If finish reason indicates completion, exit loop
+ let finish_lower = result.finish_reason.to_lowercase();
+ if finish_lower == "stop" || finish_lower == "end_turn" {
+ final_response = result.content;
+ break;
+ }
}
// Save changes to database if any tools were executed
- if !result.tool_calls.is_empty() {
+ if !all_tool_call_infos.is_empty() {
let update_req = crate::db::models::UpdateFileRequest {
name: None,
description: None,
@@ -214,14 +400,14 @@ pub async fn chat_handler(
}
// Build response
- let response_text = result.content.unwrap_or_else(|| {
- if tool_call_infos.is_empty() {
+ let response_text = final_response.unwrap_or_else(|| {
+ if all_tool_call_infos.is_empty() {
"I couldn't understand your request. Please try rephrasing.".to_string()
} else {
format!(
"Done! Executed {} tool{}.",
- tool_call_infos.len(),
- if tool_call_infos.len() == 1 { "" } else { "s" }
+ all_tool_call_infos.len(),
+ if all_tool_call_infos.len() == 1 { "" } else { "s" }
)
}
});
@@ -230,7 +416,7 @@ pub async fn chat_handler(
StatusCode::OK,
Json(ChatResponse {
response: response_text,
- tool_calls: tool_call_infos,
+ tool_calls: all_tool_call_infos,
updated_body: current_body,
updated_summary: current_summary,
}),