diff options
| author | soryu <soryu@soryu.co> | 2025-12-23 18:24:42 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 18:24:42 +0000 |
| commit | 3c0adec8e3a9dd3bc34251e87e0fb5314793426d (patch) | |
| tree | 9dfe61e55bd703aa09df03abfcbf8e7a8b2babce | |
| parent | 555061b179b8ec034cb70f9a2dd6c823ced0f637 (diff) | |
| download | soryu-3c0adec8e3a9dd3bc34251e87e0fb5314793426d.tar.gz soryu-3c0adec8e3a9dd3bc34251e87e0fb5314793426d.zip | |
Add claude opus/sonnet support
| -rw-r--r-- | Cargo.lock | 103 | ||||
| -rw-r--r-- | makima/Cargo.toml | 8 | ||||
| -rw-r--r-- | makima/frontend/src/components/files/CliInput.tsx | 27 | ||||
| -rw-r--r-- | makima/frontend/src/lib/api.ts | 13 | ||||
| -rw-r--r-- | makima/src/llm/claude.rs | 304 | ||||
| -rw-r--r-- | makima/src/llm/groq.rs | 16 | ||||
| -rw-r--r-- | makima/src/llm/mod.rs | 25 | ||||
| -rw-r--r-- | makima/src/llm/tools.rs | 308 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 308 |
9 files changed, 1011 insertions, 101 deletions
@@ -300,6 +300,15 @@ dependencies = [ ] [[package]] +name = "chumsky" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eebd66744a15ded14960ab4ccdbfb51ad3b81f51f3f04a80adac98c985396c9" +dependencies = [ + "hashbrown 0.14.5", +] + +[[package]] name = "clap" version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -635,6 +644,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" [[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -963,6 +978,16 @@ dependencies = [ [[package]] name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + +[[package]] +name = "hashbrown" version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" @@ -1030,6 +1055,12 @@ dependencies = [ ] [[package]] +name = "hifijson" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7763b98ba8a24f59e698bf9ab197e7676c640d6455d1580b4ce7dc560f0f0d" + +[[package]] name = "hkdf" version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1407,6 +1438,66 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] +name = "jaq-core" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6fda09ee08c84c81293fdf811d9ebaa87b327557b5391f290c926d728c2ddd4" +dependencies = [ + "aho-corasick", + "base64 0.22.1", + "chrono", + "hifijson", + "jaq-interpret", + "libm", + "log", + "regex", + "urlencoding", +] + +[[package]] +name = "jaq-interpret" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fe95ec3c24af3fd9f3dd1091593f5e49b003a66c496a8aa39d764d0a06ae17b" +dependencies = [ + "ahash", + "dyn-clone", + "hifijson", + "indexmap", + "jaq-syn", + "once_cell", + "serde_json", +] + +[[package]] +name = "jaq-parse" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0346d7d3146cdda8acd929581f3d6626a332356c74d5c95aeaffaac2eb6dee82" +dependencies = [ + "chumsky", + "jaq-syn", +] + +[[package]] +name = "jaq-std" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfbaa55578fd3b70433b594a370741e0c364e4afff92cc0099623fce87311bc1" +dependencies = [ + "jaq-syn", +] + +[[package]] +name = "jaq-syn" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba44fe4428c71304604261ecbae047ee9cfb60c4f1a6bd222ebbb31726d3948" +dependencies = [ + "serde", +] + +[[package]] name = "js-sys" version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1514,12 +1605,18 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" name = "makima" version = "0.1.0" dependencies = [ + "ahash", "anyhow", "axum", "bytes", "chrono", "futures", "hf-hub", + "indexmap", + "jaq-core", + "jaq-interpret", + "jaq-parse", + "jaq-std", "ndarray", "once_cell", "ort", @@ -3682,6 +3779,12 @@ dependencies = [ ] [[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] name = "utf-8" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/makima/Cargo.toml b/makima/Cargo.toml index 5cf1f65..4bf629f 100644 --- a/makima/Cargo.toml +++ b/makima/Cargo.toml @@ -46,3 +46,11 @@ reqwest = { version = "0.12", features = ["json"] } # Lazy statics once_cell = "1.19" + +# JQ for JSON transformation +jaq-interpret = "1.5" +jaq-parse = "1.0" +jaq-core = "1.5" +jaq-std = "1.6" +indexmap = "2.0" +ahash = "0.8" diff --git a/makima/frontend/src/components/files/CliInput.tsx b/makima/frontend/src/components/files/CliInput.tsx index b20eb27..1dcc884 100644 --- a/makima/frontend/src/components/files/CliInput.tsx +++ b/makima/frontend/src/components/files/CliInput.tsx @@ -1,5 +1,5 @@ import { useState, useCallback, useRef, useEffect } from "react"; -import { chatWithFile, type BodyElement } from "../../lib/api"; +import { chatWithFile, type BodyElement, type LlmModel } from "../../lib/api"; interface CliInputProps { fileId: string; @@ -13,11 +13,18 @@ interface Message { toolCalls?: { name: string; success: boolean; message: string }[]; } +const MODEL_OPTIONS: { value: LlmModel; label: string }[] = [ + { value: "claude-opus", label: "Claude Opus" }, + { value: "claude-sonnet", label: "Claude Sonnet" }, + { value: "groq", label: "Groq Kimi" }, +]; + export function CliInput({ fileId, onUpdate }: CliInputProps) { const [input, setInput] = useState(""); const [loading, setLoading] = useState(false); const [messages, setMessages] = useState<Message[]>([]); const [expanded, setExpanded] = useState(false); + const [model, setModel] = useState<LlmModel>("claude-opus"); const inputRef = useRef<HTMLInputElement>(null); const messagesRef = useRef<HTMLDivElement>(null); @@ -47,7 +54,7 @@ export function CliInput({ fileId, onUpdate }: CliInputProps) { setLoading(true); try { - const response = await chatWithFile(fileId, userMessage); + const response = await chatWithFile(fileId, userMessage, model); // Add assistant response const assistantMsgId = (Date.now() + 1).toString(); @@ -82,7 +89,7 @@ export function CliInput({ fileId, onUpdate }: CliInputProps) { inputRef.current?.focus(); } }, - [input, loading, fileId, onUpdate] + [input, loading, fileId, model, onUpdate] ); const clearMessages = useCallback(() => { @@ -136,7 +143,19 @@ export function CliInput({ fileId, onUpdate }: CliInputProps) { {/* Input Bar */} <form onSubmit={handleSubmit} className="flex items-center gap-2 p-3"> - <span className="text-[#9bc3ff] font-mono text-sm">$</span> + <select + value={model} + onChange={(e) => setModel(e.target.value as LlmModel)} + disabled={loading} + className="bg-[#0d1b2d] border border-[rgba(117,170,252,0.25)] text-[#9bc3ff] font-mono text-xs px-2 py-1 rounded-none outline-none focus:border-[#3f6fb3] disabled:opacity-50" + > + {MODEL_OPTIONS.map((opt) => ( + <option key={opt.value} value={opt.value}> + {opt.label} + </option> + ))} + </select> + <span className="text-[#9bc3ff] font-mono text-sm">></span> <input ref={inputRef} type="text" diff --git a/makima/frontend/src/lib/api.ts b/makima/frontend/src/lib/api.ts index 5ef9c22..6f7071d 100644 --- a/makima/frontend/src/lib/api.ts +++ b/makima/frontend/src/lib/api.ts @@ -108,9 +108,13 @@ export interface UpdateFileRequest { body?: BodyElement[]; } +// Available LLM models +export type LlmModel = "claude-sonnet" | "claude-opus" | "groq"; + // Chat API types export interface ChatRequest { message: string; + model?: LlmModel; } export interface ToolCallInfo { @@ -184,12 +188,17 @@ export async function deleteFile(id: string): Promise<void> { // Chat API function export async function chatWithFile( id: string, - message: string + message: string, + model?: LlmModel ): Promise<ChatResponse> { + const body: ChatRequest = { message }; + if (model) { + body.model = model; + } const res = await fetch(`${API_BASE}/api/v1/files/${id}/chat`, { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ message }), + body: JSON.stringify(body), }); if (!res.ok) { const errorText = await res.text(); diff --git a/makima/src/llm/claude.rs b/makima/src/llm/claude.rs new file mode 100644 index 0000000..f475acd --- /dev/null +++ b/makima/src/llm/claude.rs @@ -0,0 +1,304 @@ +//! Claude API client for LLM tool calling. + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::tools::{Tool, ToolCall}; + +const CLAUDE_API_URL: &str = "https://api.anthropic.com/v1/messages"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ClaudeModel { + Opus, + Sonnet, +} + +impl ClaudeModel { + pub fn model_id(&self) -> &'static str { + match self { + ClaudeModel::Opus => "claude-opus-4-5-20251101", + ClaudeModel::Sonnet => "claude-sonnet-4-5-20250929", + } + } +} + +impl Default for ClaudeModel { + fn default() -> Self { + ClaudeModel::Opus + } +} + +#[derive(Debug, Error)] +pub enum ClaudeError { + #[error("HTTP request failed: {0}")] + Request(#[from] reqwest::Error), + #[error("API error: {0}")] + Api(String), + #[error("Missing API key")] + MissingApiKey, +} + +#[derive(Debug, Clone)] +pub struct ClaudeClient { + api_key: String, + client: reqwest::Client, + model: ClaudeModel, +} + +// Request types +#[derive(Debug, Serialize)] +struct ClaudeRequest { + model: String, + max_tokens: u32, + messages: Vec<Message>, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option<Vec<ToolDefinition>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: MessageContent, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Blocks(Vec<ContentBlock>), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + }, +} + +#[derive(Debug, Serialize)] +struct ToolDefinition { + name: String, + description: String, + input_schema: serde_json::Value, +} + +// Response types +#[derive(Debug, Deserialize)] +struct ClaudeResponse { + content: Vec<ResponseContentBlock>, + stop_reason: Option<String>, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type")] +pub enum ResponseContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, +} + +#[derive(Debug)] +pub struct ChatResult { + pub content: Option<String>, + pub tool_calls: Vec<ToolCall>, + /// Raw tool use blocks for including in subsequent messages + pub raw_tool_uses: Vec<ResponseContentBlock>, + pub stop_reason: String, +} + +impl ClaudeClient { + pub fn new(api_key: String, model: ClaudeModel) -> Self { + Self { + api_key, + client: reqwest::Client::new(), + model, + } + } + + pub fn from_env(model: ClaudeModel) -> Result<Self, ClaudeError> { + let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| ClaudeError::MissingApiKey)?; + Ok(Self::new(api_key, model)) + } + + pub async fn chat_with_tools( + &self, + messages: Vec<Message>, + tools: &[Tool], + ) -> Result<ChatResult, ClaudeError> { + let tool_definitions: Vec<ToolDefinition> = tools + .iter() + .map(|t| ToolDefinition { + name: t.name.clone(), + description: t.description.clone(), + input_schema: t.parameters.clone(), + }) + .collect(); + + let request = ClaudeRequest { + model: self.model.model_id().to_string(), + max_tokens: 4096, + messages, + tools: Some(tool_definitions), + }; + + let response = self + .client + .post(CLAUDE_API_URL) + .header("x-api-key", &self.api_key) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("Content-Type", "application/json") + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let error_text = response.text().await.unwrap_or_default(); + return Err(ClaudeError::Api(error_text)); + } + + let claude_response: ClaudeResponse = response.json().await?; + + let stop_reason = claude_response.stop_reason.unwrap_or_else(|| "end_turn".to_string()); + + // Extract text content and tool uses from content blocks + let mut text_parts: Vec<String> = Vec::new(); + let mut raw_tool_uses: Vec<ResponseContentBlock> = Vec::new(); + + for block in &claude_response.content { + match block { + ResponseContentBlock::Text { text } => { + if !text.is_empty() { + text_parts.push(text.clone()); + } + } + ResponseContentBlock::ToolUse { .. } => { + raw_tool_uses.push(block.clone()); + } + } + } + + let content = if text_parts.is_empty() { + None + } else { + Some(text_parts.join("\n")) + }; + + // Convert tool uses to ToolCalls + let tool_calls: Vec<ToolCall> = raw_tool_uses + .iter() + .filter_map(|block| { + if let ResponseContentBlock::ToolUse { id, name, input } = block { + Some(ToolCall { + id: id.clone(), + name: name.clone(), + arguments: input.clone(), + }) + } else { + None + } + }) + .collect(); + + Ok(ChatResult { + content, + tool_calls, + raw_tool_uses, + stop_reason, + }) + } +} + +/// Helper to convert Groq-style messages to Claude messages +pub fn groq_messages_to_claude(messages: &[super::groq::Message]) -> Vec<Message> { + let mut claude_messages: Vec<Message> = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + // Claude handles system prompts as first user message + if let Some(ref content) = msg.content { + claude_messages.push(Message { + role: "user".to_string(), + content: MessageContent::Text(format!("[System Instructions]: {}", content)), + }); + // Add assistant acknowledgment to maintain conversation structure + claude_messages.push(Message { + role: "assistant".to_string(), + content: MessageContent::Text("Understood. I'll follow these instructions.".to_string()), + }); + } + } + "user" => { + if let Some(ref content) = msg.content { + claude_messages.push(Message { + role: "user".to_string(), + content: MessageContent::Text(content.clone()), + }); + } + } + "assistant" => { + let mut blocks: Vec<ContentBlock> = Vec::new(); + + // Add text content if present + if let Some(ref content) = msg.content { + if !content.is_empty() { + blocks.push(ContentBlock::Text { text: content.clone() }); + } + } + + // Add tool uses if present + if let Some(ref tool_calls) = msg.tool_calls { + for tc in tool_calls { + let input: serde_json::Value = + serde_json::from_str(&tc.function.arguments).unwrap_or_default(); + blocks.push(ContentBlock::ToolUse { + id: tc.id.clone(), + name: tc.function.name.clone(), + input, + }); + } + } + + if !blocks.is_empty() { + claude_messages.push(Message { + role: "assistant".to_string(), + content: MessageContent::Blocks(blocks), + }); + } + } + "tool" => { + // Tool results in Claude go in a user message with tool_result blocks + if let Some(ref content) = msg.content { + let tool_use_id = msg.tool_call_id.clone().unwrap_or_default(); + claude_messages.push(Message { + role: "user".to_string(), + content: MessageContent::Blocks(vec![ContentBlock::ToolResult { + tool_use_id, + content: content.clone(), + }]), + }); + } + } + _ => {} + } + } + + claude_messages +} diff --git a/makima/src/llm/groq.rs b/makima/src/llm/groq.rs index be0e2bc..ee01fcf 100644 --- a/makima/src/llm/groq.rs +++ b/makima/src/llm/groq.rs @@ -92,6 +92,8 @@ struct MessageResponse { pub struct ChatResult { pub content: Option<String>, pub tool_calls: Vec<ToolCall>, + /// Raw tool call responses for including in subsequent messages + pub raw_tool_calls: Vec<ToolCallResponse>, pub finish_reason: String, } @@ -154,14 +156,13 @@ impl GroqClient { .next() .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?; - let tool_calls = choice - .message - .tool_calls - .unwrap_or_default() - .into_iter() + let raw_tool_calls = choice.message.tool_calls.unwrap_or_default(); + + let tool_calls = raw_tool_calls + .iter() .map(|tc| ToolCall { - id: tc.id, - name: tc.function.name, + id: tc.id.clone(), + name: tc.function.name.clone(), arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(), }) .collect(); @@ -169,6 +170,7 @@ impl GroqClient { Ok(ChatResult { content: choice.message.content, tool_calls, + raw_tool_calls, finish_reason: choice.finish_reason, }) } diff --git a/makima/src/llm/mod.rs b/makima/src/llm/mod.rs index 00f3333..7de8afe 100644 --- a/makima/src/llm/mod.rs +++ b/makima/src/llm/mod.rs @@ -1,7 +1,32 @@ //! LLM integration module for file editing via tool calling. +pub mod claude; pub mod groq; pub mod tools; +pub use claude::{ClaudeClient, ClaudeModel}; pub use groq::GroqClient; pub use tools::{execute_tool_call, Tool, ToolCall, ToolResult, AVAILABLE_TOOLS}; + +/// Available LLM providers and models +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LlmModel { + /// Claude Sonnet 4.5 - balanced speed and capability + ClaudeSonnet, + /// Claude Opus 4.5 (default) - most capable + #[default] + ClaudeOpus, + /// Groq Kimi - fast alternative provider + GroqKimi, +} + +impl LlmModel { + pub fn from_str(s: &str) -> Option<Self> { + match s.to_lowercase().as_str() { + "claude-sonnet" | "sonnet" | "claude" => Some(LlmModel::ClaudeSonnet), + "claude-opus" | "opus" => Some(LlmModel::ClaudeOpus), + "groq" | "kimi" | "groq-kimi" => Some(LlmModel::GroqKimi), + _ => None, + } + } +} diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs index 3bd102f..e6b2954 100644 --- a/makima/src/llm/tools.rs +++ b/makima/src/llm/tools.rs @@ -1,5 +1,6 @@ //! Tool definitions for file editing via LLM. +use jaq_interpret::FilterT; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -121,7 +122,7 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = }, Tool { name: "update_element".to_string(), - description: "Update an existing element in the file body".to_string(), + description: "Update an existing element in the file body. IMPORTANT: You must provide ALL required fields. For heading: type, level (1-6), text. For paragraph: type, text. For chart: type, chartType (line/bar/pie/area), data (array of objects).".to_string(), parameters: json!({ "type": "object", "properties": { @@ -129,12 +130,35 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "type": "integer", "description": "Index of element to update (0-indexed)" }, - "element": { - "type": "object", - "description": "New element data. Must include 'type' field (heading, paragraph, chart)." + "element_type": { + "type": "string", + "enum": ["heading", "paragraph", "chart"], + "description": "Type of element" + }, + "text": { + "type": "string", + "description": "Text content (required for heading and paragraph)" + }, + "level": { + "type": "integer", + "description": "Heading level 1-6 (required for heading)" + }, + "chartType": { + "type": "string", + "enum": ["line", "bar", "pie", "area"], + "description": "Chart type (required for chart)" + }, + "data": { + "type": "array", + "description": "Chart data array (required for chart)", + "items": { "type": "object" } + }, + "title": { + "type": "string", + "description": "Chart title (optional for chart)" } }, - "required": ["index", "element"] + "required": ["index", "element_type"] }), }, Tool { @@ -191,6 +215,23 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "properties": {} }), }, + Tool { + name: "jq".to_string(), + description: "Transform JSON data using jq expressions. Useful for filtering, mapping, grouping, and aggregating data before creating charts.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "input": { + "description": "The JSON data to transform (can be an array or object)" + }, + "filter": { + "type": "string", + "description": "The jq filter expression. Examples: '.[] | select(.value > 10)', 'group_by(.category) | map({name: .[0].category, count: length})', '[.[] | {name: .label, value: .amount}]'" + } + }, + "required": ["input", "filter"] + }), + }, ] }); @@ -219,6 +260,7 @@ pub fn execute_tool_call( "set_summary" => execute_set_summary(call, current_summary), "parse_csv" => execute_parse_csv(call), "clear_body" => execute_clear_body(), + "jq" => execute_jq(call), _ => ToolExecutionResult { result: ToolResult { success: false, @@ -415,7 +457,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { let index = call.arguments.get("index").and_then(|v| v.as_u64()); - let element_json = call.arguments.get("element"); + let element_type = call.arguments.get("element_type").and_then(|v| v.as_str()); let Some(index) = index else { return ToolExecutionResult { @@ -429,11 +471,11 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool }; }; - let Some(element_json) = element_json else { + let Some(element_type) = element_type else { return ToolExecutionResult { result: ToolResult { success: false, - message: "Missing element parameter".to_string(), + message: "Missing element_type parameter".to_string(), }, new_body: None, new_summary: None, @@ -454,31 +496,55 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool }; } - let new_element: Result<BodyElement, _> = serde_json::from_value(element_json.clone()); - match new_element { - Ok(element) => { - let mut new_body = current_body.to_vec(); - new_body[index] = element; - - ToolExecutionResult { + // Build the element based on type + let new_element = match element_type { + "heading" => { + let level = call.arguments.get("level").and_then(|v| v.as_u64()).unwrap_or(1) as u8; + let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Heading { level, text } + } + "paragraph" => { + let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Paragraph { text } + } + "chart" => { + let chart_type_str = call.arguments.get("chartType").and_then(|v| v.as_str()).unwrap_or("bar"); + let chart_type = match chart_type_str { + "line" => ChartType::Line, + "bar" => ChartType::Bar, + "pie" => ChartType::Pie, + "area" => ChartType::Area, + _ => ChartType::Bar, + }; + let title = call.arguments.get("title").and_then(|v| v.as_str()).map(|s| s.to_string()); + let data = call.arguments.get("data").cloned().unwrap_or(json!([])); + let config = call.arguments.get("config").cloned(); + BodyElement::Chart { chart_type, title, data, config } + } + _ => { + return ToolExecutionResult { result: ToolResult { - success: true, - message: format!("Updated element at index {}", index), + success: false, + message: format!("Unknown element_type: {}. Must be heading, paragraph, or chart.", element_type), }, - new_body: Some(new_body), + new_body: None, new_summary: None, parsed_data: None, - } + }; } - Err(e) => ToolExecutionResult { - result: ToolResult { - success: false, - message: format!("Invalid element format: {}", e), - }, - new_body: None, - new_summary: None, - parsed_data: None, + }; + + let mut new_body = current_body.to_vec(); + new_body[index] = new_element; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Updated element at index {} to {}", index, element_type), }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, } } @@ -616,3 +682,191 @@ fn execute_clear_body() -> ToolExecutionResult { parsed_data: None, } } + +fn execute_jq(call: &ToolCall) -> ToolExecutionResult { + let input = match call.arguments.get("input") { + Some(v) => v.clone(), + None => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing input parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + }; + + let filter = match call.arguments.get("filter").and_then(|v| v.as_str()) { + Some(f) => f, + None => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing filter parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + }; + + // Parse the jq filter + let mut defs = jaq_interpret::ParseCtx::new(Vec::new()); + defs.insert_natives(jaq_core::core()); + defs.insert_defs(jaq_std::std()); + + let (parsed_filter, errs) = jaq_parse::parse(filter, jaq_parse::main()); + if !errs.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Invalid jq filter: {:?}", errs), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let Some(parsed_filter) = parsed_filter else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Failed to parse jq filter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + // Compile the filter + let compiled = defs.compile(parsed_filter); + if !defs.errs.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Failed to compile jq filter ({} errors)", defs.errs.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + // Convert serde_json::Value to jaq Value + let jaq_input = json_to_jaq(&input); + + // Execute the filter + let inputs = jaq_interpret::RcIter::new(std::iter::empty()); + let mut results: Vec<serde_json::Value> = Vec::new(); + + for output in compiled.run((jaq_interpret::Ctx::new([], &inputs), jaq_input)) { + match output { + Ok(val) => { + results.push(jaq_to_json(&val)); + } + Err(e) => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("jq execution error: {:?}", e), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + } + } + + // Return single value or array based on results + let output = if results.len() == 1 { + results.into_iter().next().unwrap() + } else { + json!(results) + }; + + let preview = { + let s = output.to_string(); + if s.len() > 100 { + format!("{}...", &s[..100]) + } else { + s + } + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("jq transform complete: {}", preview), + }, + new_body: None, + new_summary: None, + parsed_data: Some(output), + } +} + +/// Convert serde_json::Value to jaq_interpret::Val +fn json_to_jaq(value: &serde_json::Value) -> jaq_interpret::Val { + match value { + serde_json::Value::Null => jaq_interpret::Val::Null, + serde_json::Value::Bool(b) => jaq_interpret::Val::Bool(*b), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + jaq_interpret::Val::Int(i as isize) + } else if let Some(f) = n.as_f64() { + jaq_interpret::Val::Float(f) + } else { + jaq_interpret::Val::Null + } + } + serde_json::Value::String(s) => jaq_interpret::Val::Str(s.clone().into()), + serde_json::Value::Array(arr) => { + jaq_interpret::Val::Arr(std::rc::Rc::new(arr.iter().map(json_to_jaq).collect())) + } + serde_json::Value::Object(obj) => { + let mut map: indexmap::IndexMap<std::rc::Rc<String>, jaq_interpret::Val, ahash::RandomState> = + indexmap::IndexMap::with_hasher(ahash::RandomState::new()); + for (k, v) in obj { + map.insert(std::rc::Rc::new(k.clone()), json_to_jaq(v)); + } + jaq_interpret::Val::Obj(std::rc::Rc::new(map)) + } + } +} + +/// Convert jaq_interpret::Val to serde_json::Value +fn jaq_to_json(value: &jaq_interpret::Val) -> serde_json::Value { + match value { + jaq_interpret::Val::Null => serde_json::Value::Null, + jaq_interpret::Val::Bool(b) => json!(*b), + jaq_interpret::Val::Int(i) => json!(*i), + jaq_interpret::Val::Float(f) => json!(*f), + jaq_interpret::Val::Num(n) => { + // Try to parse the number string + if let Ok(i) = n.parse::<i64>() { + json!(i) + } else if let Ok(f) = n.parse::<f64>() { + json!(f) + } else { + json!(n.as_ref()) + } + } + jaq_interpret::Val::Str(s) => json!(s.as_ref()), + jaq_interpret::Val::Arr(arr) => { + json!(arr.iter().map(jaq_to_json).collect::<Vec<_>>()) + } + jaq_interpret::Val::Obj(obj) => { + let mut map = serde_json::Map::new(); + for (k, v) in obj.iter() { + map.insert((**k).clone(), jaq_to_json(v)); + } + serde_json::Value::Object(map) + } + } +} diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs index e6d22ca..92c4ec8 100644 --- a/makima/src/server/handlers/chat.rs +++ b/makima/src/server/handlers/chat.rs @@ -12,17 +12,24 @@ use uuid::Uuid; use crate::db::{models::BodyElement, repository}; use crate::llm::{ + claude::{self, ClaudeClient, ClaudeError, ClaudeModel}, execute_tool_call, - groq::{GroqClient, GroqError, Message}, - ToolResult, AVAILABLE_TOOLS, + groq::{GroqClient, GroqError, Message, ToolCallResponse}, + LlmModel, ToolCall, ToolResult, AVAILABLE_TOOLS, }; use crate::server::state::SharedState; +/// Maximum number of tool-calling rounds to prevent infinite loops +const MAX_TOOL_ROUNDS: usize = 10; + #[derive(Debug, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] pub struct ChatRequest { /// The user's message/instruction pub message: String, + /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq" + #[serde(default)] + pub model: Option<String>, } #[derive(Debug, Serialize, ToSchema)] @@ -45,9 +52,18 @@ pub struct ToolCallInfo { pub result: ToolResult, } -#[derive(Debug, Serialize)] -struct ErrorResponse { - error: String, +/// Enum to hold LLM clients +enum LlmClient { + Groq(GroqClient), + Claude(ClaudeClient), +} + +/// Unified result from LLM call +struct LlmResult { + content: Option<String>, + tool_calls: Vec<ToolCall>, + raw_tool_calls: Vec<ToolCallResponse>, + finish_reason: String, } /// Chat with a file using LLM tool calling @@ -105,40 +121,102 @@ pub async fn chat_handler( } }; - // Initialize Groq client - let groq = match GroqClient::from_env() { - Ok(client) => client, - Err(GroqError::MissingApiKey) => { - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "error": "GROQ_API_KEY not configured" - })), - ) - .into_response(); + // Parse model selection (default to Claude Sonnet) + let model = request + .model + .as_ref() + .and_then(|m| LlmModel::from_str(m)) + .unwrap_or_default(); + + tracing::info!("Using LLM model: {:?}", model); + + // Initialize the appropriate LLM client + let llm_client = match model { + LlmModel::ClaudeSonnet => { + match ClaudeClient::from_env(ClaudeModel::Sonnet) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "ANTHROPIC_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Claude client error: {}", e) + })), + ) + .into_response(); + } + } } - Err(e) => { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("Groq client error: {}", e) - })), - ) - .into_response(); + LlmModel::ClaudeOpus => { + match ClaudeClient::from_env(ClaudeModel::Opus) { + Ok(client) => LlmClient::Claude(client), + Err(ClaudeError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "ANTHROPIC_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Claude client error: {}", e) + })), + ) + .into_response(); + } + } + } + LlmModel::GroqKimi => { + match GroqClient::from_env() { + Ok(client) => LlmClient::Groq(client), + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "GROQ_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Groq client error: {}", e) + })), + ) + .into_response(); + } + } } }; // Build context about the file let file_context = build_file_context(&file); - // Build messages - let messages = vec![ + // Build initial messages (Groq/OpenAI format - will be converted for Claude) + let mut messages = vec![ Message { role: "system".to_string(), content: Some(format!( "You are a helpful assistant that helps users edit and analyze document files. \ You have access to tools to add headings, paragraphs, charts, and set summaries. \ When the user asks you to modify the file, use the appropriate tools.\n\n\ + IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \ + and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \ + to call add_chart.\n\n\ Current file context:\n{}", file_context )), @@ -153,46 +231,154 @@ pub async fn chat_handler( }, ]; - // Call Groq API - let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await { - Ok(result) => result, - Err(e) => { - tracing::error!("Groq API error: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": format!("LLM API error: {}", e) - })), - ) - .into_response(); - } - }; - - // Execute tool calls + // State for tracking changes let mut current_body = file.body.clone(); let mut current_summary = file.summary.clone(); - let mut tool_call_infos = Vec::new(); + let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new(); + let mut final_response: Option<String> = None; - for tool_call in &result.tool_calls { - let execution_result = - execute_tool_call(tool_call, ¤t_body, current_summary.as_deref()); + // Multi-turn tool calling loop + for round in 0..MAX_TOOL_ROUNDS { + tracing::debug!(round = round, "LLM tool calling round"); - // Apply state changes - if let Some(new_body) = execution_result.new_body { - current_body = new_body; - } - if let Some(new_summary) = execution_result.new_summary { - current_summary = Some(new_summary); + // Call the appropriate LLM API + let result = match &llm_client { + LlmClient::Groq(groq) => { + match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await { + Ok(r) => LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls: r.raw_tool_calls, + finish_reason: r.finish_reason, + }, + Err(e) => { + tracing::error!("Groq API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("LLM API error: {}", e) + })), + ) + .into_response(); + } + } + } + LlmClient::Claude(claude_client) => { + // Convert messages to Claude format + let claude_messages = claude::groq_messages_to_claude(&messages); + match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await { + Ok(r) => { + // Convert Claude tool uses to Groq-style ToolCallResponse for consistency + let raw_tool_calls: Vec<ToolCallResponse> = r + .tool_calls + .iter() + .map(|tc| ToolCallResponse { + id: tc.id.clone(), + call_type: "function".to_string(), + function: crate::llm::groq::FunctionCall { + name: tc.name.clone(), + arguments: tc.arguments.to_string(), + }, + }) + .collect(); + + LlmResult { + content: r.content, + tool_calls: r.tool_calls, + raw_tool_calls, + finish_reason: r.stop_reason, + } + } + Err(e) => { + tracing::error!("Claude API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("LLM API error: {}", e) + })), + ) + .into_response(); + } + } + } + }; + + // Check if there are tool calls to execute + if result.tool_calls.is_empty() { + // No more tool calls - capture the final response and exit loop + final_response = result.content; + break; } - tool_call_infos.push(ToolCallInfo { - name: tool_call.name.clone(), - result: execution_result.result, + // Add assistant message with tool calls to conversation + messages.push(Message { + role: "assistant".to_string(), + content: result.content.clone(), + tool_calls: Some(result.raw_tool_calls.clone()), + tool_call_id: None, }); + + // Execute each tool call and add results to conversation + for (i, tool_call) in result.tool_calls.iter().enumerate() { + let execution_result = + execute_tool_call(tool_call, ¤t_body, current_summary.as_deref()); + + // Apply state changes + if let Some(new_body) = execution_result.new_body { + current_body = new_body; + } + if let Some(new_summary) = execution_result.new_summary { + current_summary = Some(new_summary); + } + + // Build tool result message content + let result_content = if let Some(parsed_data) = &execution_result.parsed_data { + // Include parsed data in the result for the LLM to use + serde_json::json!({ + "success": execution_result.result.success, + "message": execution_result.result.message, + "data": parsed_data + }) + .to_string() + } else { + serde_json::json!({ + "success": execution_result.result.success, + "message": execution_result.result.message + }) + .to_string() + }; + + // Add tool result message + // Use the appropriate ID format for each provider + let tool_call_id = match &llm_client { + LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(), + LlmClient::Claude(_) => tool_call.id.clone(), + }; + + messages.push(Message { + role: "tool".to_string(), + content: Some(result_content), + tool_calls: None, + tool_call_id: Some(tool_call_id), + }); + + // Track for response + all_tool_call_infos.push(ToolCallInfo { + name: tool_call.name.clone(), + result: execution_result.result, + }); + } + + // If finish reason indicates completion, exit loop + let finish_lower = result.finish_reason.to_lowercase(); + if finish_lower == "stop" || finish_lower == "end_turn" { + final_response = result.content; + break; + } } // Save changes to database if any tools were executed - if !result.tool_calls.is_empty() { + if !all_tool_call_infos.is_empty() { let update_req = crate::db::models::UpdateFileRequest { name: None, description: None, @@ -214,14 +400,14 @@ pub async fn chat_handler( } // Build response - let response_text = result.content.unwrap_or_else(|| { - if tool_call_infos.is_empty() { + let response_text = final_response.unwrap_or_else(|| { + if all_tool_call_infos.is_empty() { "I couldn't understand your request. Please try rephrasing.".to_string() } else { format!( "Done! Executed {} tool{}.", - tool_call_infos.len(), - if tool_call_infos.len() == 1 { "" } else { "s" } + all_tool_call_infos.len(), + if all_tool_call_infos.len() == 1 { "" } else { "s" } ) } }); @@ -230,7 +416,7 @@ pub async fn chat_handler( StatusCode::OK, Json(ChatResponse { response: response_text, - tool_calls: tool_call_infos, + tool_calls: all_tool_call_infos, updated_body: current_body, updated_summary: current_summary, }), |
