summaryrefslogblamecommitdiff
path: root/makima/src/llm/claude.rs
blob: f475acd6b0fb13478321789f248d66bbdc7cf849 (plain) (tree)















































































































































































































































































































                                                                                                                 
//! Claude API client for LLM tool calling.

use serde::{Deserialize, Serialize};
use thiserror::Error;

use super::tools::{Tool, ToolCall};

const CLAUDE_API_URL: &str = "https://api.anthropic.com/v1/messages";
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Claude models this client can target.
///
/// [`ClaudeModel::Opus`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ClaudeModel {
    /// Maps to `claude-opus-4-5-20251101`.
    #[default]
    Opus,
    /// Maps to `claude-sonnet-4-5-20250929`.
    Sonnet,
}

impl ClaudeModel {
    /// Returns the dated model identifier sent in the `model` field of API requests.
    pub const fn model_id(&self) -> &'static str {
        match self {
            ClaudeModel::Opus => "claude-opus-4-5-20251101",
            ClaudeModel::Sonnet => "claude-sonnet-4-5-20250929",
        }
    }
}

/// Failures surfaced by [`ClaudeClient`].
#[derive(Debug, Error)]
pub enum ClaudeError {
    /// No usable `ANTHROPIC_API_KEY` was found when building a client from the environment.
    #[error("Missing API key")]
    MissingApiKey,
    /// Transport-level failure: sending the request or decoding the response body failed.
    #[error("HTTP request failed: {0}")]
    Request(#[from] reqwest::Error),
    /// The API replied with a non-success HTTP status; the payload describes the failure.
    #[error("API error: {0}")]
    Api(String),
}

/// Async client for the Anthropic Messages API.
///
/// Cloning is intended to be cheap to share; `Debug` output never includes the API key.
#[derive(Clone)]
pub struct ClaudeClient {
    /// Secret credential; redacted in the `Debug` implementation below.
    api_key: String,
    /// Underlying HTTP client used for every request.
    client: reqwest::Client,
    /// Model selected for requests made by this client.
    model: ClaudeModel,
}

// `Debug` is written by hand (instead of derived) so the API key cannot leak into logs,
// tracing output, or panic messages that debug-format the client.
impl std::fmt::Debug for ClaudeClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClaudeClient")
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .field("model", &self.model)
            .finish()
    }
}

// Request types
/// JSON body POSTed to the Messages endpoint.
///
/// serde serializes fields in declaration order, so this order is the key order on the wire.
#[derive(Debug, Serialize)]
struct ClaudeRequest {
    /// Model identifier (see `ClaudeModel::model_id`).
    model: String,
    /// Maximum number of tokens the model may generate (Anthropic `max_tokens`).
    max_tokens: u32,
    /// Conversation history to send.
    messages: Vec<Message>,
    /// Tool definitions offered to the model; the `tools` key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ToolDefinition>>,
}

/// One conversation turn in Messages API format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role; this module produces `"user"` and `"assistant"`.
    pub role: String,
    /// Turn body: either plain text or a list of typed content blocks.
    pub content: MessageContent,
}

/// Body of a [`Message`].
///
/// `#[serde(untagged)]` means `Text` (de)serializes as a bare JSON string and `Blocks` as a
/// bare JSON array, with no wrapping key. On deserialization the variants are tried in order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-text content.
    Text(String),
    /// Structured content such as text, `tool_use`, or `tool_result` blocks.
    Blocks(Vec<ContentBlock>),
}

/// A typed content block inside a request message, internally tagged by its `"type"` key
/// (e.g. `{"type": "text", "text": "..."}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
    /// Plain text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation previously made by the assistant, replayed in the history.
    #[serde(rename = "tool_use")]
    ToolUse {
        /// Identifier that a later `ToolResult::tool_use_id` refers back to.
        id: String,
        /// Name of the invoked tool.
        name: String,
        /// Arguments passed to the tool, as JSON.
        input: serde_json::Value,
    },
    /// Output of a tool run; this module sends it back inside a `user` message.
    #[serde(rename = "tool_result")]
    ToolResult {
        /// `id` of the `ToolUse` block this result answers.
        tool_use_id: String,
        /// Tool output as plain text.
        content: String,
    },
}

/// A tool description in the shape the Messages API expects in the `tools` array.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    /// Tool name the model uses when emitting a `tool_use` block.
    name: String,
    /// Natural-language description shown to the model.
    description: String,
    /// JSON Schema for the tool's input; copied from `Tool::parameters`.
    input_schema: serde_json::Value,
}

// Response types
/// The subset of the Messages API response body this client reads.
///
/// Any other keys in the response are ignored (serde's default for unknown fields).
#[derive(Debug, Deserialize)]
struct ClaudeResponse {
    /// Content blocks produced by the model, in response order.
    content: Vec<ResponseContentBlock>,
    /// Why generation stopped; may be absent or `null`.
    stop_reason: Option<String>,
}

/// A content block returned by the model, selected by its `"type"` key.
///
/// NOTE(review): only `text` and `tool_use` are modeled. Because the enum is internally tagged
/// with no catch-all variant, any other block type in a response makes deserialization of the
/// whole response fail. Confirm no other types (e.g. thinking blocks) can be returned for the
/// requests this client sends.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContentBlock {
    /// Generated text (may be empty).
    #[serde(rename = "text")]
    Text { text: String },
    /// A request from the model to call a tool.
    #[serde(rename = "tool_use")]
    ToolUse {
        /// Identifier to echo back as `tool_use_id` in the matching tool result.
        id: String,
        /// Name of the tool to call.
        name: String,
        /// Arguments for the tool, as JSON.
        input: serde_json::Value,
    },
}

/// Outcome of a single `ClaudeClient::chat_with_tools` call.
#[derive(Debug)]
pub struct ChatResult {
    /// Non-empty text blocks joined with `"\n"`, or `None` when the response had none.
    pub content: Option<String>,
    /// Tool invocations requested by the model, in response order.
    pub tool_calls: Vec<ToolCall>,
    /// Raw tool use blocks for including in subsequent messages
    pub raw_tool_uses: Vec<ResponseContentBlock>,
    /// The API's `stop_reason`, or `"end_turn"` when the response did not include one.
    pub stop_reason: String,
}

impl ClaudeClient {
    /// `max_tokens` value sent with every request.
    const MAX_TOKENS: u32 = 4096;

    /// Creates a client that authenticates with `api_key` and targets `model`.
    pub fn new(api_key: String, model: ClaudeModel) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
            model,
        }
    }

    /// Creates a client from the `ANTHROPIC_API_KEY` environment variable.
    ///
    /// Surrounding whitespace (e.g. a trailing newline from a secrets file) is trimmed, since it
    /// would otherwise end up in the `x-api-key` header.
    ///
    /// # Errors
    ///
    /// [`ClaudeError::MissingApiKey`] if the variable is unset, not valid Unicode, or blank.
    pub fn from_env(model: ClaudeModel) -> Result<Self, ClaudeError> {
        let api_key = std::env::var("ANTHROPIC_API_KEY")
            .ok()
            .map(|key| key.trim().to_owned())
            .filter(|key| !key.is_empty())
            .ok_or(ClaudeError::MissingApiKey)?;
        Ok(Self::new(api_key, model))
    }

    /// Sends `messages` to the Messages API, offering `tools` to the model.
    ///
    /// # Errors
    ///
    /// - [`ClaudeError::Request`] if the request cannot be sent or the response cannot be decoded.
    /// - [`ClaudeError::Api`] for a non-success HTTP status. The message contains the status
    ///   and the response body.
    pub async fn chat_with_tools(
        &self,
        messages: Vec<Message>,
        tools: &[Tool],
    ) -> Result<ChatResult, ClaudeError> {
        let tool_definitions: Vec<ToolDefinition> = tools
            .iter()
            .map(|t| ToolDefinition {
                name: t.name.clone(),
                description: t.description.clone(),
                input_schema: t.parameters.clone(),
            })
            .collect();

        let request = ClaudeRequest {
            model: self.model.model_id().to_string(),
            max_tokens: Self::MAX_TOKENS,
            messages,
            tools: Some(tool_definitions),
        };

        let response = self
            .client
            .post(CLAUDE_API_URL)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", ANTHROPIC_VERSION)
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Keep the status in the error: the body may be empty or unreadable, and the
            // status alone often identifies the problem (401, 429, 5xx, ...).
            let body = response.text().await.unwrap_or_default();
            return Err(ClaudeError::Api(format!("HTTP {status}: {body}")));
        }

        let claude_response: ClaudeResponse = response.json().await?;
        Ok(Self::chat_result_from_response(claude_response))
    }

    /// Splits a decoded response into joined text, tool calls, and raw `tool_use` blocks.
    fn chat_result_from_response(response: ClaudeResponse) -> ChatResult {
        let stop_reason = response
            .stop_reason
            .unwrap_or_else(|| "end_turn".to_string());

        let mut text_parts: Vec<String> = Vec::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        let mut raw_tool_uses: Vec<ResponseContentBlock> = Vec::new();

        // Consume the blocks so text is moved rather than cloned. Tool-use fields are cloned
        // once because they are kept both as `ToolCall`s and as raw blocks.
        for block in response.content {
            match block {
                ResponseContentBlock::Text { text } => {
                    if !text.is_empty() {
                        text_parts.push(text);
                    }
                }
                ResponseContentBlock::ToolUse { id, name, input } => {
                    tool_calls.push(ToolCall {
                        id: id.clone(),
                        name: name.clone(),
                        arguments: input.clone(),
                    });
                    raw_tool_uses.push(ResponseContentBlock::ToolUse { id, name, input });
                }
            }
        }

        let content = (!text_parts.is_empty()).then(|| text_parts.join("\n"));

        ChatResult {
            content,
            tool_calls,
            raw_tool_uses,
            stop_reason,
        }
    }
}

/// Helper to convert Groq-style messages to Claude messages
pub fn groq_messages_to_claude(messages: &[super::groq::Message]) -> Vec<Message> {
    let mut claude_messages: Vec<Message> = Vec::new();

    for msg in messages {
        match msg.role.as_str() {
            "system" => {
                // Claude handles system prompts as first user message
                if let Some(ref content) = msg.content {
                    claude_messages.push(Message {
                        role: "user".to_string(),
                        content: MessageContent::Text(format!("[System Instructions]: {}", content)),
                    });
                    // Add assistant acknowledgment to maintain conversation structure
                    claude_messages.push(Message {
                        role: "assistant".to_string(),
                        content: MessageContent::Text("Understood. I'll follow these instructions.".to_string()),
                    });
                }
            }
            "user" => {
                if let Some(ref content) = msg.content {
                    claude_messages.push(Message {
                        role: "user".to_string(),
                        content: MessageContent::Text(content.clone()),
                    });
                }
            }
            "assistant" => {
                let mut blocks: Vec<ContentBlock> = Vec::new();

                // Add text content if present
                if let Some(ref content) = msg.content {
                    if !content.is_empty() {
                        blocks.push(ContentBlock::Text { text: content.clone() });
                    }
                }

                // Add tool uses if present
                if let Some(ref tool_calls) = msg.tool_calls {
                    for tc in tool_calls {
                        let input: serde_json::Value =
                            serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                        blocks.push(ContentBlock::ToolUse {
                            id: tc.id.clone(),
                            name: tc.function.name.clone(),
                            input,
                        });
                    }
                }

                if !blocks.is_empty() {
                    claude_messages.push(Message {
                        role: "assistant".to_string(),
                        content: MessageContent::Blocks(blocks),
                    });
                }
            }
            "tool" => {
                // Tool results in Claude go in a user message with tool_result blocks
                if let Some(ref content) = msg.content {
                    let tool_use_id = msg.tool_call_id.clone().unwrap_or_default();
                    claude_messages.push(Message {
                        role: "user".to_string(),
                        content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                            tool_use_id,
                            content: content.clone(),
                        }]),
                    });
                }
            }
            _ => {}
        }
    }

    claude_messages
}