diff options
| author | soryu <soryu@soryu.co> | 2025-12-23 14:43:23 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 555061b179b8ec034cb70f9a2dd6c823ced0f637 (patch) | |
| tree | 0545b4395dab6d957884d8d36bf15b8da529dc1f /makima/src/llm/groq.rs | |
| parent | a32dc56d2e5447ef8988cb98b8686476cc94e70c (diff) | |
| download | soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.tar.gz soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.zip | |
Add file body and initial tool call system
Diffstat (limited to 'makima/src/llm/groq.rs')
| -rw-r--r-- | makima/src/llm/groq.rs | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/makima/src/llm/groq.rs b/makima/src/llm/groq.rs new file mode 100644 index 0000000..be0e2bc --- /dev/null +++ b/makima/src/llm/groq.rs @@ -0,0 +1,175 @@ +//! Groq API client for LLM tool calling. + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::tools::{Tool, ToolCall}; + +const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions"; +const MODEL: &str = "moonshotai/kimi-k2-instruct-0905"; + +#[derive(Debug, Error)] +pub enum GroqError { + #[error("HTTP request failed: {0}")] + Request(#[from] reqwest::Error), + #[error("API error: {0}")] + Api(String), + #[error("Missing API key")] + MissingApiKey, +} + +#[derive(Debug, Clone)] +pub struct GroqClient { + api_key: String, + client: reqwest::Client, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Vec<ToolCallResponse>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallResponse { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, + pub function: FunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Serialize)] +struct ChatRequest { + model: String, + messages: Vec<Message>, + tools: Vec<ToolDefinition>, + tool_choice: String, +} + +#[derive(Debug, Serialize)] +struct ToolDefinition { + #[serde(rename = "type")] + tool_type: String, + function: FunctionDefinition, +} + +#[derive(Debug, Serialize)] +struct FunctionDefinition { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +struct ChatResponse { + choices: Vec<Choice>, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: MessageResponse, + finish_reason: String, +} + +#[derive(Debug, Deserialize)] +struct MessageResponse { + role: String, + content: Option<String>, + tool_calls: Option<Vec<ToolCallResponse>>, +} + +#[derive(Debug)] +pub struct ChatResult { + pub content: Option<String>, + pub tool_calls: Vec<ToolCall>, + pub finish_reason: String, +} + +impl GroqClient { + pub fn new(api_key: String) -> Self { + Self { + api_key, + client: reqwest::Client::new(), + } + } + + pub fn from_env() -> Result<Self, GroqError> { + let api_key = std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingApiKey)?; + Ok(Self::new(api_key)) + } + + pub async fn chat_with_tools( + &self, + messages: Vec<Message>, + tools: &[Tool], + ) -> Result<ChatResult, GroqError> { + let tool_definitions: Vec<ToolDefinition> = tools + .iter() + .map(|t| ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: t.name.clone(), + description: t.description.clone(), + parameters: t.parameters.clone(), + }, + }) + .collect(); + + let request = ChatRequest { + model: MODEL.to_string(), + messages, + tools: tool_definitions, + tool_choice: "auto".to_string(), + }; + + let response = self + .client + .post(GROQ_API_URL) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let error_text = response.text().await.unwrap_or_default(); + return Err(GroqError::Api(error_text)); + } + + let chat_response: ChatResponse = response.json().await?; + + let choice = chat_response + .choices + .into_iter() + .next() + .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?; + + let tool_calls = choice + .message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ToolCall { + id: tc.id, + name: tc.function.name, + arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(), + }) + .collect(); + + Ok(ChatResult { + content: choice.message.content, + tool_calls, + finish_reason: choice.finish_reason, + }) + } +} |
