diff options
| author | soryu <soryu@soryu.co> | 2025-12-23 14:43:23 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 555061b179b8ec034cb70f9a2dd6c823ced0f637 (patch) | |
| tree | 0545b4395dab6d957884d8d36bf15b8da529dc1f /makima/src/server/handlers | |
| parent | a32dc56d2e5447ef8988cb98b8686476cc94e70c (diff) | |
| download | soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.tar.gz soryu-555061b179b8ec034cb70f9a2dd6c823ced0f637.zip | |
Add file body and initial tool call system
Diffstat (limited to 'makima/src/server/handlers')
| -rw-r--r-- | makima/src/server/handlers/chat.rs | 296 | ||||
| -rw-r--r-- | makima/src/server/handlers/listen.rs | 83 | ||||
| -rw-r--r-- | makima/src/server/handlers/mod.rs | 1 |
3 files changed, 375 insertions, 5 deletions
diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs new file mode 100644 index 0000000..e6d22ca --- /dev/null +++ b/makima/src/server/handlers/chat.rs @@ -0,0 +1,296 @@ +//! Chat endpoint for LLM-powered file editing. + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + Json, +}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use uuid::Uuid; + +use crate::db::{models::BodyElement, repository}; +use crate::llm::{ + execute_tool_call, + groq::{GroqClient, GroqError, Message}, + ToolResult, AVAILABLE_TOOLS, +}; +use crate::server::state::SharedState; + +#[derive(Debug, Deserialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChatRequest { + /// The user's message/instruction + pub message: String, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ChatResponse { + /// The LLM's response message + pub response: String, + /// Tool calls that were executed + pub tool_calls: Vec<ToolCallInfo>, + /// Updated file body after tool execution + pub updated_body: Vec<BodyElement>, + /// Updated summary (if changed) + pub updated_summary: Option<String>, +} + +#[derive(Debug, Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct ToolCallInfo { + pub name: String, + pub result: ToolResult, +} + +#[derive(Debug, Serialize)] +struct ErrorResponse { + error: String, +} + +/// Chat with a file using LLM tool calling +#[utoipa::path( + post, + path = "/api/v1/files/{id}/chat", + request_body = ChatRequest, + responses( + (status = 200, description = "Chat completed successfully", body = ChatResponse), + (status = 404, description = "File not found"), + (status = 500, description = "Internal server error") + ), + params( + ("id" = Uuid, Path, description = "File ID") + ), + tag = "chat" +)] +pub async fn chat_handler( + State(state): State<SharedState>, + Path(id): Path<Uuid>, + Json(request): Json<ChatRequest>, +) -> impl IntoResponse { + // Check if database is configured + let Some(ref pool) = state.db_pool else { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "Database not configured" + })), + ) + .into_response(); + }; + + // Get the file + let file = match repository::get_file(pool, id).await { + Ok(Some(file)) => file, + Ok(None) => { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ + "error": "File not found" + })), + ) + .into_response(); + } + Err(e) => { + tracing::error!("Database error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Database error: {}", e) + })), + ) + .into_response(); + } + }; + + // Initialize Groq client + let groq = match GroqClient::from_env() { + Ok(client) => client, + Err(GroqError::MissingApiKey) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "error": "GROQ_API_KEY not configured" + })), + ) + .into_response(); + } + Err(e) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Groq client error: {}", e) + })), + ) + .into_response(); + } + }; + + // Build context about the file + let file_context = build_file_context(&file); + + // Build messages + let messages = vec![ + Message { + role: "system".to_string(), + content: Some(format!( + "You are a helpful assistant that helps users edit and analyze document files. \ + You have access to tools to add headings, paragraphs, charts, and set summaries. \ + When the user asks you to modify the file, use the appropriate tools.\n\n\ + Current file context:\n{}", + file_context + )), + tool_calls: None, + tool_call_id: None, + }, + Message { + role: "user".to_string(), + content: Some(request.message.clone()), + tool_calls: None, + tool_call_id: None, + }, + ]; + + // Call Groq API + let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await { + Ok(result) => result, + Err(e) => { + tracing::error!("Groq API error: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("LLM API error: {}", e) + })), + ) + .into_response(); + } + }; + + // Execute tool calls + let mut current_body = file.body.clone(); + let mut current_summary = file.summary.clone(); + let mut tool_call_infos = Vec::new(); + + for tool_call in &result.tool_calls { + let execution_result = + execute_tool_call(tool_call, ¤t_body, current_summary.as_deref()); + + // Apply state changes + if let Some(new_body) = execution_result.new_body { + current_body = new_body; + } + if let Some(new_summary) = execution_result.new_summary { + current_summary = Some(new_summary); + } + + tool_call_infos.push(ToolCallInfo { + name: tool_call.name.clone(), + result: execution_result.result, + }); + } + + // Save changes to database if any tools were executed + if !result.tool_calls.is_empty() { + let update_req = crate::db::models::UpdateFileRequest { + name: None, + description: None, + transcript: None, + summary: current_summary.clone(), + body: Some(current_body.clone()), + }; + + if let Err(e) = repository::update_file(pool, id, update_req).await { + tracing::error!("Failed to save file changes: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": format!("Failed to save changes: {}", e) + })), + ) + .into_response(); + } + } + + // Build response + let response_text = result.content.unwrap_or_else(|| { + if tool_call_infos.is_empty() { + "I couldn't understand your request. Please try rephrasing.".to_string() + } else { + format!( + "Done! Executed {} tool{}.", + tool_call_infos.len(), + if tool_call_infos.len() == 1 { "" } else { "s" } + ) + } + }); + + ( + StatusCode::OK, + Json(ChatResponse { + response: response_text, + tool_calls: tool_call_infos, + updated_body: current_body, + updated_summary: current_summary, + }), + ) + .into_response() +} + +fn build_file_context(file: &crate::db::models::File) -> String { + let mut context = format!("File: {}\n", file.name); + + if let Some(ref desc) = file.description { + context.push_str(&format!("Description: {}\n", desc)); + } + + if let Some(ref summary) = file.summary { + context.push_str(&format!("Summary: {}\n", summary)); + } + + context.push_str(&format!("Transcript entries: {}\n", file.transcript.len())); + context.push_str(&format!("Body elements: {}\n", file.body.len())); + + // Add body overview + if !file.body.is_empty() { + context.push_str("\nCurrent body elements:\n"); + for (i, element) in file.body.iter().enumerate() { + let desc = match element { + BodyElement::Heading { level, text } => format!("H{}: {}", level, text), + BodyElement::Paragraph { text } => { + let preview = if text.len() > 50 { + format!("{}...", &text[..50]) + } else { + text.clone() + }; + format!("Paragraph: {}", preview) + } + BodyElement::Chart { chart_type, title, .. } => { + format!( + "Chart ({:?}){}", + chart_type, + title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default() + ) + } + BodyElement::Image { alt, .. } => { + format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default()) + } + }; + context.push_str(&format!(" [{}] {}\n", i, desc)); + } + } + + // Add transcript preview if available + if !file.transcript.is_empty() { + context.push_str("\nTranscript preview (first 5 entries):\n"); + for entry in file.transcript.iter().take(5) { + context.push_str(&format!(" - {}: {}\n", entry.speaker, entry.text)); + } + if file.transcript.len() > 5 { + context.push_str(&format!(" ... and {} more entries\n", file.transcript.len() - 5)); + } + } + + context +} diff --git a/makima/src/server/handlers/listen.rs b/makima/src/server/handlers/listen.rs index 93062f3..3055cb7 100644 --- a/makima/src/server/handlers/listen.rs +++ b/makima/src/server/handlers/listen.rs @@ -449,21 +449,31 @@ async fn handle_socket(socket: WebSocket, state: SharedState) { // Save final transcript to file if we have one if let Some(fid) = file_id { if let Some(ref pool) = state.db_pool { + // Deduplicate transcript entries before saving + let deduplicated = deduplicate_transcripts(&transcript_entries); + // Mark all entries as final - for entry in &mut transcript_entries { - entry.is_final = true; - } + let final_entries: Vec<TranscriptEntry> = deduplicated + .into_iter() + .map(|mut entry| { + entry.is_final = true; + entry + }) + .collect(); match repository::update_file(pool, fid, UpdateFileRequest { name: None, description: None, - transcript: Some(transcript_entries.clone()), + transcript: Some(final_entries.clone()), + summary: None, + body: None, }).await { Ok(_) => { tracing::info!( session_id = %session_id, file_id = %fid, - transcript_count = transcript_entries.len(), + original_count = transcript_entries.len(), + deduplicated_count = final_entries.len(), "Saved final transcript to file" ); } @@ -502,6 +512,69 @@ fn decode_audio_chunk(data: &[u8], format: &StartMessage) -> Vec<f32> { } } +/// Deduplicate transcript entries by removing entries with similar start times and text. +/// +/// Entries are considered duplicates if: +/// - Start times are within 0.5 seconds of each other +/// - Speaker is the same +/// - Text is identical or one is a substring of the other +fn deduplicate_transcripts(entries: &[TranscriptEntry]) -> Vec<TranscriptEntry> { + if entries.is_empty() { + return vec![]; + } + + // Sort by start time + let mut sorted: Vec<TranscriptEntry> = entries.to_vec(); + sorted.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap_or(std::cmp::Ordering::Equal)); + + let mut result: Vec<TranscriptEntry> = Vec::new(); + + for entry in sorted { + // Check if this entry is a duplicate of any existing entry + let is_duplicate = result.iter().any(|existing| { + // Check if start times are close (within 0.5 seconds) + let time_close = (existing.start - entry.start).abs() < 0.5; + + // Check if same speaker + let same_speaker = existing.speaker == entry.speaker; + + // Check if text matches or one contains the other + let text_match = existing.text == entry.text + || existing.text.contains(&entry.text) + || entry.text.contains(&existing.text); + + time_close && same_speaker && text_match + }); + + if !is_duplicate { + result.push(entry); + } else { + // If duplicate, check if the new entry has longer text and update + for existing in &mut result { + let time_close = (existing.start - entry.start).abs() < 0.5; + let same_speaker = existing.speaker == entry.speaker; + + if time_close && same_speaker && entry.text.len() > existing.text.len() { + // Keep the longer text version + existing.text = entry.text.clone(); + existing.end = entry.end; + break; + } + } + } + } + + // Reassign IDs to be sequential + for (i, entry) in result.iter_mut().enumerate() { + let parts: Vec<&str> = entry.id.split('-').collect(); + if let Some(session_prefix) = parts.first() { + entry.id = format!("{}-{}", session_prefix, i + 1); + } + } + + result +} + /// Process audio using sliding window through STT and streaming diarization models. /// /// Only processes the last MAX_WINDOW_SECONDS of audio to maintain constant diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs index f249234..b13668a 100644 --- a/makima/src/server/handlers/mod.rs +++ b/makima/src/server/handlers/mod.rs @@ -1,4 +1,5 @@ //! HTTP and WebSocket request handlers. +pub mod chat; pub mod files; pub mod listen; |
