summaryrefslogblamecommitdiff
path: root/makima/src/server/handlers/chat.rs
blob: 3bdbc7445a138f6970d24519224178ff733ab4eb (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14













                                                 
                                                           
                      

                                                             
  
                                                                
 


                                                                   




                                       


                                                                                     





















                                              











                                          
























































                                                                                         
































                                                                          
         












































                                                                          





                                                 

                                                                                 





                                                                                                  


                                                                                                                













                                                   
                                 

                                                   

                                                                
 


                                                                 
 


































































                                                                                                  

         





                                                                
           

























































                                                                                           


                                                          
                                        





                                                               
                                                                 

          

































                                                                         



                     

                                                          



                                                                                    

                                                                     







                                    
                                            






























































                                                                                                    
//! Chat endpoint for LLM-powered file editing.

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::IntoResponse,
    Json,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;

use crate::db::{models::BodyElement, repository};
use crate::llm::{
    claude::{self, ClaudeClient, ClaudeError, ClaudeModel},
    execute_tool_call,
    groq::{GroqClient, GroqError, Message, ToolCallResponse},
    LlmModel, ToolCall, ToolResult, AVAILABLE_TOOLS,
};
use crate::server::state::{FileUpdateNotification, SharedState};

/// Maximum number of LLM round-trips (tool-calling rounds) per chat request.
///
/// Bounds the conversation loop in `chat_handler` so a model that keeps
/// requesting tools cannot keep the request looping forever.
const MAX_TOOL_ROUNDS: usize = 10;

/// Request body for `POST /api/v1/files/{id}/chat` (camelCase field names on the wire).
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatRequest {
    /// The user's message/instruction, sent to the model as the user turn
    pub message: String,
    /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq"
    ///
    /// The value is parsed with `LlmModel::from_str`; when it is absent or the parser
    /// returns `None` for it, the handler falls back to `LlmModel::default()`.
    #[serde(default)]
    pub model: Option<String>,
}

/// Response body for a successful chat request (camelCase field names on the wire).
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatResponse {
    /// The LLM's final response message, or a generated fallback message when the
    /// model returned no text
    pub response: String,
    /// Tool calls that were executed, in execution order across all rounds
    pub tool_calls: Vec<ToolCallInfo>,
    /// File body after tool execution (the original body if no tool modified it)
    pub updated_body: Vec<BodyElement>,
    /// Summary after tool execution: the new summary if a tool set one, otherwise the
    /// file's existing summary (`None` if the file has none). This is populated even
    /// when the summary did not change.
    pub updated_summary: Option<String>,
}

/// Record of one executed tool call, reported back to the client in `ChatResponse`.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallInfo {
    /// Name of the tool the model invoked
    pub name: String,
    /// Outcome reported by the tool executor
    pub result: ToolResult,
}

/// The LLM provider client used for a single chat request.
///
/// Chosen from the requested model; both Claude Sonnet and Claude Opus use the
/// `Claude` variant (configured with the corresponding `ClaudeModel`).
enum LlmClient {
    /// Groq client; the conversation is already in its message format.
    Groq(GroqClient),
    /// Anthropic Claude client; the conversation is converted before each call.
    Claude(ClaudeClient),
}

/// Provider-independent result of one LLM call.
struct LlmResult {
    /// Assistant text, if the model returned any.
    content: Option<String>,
    /// Parsed tool calls to execute this round; the handler treats an empty list as
    /// the model's final answer.
    tool_calls: Vec<ToolCall>,
    /// Tool calls in Groq/OpenAI wire format, echoed back in the assistant message of
    /// the conversation history. For Claude these are synthesized from `tool_calls`.
    raw_tool_calls: Vec<ToolCallResponse>,
    /// Provider stop signal: Groq's `finish_reason` or Claude's `stop_reason`.
    finish_reason: String,
}

/// Chat with a file using LLM tool calling
#[utoipa::path(
    post,
    path = "/api/v1/files/{id}/chat",
    request_body = ChatRequest,
    responses(
        (status = 200, description = "Chat completed successfully", body = ChatResponse),
        (status = 404, description = "File not found"),
        (status = 500, description = "Internal server error")
    ),
    params(
        ("id" = Uuid, Path, description = "File ID")
    ),
    tag = "chat"
)]
pub async fn chat_handler(
    State(state): State<SharedState>,
    Path(id): Path<Uuid>,
    Json(request): Json<ChatRequest>,
) -> impl IntoResponse {
    // Check if database is configured
    let Some(ref pool) = state.db_pool else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({
                "error": "Database not configured"
            })),
        )
            .into_response();
    };

    // Get the file
    let file = match repository::get_file(pool, id).await {
        Ok(Some(file)) => file,
        Ok(None) => {
            return (
                StatusCode::NOT_FOUND,
                Json(serde_json::json!({
                    "error": "File not found"
                })),
            )
                .into_response();
        }
        Err(e) => {
            tracing::error!("Database error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": format!("Database error: {}", e)
                })),
            )
                .into_response();
        }
    };

    // Parse model selection (default to Claude Sonnet)
    let model = request
        .model
        .as_ref()
        .and_then(|m| LlmModel::from_str(m))
        .unwrap_or_default();

    tracing::info!("Using LLM model: {:?}", model);

    // Initialize the appropriate LLM client
    let llm_client = match model {
        LlmModel::ClaudeSonnet => {
            match ClaudeClient::from_env(ClaudeModel::Sonnet) {
                Ok(client) => LlmClient::Claude(client),
                Err(ClaudeError::MissingApiKey) => {
                    return (
                        StatusCode::SERVICE_UNAVAILABLE,
                        Json(serde_json::json!({
                            "error": "ANTHROPIC_API_KEY not configured"
                        })),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({
                            "error": format!("Claude client error: {}", e)
                        })),
                    )
                        .into_response();
                }
            }
        }
        LlmModel::ClaudeOpus => {
            match ClaudeClient::from_env(ClaudeModel::Opus) {
                Ok(client) => LlmClient::Claude(client),
                Err(ClaudeError::MissingApiKey) => {
                    return (
                        StatusCode::SERVICE_UNAVAILABLE,
                        Json(serde_json::json!({
                            "error": "ANTHROPIC_API_KEY not configured"
                        })),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({
                            "error": format!("Claude client error: {}", e)
                        })),
                    )
                        .into_response();
                }
            }
        }
        LlmModel::GroqKimi => {
            match GroqClient::from_env() {
                Ok(client) => LlmClient::Groq(client),
                Err(GroqError::MissingApiKey) => {
                    return (
                        StatusCode::SERVICE_UNAVAILABLE,
                        Json(serde_json::json!({
                            "error": "GROQ_API_KEY not configured"
                        })),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({
                            "error": format!("Groq client error: {}", e)
                        })),
                    )
                        .into_response();
                }
            }
        }
    };

    // Build context about the file
    let file_context = build_file_context(&file);

    // Build initial messages (Groq/OpenAI format - will be converted for Claude)
    let mut messages = vec![
        Message {
            role: "system".to_string(),
            content: Some(format!(
                "You are a helpful assistant that helps users edit and analyze document files. \
                You have access to tools to add headings, paragraphs, charts, and set summaries. \
                When the user asks you to modify the file, use the appropriate tools.\n\n\
                IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \
                and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \
                to call add_chart.\n\n\
                Current file context:\n{}",
                file_context
            )),
            tool_calls: None,
            tool_call_id: None,
        },
        Message {
            role: "user".to_string(),
            content: Some(request.message.clone()),
            tool_calls: None,
            tool_call_id: None,
        },
    ];

    // State for tracking changes
    let mut current_body = file.body.clone();
    let mut current_summary = file.summary.clone();
    let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new();
    let mut final_response: Option<String> = None;

    // Multi-turn tool calling loop
    for round in 0..MAX_TOOL_ROUNDS {
        tracing::debug!(round = round, "LLM tool calling round");

        // Call the appropriate LLM API
        let result = match &llm_client {
            LlmClient::Groq(groq) => {
                match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await {
                    Ok(r) => LlmResult {
                        content: r.content,
                        tool_calls: r.tool_calls,
                        raw_tool_calls: r.raw_tool_calls,
                        finish_reason: r.finish_reason,
                    },
                    Err(e) => {
                        tracing::error!("Groq API error: {}", e);
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(serde_json::json!({
                                "error": format!("LLM API error: {}", e)
                            })),
                        )
                            .into_response();
                    }
                }
            }
            LlmClient::Claude(claude_client) => {
                // Convert messages to Claude format
                let claude_messages = claude::groq_messages_to_claude(&messages);
                match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await {
                    Ok(r) => {
                        // Convert Claude tool uses to Groq-style ToolCallResponse for consistency
                        let raw_tool_calls: Vec<ToolCallResponse> = r
                            .tool_calls
                            .iter()
                            .map(|tc| ToolCallResponse {
                                id: tc.id.clone(),
                                call_type: "function".to_string(),
                                function: crate::llm::groq::FunctionCall {
                                    name: tc.name.clone(),
                                    arguments: tc.arguments.to_string(),
                                },
                            })
                            .collect();

                        LlmResult {
                            content: r.content,
                            tool_calls: r.tool_calls,
                            raw_tool_calls,
                            finish_reason: r.stop_reason,
                        }
                    }
                    Err(e) => {
                        tracing::error!("Claude API error: {}", e);
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(serde_json::json!({
                                "error": format!("LLM API error: {}", e)
                            })),
                        )
                            .into_response();
                    }
                }
            }
        };

        // Check if there are tool calls to execute
        if result.tool_calls.is_empty() {
            // No more tool calls - capture the final response and exit loop
            final_response = result.content;
            break;
        }

        // Add assistant message with tool calls to conversation
        messages.push(Message {
            role: "assistant".to_string(),
            content: result.content.clone(),
            tool_calls: Some(result.raw_tool_calls.clone()),
            tool_call_id: None,
        });

        // Execute each tool call and add results to conversation
        for (i, tool_call) in result.tool_calls.iter().enumerate() {
            let execution_result =
                execute_tool_call(tool_call, &current_body, current_summary.as_deref());

            // Apply state changes
            if let Some(new_body) = execution_result.new_body {
                current_body = new_body;
            }
            if let Some(new_summary) = execution_result.new_summary {
                current_summary = Some(new_summary);
            }

            // Build tool result message content
            let result_content = if let Some(parsed_data) = &execution_result.parsed_data {
                // Include parsed data in the result for the LLM to use
                serde_json::json!({
                    "success": execution_result.result.success,
                    "message": execution_result.result.message,
                    "data": parsed_data
                })
                .to_string()
            } else {
                serde_json::json!({
                    "success": execution_result.result.success,
                    "message": execution_result.result.message
                })
                .to_string()
            };

            // Add tool result message
            // Use the appropriate ID format for each provider
            let tool_call_id = match &llm_client {
                LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(),
                LlmClient::Claude(_) => tool_call.id.clone(),
            };

            messages.push(Message {
                role: "tool".to_string(),
                content: Some(result_content),
                tool_calls: None,
                tool_call_id: Some(tool_call_id),
            });

            // Track for response
            all_tool_call_infos.push(ToolCallInfo {
                name: tool_call.name.clone(),
                result: execution_result.result,
            });
        }

        // If finish reason indicates completion, exit loop
        let finish_lower = result.finish_reason.to_lowercase();
        if finish_lower == "stop" || finish_lower == "end_turn" {
            final_response = result.content;
            break;
        }
    }

    // Save changes to database if any tools were executed
    if !all_tool_call_infos.is_empty() {
        let update_req = crate::db::models::UpdateFileRequest {
            name: None,
            description: None,
            transcript: None,
            summary: current_summary.clone(),
            body: Some(current_body.clone()),
            version: None, // Internal update, skip version check
        };

        match repository::update_file(pool, id, update_req).await {
            Ok(Some(updated_file)) => {
                // Broadcast update notification for LLM changes
                let mut updated_fields = vec!["body".to_string()];
                if current_summary.is_some() {
                    updated_fields.push("summary".to_string());
                }
                state.broadcast_file_update(FileUpdateNotification {
                    file_id: id,
                    version: updated_file.version,
                    updated_fields,
                    updated_by: "llm".to_string(),
                });
            }
            Ok(None) => {
                // File was deleted during processing
                return (
                    StatusCode::NOT_FOUND,
                    Json(serde_json::json!({
                        "error": "File not found"
                    })),
                )
                    .into_response();
            }
            Err(e) => {
                tracing::error!("Failed to save file changes: {}", e);
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": format!("Failed to save changes: {}", e)
                    })),
                )
                    .into_response();
            }
        }
    }

    // Build response
    let response_text = final_response.unwrap_or_else(|| {
        if all_tool_call_infos.is_empty() {
            "I couldn't understand your request. Please try rephrasing.".to_string()
        } else {
            format!(
                "Done! Executed {} tool{}.",
                all_tool_call_infos.len(),
                if all_tool_call_infos.len() == 1 { "" } else { "s" }
            )
        }
    });

    (
        StatusCode::OK,
        Json(ChatResponse {
            response: response_text,
            tool_calls: all_tool_call_infos,
            updated_body: current_body,
            updated_summary: current_summary,
        }),
    )
        .into_response()
}

/// Maximum number of characters of a paragraph shown in the body overview.
const PARAGRAPH_PREVIEW_CHARS: usize = 50;

/// Number of transcript entries quoted in the file context.
const TRANSCRIPT_PREVIEW_ENTRIES: usize = 5;

/// Renders a plain-text overview of `file` for the LLM system prompt.
///
/// The overview contains:
/// - the file name, plus the description and summary when present,
/// - the transcript-entry and body-element counts,
/// - a one-line description of every body element, with paragraphs truncated to
///   `PARAGRAPH_PREVIEW_CHARS` characters,
/// - the first `TRANSCRIPT_PREVIEW_ENTRIES` transcript entries.
fn build_file_context(file: &crate::db::models::File) -> String {
    let mut context = String::new();
    write_file_context(&mut context, file).expect("writing to a String cannot fail");
    context
}

/// Writes the overview produced by [`build_file_context`] into `out`.
fn write_file_context(
    out: &mut impl std::fmt::Write,
    file: &crate::db::models::File,
) -> std::fmt::Result {
    writeln!(out, "File: {}", file.name)?;

    if let Some(desc) = &file.description {
        writeln!(out, "Description: {}", desc)?;
    }

    if let Some(summary) = &file.summary {
        writeln!(out, "Summary: {}", summary)?;
    }

    writeln!(out, "Transcript entries: {}", file.transcript.len())?;
    writeln!(out, "Body elements: {}", file.body.len())?;

    // Add body overview
    if !file.body.is_empty() {
        out.write_str("\nCurrent body elements:\n")?;
        for (i, element) in file.body.iter().enumerate() {
            let desc = match element {
                BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
                BodyElement::Paragraph { text } => {
                    format!("Paragraph: {}", paragraph_preview(text, PARAGRAPH_PREVIEW_CHARS))
                }
                BodyElement::Chart { chart_type, title, .. } => format!(
                    "Chart ({:?}){}",
                    chart_type,
                    title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
                ),
                BodyElement::Image { alt, .. } => {
                    format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
                }
            };
            writeln!(out, "  [{}] {}", i, desc)?;
        }
    }

    // Add transcript preview if available
    if !file.transcript.is_empty() {
        writeln!(
            out,
            "\nTranscript preview (first {} entries):",
            TRANSCRIPT_PREVIEW_ENTRIES
        )?;
        for entry in file.transcript.iter().take(TRANSCRIPT_PREVIEW_ENTRIES) {
            writeln!(out, "  - {}: {}", entry.speaker, entry.text)?;
        }
        if file.transcript.len() > TRANSCRIPT_PREVIEW_ENTRIES {
            writeln!(
                out,
                "  ... and {} more entries",
                file.transcript.len() - TRANSCRIPT_PREVIEW_ENTRIES
            )?;
        }
    }

    Ok(())
}

/// Returns `text` unchanged if it has at most `max_chars` characters. Otherwise it
/// returns the first `max_chars` characters followed by `...`.
///
/// Truncation counts `char`s rather than bytes. Slicing `&text[..n]` at an arbitrary
/// byte offset panics when `n` falls inside a multi-byte UTF-8 character.
fn paragraph_preview(text: &str, max_chars: usize) -> String {
    match text.char_indices().nth(max_chars) {
        // `cut` is the byte offset where character number `max_chars` starts, i.e.
        // the end of the first `max_chars` characters — always a char boundary.
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}