summaryrefslogtreecommitdiff
path: root/makima/src/server/handlers/chat.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/server/handlers/chat.rs')
-rw-r--r--makima/src/server/handlers/chat.rs296
1 files changed, 296 insertions, 0 deletions
diff --git a/makima/src/server/handlers/chat.rs b/makima/src/server/handlers/chat.rs
new file mode 100644
index 0000000..e6d22ca
--- /dev/null
+++ b/makima/src/server/handlers/chat.rs
@@ -0,0 +1,296 @@
+//! Chat endpoint for LLM-powered file editing.
+
+use axum::{
+ extract::{Path, State},
+ http::StatusCode,
+ response::IntoResponse,
+ Json,
+};
+use serde::{Deserialize, Serialize};
+use utoipa::ToSchema;
+use uuid::Uuid;
+
+use crate::db::{models::BodyElement, repository};
+use crate::llm::{
+ execute_tool_call,
+ groq::{GroqClient, GroqError, Message},
+ ToolResult, AVAILABLE_TOOLS,
+};
+use crate::server::state::SharedState;
+
+#[derive(Debug, Deserialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct ChatRequest {
+ /// The user's message/instruction
+ pub message: String,
+}
+
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct ChatResponse {
+ /// The LLM's response message
+ pub response: String,
+ /// Tool calls that were executed
+ pub tool_calls: Vec<ToolCallInfo>,
+ /// Updated file body after tool execution
+ pub updated_body: Vec<BodyElement>,
+ /// Updated summary (if changed)
+ pub updated_summary: Option<String>,
+}
+
+#[derive(Debug, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolCallInfo {
+ pub name: String,
+ pub result: ToolResult,
+}
+
+#[derive(Debug, Serialize)]
+struct ErrorResponse {
+ error: String,
+}
+
+/// Chat with a file using LLM tool calling
+#[utoipa::path(
+ post,
+ path = "/api/v1/files/{id}/chat",
+ request_body = ChatRequest,
+ responses(
+ (status = 200, description = "Chat completed successfully", body = ChatResponse),
+ (status = 404, description = "File not found"),
+ (status = 500, description = "Internal server error")
+ ),
+ params(
+ ("id" = Uuid, Path, description = "File ID")
+ ),
+ tag = "chat"
+)]
+pub async fn chat_handler(
+ State(state): State<SharedState>,
+ Path(id): Path<Uuid>,
+ Json(request): Json<ChatRequest>,
+) -> impl IntoResponse {
+ // Check if database is configured
+ let Some(ref pool) = state.db_pool else {
+ return (
+ StatusCode::SERVICE_UNAVAILABLE,
+ Json(serde_json::json!({
+ "error": "Database not configured"
+ })),
+ )
+ .into_response();
+ };
+
+ // Get the file
+ let file = match repository::get_file(pool, id).await {
+ Ok(Some(file)) => file,
+ Ok(None) => {
+ return (
+ StatusCode::NOT_FOUND,
+ Json(serde_json::json!({
+ "error": "File not found"
+ })),
+ )
+ .into_response();
+ }
+ Err(e) => {
+ tracing::error!("Database error: {}", e);
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("Database error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ };
+
+ // Initialize Groq client
+ let groq = match GroqClient::from_env() {
+ Ok(client) => client,
+ Err(GroqError::MissingApiKey) => {
+ return (
+ StatusCode::SERVICE_UNAVAILABLE,
+ Json(serde_json::json!({
+ "error": "GROQ_API_KEY not configured"
+ })),
+ )
+ .into_response();
+ }
+ Err(e) => {
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("Groq client error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ };
+
+ // Build context about the file
+ let file_context = build_file_context(&file);
+
+ // Build messages
+ let messages = vec![
+ Message {
+ role: "system".to_string(),
+ content: Some(format!(
+ "You are a helpful assistant that helps users edit and analyze document files. \
+ You have access to tools to add headings, paragraphs, charts, and set summaries. \
+ When the user asks you to modify the file, use the appropriate tools.\n\n\
+ Current file context:\n{}",
+ file_context
+ )),
+ tool_calls: None,
+ tool_call_id: None,
+ },
+ Message {
+ role: "user".to_string(),
+ content: Some(request.message.clone()),
+ tool_calls: None,
+ tool_call_id: None,
+ },
+ ];
+
+ // Call Groq API
+ let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await {
+ Ok(result) => result,
+ Err(e) => {
+ tracing::error!("Groq API error: {}", e);
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("LLM API error: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ };
+
+ // Execute tool calls
+ let mut current_body = file.body.clone();
+ let mut current_summary = file.summary.clone();
+ let mut tool_call_infos = Vec::new();
+
+ for tool_call in &result.tool_calls {
+ let execution_result =
+ execute_tool_call(tool_call, &current_body, current_summary.as_deref());
+
+ // Apply state changes
+ if let Some(new_body) = execution_result.new_body {
+ current_body = new_body;
+ }
+ if let Some(new_summary) = execution_result.new_summary {
+ current_summary = Some(new_summary);
+ }
+
+ tool_call_infos.push(ToolCallInfo {
+ name: tool_call.name.clone(),
+ result: execution_result.result,
+ });
+ }
+
+ // Save changes to database if any tools were executed
+ if !result.tool_calls.is_empty() {
+ let update_req = crate::db::models::UpdateFileRequest {
+ name: None,
+ description: None,
+ transcript: None,
+ summary: current_summary.clone(),
+ body: Some(current_body.clone()),
+ };
+
+ if let Err(e) = repository::update_file(pool, id, update_req).await {
+ tracing::error!("Failed to save file changes: {}", e);
+ return (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ Json(serde_json::json!({
+ "error": format!("Failed to save changes: {}", e)
+ })),
+ )
+ .into_response();
+ }
+ }
+
+ // Build response
+ let response_text = result.content.unwrap_or_else(|| {
+ if tool_call_infos.is_empty() {
+ "I couldn't understand your request. Please try rephrasing.".to_string()
+ } else {
+ format!(
+ "Done! Executed {} tool{}.",
+ tool_call_infos.len(),
+ if tool_call_infos.len() == 1 { "" } else { "s" }
+ )
+ }
+ });
+
+ (
+ StatusCode::OK,
+ Json(ChatResponse {
+ response: response_text,
+ tool_calls: tool_call_infos,
+ updated_body: current_body,
+ updated_summary: current_summary,
+ }),
+ )
+ .into_response()
+}
+
+fn build_file_context(file: &crate::db::models::File) -> String {
+ let mut context = format!("File: {}\n", file.name);
+
+ if let Some(ref desc) = file.description {
+ context.push_str(&format!("Description: {}\n", desc));
+ }
+
+ if let Some(ref summary) = file.summary {
+ context.push_str(&format!("Summary: {}\n", summary));
+ }
+
+ context.push_str(&format!("Transcript entries: {}\n", file.transcript.len()));
+ context.push_str(&format!("Body elements: {}\n", file.body.len()));
+
+ // Add body overview
+ if !file.body.is_empty() {
+ context.push_str("\nCurrent body elements:\n");
+ for (i, element) in file.body.iter().enumerate() {
+ let desc = match element {
+ BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
+ BodyElement::Paragraph { text } => {
+ let preview = if text.len() > 50 {
+ format!("{}...", &text[..50])
+ } else {
+ text.clone()
+ };
+ format!("Paragraph: {}", preview)
+ }
+ BodyElement::Chart { chart_type, title, .. } => {
+ format!(
+ "Chart ({:?}){}",
+ chart_type,
+ title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
+ )
+ }
+ BodyElement::Image { alt, .. } => {
+ format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
+ }
+ };
+ context.push_str(&format!(" [{}] {}\n", i, desc));
+ }
+ }
+
+ // Add transcript preview if available
+ if !file.transcript.is_empty() {
+ context.push_str("\nTranscript preview (first 5 entries):\n");
+ for entry in file.transcript.iter().take(5) {
+ context.push_str(&format!(" - {}: {}\n", entry.speaker, entry.text));
+ }
+ if file.transcript.len() > 5 {
+ context.push_str(&format!(" ... and {} more entries\n", file.transcript.len() - 5));
+ }
+ }
+
+ context
+}