//! Groq API client for LLM tool calling.
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::tools::{Tool, ToolCall};
/// Groq's OpenAI-compatible chat-completions endpoint that every request is posted to.
const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions";
/// Model identifier sent in the `model` field of every chat request.
const MODEL: &str = "moonshotai/kimi-k2-instruct-0905";
/// Errors produced by [`GroqClient`].
#[derive(Debug, Error)]
pub enum GroqError {
/// The underlying `reqwest` call failed; converted automatically via `?`
/// thanks to `#[from]`.
#[error("HTTP request failed: {0}")]
Request(#[from] reqwest::Error),
/// The API answered, but with a non-success status or a payload that could
/// not be used; the string carries the detail.
#[error("API error: {0}")]
Api(String),
/// No usable API key could be obtained from the environment.
#[error("Missing API key")]
MissingApiKey,
}
/// Client for the Groq chat-completions API.
///
/// Stores the API key used for bearer authentication together with the
/// `reqwest::Client` that performs the requests.
#[derive(Clone)]
pub struct GroqClient {
    /// Secret API key sent as the bearer token on every request.
    api_key: String,
    /// HTTP client used to perform the requests.
    client: reqwest::Client,
}

// `Debug` is written by hand instead of derived so that the secret API key is
// never printed by `{:?}` in logs, panic messages, or error chains.
impl std::fmt::Debug for GroqClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroqClient")
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .finish()
    }
}
/// A single chat message, sent as an element of the request's `messages` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Role of the message author (presumably values such as "user" or
/// "assistant" — confirm against the API reference).
pub role: String,
/// Text content; serialized as `null` when `None`.
pub content: Option<String>,
/// Tool calls attached to this message; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallResponse>>,
/// Identifier of the tool call this message relates to; omitted from the
/// JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// A tool call in its wire format, as it appears in API responses and in
/// [`Message::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallResponse {
/// Identifier of this tool call.
pub id: String,
/// Kind of call; named `type` in the JSON (`type` is a Rust keyword, hence
/// the rename).
#[serde(rename = "type")]
pub call_type: String,
/// The function name and arguments of the call.
pub function: FunctionCall,
}
/// The function part of a [`ToolCallResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
/// Name of the function to invoke.
pub name: String,
/// Arguments as a string; `GroqClient::chat_with_tools` parses this text
/// as JSON.
pub arguments: String,
}
/// JSON body posted to the chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
/// Model identifier.
model: String,
/// Conversation messages, in the order supplied by the caller.
messages: Vec<Message>,
/// Tools made available for this request.
tools: Vec<ToolDefinition>,
/// Tool selection mode sent to the API.
tool_choice: String,
}
/// One entry of the request's `tools` array.
#[derive(Debug, Serialize)]
struct ToolDefinition {
/// Kind of tool; named `type` in the JSON.
#[serde(rename = "type")]
tool_type: String,
/// Description of the callable function.
function: FunctionDefinition,
}
/// Describes a callable function inside a [`ToolDefinition`].
#[derive(Debug, Serialize)]
struct FunctionDefinition {
/// Name of the function.
name: String,
/// Description of the function.
description: String,
/// Parameter description as arbitrary JSON (presumably a JSON Schema
/// object — confirm against the `Tool` definition).
parameters: serde_json::Value,
}
/// The subset of the chat-completions response body that this client reads.
#[derive(Debug, Deserialize)]
struct ChatResponse {
/// Completion choices returned by the API.
choices: Vec<Choice>,
}
/// One completion choice of a [`ChatResponse`].
#[derive(Debug, Deserialize)]
struct Choice {
/// The message produced for this choice.
message: MessageResponse,
/// The finish reason reported by the API for this choice.
finish_reason: String,
}
/// The message carried by a [`Choice`].
#[derive(Debug, Deserialize)]
struct MessageResponse {
/// Role of the message author; must be present in the response for
/// deserialization to succeed.
role: String,
/// Text content, if any.
content: Option<String>,
/// Tool calls requested in this message, if any.
tool_calls: Option<Vec<ToolCallResponse>>,
}
/// Result of one `GroqClient::chat_with_tools` call.
#[derive(Debug)]
pub struct ChatResult {
/// Text content of the returned message, if any.
pub content: Option<String>,
/// Tool calls from the returned message, converted to [`ToolCall`] values;
/// empty when the message carried none.
pub tool_calls: Vec<ToolCall>,
/// Raw tool call responses for including in subsequent messages
pub raw_tool_calls: Vec<ToolCallResponse>,
/// The finish reason reported by the API.
pub finish_reason: String,
}
impl GroqClient {
pub fn new(api_key: String) -> Self {
Self {
api_key,
client: reqwest::Client::new(),
}
}
pub fn from_env() -> Result<Self, GroqError> {
let api_key = std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingApiKey)?;
Ok(Self::new(api_key))
}
pub async fn chat_with_tools(
&self,
messages: Vec<Message>,
tools: &[Tool],
) -> Result<ChatResult, GroqError> {
let tool_definitions: Vec<ToolDefinition> = tools
.iter()
.map(|t| ToolDefinition {
tool_type: "function".to_string(),
function: FunctionDefinition {
name: t.name.clone(),
description: t.description.clone(),
parameters: t.parameters.clone(),
},
})
.collect();
let request = ChatRequest {
model: MODEL.to_string(),
messages,
tools: tool_definitions,
tool_choice: "auto".to_string(),
};
let response = self
.client
.post(GROQ_API_URL)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(GroqError::Api(error_text));
}
let chat_response: ChatResponse = response.json().await?;
let choice = chat_response
.choices
.into_iter()
.next()
.ok_or_else(|| GroqError::Api("No choices in response".to_string()))?;
let raw_tool_calls = choice.message.tool_calls.unwrap_or_default();
let tool_calls = raw_tool_calls
.iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
})
.collect();
Ok(ChatResult {
content: choice.message.content,
tool_calls,
raw_tool_calls,
finish_reason: choice.finish_reason,
})
}
}