summaryrefslogblamecommitdiff
path: root/makima/src/server/handlers/chat.rs
blob: e6d22ca8a7a7183bca850d27200279f044ac3eaf (plain) (tree)







































































































































































































































































































                                                                                                    
//! Chat endpoint for LLM-powered file editing.

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::IntoResponse,
    Json,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;

use crate::db::{models::BodyElement, repository};
use crate::llm::{
    execute_tool_call,
    groq::{GroqClient, GroqError, Message},
    ToolResult, AVAILABLE_TOOLS,
};
use crate::server::state::SharedState;

/// Request body for `POST /api/v1/files/{id}/chat`.
#[derive(Deserialize, ToSchema, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ChatRequest {
    /// The user's message/instruction
    pub message: String,
}

/// Response body for a successful chat request.
///
/// Field docs are emitted into the OpenAPI schema by `ToSchema`, so they describe
/// exactly what clients receive.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatResponse {
    /// The LLM's reply, or a server-generated fallback message when the LLM returned no text
    pub response: String,
    /// Tool calls that were executed, in the order they were applied
    pub tool_calls: Vec<ToolCallInfo>,
    /// File body after all tool calls were applied (the original body if none changed it)
    pub updated_body: Vec<BodyElement>,
    /// File summary after all tool calls were applied; this is the existing summary when no
    /// tool changed it, and null when the file has no summary and no tool set one
    pub updated_summary: Option<String>,
}

/// One tool call requested by the LLM, together with the outcome of executing it.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallInfo {
    /// Name of the tool the LLM invoked
    pub name: String,
    /// Result reported by the tool execution
    pub result: ToolResult,
}

/// JSON error payload, serialized as `{ "error": "<message>" }`.
#[derive(Serialize, Debug)]
struct ErrorResponse {
    /// Error message text.
    error: String,
}

/// Chat with a file using LLM tool calling
#[utoipa::path(
    post,
    path = "/api/v1/files/{id}/chat",
    request_body = ChatRequest,
    responses(
        (status = 200, description = "Chat completed successfully", body = ChatResponse),
        (status = 404, description = "File not found"),
        (status = 500, description = "Internal server error")
    ),
    params(
        ("id" = Uuid, Path, description = "File ID")
    ),
    tag = "chat"
)]
pub async fn chat_handler(
    State(state): State<SharedState>,
    Path(id): Path<Uuid>,
    Json(request): Json<ChatRequest>,
) -> impl IntoResponse {
    // Check if database is configured
    let Some(ref pool) = state.db_pool else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({
                "error": "Database not configured"
            })),
        )
            .into_response();
    };

    // Get the file
    let file = match repository::get_file(pool, id).await {
        Ok(Some(file)) => file,
        Ok(None) => {
            return (
                StatusCode::NOT_FOUND,
                Json(serde_json::json!({
                    "error": "File not found"
                })),
            )
                .into_response();
        }
        Err(e) => {
            tracing::error!("Database error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": format!("Database error: {}", e)
                })),
            )
                .into_response();
        }
    };

    // Initialize Groq client
    let groq = match GroqClient::from_env() {
        Ok(client) => client,
        Err(GroqError::MissingApiKey) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "error": "GROQ_API_KEY not configured"
                })),
            )
                .into_response();
        }
        Err(e) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": format!("Groq client error: {}", e)
                })),
            )
                .into_response();
        }
    };

    // Build context about the file
    let file_context = build_file_context(&file);

    // Build messages
    let messages = vec![
        Message {
            role: "system".to_string(),
            content: Some(format!(
                "You are a helpful assistant that helps users edit and analyze document files. \
                You have access to tools to add headings, paragraphs, charts, and set summaries. \
                When the user asks you to modify the file, use the appropriate tools.\n\n\
                Current file context:\n{}",
                file_context
            )),
            tool_calls: None,
            tool_call_id: None,
        },
        Message {
            role: "user".to_string(),
            content: Some(request.message.clone()),
            tool_calls: None,
            tool_call_id: None,
        },
    ];

    // Call Groq API
    let result = match groq.chat_with_tools(messages, &AVAILABLE_TOOLS).await {
        Ok(result) => result,
        Err(e) => {
            tracing::error!("Groq API error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": format!("LLM API error: {}", e)
                })),
            )
                .into_response();
        }
    };

    // Execute tool calls
    let mut current_body = file.body.clone();
    let mut current_summary = file.summary.clone();
    let mut tool_call_infos = Vec::new();

    for tool_call in &result.tool_calls {
        let execution_result =
            execute_tool_call(tool_call, &current_body, current_summary.as_deref());

        // Apply state changes
        if let Some(new_body) = execution_result.new_body {
            current_body = new_body;
        }
        if let Some(new_summary) = execution_result.new_summary {
            current_summary = Some(new_summary);
        }

        tool_call_infos.push(ToolCallInfo {
            name: tool_call.name.clone(),
            result: execution_result.result,
        });
    }

    // Save changes to database if any tools were executed
    if !result.tool_calls.is_empty() {
        let update_req = crate::db::models::UpdateFileRequest {
            name: None,
            description: None,
            transcript: None,
            summary: current_summary.clone(),
            body: Some(current_body.clone()),
        };

        if let Err(e) = repository::update_file(pool, id, update_req).await {
            tracing::error!("Failed to save file changes: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": format!("Failed to save changes: {}", e)
                })),
            )
                .into_response();
        }
    }

    // Build response
    let response_text = result.content.unwrap_or_else(|| {
        if tool_call_infos.is_empty() {
            "I couldn't understand your request. Please try rephrasing.".to_string()
        } else {
            format!(
                "Done! Executed {} tool{}.",
                tool_call_infos.len(),
                if tool_call_infos.len() == 1 { "" } else { "s" }
            )
        }
    });

    (
        StatusCode::OK,
        Json(ChatResponse {
            response: response_text,
            tool_calls: tool_call_infos,
            updated_body: current_body,
            updated_summary: current_summary,
        }),
    )
        .into_response()
}

/// Renders a plain-text overview of `file` for the LLM system prompt.
///
/// The overview contains the name, the optional description and summary, the
/// transcript/body counts, a one-line description of every body element
/// (paragraphs truncated to 50 characters), and up to the first 5 transcript entries.
fn build_file_context(file: &crate::db::models::File) -> String {
    let mut context = format!("File: {}\n", file.name);

    if let Some(ref desc) = file.description {
        context.push_str(&format!("Description: {}\n", desc));
    }

    if let Some(ref summary) = file.summary {
        context.push_str(&format!("Summary: {}\n", summary));
    }

    context.push_str(&format!("Transcript entries: {}\n", file.transcript.len()));
    context.push_str(&format!("Body elements: {}\n", file.body.len()));

    // Add body overview
    if !file.body.is_empty() {
        context.push_str("\nCurrent body elements:\n");
        for (i, element) in file.body.iter().enumerate() {
            let desc = match element {
                BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
                BodyElement::Paragraph { text } => {
                    format!("Paragraph: {}", truncate_with_ellipsis(text, 50))
                }
                BodyElement::Chart { chart_type, title, .. } => {
                    format!(
                        "Chart ({:?}){}",
                        chart_type,
                        title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
                    )
                }
                BodyElement::Image { alt, .. } => {
                    format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
                }
            };
            context.push_str(&format!("  [{}] {}\n", i, desc));
        }
    }

    // Add transcript preview if available
    if !file.transcript.is_empty() {
        context.push_str("\nTranscript preview (first 5 entries):\n");
        for entry in file.transcript.iter().take(5) {
            context.push_str(&format!("  - {}: {}\n", entry.speaker, entry.text));
        }
        if file.transcript.len() > 5 {
            context.push_str(&format!("  ... and {} more entries\n", file.transcript.len() - 5));
        }
    }

    context
}

/// Returns `text` unchanged if it has at most `max_chars` characters; otherwise returns its
/// first `max_chars` characters followed by `...`.
///
/// Counts `char`s rather than bytes: slicing a `str` at a byte index that is not a UTF-8
/// character boundary panics, which byte-based truncation would do on multi-byte text.
fn truncate_with_ellipsis(text: &str, max_chars: usize) -> String {
    match text.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}