diff options
| author | soryu <soryu@soryu.co> | 2025-12-23 18:24:42 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 18:24:42 +0000 |
| commit | 3c0adec8e3a9dd3bc34251e87e0fb5314793426d (patch) | |
| tree | 9dfe61e55bd703aa09df03abfcbf8e7a8b2babce /makima/src/llm | |
| parent | 555061b179b8ec034cb70f9a2dd6c823ced0f637 (diff) | |
| download | soryu-3c0adec8e3a9dd3bc34251e87e0fb5314793426d.tar.gz soryu-3c0adec8e3a9dd3bc34251e87e0fb5314793426d.zip | |
Add claude opus/sonnet support
Diffstat (limited to 'makima/src/llm')
| -rw-r--r-- | makima/src/llm/claude.rs | 304 | ||||
| -rw-r--r-- | makima/src/llm/groq.rs | 16 | ||||
| -rw-r--r-- | makima/src/llm/mod.rs | 25 | ||||
| -rw-r--r-- | makima/src/llm/tools.rs | 308 |
4 files changed, 619 insertions, 34 deletions
diff --git a/makima/src/llm/claude.rs b/makima/src/llm/claude.rs new file mode 100644 index 0000000..f475acd --- /dev/null +++ b/makima/src/llm/claude.rs @@ -0,0 +1,304 @@ +//! Claude API client for LLM tool calling. + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::tools::{Tool, ToolCall}; + +const CLAUDE_API_URL: &str = "https://api.anthropic.com/v1/messages"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ClaudeModel { + Opus, + Sonnet, +} + +impl ClaudeModel { + pub fn model_id(&self) -> &'static str { + match self { + ClaudeModel::Opus => "claude-opus-4-5-20251101", + ClaudeModel::Sonnet => "claude-sonnet-4-5-20250929", + } + } +} + +impl Default for ClaudeModel { + fn default() -> Self { + ClaudeModel::Opus + } +} + +#[derive(Debug, Error)] +pub enum ClaudeError { + #[error("HTTP request failed: {0}")] + Request(#[from] reqwest::Error), + #[error("API error: {0}")] + Api(String), + #[error("Missing API key")] + MissingApiKey, +} + +#[derive(Debug, Clone)] +pub struct ClaudeClient { + api_key: String, + client: reqwest::Client, + model: ClaudeModel, +} + +// Request types +#[derive(Debug, Serialize)] +struct ClaudeRequest { + model: String, + max_tokens: u32, + messages: Vec<Message>, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option<Vec<ToolDefinition>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: MessageContent, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Blocks(Vec<ContentBlock>), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + }, +} + +#[derive(Debug, Serialize)] +struct ToolDefinition { + name: String, + description: String, + input_schema: serde_json::Value, +} + +// Response types +#[derive(Debug, Deserialize)] +struct ClaudeResponse { + content: Vec<ResponseContentBlock>, + stop_reason: Option<String>, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type")] +pub enum ResponseContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, +} + +#[derive(Debug)] +pub struct ChatResult { + pub content: Option<String>, + pub tool_calls: Vec<ToolCall>, + /// Raw tool use blocks for including in subsequent messages + pub raw_tool_uses: Vec<ResponseContentBlock>, + pub stop_reason: String, +} + +impl ClaudeClient { + pub fn new(api_key: String, model: ClaudeModel) -> Self { + Self { + api_key, + client: reqwest::Client::new(), + model, + } + } + + pub fn from_env(model: ClaudeModel) -> Result<Self, ClaudeError> { + let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| ClaudeError::MissingApiKey)?; + Ok(Self::new(api_key, model)) + } + + pub async fn chat_with_tools( + &self, + messages: Vec<Message>, + tools: &[Tool], + ) -> Result<ChatResult, ClaudeError> { + let tool_definitions: Vec<ToolDefinition> = tools + .iter() + .map(|t| ToolDefinition { + name: t.name.clone(), + description: t.description.clone(), + input_schema: t.parameters.clone(), + }) + .collect(); + + let request = ClaudeRequest { + model: self.model.model_id().to_string(), + max_tokens: 4096, + messages, + tools: Some(tool_definitions), + }; + + let response = self + .client + .post(CLAUDE_API_URL) + .header("x-api-key", &self.api_key) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("Content-Type", "application/json") + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let error_text = response.text().await.unwrap_or_default(); + return Err(ClaudeError::Api(error_text)); + } + + let claude_response: ClaudeResponse = response.json().await?; + + let stop_reason = claude_response.stop_reason.unwrap_or_else(|| "end_turn".to_string()); + + // Extract text content and tool uses from content blocks + let mut text_parts: Vec<String> = Vec::new(); + let mut raw_tool_uses: Vec<ResponseContentBlock> = Vec::new(); + + for block in &claude_response.content { + match block { + ResponseContentBlock::Text { text } => { + if !text.is_empty() { + text_parts.push(text.clone()); + } + } + ResponseContentBlock::ToolUse { .. } => { + raw_tool_uses.push(block.clone()); + } + } + } + + let content = if text_parts.is_empty() { + None + } else { + Some(text_parts.join("\n")) + }; + + // Convert tool uses to ToolCalls + let tool_calls: Vec<ToolCall> = raw_tool_uses + .iter() + .filter_map(|block| { + if let ResponseContentBlock::ToolUse { id, name, input } = block { + Some(ToolCall { + id: id.clone(), + name: name.clone(), + arguments: input.clone(), + }) + } else { + None + } + }) + .collect(); + + Ok(ChatResult { + content, + tool_calls, + raw_tool_uses, + stop_reason, + }) + } +} + +/// Helper to convert Groq-style messages to Claude messages +pub fn groq_messages_to_claude(messages: &[super::groq::Message]) -> Vec<Message> { + let mut claude_messages: Vec<Message> = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + // Claude handles system prompts as first user message + if let Some(ref content) = msg.content { + claude_messages.push(Message { + role: "user".to_string(), + content: MessageContent::Text(format!("[System Instructions]: {}", content)), + }); + // Add assistant acknowledgment to maintain conversation structure + claude_messages.push(Message { + role: "assistant".to_string(), + content: MessageContent::Text("Understood. I'll follow these instructions.".to_string()), + }); + } + } + "user" => { + if let Some(ref content) = msg.content { + claude_messages.push(Message { + role: "user".to_string(), + content: MessageContent::Text(content.clone()), + }); + } + } + "assistant" => { + let mut blocks: Vec<ContentBlock> = Vec::new(); + + // Add text content if present + if let Some(ref content) = msg.content { + if !content.is_empty() { + blocks.push(ContentBlock::Text { text: content.clone() }); + } + } + + // Add tool uses if present + if let Some(ref tool_calls) = msg.tool_calls { + for tc in tool_calls { + let input: serde_json::Value = + serde_json::from_str(&tc.function.arguments).unwrap_or_default(); + blocks.push(ContentBlock::ToolUse { + id: tc.id.clone(), + name: tc.function.name.clone(), + input, + }); + } + } + + if !blocks.is_empty() { + claude_messages.push(Message { + role: "assistant".to_string(), + content: MessageContent::Blocks(blocks), + }); + } + } + "tool" => { + // Tool results in Claude go in a user message with tool_result blocks + if let Some(ref content) = msg.content { + let tool_use_id = msg.tool_call_id.clone().unwrap_or_default(); + claude_messages.push(Message { + role: "user".to_string(), + content: MessageContent::Blocks(vec![ContentBlock::ToolResult { + tool_use_id, + content: content.clone(), + }]), + }); + } + } + _ => {} + } + } + + claude_messages +} diff --git a/makima/src/llm/groq.rs b/makima/src/llm/groq.rs index be0e2bc..ee01fcf 100644 --- a/makima/src/llm/groq.rs +++ b/makima/src/llm/groq.rs @@ -92,6 +92,8 @@ struct MessageResponse { pub struct ChatResult { pub content: Option<String>, pub tool_calls: Vec<ToolCall>, + /// Raw tool call responses for including in subsequent messages + pub raw_tool_calls: Vec<ToolCallResponse>, pub finish_reason: String, } @@ -154,14 +156,13 @@ impl GroqClient { .next() .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?; - let tool_calls = choice - .message - .tool_calls - .unwrap_or_default() - .into_iter() + let raw_tool_calls = choice.message.tool_calls.unwrap_or_default(); + + let tool_calls = raw_tool_calls + .iter() .map(|tc| ToolCall { - id: tc.id, - name: tc.function.name, + id: tc.id.clone(), + name: tc.function.name.clone(), arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(), }) .collect(); @@ -169,6 +170,7 @@ impl GroqClient { Ok(ChatResult { content: choice.message.content, tool_calls, + raw_tool_calls, finish_reason: choice.finish_reason, }) } diff --git a/makima/src/llm/mod.rs b/makima/src/llm/mod.rs index 00f3333..7de8afe 100644 --- a/makima/src/llm/mod.rs +++ b/makima/src/llm/mod.rs @@ -1,7 +1,32 @@ //! LLM integration module for file editing via tool calling. +pub mod claude; pub mod groq; pub mod tools; +pub use claude::{ClaudeClient, ClaudeModel}; pub use groq::GroqClient; pub use tools::{execute_tool_call, Tool, ToolCall, ToolResult, AVAILABLE_TOOLS}; + +/// Available LLM providers and models +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LlmModel { + /// Claude Sonnet 4.5 - balanced speed and capability + ClaudeSonnet, + /// Claude Opus 4.5 (default) - most capable + #[default] + ClaudeOpus, + /// Groq Kimi - fast alternative provider + GroqKimi, +} + +impl LlmModel { + pub fn from_str(s: &str) -> Option<Self> { + match s.to_lowercase().as_str() { + "claude-sonnet" | "sonnet" | "claude" => Some(LlmModel::ClaudeSonnet), + "claude-opus" | "opus" => Some(LlmModel::ClaudeOpus), + "groq" | "kimi" | "groq-kimi" => Some(LlmModel::GroqKimi), + _ => None, + } + } +} diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs index 3bd102f..e6b2954 100644 --- a/makima/src/llm/tools.rs +++ b/makima/src/llm/tools.rs @@ -1,5 +1,6 @@ //! Tool definitions for file editing via LLM. +use jaq_interpret::FilterT; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -121,7 +122,7 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = }, Tool { name: "update_element".to_string(), - description: "Update an existing element in the file body".to_string(), + description: "Update an existing element in the file body. IMPORTANT: You must provide ALL required fields. For heading: type, level (1-6), text. For paragraph: type, text. For chart: type, chartType (line/bar/pie/area), data (array of objects).".to_string(), parameters: json!({ "type": "object", "properties": { @@ -129,12 +130,35 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "type": "integer", "description": "Index of element to update (0-indexed)" }, - "element": { - "type": "object", - "description": "New element data. Must include 'type' field (heading, paragraph, chart)." + "element_type": { + "type": "string", + "enum": ["heading", "paragraph", "chart"], + "description": "Type of element" + }, + "text": { + "type": "string", + "description": "Text content (required for heading and paragraph)" + }, + "level": { + "type": "integer", + "description": "Heading level 1-6 (required for heading)" + }, + "chartType": { + "type": "string", + "enum": ["line", "bar", "pie", "area"], + "description": "Chart type (required for chart)" + }, + "data": { + "type": "array", + "description": "Chart data array (required for chart)", + "items": { "type": "object" } + }, + "title": { + "type": "string", + "description": "Chart title (optional for chart)" } }, - "required": ["index", "element"] + "required": ["index", "element_type"] }), }, Tool { @@ -191,6 +215,23 @@ pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = "properties": {} }), }, + Tool { + name: "jq".to_string(), + description: "Transform JSON data using jq expressions. Useful for filtering, mapping, grouping, and aggregating data before creating charts.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "input": { + "description": "The JSON data to transform (can be an array or object)" + }, + "filter": { + "type": "string", + "description": "The jq filter expression. Examples: '.[] | select(.value > 10)', 'group_by(.category) | map({name: .[0].category, count: length})', '[.[] | {name: .label, value: .amount}]'" + } + }, + "required": ["input", "filter"] + }), + }, ] }); @@ -219,6 +260,7 @@ pub fn execute_tool_call( "set_summary" => execute_set_summary(call, current_summary), "parse_csv" => execute_parse_csv(call), "clear_body" => execute_clear_body(), + "jq" => execute_jq(call), _ => ToolExecutionResult { result: ToolResult { success: false, @@ -415,7 +457,7 @@ fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { let index = call.arguments.get("index").and_then(|v| v.as_u64()); - let element_json = call.arguments.get("element"); + let element_type = call.arguments.get("element_type").and_then(|v| v.as_str()); let Some(index) = index else { return ToolExecutionResult { @@ -429,11 +471,11 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool }; }; - let Some(element_json) = element_json else { + let Some(element_type) = element_type else { return ToolExecutionResult { result: ToolResult { success: false, - message: "Missing element parameter".to_string(), + message: "Missing element_type parameter".to_string(), }, new_body: None, new_summary: None, @@ -454,31 +496,55 @@ fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> Tool }; } - let new_element: Result<BodyElement, _> = serde_json::from_value(element_json.clone()); - match new_element { - Ok(element) => { - let mut new_body = current_body.to_vec(); - new_body[index] = element; - - ToolExecutionResult { + // Build the element based on type + let new_element = match element_type { + "heading" => { + let level = call.arguments.get("level").and_then(|v| v.as_u64()).unwrap_or(1) as u8; + let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Heading { level, text } + } + "paragraph" => { + let text = call.arguments.get("text").and_then(|v| v.as_str()).unwrap_or("").to_string(); + BodyElement::Paragraph { text } + } + "chart" => { + let chart_type_str = call.arguments.get("chartType").and_then(|v| v.as_str()).unwrap_or("bar"); + let chart_type = match chart_type_str { + "line" => ChartType::Line, + "bar" => ChartType::Bar, + "pie" => ChartType::Pie, + "area" => ChartType::Area, + _ => ChartType::Bar, + }; + let title = call.arguments.get("title").and_then(|v| v.as_str()).map(|s| s.to_string()); + let data = call.arguments.get("data").cloned().unwrap_or(json!([])); + let config = call.arguments.get("config").cloned(); + BodyElement::Chart { chart_type, title, data, config } + } + _ => { + return ToolExecutionResult { result: ToolResult { - success: true, - message: format!("Updated element at index {}", index), + success: false, + message: format!("Unknown element_type: {}. Must be heading, paragraph, or chart.", element_type), }, - new_body: Some(new_body), + new_body: None, new_summary: None, parsed_data: None, - } + }; } - Err(e) => ToolExecutionResult { - result: ToolResult { - success: false, - message: format!("Invalid element format: {}", e), - }, - new_body: None, - new_summary: None, - parsed_data: None, + }; + + let mut new_body = current_body.to_vec(); + new_body[index] = new_element; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Updated element at index {} to {}", index, element_type), }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, } } @@ -616,3 +682,191 @@ fn execute_clear_body() -> ToolExecutionResult { parsed_data: None, } } + +fn execute_jq(call: &ToolCall) -> ToolExecutionResult { + let input = match call.arguments.get("input") { + Some(v) => v.clone(), + None => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing input parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + }; + + let filter = match call.arguments.get("filter").and_then(|v| v.as_str()) { + Some(f) => f, + None => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing filter parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + }; + + // Parse the jq filter + let mut defs = jaq_interpret::ParseCtx::new(Vec::new()); + defs.insert_natives(jaq_core::core()); + defs.insert_defs(jaq_std::std()); + + let (parsed_filter, errs) = jaq_parse::parse(filter, jaq_parse::main()); + if !errs.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Invalid jq filter: {:?}", errs), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let Some(parsed_filter) = parsed_filter else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Failed to parse jq filter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + // Compile the filter + let compiled = defs.compile(parsed_filter); + if !defs.errs.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Failed to compile jq filter ({} errors)", defs.errs.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + // Convert serde_json::Value to jaq Value + let jaq_input = json_to_jaq(&input); + + // Execute the filter + let inputs = jaq_interpret::RcIter::new(std::iter::empty()); + let mut results: Vec<serde_json::Value> = Vec::new(); + + for output in compiled.run((jaq_interpret::Ctx::new([], &inputs), jaq_input)) { + match output { + Ok(val) => { + results.push(jaq_to_json(&val)); + } + Err(e) => { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("jq execution error: {:?}", e), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + } + } + + // Return single value or array based on results + let output = if results.len() == 1 { + results.into_iter().next().unwrap() + } else { + json!(results) + }; + + let preview = { + let s = output.to_string(); + if s.len() > 100 { + format!("{}...", &s[..100]) + } else { + s + } + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("jq transform complete: {}", preview), + }, + new_body: None, + new_summary: None, + parsed_data: Some(output), + } +} + +/// Convert serde_json::Value to jaq_interpret::Val +fn json_to_jaq(value: &serde_json::Value) -> jaq_interpret::Val { + match value { + serde_json::Value::Null => jaq_interpret::Val::Null, + serde_json::Value::Bool(b) => jaq_interpret::Val::Bool(*b), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + jaq_interpret::Val::Int(i as isize) + } else if let Some(f) = n.as_f64() { + jaq_interpret::Val::Float(f) + } else { + jaq_interpret::Val::Null + } + } + serde_json::Value::String(s) => jaq_interpret::Val::Str(s.clone().into()), + serde_json::Value::Array(arr) => { + jaq_interpret::Val::Arr(std::rc::Rc::new(arr.iter().map(json_to_jaq).collect())) + } + serde_json::Value::Object(obj) => { + let mut map: indexmap::IndexMap<std::rc::Rc<String>, jaq_interpret::Val, ahash::RandomState> = + indexmap::IndexMap::with_hasher(ahash::RandomState::new()); + for (k, v) in obj { + map.insert(std::rc::Rc::new(k.clone()), json_to_jaq(v)); + } + jaq_interpret::Val::Obj(std::rc::Rc::new(map)) + } + } +} + +/// Convert jaq_interpret::Val to serde_json::Value +fn jaq_to_json(value: &jaq_interpret::Val) -> serde_json::Value { + match value { + jaq_interpret::Val::Null => serde_json::Value::Null, + jaq_interpret::Val::Bool(b) => json!(*b), + jaq_interpret::Val::Int(i) => json!(*i), + jaq_interpret::Val::Float(f) => json!(*f), + jaq_interpret::Val::Num(n) => { + // Try to parse the number string + if let Ok(i) = n.parse::<i64>() { + json!(i) + } else if let Ok(f) = n.parse::<f64>() { + json!(f) + } else { + json!(n.as_ref()) + } + } + jaq_interpret::Val::Str(s) => json!(s.as_ref()), + jaq_interpret::Val::Arr(arr) => { + json!(arr.iter().map(jaq_to_json).collect::<Vec<_>>()) + } + jaq_interpret::Val::Obj(obj) => { + let mut map = serde_json::Map::new(); + for (k, v) in obj.iter() { + map.insert((**k).clone(), jaq_to_json(v)); + } + serde_json::Value::Object(map) + } + } +} |
