summaryrefslogblamecommitdiff
path: root/makima/src/llm/groq.rs
blob: ee01fcf3d7debc81d662a9f927b04a54742bd5d2 (plain) (tree)





























































































                                                                             

                                                                    





























































                                                                                           



                                                                           
                                

                                               






                                                                                            
                           



                                                
//! Groq API client for LLM tool calling.

use serde::{Deserialize, Serialize};
use thiserror::Error;

use super::tools::{Tool, ToolCall};

/// Endpoint that chat-completion requests are POSTed to.
const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions";
/// Model identifier sent in the `model` field of every chat request.
const MODEL: &str = "moonshotai/kimi-k2-instruct-0905";

/// Errors produced by [`GroqClient`].
#[derive(Debug, Error)]
pub enum GroqError {
    /// A `reqwest` operation failed. Both sending the request and decoding the
    /// JSON response body are propagated into this variant via `?`.
    #[error("HTTP request failed: {0}")]
    Request(#[from] reqwest::Error),
    /// The API answered, but not usefully: either a non-success HTTP status or
    /// a response without any choices. The string describes the failure.
    #[error("API error: {0}")]
    Api(String),
    /// No usable API key was found in the environment
    /// (see [`GroqClient::from_env`]).
    #[error("Missing API key")]
    MissingApiKey,
}

/// Client for the Groq chat-completions API.
///
/// Holds the API key and the HTTP client that requests are sent through.
#[derive(Clone)]
pub struct GroqClient {
    /// Secret API key, sent as a bearer token with each request.
    api_key: String,
    /// HTTP client used to send requests.
    client: reqwest::Client,
}

// `Debug` is written by hand instead of derived so that the API key can never
// leak into logs, panic messages or error reports through `{:?}`.
impl std::fmt::Debug for GroqClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GroqClient")
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .finish()
    }
}

/// A single chat message, as placed in the `messages` list of a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Role of the message author. Not validated here; presumably one of the
    /// roles the chat API accepts (e.g. system/user/assistant/tool) — confirm
    /// against callers.
    pub role: String,
    /// Text content of the message. Serialized as `null` when `None`
    /// (there is no `skip_serializing_if` on this field).
    pub content: Option<String>,
    /// Tool calls attached to this message; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallResponse>>,
    /// Id of a tool call this message relates to; omitted from the JSON when
    /// `None`. NOTE(review): presumably set on tool-result messages — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

/// A tool call in the API's wire format.
///
/// Deserialized from response messages and also serializable, so it can be
/// embedded in an outgoing [`Message`] via `tool_calls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallResponse {
    /// Identifier of this tool call.
    pub id: String,
    /// Kind of call; appears as `type` in the JSON.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function the model wants to invoke and its arguments.
    pub function: FunctionCall,
}

/// Function name and arguments of a tool call, in the API's wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a raw string; `chat_with_tools` parses this as JSON when
    /// building the corresponding `ToolCall`.
    pub arguments: String,
}

/// JSON body of a chat-completion request.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier to run the completion with.
    model: String,
    /// Conversation history to send.
    messages: Vec<Message>,
    /// Tools the model is allowed to call.
    tools: Vec<ToolDefinition>,
    /// Tool selection mode; `chat_with_tools` always sends `"auto"`.
    tool_choice: String,
}

/// One entry of the request's `tools` list.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    /// Kind of tool; appears as `type` in the JSON. `chat_with_tools` always
    /// sets this to `"function"`.
    #[serde(rename = "type")]
    tool_type: String,
    /// Description of the callable function.
    function: FunctionDefinition,
}

/// Description of a function the model may call, built from a `Tool`.
#[derive(Debug, Serialize)]
struct FunctionDefinition {
    /// Function name, copied from the tool's `name`.
    name: String,
    /// Human-readable description, copied from the tool's `description`.
    description: String,
    /// Parameter description as arbitrary JSON, copied from the tool's
    /// `parameters`. NOTE(review): presumably a JSON Schema object — confirm.
    parameters: serde_json::Value,
}

/// Subset of the chat-completion response body that this client reads.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion choices; only the first one is used by `chat_with_tools`.
    choices: Vec<Choice>,
}

/// A single completion choice from the response.
#[derive(Debug, Deserialize)]
struct Choice {
    /// The message produced for this choice.
    message: MessageResponse,
    /// Why generation stopped, as reported by the API; passed through
    /// unchanged into `ChatResult::finish_reason`.
    finish_reason: String,
}

/// The message part of a completion choice.
#[derive(Debug, Deserialize)]
struct MessageResponse {
    /// Role of the message author. Required for deserialization but not read
    /// by the code visible in this file.
    role: String,
    /// Text content of the reply, if any.
    content: Option<String>,
    /// Tool calls requested by the model; a missing or `null` field
    /// deserializes to `None`.
    tool_calls: Option<Vec<ToolCallResponse>>,
}

/// Outcome of [`GroqClient::chat_with_tools`], taken from the first choice
/// of the API response.
#[derive(Debug)]
pub struct ChatResult {
    /// Text content of the reply, if any.
    pub content: Option<String>,
    /// Tool calls converted to `ToolCall`, with the argument string parsed
    /// as JSON. Empty when the model requested no tools.
    pub tool_calls: Vec<ToolCall>,
    /// Raw tool call responses for including in subsequent messages
    pub raw_tool_calls: Vec<ToolCallResponse>,
    /// Why generation stopped, as reported by the API.
    pub finish_reason: String,
}

impl GroqClient {
    /// Creates a client that authenticates with the given API key.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            client: reqwest::Client::new(),
        }
    }

    /// Creates a client using the key stored in the `GROQ_API_KEY`
    /// environment variable.
    ///
    /// # Errors
    ///
    /// Returns [`GroqError::MissingApiKey`] if the variable is unset, is not
    /// valid Unicode, or contains only whitespace. A blank key is rejected
    /// here so the problem is reported up front instead of as an
    /// authentication failure on the first request.
    pub fn from_env() -> Result<Self, GroqError> {
        let api_key = std::env::var("GROQ_API_KEY")
            .ok()
            .filter(|key| !key.trim().is_empty())
            .ok_or(GroqError::MissingApiKey)?;
        Ok(Self::new(api_key))
    }

    /// Sends `messages` to the chat-completions endpoint, offering `tools`
    /// to the model with `tool_choice` set to `"auto"`.
    ///
    /// Only the first choice of the response is used. Its tool calls are
    /// returned twice: converted to [`ToolCall`] (arguments parsed as JSON)
    /// and in raw wire form for inclusion in follow-up messages.
    ///
    /// # Errors
    ///
    /// * [`GroqError::Request`] if sending the request or decoding the
    ///   response body fails.
    /// * [`GroqError::Api`] if the server replies with a non-success status
    ///   (the message contains the status and the response body) or the
    ///   response contains no choices.
    pub async fn chat_with_tools(
        &self,
        messages: Vec<Message>,
        tools: &[Tool],
    ) -> Result<ChatResult, GroqError> {
        let tool_definitions: Vec<ToolDefinition> = tools
            .iter()
            .map(|tool| ToolDefinition {
                tool_type: "function".to_string(),
                function: FunctionDefinition {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    parameters: tool.parameters.clone(),
                },
            })
            .collect();

        let request = ChatRequest {
            model: MODEL.to_string(),
            messages,
            tools: tool_definitions,
            tool_choice: "auto".to_string(),
        };

        // `bearer_auth` builds the `Authorization: Bearer …` header, and
        // `json` both serializes the body and sets the JSON content type, so
        // neither header needs to be assembled by hand.
        let response = self
            .client
            .post(GROQ_API_URL)
            .bearer_auth(&self.api_key)
            .json(&request)
            .send()
            .await?;

        // Capture the status before `text()` consumes the response, so the
        // error is still informative when the body is empty or unreadable.
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(GroqError::Api(format!("HTTP {}: {}", status, body)));
        }

        let chat_response: ChatResponse = response.json().await?;

        let choice = chat_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?;

        let raw_tool_calls = choice.message.tool_calls.unwrap_or_default();

        let tool_calls = raw_tool_calls
            .iter()
            .map(|call| ToolCall {
                id: call.id.clone(),
                name: call.function.name.clone(),
                // Best effort: an argument string that is not valid JSON
                // falls back to the default value instead of failing the
                // whole request.
                arguments: serde_json::from_str(&call.function.arguments).unwrap_or_default(),
            })
            .collect();

        Ok(ChatResult {
            content: choice.message.content,
            tool_calls,
            raw_tool_calls,
            finish_reason: choice.finish_reason,
        })
    }
}