diff options
Diffstat (limited to 'makima/src/llm/groq.rs')
| -rw-r--r-- | makima/src/llm/groq.rs | 177 |
1 files changed, 0 insertions, 177 deletions
diff --git a/makima/src/llm/groq.rs b/makima/src/llm/groq.rs deleted file mode 100644 index ee01fcf..0000000 --- a/makima/src/llm/groq.rs +++ /dev/null @@ -1,177 +0,0 @@ -//! Groq API client for LLM tool calling. - -use serde::{Deserialize, Serialize}; -use thiserror::Error; - -use super::tools::{Tool, ToolCall}; - -const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions"; -const MODEL: &str = "moonshotai/kimi-k2-instruct-0905"; - -#[derive(Debug, Error)] -pub enum GroqError { - #[error("HTTP request failed: {0}")] - Request(#[from] reqwest::Error), - #[error("API error: {0}")] - Api(String), - #[error("Missing API key")] - MissingApiKey, -} - -#[derive(Debug, Clone)] -pub struct GroqClient { - api_key: String, - client: reqwest::Client, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub role: String, - pub content: Option<String>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option<Vec<ToolCallResponse>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option<String>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCallResponse { - pub id: String, - #[serde(rename = "type")] - pub call_type: String, - pub function: FunctionCall, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionCall { - pub name: String, - pub arguments: String, -} - -#[derive(Debug, Serialize)] -struct ChatRequest { - model: String, - messages: Vec<Message>, - tools: Vec<ToolDefinition>, - tool_choice: String, -} - -#[derive(Debug, Serialize)] -struct ToolDefinition { - #[serde(rename = "type")] - tool_type: String, - function: FunctionDefinition, -} - -#[derive(Debug, Serialize)] -struct FunctionDefinition { - name: String, - description: String, - parameters: serde_json::Value, -} - -#[derive(Debug, Deserialize)] -struct ChatResponse { - choices: Vec<Choice>, -} - -#[derive(Debug, Deserialize)] -struct Choice { - message: MessageResponse, - finish_reason: String, -} - -#[derive(Debug, Deserialize)] -struct MessageResponse { - role: String, - content: Option<String>, - tool_calls: Option<Vec<ToolCallResponse>>, -} - -#[derive(Debug)] -pub struct ChatResult { - pub content: Option<String>, - pub tool_calls: Vec<ToolCall>, - /// Raw tool call responses for including in subsequent messages - pub raw_tool_calls: Vec<ToolCallResponse>, - pub finish_reason: String, -} - -impl GroqClient { - pub fn new(api_key: String) -> Self { - Self { - api_key, - client: reqwest::Client::new(), - } - } - - pub fn from_env() -> Result<Self, GroqError> { - let api_key = std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingApiKey)?; - Ok(Self::new(api_key)) - } - - pub async fn chat_with_tools( - &self, - messages: Vec<Message>, - tools: &[Tool], - ) -> Result<ChatResult, GroqError> { - let tool_definitions: Vec<ToolDefinition> = tools - .iter() - .map(|t| ToolDefinition { - tool_type: "function".to_string(), - function: FunctionDefinition { - name: t.name.clone(), - description: t.description.clone(), - parameters: t.parameters.clone(), - }, - }) - .collect(); - - let request = ChatRequest { - model: MODEL.to_string(), - messages, - tools: tool_definitions, - tool_choice: "auto".to_string(), - }; - - let response = self - .client - .post(GROQ_API_URL) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await.unwrap_or_default(); - return Err(GroqError::Api(error_text)); - } - - let chat_response: ChatResponse = response.json().await?; - - let choice = chat_response - .choices - .into_iter() - .next() - .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?; - - let raw_tool_calls = choice.message.tool_calls.unwrap_or_default(); - - let tool_calls = raw_tool_calls - .iter() - .map(|tc| ToolCall { - id: tc.id.clone(), - name: tc.function.name.clone(), - arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(), - }) - .collect(); - - Ok(ChatResult { - content: choice.message.content, - tool_calls, - raw_tool_calls, - finish_reason: choice.finish_reason, - }) - } -} |
