diff options
| author | soryu <soryu@soryu.co> | 2025-12-23 14:43:23 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 555061b179b8ec034cb70f9a2dd6c823ced0f637 (patch) | |
| tree | 0545b4395dab6d957884d8d36bf15b8da529dc1f /makima/src | |
| parent | a32dc56d2e5447ef8988cb98b8686476cc94e70c (diff) | |
| download | soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.tar.gz soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.zip | |
Add file body and initial tool call system
Diffstat (limited to 'makima/src')
| -rw-r--r-- | makima/src/db/models.rs | 43 | ||||
| -rw-r--r-- | makima/src/db/repository.rs | 21 | ||||
| -rw-r--r-- | makima/src/lib.rs | 1 | ||||
| -rw-r--r-- | makima/src/llm/groq.rs | 175 | ||||
| -rw-r--r-- | makima/src/llm/mod.rs | 7 | ||||
| -rw-r--r-- | makima/src/llm/tools.rs | 618 | ||||
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 296 | ||||
| -rw-r--r-- | makima/src/server/handlers/listen.rs | 83 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 1 | ||||
| -rw-r--r-- | makima/src/server/mod.rs | 5 |
10 files changed, 1236 insertions, 14 deletions
diff --git a/makima/src/db/models.rs b/makima/src/db/models.rs index 45b0e53..135ae75 100644 --- a/makima/src/db/models.rs +++ b/makima/src/db/models.rs @@ -18,6 +18,40 @@ pub struct TranscriptEntry { pub is_final: bool, } +/// Chart type for visualization elements +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum ChartType { + Line, + Bar, + Pie, + Area, +} + +/// Body element types for structured file content +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum BodyElement { + /// Heading element (h1-h6) + Heading { level: u8, text: String }, + /// Paragraph text + Paragraph { text: String }, + /// Chart visualization + Chart { + #[serde(rename = "chartType")] + chart_type: ChartType, + title: Option<String>, + data: serde_json::Value, + config: Option<serde_json::Value>, + }, + /// Image element (deferred for MVP) + Image { + src: String, + alt: Option<String>, + caption: Option<String>, + }, +} + /// File record from the database. #[derive(Debug, Clone, FromRow, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] @@ -29,6 +63,11 @@ pub struct File { #[sqlx(json)] pub transcript: Vec<TranscriptEntry>, pub location: Option<String>, + /// AI-generated summary of the transcript + pub summary: Option<String>, + /// Structured body content (headings, paragraphs, charts) + #[sqlx(json)] + pub body: Vec<BodyElement>, pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, } @@ -57,6 +96,10 @@ pub struct UpdateFileRequest { pub description: Option<String>, /// New transcript (optional) pub transcript: Option<Vec<TranscriptEntry>>, + /// AI-generated summary (optional) + pub summary: Option<String>, + /// Structured body content (optional) + pub body: Option<Vec<BodyElement>>, } /// Response for file list endpoint. diff --git a/makima/src/db/repository.rs b/makima/src/db/repository.rs index 90cb1b9..f8b90b3 100644 --- a/makima/src/db/repository.rs +++ b/makima/src/db/repository.rs @@ -19,12 +19,13 @@ fn generate_default_name() -> String { pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, sqlx::Error> { let name = req.name.unwrap_or_else(generate_default_name); let transcript_json = serde_json::to_value(&req.transcript).unwrap_or_default(); + let body_json = serde_json::to_value::<Vec<super::models::BodyElement>>(vec![]).unwrap(); sqlx::query_as::<_, File>( r#" - INSERT INTO files (owner_id, name, description, transcript, location) - VALUES ($1, $2, $3, $4, $5) - RETURNING id, owner_id, name, description, transcript, location, created_at, updated_at + INSERT INTO files (owner_id, name, description, transcript, location, summary, body) + VALUES ($1, $2, $3, $4, $5, NULL, $6) + RETURNING id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at "#, ) .bind(ANONYMOUS_OWNER_ID) @@ -32,6 +33,7 @@ pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, .bind(&req.description) .bind(&transcript_json) .bind(&req.location) + .bind(&body_json) .fetch_one(pool) .await } @@ -40,7 +42,7 @@ pub async fn create_file(pool: &PgPool, req: CreateFileRequest) -> Result<File, pub async fn get_file(pool: &PgPool, id: Uuid) -> Result<Option<File>, sqlx::Error> { sqlx::query_as::<_, File>( r#" - SELECT id, owner_id, name, description, transcript, location, created_at, updated_at + SELECT id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at FROM files WHERE id = $1 AND owner_id = $2 "#, @@ -55,7 +57,7 @@ pub async fn get_file(pool: &PgPool, id: Uuid) -> Result<Option<File>, sqlx::Err pub async fn list_files(pool: &PgPool) -> Result<Vec<File>, sqlx::Error> { sqlx::query_as::<_, File>( r#" - SELECT id, owner_id, name, description, transcript, location, created_at, updated_at + SELECT id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at FROM files WHERE owner_id = $1 ORDER BY created_at DESC @@ -83,13 +85,16 @@ pub async fn update_file( let description = req.description.or(existing.description); let transcript = req.transcript.unwrap_or(existing.transcript); let transcript_json = serde_json::to_value(&transcript).unwrap_or_default(); + let summary = req.summary.or(existing.summary); + let body = req.body.unwrap_or(existing.body); + let body_json = serde_json::to_value(&body).unwrap_or_default(); sqlx::query_as::<_, File>( r#" UPDATE files - SET name = $3, description = $4, transcript = $5 + SET name = $3, description = $4, transcript = $5, summary = $6, body = $7, updated_at = NOW() WHERE id = $1 AND owner_id = $2 - RETURNING id, owner_id, name, description, transcript, location, created_at, updated_at + RETURNING id, owner_id, name, description, transcript, location, summary, body, created_at, updated_at "#, ) .bind(id) @@ -97,6 +102,8 @@ pub async fn update_file( .bind(&name) .bind(&description) .bind(&transcript_json) + .bind(&summary) + .bind(&body_json) .fetch_optional(pool) .await } diff --git a/makima/src/lib.rs b/makima/src/lib.rs index 35d376c..064b123 100644 --- a/makima/src/lib.rs +++ b/makima/src/lib.rs @@ -1,5 +1,6 @@ pub mod audio; pub mod db; pub mod listen; +pub mod llm; pub mod server; pub mod tts; diff --git a/makima/src/llm/groq.rs b/makima/src/llm/groq.rs new file mode 100644 index 0000000..be0e2bc --- /dev/null +++ b/makima/src/llm/groq.rs @@ -0,0 +1,175 @@ +//! Groq API client for LLM tool calling. + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::tools::{Tool, ToolCall}; + +const GROQ_API_URL: &str = "https://api.groq.com/openai/v1/chat/completions"; +const MODEL: &str = "moonshotai/kimi-k2-instruct-0905"; + +#[derive(Debug, Error)] +pub enum GroqError { + #[error("HTTP request failed: {0}")] + Request(#[from] reqwest::Error), + #[error("API error: {0}")] + Api(String), + #[error("Missing API key")] + MissingApiKey, +} + +#[derive(Debug, Clone)] +pub struct GroqClient { + api_key: String, + client: reqwest::Client, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: String, + pub content: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Vec<ToolCallResponse>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallResponse { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, + pub function: FunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Serialize)] +struct ChatRequest { + model: String, + messages: Vec<Message>, + tools: Vec<ToolDefinition>, + tool_choice: String, +} + +#[derive(Debug, Serialize)] +struct ToolDefinition { + #[serde(rename = "type")] + tool_type: String, + function: FunctionDefinition, +} + +#[derive(Debug, Serialize)] +struct FunctionDefinition { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +struct ChatResponse { + choices: Vec<Choice>, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: MessageResponse, + finish_reason: String, +} + +#[derive(Debug, Deserialize)] +struct MessageResponse { + role: String, + content: Option<String>, + tool_calls: Option<Vec<ToolCallResponse>>, +} + +#[derive(Debug)] +pub struct ChatResult { + pub content: Option<String>, + pub tool_calls: Vec<ToolCall>, + pub finish_reason: String, +} + +impl GroqClient { + pub fn new(api_key: String) -> Self { + Self { + api_key, + client: reqwest::Client::new(), + } + } + + pub fn from_env() -> Result<Self, GroqError> { + let api_key = std::env::var("GROQ_API_KEY").map_err(|_| GroqError::MissingApiKey)?; + Ok(Self::new(api_key)) + } + + pub async fn chat_with_tools( + &self, + messages: Vec<Message>, + tools: &[Tool], + ) -> Result<ChatResult, GroqError> { + let tool_definitions: Vec<ToolDefinition> = tools + .iter() + .map(|t| ToolDefinition { + tool_type: "function".to_string(), + function: FunctionDefinition { + name: t.name.clone(), + description: t.description.clone(), + parameters: t.parameters.clone(), + }, + }) + .collect(); + + let request = ChatRequest { + model: MODEL.to_string(), + messages, + tools: tool_definitions, + tool_choice: "auto".to_string(), + }; + + let response = self + .client + .post(GROQ_API_URL) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + let error_text = response.text().await.unwrap_or_default(); + return Err(GroqError::Api(error_text)); + } + + let chat_response: ChatResponse = response.json().await?; + + let choice = chat_response + .choices + .into_iter() + .next() + .ok_or_else(|| GroqError::Api("No choices in response".to_string()))?; + + let tool_calls = choice + .message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ToolCall { + id: tc.id, + name: tc.function.name, + arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(), + }) + .collect(); + + Ok(ChatResult { + content: choice.message.content, + tool_calls, + finish_reason: choice.finish_reason, + }) + } +} diff --git a/makima/src/llm/mod.rs b/makima/src/llm/mod.rs new file mode 100644 index 0000000..00f3333 --- /dev/null +++ b/makima/src/llm/mod.rs @@ -0,0 +1,7 @@ +//! LLM integration module for file editing via tool calling. + +pub mod groq; +pub mod tools; + +pub use groq::GroqClient; +pub use tools::{execute_tool_call, Tool, ToolCall, ToolResult, AVAILABLE_TOOLS}; diff --git a/makima/src/llm/tools.rs b/makima/src/llm/tools.rs new file mode 100644 index 0000000..3bd102f --- /dev/null +++ b/makima/src/llm/tools.rs @@ -0,0 +1,618 @@ +//! Tool definitions for file editing via LLM. + +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::db::models::{BodyElement, ChartType}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ToolResult { + pub success: bool, + pub message: String, +} + +/// Available tools for file editing +pub static AVAILABLE_TOOLS: once_cell::sync::Lazy<Vec<Tool>> = + once_cell::sync::Lazy::new(|| { + vec![ + Tool { + name: "add_heading".to_string(), + description: "Add a heading element to the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "level": { + "type": "integer", + "description": "Heading level (1-6)", + "minimum": 1, + "maximum": 6 + }, + "text": { + "type": "string", + "description": "The heading text" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["level", "text"] + }), + }, + Tool { + name: "add_paragraph".to_string(), + description: "Add a paragraph element to the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The paragraph text" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["text"] + }), + }, + Tool { + name: "add_chart".to_string(), + description: "Add a chart visualization to the file body. Supports line, bar, pie, and area charts.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "chart_type": { + "type": "string", + "enum": ["line", "bar", "pie", "area"], + "description": "Type of chart to create" + }, + "title": { + "type": "string", + "description": "Optional chart title" + }, + "data": { + "type": "array", + "description": "Array of data points. Each point should have a 'name' field and one or more numeric value fields.", + "items": { + "type": "object" + } + }, + "config": { + "type": "object", + "description": "Optional chart configuration (colors, axes, etc.)" + }, + "position": { + "type": "integer", + "description": "Optional position to insert at (0-indexed). If not specified, appends to end." + } + }, + "required": ["chart_type", "data"] + }), + }, + Tool { + name: "remove_element".to_string(), + description: "Remove an element from the file body by index".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "Index of element to remove (0-indexed)" + } + }, + "required": ["index"] + }), + }, + Tool { + name: "update_element".to_string(), + description: "Update an existing element in the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "Index of element to update (0-indexed)" + }, + "element": { + "type": "object", + "description": "New element data. Must include 'type' field (heading, paragraph, chart)." + } + }, + "required": ["index", "element"] + }), + }, + Tool { + name: "reorder_elements".to_string(), + description: "Move an element from one position to another".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "from_index": { + "type": "integer", + "description": "Current index of the element" + }, + "to_index": { + "type": "integer", + "description": "New index for the element" + } + }, + "required": ["from_index", "to_index"] + }), + }, + Tool { + name: "set_summary".to_string(), + description: "Set the file summary text".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "summary": { + "type": "string", + "description": "The summary text" + } + }, + "required": ["summary"] + }), + }, + Tool { + name: "parse_csv".to_string(), + description: "Parse CSV data into JSON format suitable for charts".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "csv": { + "type": "string", + "description": "CSV data string with header row" + } + }, + "required": ["csv"] + }), + }, + Tool { + name: "clear_body".to_string(), + description: "Clear all elements from the file body".to_string(), + parameters: json!({ + "type": "object", + "properties": {} + }), + }, + ] + }); + +/// Result of executing a tool call with modified file state +#[derive(Debug)] +pub struct ToolExecutionResult { + pub result: ToolResult, + pub new_body: Option<Vec<BodyElement>>, + pub new_summary: Option<String>, + pub parsed_data: Option<serde_json::Value>, +} + +/// Execute a tool call and return the result along with any state changes +pub fn execute_tool_call( + call: &ToolCall, + current_body: &[BodyElement], + current_summary: Option<&str>, +) -> ToolExecutionResult { + match call.name.as_str() { + "add_heading" => execute_add_heading(call, current_body), + "add_paragraph" => execute_add_paragraph(call, current_body), + "add_chart" => execute_add_chart(call, current_body), + "remove_element" => execute_remove_element(call, current_body), + "update_element" => execute_update_element(call, current_body), + "reorder_elements" => execute_reorder_elements(call, current_body), + "set_summary" => execute_set_summary(call, current_summary), + "parse_csv" => execute_parse_csv(call), + "clear_body" => execute_clear_body(), + _ => ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Unknown tool: {}", call.name), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }, + } +} + +fn execute_add_heading(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let level = call.arguments.get("level").and_then(|v| v.as_u64()).unwrap_or(1) as u8; + let text = call + .arguments + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Heading { level, text: text.clone() }; + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Added heading: {}", text), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_add_paragraph(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let text = call + .arguments + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Paragraph { text: text.clone() }; + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + let preview = if text.len() > 50 { + format!("{}...", &text[..50]) + } else { + text + }; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Added paragraph: {}", preview), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_add_chart(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let chart_type_str = call + .arguments + .get("chart_type") + .and_then(|v| v.as_str()) + .unwrap_or("bar"); + + let chart_type = match chart_type_str { + "line" => ChartType::Line, + "bar" => ChartType::Bar, + "pie" => ChartType::Pie, + "area" => ChartType::Area, + _ => ChartType::Bar, + }; + + let title = call + .arguments + .get("title") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let data = call + .arguments + .get("data") + .cloned() + .unwrap_or(json!([])); + + let config = call.arguments.get("config").cloned(); + let position = call.arguments.get("position").and_then(|v| v.as_u64()); + + let element = BodyElement::Chart { + chart_type, + title: title.clone(), + data, + config, + }; + + let mut new_body = current_body.to_vec(); + + if let Some(pos) = position { + let pos = pos as usize; + if pos <= new_body.len() { + new_body.insert(pos, element); + } else { + new_body.push(element); + } + } else { + new_body.push(element); + } + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!( + "Added {} chart{}", + chart_type_str, + title.map(|t| format!(": {}", t)).unwrap_or_default() + ), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_remove_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let index = call.arguments.get("index").and_then(|v| v.as_u64()); + + let Some(index) = index else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let index = index as usize; + if index >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Index {} out of bounds (body has {} elements)", index, current_body.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let mut new_body = current_body.to_vec(); + new_body.remove(index); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Removed element at index {}", index), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_update_element(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let index = call.arguments.get("index").and_then(|v| v.as_u64()); + let element_json = call.arguments.get("element"); + + let Some(index) = index else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let Some(element_json) = element_json else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing element parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let index = index as usize; + if index >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Index {} out of bounds (body has {} elements)", index, current_body.len()), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let new_element: Result<BodyElement, _> = serde_json::from_value(element_json.clone()); + match new_element { + Ok(element) => { + let mut new_body = current_body.to_vec(); + new_body[index] = element; + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Updated element at index {}", index), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } + } + Err(e) => ToolExecutionResult { + result: ToolResult { + success: false, + message: format!("Invalid element format: {}", e), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }, + } +} + +fn execute_reorder_elements(call: &ToolCall, current_body: &[BodyElement]) -> ToolExecutionResult { + let from_index = call.arguments.get("from_index").and_then(|v| v.as_u64()); + let to_index = call.arguments.get("to_index").and_then(|v| v.as_u64()); + + let (Some(from), Some(to)) = (from_index, to_index) else { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Missing from_index or to_index parameter".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + }; + + let from = from as usize; + let to = to as usize; + + if from >= current_body.len() || to >= current_body.len() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: format!( + "Index out of bounds: from={}, to={}, body has {} elements", + from, to, current_body.len() + ), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let mut new_body = current_body.to_vec(); + let element = new_body.remove(from); + new_body.insert(to, element); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Moved element from index {} to {}", from, to), + }, + new_body: Some(new_body), + new_summary: None, + parsed_data: None, + } +} + +fn execute_set_summary(call: &ToolCall, _current_summary: Option<&str>) -> ToolExecutionResult { + let summary = call + .arguments + .get("summary") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + ToolExecutionResult { + result: ToolResult { + success: true, + message: "Summary updated".to_string(), + }, + new_body: None, + new_summary: Some(summary), + parsed_data: None, + } +} + +fn execute_parse_csv(call: &ToolCall) -> ToolExecutionResult { + let csv = call + .arguments + .get("csv") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let lines: Vec<&str> = csv.lines().collect(); + if lines.is_empty() { + return ToolExecutionResult { + result: ToolResult { + success: false, + message: "Empty CSV data".to_string(), + }, + new_body: None, + new_summary: None, + parsed_data: None, + }; + } + + let headers: Vec<&str> = lines[0].split(',').map(|s| s.trim()).collect(); + let mut data: Vec<serde_json::Value> = Vec::new(); + + for line in lines.iter().skip(1) { + if line.trim().is_empty() { + continue; + } + let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect(); + let mut row = serde_json::Map::new(); + + for (i, header) in headers.iter().enumerate() { + if let Some(value) = values.get(i) { + // Try to parse as number, otherwise use string + if let Ok(num) = value.parse::<f64>() { + row.insert(header.to_string(), json!(num)); + } else { + row.insert(header.to_string(), json!(value)); + } + } + } + + data.push(serde_json::Value::Object(row)); + } + + ToolExecutionResult { + result: ToolResult { + success: true, + message: format!("Parsed {} rows from CSV", data.len()), + }, + new_body: None, + new_summary: None, + parsed_data: Some(json!(data)), + } +} + +fn execute_clear_body() -> ToolExecutionResult { + ToolExecutionResult { + result: ToolResult { + success: true, + message: "Cleared all body elements".to_string(), + }, + new_body: Some(vec![]), + new_summary: None, + parsed_data: None, + } +} diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs new file mode 100644 index 0000000..e6d22ca --- /dev/null +++ b/makima/src/server/handlers/chat.rs @@ -0,0 +1,296 @@ +//! Chat endpoint for LLM-powered file editing. + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::db::{models::BodyElement, repository}; +use crate::llm::{ + execute_tool_call, + groq::{GroqClient, GroqError, Message}, + ToolResult, AVAILABLE_TOOLS, +}; +use crate::server::state::SharedState; + +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChatRequest { + /// The user's message/instruction + pub message: String, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChatResponse { + /// The LLM's response message + pub response: String, + /// Tool calls that were executed + pub tool_calls: Vec<ToolCallInfo>, + /// Updated file body after tool execution + pub updated_body: Vec<BodyElement>, + /// Updated summary (if changed) + pub updated_summary: Option<String>, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ToolCallInfo { + pub name: String, + pub result: ToolResult, +} + +#[derive(Debug, Serialize)] +struct ErrorResponse { + error: String, +} + +/// Chat with a file using LLM tool calling +#[utoipa::path( + post, + path = "/api/v1/files/{id}/chat", + request_body = ChatRequest, + responses( + (status = 200, description = "Chat completed successfully", body = ChatResponse), + (status = 404, description = "File not found"), + (status = 500, description = "Internal server error") + ), + params( + ("id" = Uuid, Path, description = "File ID") + ), + tag = "chat" +)] +pub async fn chat_handler( + State(state): State<SharedState>, + Path(id): Path<Uuid>, + Json(request): Json<ChatRequest>, +) -> impl IntoResponse { + // Check if database is configured + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "Database not configured" + })), + ) + .into_response(); + }; + + // Get the file + let file = match repository::get_file(pool, id).await { + Ok(Some(file)) => file, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ + "error": "File not found" + })), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Database error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Database error: {}", e) + })), + ) + .into_response(); + } + }; + + // Initialize Groq client + let groq = match GroqClient::from_env() { + Ok(client) => client, + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "GROQ_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Groq client error: {}", e) + })), + ) + .into_response(); + } + }; + + // Build context about the file + let file_context = build_file_context(&file); + + // Build messages + let messages = vec![ + Message { + role: "system".to_string(), + content: Some(format!( + "You are a helpful assistant that helps users edit and analyze document files. \ + You have access to tools to add headings, paragraphs, charts, and set summaries. \ + When the user asks you to modify the file, use the appropriate tools.\n\n\ + Current file context:\n{}", + file_context + )), + tool_calls: None, + tool_call_id: None, + }, + Message { + role: "user".to_string(), + content: Some(request.message.clone()), + tool_calls: None, + tool_call_id: None, + }, + ]; + + // Call Groq API + let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await { + Ok(result) => result, + Err(e) => { + tracing::error!("Groq API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("LLM API error: {}", e) + })), + ) + .into_response(); + } + }; + + // Execute tool calls + let mut current_body = file.body.clone(); + let mut current_summary = file.summary.clone(); + let mut tool_call_infos = Vec::new(); + + for tool_call in &result.tool_calls { + let execution_result = + execute_tool_call(tool_call, ¤t_body, current_summary.as_deref()); + + // Apply state changes + if let Some(new_body) = execution_result.new_body { + current_body = new_body; + } + if let Some(new_summary) = execution_result.new_summary { + current_summary = Some(new_summary); + } + + tool_call_infos.push(ToolCallInfo { + name: tool_call.name.clone(), + result: execution_result.result, + }); + } + + // Save changes to database if any tools were executed + if !result.tool_calls.is_empty() { + let update_req = crate::db::models::UpdateFileRequest { + name: None, + description: None, + transcript: None, + summary: current_summary.clone(), + body: Some(current_body.clone()), + }; + + if let Err(e) = repository::update_file(pool, id, update_req).await { + tracing::error!("Failed to save file changes: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Failed to save changes: {}", e) + })), + ) + .into_response(); + } + } + + // Build response + let response_text = result.content.unwrap_or_else(|| { + if tool_call_infos.is_empty() { + "I couldn't understand your request. Please try rephrasing.".to_string() + } else { + format!( + "Done! Executed {} tool{}.", + tool_call_infos.len(), + if tool_call_infos.len() == 1 { "" } else { "s" } + ) + } + }); + + ( + StatusCode::OK, + Json(ChatResponse { + response: response_text, + tool_calls: tool_call_infos, + updated_body: current_body, + updated_summary: current_summary, + }), + ) + .into_response() +} + +fn build_file_context(file: &crate::db::models::File) -> String { + let mut context = format!("File: {}\n", file.name); + + if let Some(ref desc) = file.description { + context.push_str(&format!("Description: {}\n", desc)); + } + + if let Some(ref summary) = file.summary { + context.push_str(&format!("Summary: {}\n", summary)); + } + + context.push_str(&format!("Transcript entries: {}\n", file.transcript.len())); + context.push_str(&format!("Body elements: {}\n", file.body.len())); + + // Add body overview + if !file.body.is_empty() { + context.push_str("\nCurrent body elements:\n"); + for (i, element) in file.body.iter().enumerate() { + let desc = match element { + BodyElement::Heading { level, text } => format!("H{}: {}", level, text), + BodyElement::Paragraph { text } => { + let preview = if text.len() > 50 { + format!("{}...", &text[..50]) + } else { + text.clone() + }; + format!("Paragraph: {}", preview) + } + BodyElement::Chart { chart_type, title, .. } => { + format!( + "Chart ({:?}){}", + chart_type, + title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default() + ) + } + BodyElement::Image { alt, .. } => { + format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default()) + } + }; + context.push_str(&format!(" [{}] {}\n", i, desc)); + } + } + + // Add transcript preview if available + if !file.transcript.is_empty() { + context.push_str("\nTranscript preview (first 5 entries):\n"); + for entry in file.transcript.iter().take(5) { + context.push_str(&format!(" - {}: {}\n", entry.speaker, entry.text)); + } + if file.transcript.len() > 5 { + context.push_str(&format!(" ... and {} more entries\n", file.transcript.len() - 5)); + } + } + + context +} diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs index 93062f3..3055cb7 100644 --- a/makima/src/server/handlers/listen.rs +++ b/makima/src/server/handlers/listen.rs @@ -449,21 +449,31 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { // Save final transcript to file if we have one if let Some(fid) = file_id { if let Some(ref pool) = state.db_pool { + // Deduplicate transcript entries before saving + let deduplicated = deduplicate_transcripts(&transcript_entries); + // Mark all entries as final - for entry in &mut transcript_entries { - entry.is_final = true; - } + let final_entries: Vec<TranscriptEntry> = deduplicated + .into_iter() + .map(|mut entry| { + entry.is_final = true; + entry + }) + .collect(); match repository::update_file(pool, fid, UpdateFileRequest { name: None, description: None, - transcript: Some(transcript_entries.clone()), + transcript: Some(final_entries.clone()), + summary: None, + body: None, }).await { Ok(_) => { tracing::info!( session_id = %session_id, file_id = %fid, - transcript_count = transcript_entries.len(), + original_count = transcript_entries.len(), + deduplicated_count = final_entries.len(), "Saved final transcript to file" ); } @@ -502,6 +512,69 @@ fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> { } } +/// Deduplicate transcript entries by removing entries with similar start times and text. +/// +/// Entries are considered duplicates if: +/// - Start times are within 0.5 seconds of each other +/// - Speaker is the same +/// - Text is identical or one is a substring of the other +fn deduplicate_transcripts(entries: &[TranscriptEntry]) -> Vec<TranscriptEntry> { + if entries.is_empty() { + return vec![]; + } + + // Sort by start time + let mut sorted: Vec<TranscriptEntry> = entries.to_vec(); + sorted.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)); + + let mut result: Vec<TranscriptEntry> = Vec::new(); + + for entry in sorted { + // Check if this entry is a duplicate of any existing entry + let is_duplicate = result.iter().any(|existing| { + // Check if start times are close (within 0.5 seconds) + let time_close = (existing.start - entry.start).abs() < 0.5; + + // Check if same speaker + let same_speaker = existing.speaker == entry.speaker; + + // Check if text matches or one contains the other + let text_match = existing.text == entry.text + || existing.text.contains(&entry.text) + || entry.text.contains(&existing.text); + + time_close && same_speaker && text_match + }); + + if !is_duplicate { + result.push(entry); + } else { + // If duplicate, check if the new entry has longer text and update + for existing in &mut result { + let time_close = (existing.start - entry.start).abs() < 0.5; + let same_speaker = existing.speaker == entry.speaker; + + if time_close && same_speaker && entry.text.len() > existing.text.len() { + // Keep the longer text version + existing.text = entry.text.clone(); + existing.end = entry.end; + break; + } + } + } + } + + // Reassign IDs to be sequential + for (i, entry) in result.iter_mut().enumerate() { + let parts: Vec<&str> = entry.id.split('-').collect(); + if let Some(session_prefix) = parts.first() { + entry.id = format!("{}-{}", session_prefix, i + 1); + } + } + + result +} + /// Process audio using sliding window through STT and streaming diarization models. /// /// Only processes the last MAX_WINDOW_SECONDS of audio to maintain constant diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index f249234..b13668a 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -1,4 +1,5 @@ //! HTTP and WebSocket request handlers. +pub mod chat; pub mod files; pub mod listen; diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs index bc3e679..a8f98a6 100644 --- a/makima/src/server/mod.rs +++ b/makima/src/server/mod.rs @@ -8,7 +8,7 @@ pub mod state; use axum::{ http::StatusCode, response::IntoResponse, - routing::get, + routing::{get, post}, Json, Router, }; use serde::Serialize; @@ -17,7 +17,7 @@ use tower_http::trace::TraceLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use crate::server::handlers::{files, listen}; +use crate::server::handlers::{chat, files, listen}; use crate::server::openapi::ApiDoc; use crate::server::state::SharedState; @@ -50,6 +50,7 @@ pub fn make_router(state: SharedState) -> Router { .put(files::update_file) .delete(files::delete_file), ) + .route("/files/{id}/chat", post(chat::chat_handler)) .with_state(state); let swagger = SwaggerUi::new("/swagger-ui") |
