summaryrefslogblamecommitdiff
path: root/makima/src/server/handlers/chat.rs
blob: 396c9732ab5d7b4dad58d27e3c4c2f61ad19791e (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12











                                               
                                                                          
                 
                                                           
                      
                                                             
                                                                        
  
                                                                
 


                                                                   




                                       


                                                                                     





















                                              











                                          
























































                                                                                         
































                                                                          
         












































                                                                          





                                                 

                                                                                 





                                                                                                  


                                                                                                                













                                                   
                                 

                                                   

                                                                



                                                                           
 


                                                                 
 


































































                                                                                                  

         





                                                                
           


                                                                    
                                      

                                                                                        



























                                                                               

                                                               



                                                                                   


                                                                     


                                                                                   













































                                                                                           


                                                          


                                                                                            





                                                               
                                                                 

          

































                                                                         



                     

                                                          



                                                                                    

                                                                     







                                    
                                            






























































                                                                                                    



































































































































































































                                                                                                                           
//! Chat endpoint for LLM-powered file editing.

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::IntoResponse,
    Json,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;

use crate::db::{models::BodyElement, repository::{self, RepositoryError}};
use crate::llm::{
    claude::{self, ClaudeClient, ClaudeError, ClaudeModel},
    execute_tool_call,
    groq::{GroqClient, GroqError, Message, ToolCallResponse},
    LlmModel, ToolCall, ToolResult, VersionToolRequest, AVAILABLE_TOOLS,
};
use crate::server::state::{FileUpdateNotification, SharedState};

/// Maximum number of tool-calling rounds to prevent infinite loops.
///
/// Each round is one LLM request followed by execution of every tool call it returned.
/// If the model is still requesting tools after this many rounds, `chat_handler` stops
/// and replies with a generic "Done! Executed N tools." message.
const MAX_TOOL_ROUNDS: usize = 10;

/// Request body for `POST /api/v1/files/{id}/chat`.
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatRequest {
    /// The user's message/instruction
    pub message: String,
    /// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq"
    // NOTE(review): the accepted spellings are whatever `LlmModel::from_str` recognises;
    // an unrecognised value silently falls back to `LlmModel::default()` in `chat_handler`.
    #[serde(default)]
    pub model: Option<String>,
}

/// Successful response of the chat endpoint.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatResponse {
    /// The LLM's final text reply (or a generated fallback message when it gave none)
    pub response: String,
    /// Tool calls that were executed, in execution order
    pub tool_calls: Vec<ToolCallInfo>,
    /// The file body after all tool calls were applied
    pub updated_body: Vec<BodyElement>,
    /// The file summary after all tool calls were applied (`null` when the file has none),
    /// returned whether or not it changed
    pub updated_summary: Option<String>,
}

/// Name and outcome of one tool call executed during a chat request.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallInfo {
    /// Name of the tool the LLM invoked
    pub name: String,
    /// Success flag and message produced by executing the tool
    pub result: ToolResult,
}

/// The provider client selected for a single chat request.
enum LlmClient {
    /// Anthropic Claude (Sonnet or Opus); messages are converted from the Groq format.
    Claude(ClaudeClient),
    /// Groq client, which consumes the Groq/OpenAI-style `Message` list directly.
    Groq(GroqClient),
}

/// Provider-neutral view of one LLM response, so the tool loop can treat Groq and
/// Claude identically.
struct LlmResult {
    /// Why generation stopped, as reported by the provider
    /// (Groq `finish_reason` or Claude `stop_reason`).
    finish_reason: String,
    /// Free-text content, if the model produced any.
    content: Option<String>,
    /// Parsed tool calls to execute.
    tool_calls: Vec<ToolCall>,
    /// The same calls in Groq/OpenAI wire format, echoed back in the assistant message.
    raw_tool_calls: Vec<ToolCallResponse>,
}

/// Chat with a file using LLM tool calling
#[utoipa::path(
    post,
    path = "/api/v1/files/{id}/chat",
    request_body = ChatRequest,
    responses(
        (status = 200, description = "Chat completed successfully", body = ChatResponse),
        (status = 404, description = "File not found"),
        (status = 500, description = "Internal server error")
    ),
    params(
        ("id" = Uuid, Path, description = "File ID")
    ),
    tag = "chat"
)]
pub async fn chat_handler(
    State(state): State<SharedState>,
    Path(id): Path<Uuid>,
    Json(request): Json<ChatRequest>,
) -> impl IntoResponse {
    // Check if database is configured
    let Some(ref pool) = state.db_pool else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({
                "error": "Database not configured"
            })),
        )
            .into_response();
    };

    // Get the file
    let file = match repository::get_file(pool, id).await {
        Ok(Some(file)) => file,
        Ok(None) => {
            return (
                StatusCode::NOT_FOUND,
                Json(serde_json::json!({
                    "error": "File not found"
                })),
            )
                .into_response();
        }
        Err(e) => {
            tracing::error!("Database error: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "error": format!("Database error: {}", e)
                })),
            )
                .into_response();
        }
    };

    // Parse model selection (default to Claude Sonnet)
    let model = request
        .model
        .as_ref()
        .and_then(|m| LlmModel::from_str(m))
        .unwrap_or_default();

    tracing::info!("Using LLM model: {:?}", model);

    // Initialize the appropriate LLM client
    let llm_client = match model {
        LlmModel::ClaudeSonnet => {
            match ClaudeClient::from_env(ClaudeModel::Sonnet) {
                Ok(client) => LlmClient::Claude(client),
                Err(ClaudeError::MissingApiKey) => {
                    return (
                        StatusCode::SERVICE_UNAVAILABLE,
                        Json(serde_json::json!({
                            "error": "ANTHROPIC_API_KEY not configured"
                        })),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({
                            "error": format!("Claude client error: {}", e)
                        })),
                    )
                        .into_response();
                }
            }
        }
        LlmModel::ClaudeOpus => {
            match ClaudeClient::from_env(ClaudeModel::Opus) {
                Ok(client) => LlmClient::Claude(client),
                Err(ClaudeError::MissingApiKey) => {
                    return (
                        StatusCode::SERVICE_UNAVAILABLE,
                        Json(serde_json::json!({
                            "error": "ANTHROPIC_API_KEY not configured"
                        })),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({
                            "error": format!("Claude client error: {}", e)
                        })),
                    )
                        .into_response();
                }
            }
        }
        LlmModel::GroqKimi => {
            match GroqClient::from_env() {
                Ok(client) => LlmClient::Groq(client),
                Err(GroqError::MissingApiKey) => {
                    return (
                        StatusCode::SERVICE_UNAVAILABLE,
                        Json(serde_json::json!({
                            "error": "GROQ_API_KEY not configured"
                        })),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({
                            "error": format!("Groq client error: {}", e)
                        })),
                    )
                        .into_response();
                }
            }
        }
    };

    // Build context about the file
    let file_context = build_file_context(&file);

    // Build initial messages (Groq/OpenAI format - will be converted for Claude)
    let mut messages = vec![
        Message {
            role: "system".to_string(),
            content: Some(format!(
                "You are a helpful assistant that helps users edit and analyze document files. \
                You have access to tools to add headings, paragraphs, charts, and set summaries. \
                When the user asks you to modify the file, use the appropriate tools.\n\n\
                IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \
                and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \
                to call add_chart.\n\n\
                Current file context:\n{}",
                file_context
            )),
            tool_calls: None,
            tool_call_id: None,
        },
        Message {
            role: "user".to_string(),
            content: Some(request.message.clone()),
            tool_calls: None,
            tool_call_id: None,
        },
    ];

    // State for tracking changes
    let mut current_body = file.body.clone();
    let mut current_summary = file.summary.clone();
    let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new();
    let mut final_response: Option<String> = None;
    // Track if a version restore already happened (to avoid double-saving)
    let mut version_restored = false;
    // Track if there were modifications after a restore
    let mut has_changes_after_restore = false;

    // Multi-turn tool calling loop
    for round in 0..MAX_TOOL_ROUNDS {
        tracing::debug!(round = round, "LLM tool calling round");

        // Call the appropriate LLM API
        let result = match &llm_client {
            LlmClient::Groq(groq) => {
                match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await {
                    Ok(r) => LlmResult {
                        content: r.content,
                        tool_calls: r.tool_calls,
                        raw_tool_calls: r.raw_tool_calls,
                        finish_reason: r.finish_reason,
                    },
                    Err(e) => {
                        tracing::error!("Groq API error: {}", e);
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(serde_json::json!({
                                "error": format!("LLM API error: {}", e)
                            })),
                        )
                            .into_response();
                    }
                }
            }
            LlmClient::Claude(claude_client) => {
                // Convert messages to Claude format
                let claude_messages = claude::groq_messages_to_claude(&messages);
                match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await {
                    Ok(r) => {
                        // Convert Claude tool uses to Groq-style ToolCallResponse for consistency
                        let raw_tool_calls: Vec<ToolCallResponse> = r
                            .tool_calls
                            .iter()
                            .map(|tc| ToolCallResponse {
                                id: tc.id.clone(),
                                call_type: "function".to_string(),
                                function: crate::llm::groq::FunctionCall {
                                    name: tc.name.clone(),
                                    arguments: tc.arguments.to_string(),
                                },
                            })
                            .collect();

                        LlmResult {
                            content: r.content,
                            tool_calls: r.tool_calls,
                            raw_tool_calls,
                            finish_reason: r.stop_reason,
                        }
                    }
                    Err(e) => {
                        tracing::error!("Claude API error: {}", e);
                        return (
                            StatusCode::INTERNAL_SERVER_ERROR,
                            Json(serde_json::json!({
                                "error": format!("LLM API error: {}", e)
                            })),
                        )
                            .into_response();
                    }
                }
            }
        };

        // Check if there are tool calls to execute
        if result.tool_calls.is_empty() {
            // No more tool calls - capture the final response and exit loop
            final_response = result.content;
            break;
        }

        // Add assistant message with tool calls to conversation
        messages.push(Message {
            role: "assistant".to_string(),
            content: result.content.clone(),
            tool_calls: Some(result.raw_tool_calls.clone()),
            tool_call_id: None,
        });

        // Execute each tool call and add results to conversation
        for (i, tool_call) in result.tool_calls.iter().enumerate() {
            let mut execution_result =
                execute_tool_call(tool_call, &current_body, current_summary.as_deref());

            // Handle version tool requests that need async database access
            if let Some(version_request) = &execution_result.version_request {
                let version_result = handle_version_request(
                    pool,
                    id,
                    version_request,
                    &current_body,
                    current_summary.as_deref(),
                    file.version,
                )
                .await;

                // Update execution result with actual version operation result
                execution_result.result = version_result.result;
                execution_result.parsed_data = version_result.data;

                // Apply state changes from restore operation
                if let Some(new_body) = version_result.new_body {
                    current_body = new_body;
                    // Mark that a restore happened - file was already saved
                    version_restored = true;
                }
                if let Some(new_summary) = version_result.new_summary {
                    current_summary = Some(new_summary);
                }
            }

            // Apply state changes from regular tools
            if let Some(new_body) = execution_result.new_body {
                current_body = new_body;
                // If this is a regular tool (not a version operation), track it
                if execution_result.version_request.is_none() && version_restored {
                    has_changes_after_restore = true;
                }
            }
            if let Some(new_summary) = execution_result.new_summary {
                current_summary = Some(new_summary);
                if execution_result.version_request.is_none() && version_restored {
                    has_changes_after_restore = true;
                }
            }

            // Build tool result message content
            let result_content = if let Some(parsed_data) = &execution_result.parsed_data {
                // Include parsed data in the result for the LLM to use
                serde_json::json!({
                    "success": execution_result.result.success,
                    "message": execution_result.result.message,
                    "data": parsed_data
                })
                .to_string()
            } else {
                serde_json::json!({
                    "success": execution_result.result.success,
                    "message": execution_result.result.message
                })
                .to_string()
            };

            // Add tool result message
            // Use the appropriate ID format for each provider
            let tool_call_id = match &llm_client {
                LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(),
                LlmClient::Claude(_) => tool_call.id.clone(),
            };

            messages.push(Message {
                role: "tool".to_string(),
                content: Some(result_content),
                tool_calls: None,
                tool_call_id: Some(tool_call_id),
            });

            // Track for response
            all_tool_call_infos.push(ToolCallInfo {
                name: tool_call.name.clone(),
                result: execution_result.result,
            });
        }

        // If finish reason indicates completion, exit loop
        let finish_lower = result.finish_reason.to_lowercase();
        if finish_lower == "stop" || finish_lower == "end_turn" {
            final_response = result.content;
            break;
        }
    }

    // Save changes to database if any tools were executed
    // Skip if a version restore already happened (file was already saved during restore)
    // UNLESS there were additional modifications after the restore
    if !all_tool_call_infos.is_empty() && (!version_restored || has_changes_after_restore) {
        let update_req = crate::db::models::UpdateFileRequest {
            name: None,
            description: None,
            transcript: None,
            summary: current_summary.clone(),
            body: Some(current_body.clone()),
            version: None, // Internal update, skip version check
        };

        match repository::update_file(pool, id, update_req).await {
            Ok(Some(updated_file)) => {
                // Broadcast update notification for LLM changes
                let mut updated_fields = vec!["body".to_string()];
                if current_summary.is_some() {
                    updated_fields.push("summary".to_string());
                }
                state.broadcast_file_update(FileUpdateNotification {
                    file_id: id,
                    version: updated_file.version,
                    updated_fields,
                    updated_by: "llm".to_string(),
                });
            }
            Ok(None) => {
                // File was deleted during processing
                return (
                    StatusCode::NOT_FOUND,
                    Json(serde_json::json!({
                        "error": "File not found"
                    })),
                )
                    .into_response();
            }
            Err(e) => {
                tracing::error!("Failed to save file changes: {}", e);
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({
                        "error": format!("Failed to save changes: {}", e)
                    })),
                )
                    .into_response();
            }
        }
    }

    // Build response
    let response_text = final_response.unwrap_or_else(|| {
        if all_tool_call_infos.is_empty() {
            "I couldn't understand your request. Please try rephrasing.".to_string()
        } else {
            format!(
                "Done! Executed {} tool{}.",
                all_tool_call_infos.len(),
                if all_tool_call_infos.len() == 1 { "" } else { "s" }
            )
        }
    });

    (
        StatusCode::OK,
        Json(ChatResponse {
            response: response_text,
            tool_calls: all_tool_call_infos,
            updated_body: current_body,
            updated_summary: current_summary,
        }),
    )
        .into_response()
}

/// Render a plain-text description of `file` for the LLM system prompt.
///
/// The text contains:
/// - the name, plus the description and summary when present;
/// - transcript and body element counts;
/// - one indexed line per body element, with paragraphs cut to 50 characters;
/// - up to the first five transcript entries.
fn build_file_context(file: &crate::db::models::File) -> String {
    /// Paragraph previews longer than this many characters are truncated with "...".
    const PARAGRAPH_PREVIEW_CHARS: usize = 50;

    let mut context = format!("File: {}\n", file.name);

    if let Some(ref desc) = file.description {
        context.push_str(&format!("Description: {}\n", desc));
    }

    if let Some(ref summary) = file.summary {
        context.push_str(&format!("Summary: {}\n", summary));
    }

    context.push_str(&format!("Transcript entries: {}\n", file.transcript.len()));
    context.push_str(&format!("Body elements: {}\n", file.body.len()));

    // Add body overview
    if !file.body.is_empty() {
        context.push_str("\nCurrent body elements:\n");
        for (i, element) in file.body.iter().enumerate() {
            let desc = match element {
                BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
                BodyElement::Paragraph { text } => {
                    // Cut on a char boundary: byte slicing (`&text[..50]`) panics when
                    // byte 50 falls inside a multi-byte UTF-8 character.
                    let preview = match text.char_indices().nth(PARAGRAPH_PREVIEW_CHARS) {
                        Some((cut, _)) => format!("{}...", &text[..cut]),
                        None => text.clone(),
                    };
                    format!("Paragraph: {}", preview)
                }
                BodyElement::Chart { chart_type, title, .. } => {
                    format!(
                        "Chart ({:?}){}",
                        chart_type,
                        title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
                    )
                }
                BodyElement::Image { alt, .. } => {
                    format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
                }
            };
            context.push_str(&format!("  [{}] {}\n", i, desc));
        }
    }

    // Add transcript preview if available
    if !file.transcript.is_empty() {
        context.push_str("\nTranscript preview (first 5 entries):\n");
        for entry in file.transcript.iter().take(5) {
            context.push_str(&format!("  - {}: {}\n", entry.speaker, entry.text));
        }
        if file.transcript.len() > 5 {
            context.push_str(&format!("  ... and {} more entries\n", file.transcript.len() - 5));
        }
    }

    context
}

/// Outcome of a version tool operation that needed database access.
struct VersionRequestResult {
    /// Restored body; set only when a restore succeeded.
    new_body: Option<Vec<BodyElement>>,
    /// Restored summary; only produced by a successful restore.
    new_summary: Option<String>,
    /// Success flag and human-readable message reported back to the LLM.
    result: ToolResult,
    /// Structured payload (version lists, previews, restore info) for the LLM to use.
    data: Option<serde_json::Value>,
}

/// Handle version tool requests that require async database access
async fn handle_version_request(
    pool: &sqlx::PgPool,
    file_id: Uuid,
    request: &VersionToolRequest,
    _current_body: &[BodyElement],
    _current_summary: Option<&str>,
    current_version: i32,
) -> VersionRequestResult {
    match request {
        VersionToolRequest::ListVersions => {
            match repository::list_file_versions(pool, file_id).await {
                Ok(versions) => {
                    let version_data: Vec<serde_json::Value> = versions
                        .iter()
                        .map(|v| {
                            serde_json::json!({
                                "version": v.version,
                                "source": v.source,
                                "createdAt": v.created_at.to_rfc3339(),
                                "changeDescription": v.change_description,
                            })
                        })
                        .collect();

                    VersionRequestResult {
                        result: ToolResult {
                            success: true,
                            message: format!("Found {} versions. Current version is {}.", versions.len(), current_version),
                        },
                        data: Some(serde_json::json!({
                            "currentVersion": current_version,
                            "versions": version_data,
                        })),
                        new_body: None,
                        new_summary: None,
                    }
                }
                Err(e) => VersionRequestResult {
                    result: ToolResult {
                        success: false,
                        message: format!("Failed to list versions: {}", e),
                    },
                    data: None,
                    new_body: None,
                    new_summary: None,
                },
            }
        }
        VersionToolRequest::ReadVersion { version } => {
            match repository::get_file_version(pool, file_id, *version).await {
                Ok(Some(ver)) => {
                    // Convert body elements to a readable format
                    let body_preview: Vec<String> = ver
                        .body
                        .iter()
                        .enumerate()
                        .map(|(i, element)| {
                            let desc = match element {
                                BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
                                BodyElement::Paragraph { text } => {
                                    let preview = if text.len() > 100 {
                                        format!("{}...", &text[..100])
                                    } else {
                                        text.clone()
                                    };
                                    format!("Paragraph: {}", preview)
                                }
                                BodyElement::Chart { chart_type, title, .. } => {
                                    format!(
                                        "Chart ({:?}){}",
                                        chart_type,
                                        title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
                                    )
                                }
                                BodyElement::Image { alt, .. } => {
                                    format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
                                }
                            };
                            format!("[{}] {}", i, desc)
                        })
                        .collect();

                    VersionRequestResult {
                        result: ToolResult {
                            success: true,
                            message: format!(
                                "Version {} from {} (source: {}). {} body elements.",
                                ver.version,
                                ver.created_at.format("%Y-%m-%d %H:%M"),
                                ver.source,
                                ver.body.len()
                            ),
                        },
                        data: Some(serde_json::json!({
                            "version": ver.version,
                            "source": ver.source,
                            "createdAt": ver.created_at.to_rfc3339(),
                            "summary": ver.summary,
                            "bodyPreview": body_preview,
                            "changeDescription": ver.change_description,
                        })),
                        new_body: None,
                        new_summary: None,
                    }
                }
                Ok(None) => VersionRequestResult {
                    result: ToolResult {
                        success: false,
                        message: format!("Version {} not found", version),
                    },
                    data: None,
                    new_body: None,
                    new_summary: None,
                },
                Err(e) => VersionRequestResult {
                    result: ToolResult {
                        success: false,
                        message: format!("Failed to read version: {}", e),
                    },
                    data: None,
                    new_body: None,
                    new_summary: None,
                },
            }
        }
        VersionToolRequest::RestoreVersion { target_version, reason } => {
            // Set change description if provided
            if let Some(reason) = reason {
                let _ = repository::set_change_description(pool, reason).await;
            }

            match repository::restore_file_version(pool, file_id, *target_version, current_version).await {
                Ok(Some(restored_file)) => {
                    VersionRequestResult {
                        result: ToolResult {
                            success: true,
                            message: format!(
                                "Restored to version {}. New version is {}.",
                                target_version, restored_file.version
                            ),
                        },
                        data: Some(serde_json::json!({
                            "previousVersion": current_version,
                            "restoredFromVersion": target_version,
                            "newVersion": restored_file.version,
                        })),
                        new_body: Some(restored_file.body),
                        new_summary: restored_file.summary,
                    }
                }
                Ok(None) => VersionRequestResult {
                    result: ToolResult {
                        success: false,
                        message: format!("Version {} not found", target_version),
                    },
                    data: None,
                    new_body: None,
                    new_summary: None,
                },
                Err(RepositoryError::VersionConflict { expected, actual }) => {
                    VersionRequestResult {
                        result: ToolResult {
                            success: false,
                            message: format!(
                                "Version conflict: expected {}, actual {}. Document was modified.",
                                expected, actual
                            ),
                        },
                        data: None,
                        new_body: None,
                        new_summary: None,
                    }
                }
                Err(e) => VersionRequestResult {
                    result: ToolResult {
                        success: false,
                        message: format!("Failed to restore version: {}", e),
                    },
                    data: None,
                    new_body: None,
                    new_summary: None,
                },
            }
        }
    }
}