//! Groq API client for LLM tool calling. use serde::{Deserialize, Serialize}; use thiserror::Error; use super::tools::{Tool, ToolCall}; const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions"; const MODEL: &str = "moonshotai/kimi-k2-instruct-0905"; #[derive(Debug, Error)] pub enum GroqError { #[error("HTTP request failed: {0}")] Request(#[from] reqwest::Error), #[error("API error: {0}")] Api(String), #[error("Missing API key")] MissingApiKey, } #[derive(Debug, Clone)] pub struct GroqClient { api_key: String, client: reqwest::Client, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: String, pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCallResponse { pub id: String, #[serde(rename = "type")] pub call_type: String, pub function: FunctionCall, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FunctionCall { pub name: String, pub arguments: String, } #[derive(Debug, Serialize)] struct ChatRequest { model: String, messages: Vec, tools: Vec, tool_choice: String, } #[derive(Debug, Serialize)] struct ToolDefinition { #[serde(rename = "type")] tool_type: String, function: FunctionDefinition, } #[derive(Debug, Serialize)] struct FunctionDefinition { name: String, description: String, parameters: serde_json::Value, } #[derive(Debug, Deserialize)] struct ChatResponse { choices: Vec, } #[derive(Debug, Deserialize)] struct Choice { message: MessageResponse, finish_reason: String, } #[derive(Debug, Deserialize)] struct MessageResponse { #[allow(dead_code)] role: String, content: Option, tool_calls: Option>, } #[derive(Debug)] pub struct ChatResult { pub content: Option, pub tool_calls: Vec, /// Raw tool call responses for including in subsequent messages pub raw_tool_calls: Vec, pub finish_reason: String, } impl GroqClient { pub fn new(api_key: String) -> Self { Self { api_key, client: reqwest::Client::new(), } } pub fn from_env() -> Result { let api_key = std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingApiKey)?; Ok(Self::new(api_key)) } pub async fn chat_with_tools( &self, messages: Vec, tools: &[Tool], ) -> Result { let tool_definitions: Vec = tools .iter() .map(|t| ToolDefinition { tool_type: "function".to_string(), function: FunctionDefinition { name: t.name.clone(), description: t.description.clone(), parameters: t.parameters.clone(), }, }) .collect(); let request = ChatRequest { model: MODEL.to_string(), messages, tools: tool_definitions, tool_choice: "auto".to_string(), }; let response = self .client .post(GROQ_API_URL) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .json(&request) .send() .await?; if !response.status().is_success() { let error_text = response.text().await.unwrap_or_default(); return Err(GroqError::Api(error_text)); } let chat_response: ChatResponse = response.json().await?; let choice = chat_response .choices .into_iter() .next() .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?; let raw_tool_calls = choice.message.tool_calls.unwrap_or_default(); let tool_calls = raw_tool_calls .iter() .map(|tc| ToolCall { id: tc.id.clone(), name: tc.function.name.clone(), arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(), }) .collect(); Ok(ChatResult { content: choice.message.content, tool_calls, raw_tool_calls, finish_reason: choice.finish_reason, }) } }