summaryrefslogtreecommitdiff
path: root/makima/src/llm/groq.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2025-12-23 14:43:23 +0000
committersoryu <soryu@soryu.co>2025-12-23 14:47:18 +0000
commit555061b179b8ec034cb70f9a2dd6c823ced0f637 (patch)
tree0545b4395dab6d957884d8d36bf15b8da529dc1f /makima/src/llm/groq.rs
parenta32dc56d2e5447ef8988cb98b8686476cc94e70c (diff)
downloadsoryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.tar.gz
soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.zip
Add file body and initial tool call system
Diffstat (limited to 'makima/src/llm/groq.rs')
-rw-r--r--makima/src/llm/groq.rs175
1 files changed, 175 insertions, 0 deletions
diff --git a/makima/src/llm/groq.rs b/makima/src/llm/groq.rs
new file mode 100644
index 0000000..be0e2bc
--- /dev/null
+++ b/makima/src/llm/groq.rs
@@ -0,0 +1,175 @@
+//! Groq API client for LLM tool calling.
+
+use serde::{Deserialize, Serialize};
+use thiserror::Error;
+
+use super::tools::{Tool, ToolCall};
+
+const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions";
+const MODEL: &str = "moonshotai/kimi-k2-instruct-0905";
+
+#[derive(Debug, Error)]
+pub enum GroqError {
+ #[error("HTTP request failed: {0}")]
+ Request(#[from] reqwest::Error),
+ #[error("API error: {0}")]
+ Api(String),
+ #[error("Missing API key")]
+ MissingApiKey,
+}
+
+#[derive(Debug, Clone)]
+pub struct GroqClient {
+ api_key: String,
+ client: reqwest::Client,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Message {
+ pub role: String,
+ pub content: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_calls: Option<Vec<ToolCallResponse>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_call_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCallResponse {
+ pub id: String,
+ #[serde(rename = "type")]
+ pub call_type: String,
+ pub function: FunctionCall,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionCall {
+ pub name: String,
+ pub arguments: String,
+}
+
+#[derive(Debug, Serialize)]
+struct ChatRequest {
+ model: String,
+ messages: Vec<Message>,
+ tools: Vec<ToolDefinition>,
+ tool_choice: String,
+}
+
+#[derive(Debug, Serialize)]
+struct ToolDefinition {
+ #[serde(rename = "type")]
+ tool_type: String,
+ function: FunctionDefinition,
+}
+
+#[derive(Debug, Serialize)]
+struct FunctionDefinition {
+ name: String,
+ description: String,
+ parameters: serde_json::Value,
+}
+
+#[derive(Debug, Deserialize)]
+struct ChatResponse {
+ choices: Vec<Choice>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Choice {
+ message: MessageResponse,
+ finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct MessageResponse {
+ role: String,
+ content: Option<String>,
+ tool_calls: Option<Vec<ToolCallResponse>>,
+}
+
+#[derive(Debug)]
+pub struct ChatResult {
+ pub content: Option<String>,
+ pub tool_calls: Vec<ToolCall>,
+ pub finish_reason: String,
+}
+
+impl GroqClient {
+ pub fn new(api_key: String) -> Self {
+ Self {
+ api_key,
+ client: reqwest::Client::new(),
+ }
+ }
+
+ pub fn from_env() -> Result<Self, GroqError> {
+ let api_key = std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingApiKey)?;
+ Ok(Self::new(api_key))
+ }
+
+ pub async fn chat_with_tools(
+ &self,
+ messages: Vec<Message>,
+ tools: &[Tool],
+ ) -> Result<ChatResult, GroqError> {
+ let tool_definitions: Vec<ToolDefinition> = tools
+ .iter()
+ .map(|t| ToolDefinition {
+ tool_type: "function".to_string(),
+ function: FunctionDefinition {
+ name: t.name.clone(),
+ description: t.description.clone(),
+ parameters: t.parameters.clone(),
+ },
+ })
+ .collect();
+
+ let request = ChatRequest {
+ model: MODEL.to_string(),
+ messages,
+ tools: tool_definitions,
+ tool_choice: "auto".to_string(),
+ };
+
+ let response = self
+ .client
+ .post(GROQ_API_URL)
+ .header("Authorization", format!("Bearer {}", self.api_key))
+ .header("Content-Type", "application/json")
+ .json(&request)
+ .send()
+ .await?;
+
+ if !response.status().is_success() {
+ let error_text = response.text().await.unwrap_or_default();
+ return Err(GroqError::Api(error_text));
+ }
+
+ let chat_response: ChatResponse = response.json().await?;
+
+ let choice = chat_response
+ .choices
+ .into_iter()
+ .next()
+ .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?;
+
+ let tool_calls = choice
+ .message
+ .tool_calls
+ .unwrap_or_default()
+ .into_iter()
+ .map(|tc| ToolCall {
+ id: tc.id,
+ name: tc.function.name,
+ arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
+ })
+ .collect();
+
+ Ok(ChatResult {
+ content: choice.message.content,
+ tool_calls,
+ finish_reason: choice.finish_reason,
+ })
+ }
+}