//! Chat endpoint for LLM-powered file editing.
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;
use crate::db::{models::BodyElement, repository};
use crate::llm::{
claude::{self, ClaudeClient, ClaudeError, ClaudeModel},
execute_tool_call,
groq::{GroqClient, GroqError, Message, ToolCallResponse},
LlmModel, ToolCall, ToolResult, AVAILABLE_TOOLS,
};
use crate::server::state::{FileUpdateNotification, SharedState};
/// Maximum number of tool-calling rounds to prevent infinite loops.
///
/// This is the upper bound of the round loop in `chat_handler`: each round is
/// one LLM request followed by execution of the tool calls it returned.
const MAX_TOOL_ROUNDS: usize = 10;
/// JSON request body accepted by the chat endpoint.
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatRequest {
/// The user's message/instruction
pub message: String,
/// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq"
// When this is omitted, or when `LlmModel::from_str` returns `None` for it,
// `chat_handler` falls back to `LlmModel::default()`.
#[serde(default)]
pub model: Option<String>,
}
/// JSON response body returned by the chat endpoint on success.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatResponse {
/// The LLM's final text reply, or a generated fallback message when the model returned none
pub response: String,
/// Tool calls that were executed, in execution order
pub tool_calls: Vec<ToolCallInfo>,
/// File body after tool execution (the unchanged body when no tool modified it)
pub updated_body: Vec<BodyElement>,
/// File summary after tool execution (the existing summary when no tool changed it)
pub updated_summary: Option<String>,
}
/// Name and outcome of one tool call executed while handling a chat request.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallInfo {
/// Name of the tool that was invoked
pub name: String,
/// Result reported by executing the tool
pub result: ToolResult,
}
/// Provider-specific LLM client selected for a single chat request.
enum LlmClient {
/// Client for the Groq API
Groq(GroqClient),
/// Client for the Anthropic Claude API
Claude(ClaudeClient),
}
/// Unified result from an LLM call, so the tool loop can handle Groq and
/// Claude replies through the same code path.
struct LlmResult {
/// Text content of the reply, if the model produced any
content: Option<String>,
/// Tool calls requested by the model
tool_calls: Vec<ToolCall>,
/// Tool calls as Groq-style `ToolCallResponse` values; these are attached to
/// the assistant message that is appended to the conversation
raw_tool_calls: Vec<ToolCallResponse>,
/// Provider's completion reason (Groq `finish_reason` or Claude `stop_reason`)
finish_reason: String,
}
/// Chat with a file using LLM tool calling
#[utoipa::path(
post,
path = "/api/v1/files/{id}/chat",
request_body = ChatRequest,
responses(
(status = 200, description = "Chat completed successfully", body = ChatResponse),
(status = 404, description = "File not found"),
(status = 500, description = "Internal server error")
),
params(
("id" = Uuid, Path, description = "File ID")
),
tag = "chat"
)]
pub async fn chat_handler(
State(state): State<SharedState>,
Path(id): Path<Uuid>,
Json(request): Json<ChatRequest>,
) -> impl IntoResponse {
// Check if database is configured
let Some(ref pool) = state.db_pool else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "Database not configured"
})),
)
.into_response();
};
// Get the file
let file = match repository::get_file(pool, id).await {
Ok(Some(file)) => file,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "File not found"
})),
)
.into_response();
}
Err(e) => {
tracing::error!("Database error: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Database error: {}", e)
})),
)
.into_response();
}
};
// Parse model selection (default to Claude Sonnet)
let model = request
.model
.as_ref()
.and_then(|m| LlmModel::from_str(m))
.unwrap_or_default();
tracing::info!("Using LLM model: {:?}", model);
// Initialize the appropriate LLM client
let llm_client = match model {
LlmModel::ClaudeSonnet => {
match ClaudeClient::from_env(ClaudeModel::Sonnet) {
Ok(client) => LlmClient::Claude(client),
Err(ClaudeError::MissingApiKey) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "ANTHROPIC_API_KEY not configured"
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Claude client error: {}", e)
})),
)
.into_response();
}
}
}
LlmModel::ClaudeOpus => {
match ClaudeClient::from_env(ClaudeModel::Opus) {
Ok(client) => LlmClient::Claude(client),
Err(ClaudeError::MissingApiKey) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "ANTHROPIC_API_KEY not configured"
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Claude client error: {}", e)
})),
)
.into_response();
}
}
}
LlmModel::GroqKimi => {
match GroqClient::from_env() {
Ok(client) => LlmClient::Groq(client),
Err(GroqError::MissingApiKey) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "GROQ_API_KEY not configured"
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Groq client error: {}", e)
})),
)
.into_response();
}
}
}
};
// Build context about the file
let file_context = build_file_context(&file);
// Build initial messages (Groq/OpenAI format - will be converted for Claude)
let mut messages = vec![
Message {
role: "system".to_string(),
content: Some(format!(
"You are a helpful assistant that helps users edit and analyze document files. \
You have access to tools to add headings, paragraphs, charts, and set summaries. \
When the user asks you to modify the file, use the appropriate tools.\n\n\
IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \
and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \
to call add_chart.\n\n\
Current file context:\n{}",
file_context
)),
tool_calls: None,
tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some(request.message.clone()),
tool_calls: None,
tool_call_id: None,
},
];
// State for tracking changes
let mut current_body = file.body.clone();
let mut current_summary = file.summary.clone();
let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new();
let mut final_response: Option<String> = None;
// Multi-turn tool calling loop
for round in 0..MAX_TOOL_ROUNDS {
tracing::debug!(round = round, "LLM tool calling round");
// Call the appropriate LLM API
let result = match &llm_client {
LlmClient::Groq(groq) => {
match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await {
Ok(r) => LlmResult {
content: r.content,
tool_calls: r.tool_calls,
raw_tool_calls: r.raw_tool_calls,
finish_reason: r.finish_reason,
},
Err(e) => {
tracing::error!("Groq API error: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("LLM API error: {}", e)
})),
)
.into_response();
}
}
}
LlmClient::Claude(claude_client) => {
// Convert messages to Claude format
let claude_messages = claude::groq_messages_to_claude(&messages);
match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await {
Ok(r) => {
// Convert Claude tool uses to Groq-style ToolCallResponse for consistency
let raw_tool_calls: Vec<ToolCallResponse> = r
.tool_calls
.iter()
.map(|tc| ToolCallResponse {
id: tc.id.clone(),
call_type: "function".to_string(),
function: crate::llm::groq::FunctionCall {
name: tc.name.clone(),
arguments: tc.arguments.to_string(),
},
})
.collect();
LlmResult {
content: r.content,
tool_calls: r.tool_calls,
raw_tool_calls,
finish_reason: r.stop_reason,
}
}
Err(e) => {
tracing::error!("Claude API error: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("LLM API error: {}", e)
})),
)
.into_response();
}
}
}
};
// Check if there are tool calls to execute
if result.tool_calls.is_empty() {
// No more tool calls - capture the final response and exit loop
final_response = result.content;
break;
}
// Add assistant message with tool calls to conversation
messages.push(Message {
role: "assistant".to_string(),
content: result.content.clone(),
tool_calls: Some(result.raw_tool_calls.clone()),
tool_call_id: None,
});
// Execute each tool call and add results to conversation
for (i, tool_call) in result.tool_calls.iter().enumerate() {
let execution_result =
execute_tool_call(tool_call, ¤t_body, current_summary.as_deref());
// Apply state changes
if let Some(new_body) = execution_result.new_body {
current_body = new_body;
}
if let Some(new_summary) = execution_result.new_summary {
current_summary = Some(new_summary);
}
// Build tool result message content
let result_content = if let Some(parsed_data) = &execution_result.parsed_data {
// Include parsed data in the result for the LLM to use
serde_json::json!({
"success": execution_result.result.success,
"message": execution_result.result.message,
"data": parsed_data
})
.to_string()
} else {
serde_json::json!({
"success": execution_result.result.success,
"message": execution_result.result.message
})
.to_string()
};
// Add tool result message
// Use the appropriate ID format for each provider
let tool_call_id = match &llm_client {
LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(),
LlmClient::Claude(_) => tool_call.id.clone(),
};
messages.push(Message {
role: "tool".to_string(),
content: Some(result_content),
tool_calls: None,
tool_call_id: Some(tool_call_id),
});
// Track for response
all_tool_call_infos.push(ToolCallInfo {
name: tool_call.name.clone(),
result: execution_result.result,
});
}
// If finish reason indicates completion, exit loop
let finish_lower = result.finish_reason.to_lowercase();
if finish_lower == "stop" || finish_lower == "end_turn" {
final_response = result.content;
break;
}
}
// Save changes to database if any tools were executed
if !all_tool_call_infos.is_empty() {
let update_req = crate::db::models::UpdateFileRequest {
name: None,
description: None,
transcript: None,
summary: current_summary.clone(),
body: Some(current_body.clone()),
version: None, // Internal update, skip version check
};
match repository::update_file(pool, id, update_req).await {
Ok(Some(updated_file)) => {
// Broadcast update notification for LLM changes
let mut updated_fields = vec!["body".to_string()];
if current_summary.is_some() {
updated_fields.push("summary".to_string());
}
state.broadcast_file_update(FileUpdateNotification {
file_id: id,
version: updated_file.version,
updated_fields,
updated_by: "llm".to_string(),
});
}
Ok(None) => {
// File was deleted during processing
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "File not found"
})),
)
.into_response();
}
Err(e) => {
tracing::error!("Failed to save file changes: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Failed to save changes: {}", e)
})),
)
.into_response();
}
}
}
// Build response
let response_text = final_response.unwrap_or_else(|| {
if all_tool_call_infos.is_empty() {
"I couldn't understand your request. Please try rephrasing.".to_string()
} else {
format!(
"Done! Executed {} tool{}.",
all_tool_call_infos.len(),
if all_tool_call_infos.len() == 1 { "" } else { "s" }
)
}
});
(
StatusCode::OK,
Json(ChatResponse {
response: response_text,
tool_calls: all_tool_call_infos,
updated_body: current_body,
updated_summary: current_summary,
}),
)
.into_response()
}
/// Renders a plain-text overview of `file` that is embedded in the LLM system prompt.
///
/// The overview lists the file name, optional description and summary, the
/// transcript and body element counts, one line per body element, and the
/// first five transcript entries.
fn build_file_context(file: &crate::db::models::File) -> String {
    /// Maximum number of bytes of paragraph text shown per body element.
    const PREVIEW_MAX_BYTES: usize = 50;
    /// Number of transcript entries included in the preview.
    const TRANSCRIPT_PREVIEW_ENTRIES: usize = 5;

    let mut context = format!("File: {}\n", file.name);
    if let Some(ref desc) = file.description {
        context.push_str(&format!("Description: {}\n", desc));
    }
    if let Some(ref summary) = file.summary {
        context.push_str(&format!("Summary: {}\n", summary));
    }
    context.push_str(&format!("Transcript entries: {}\n", file.transcript.len()));
    context.push_str(&format!("Body elements: {}\n", file.body.len()));

    // Add body overview
    if !file.body.is_empty() {
        context.push_str("\nCurrent body elements:\n");
        for (i, element) in file.body.iter().enumerate() {
            let desc = match element {
                BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
                BodyElement::Paragraph { text } => {
                    let preview = if text.len() > PREVIEW_MAX_BYTES {
                        // Slicing at a fixed byte offset panics when that offset is
                        // inside a multi-byte UTF-8 character, so back up to the
                        // nearest char boundary (offset 0 always is one).
                        let mut end = PREVIEW_MAX_BYTES;
                        while !text.is_char_boundary(end) {
                            end -= 1;
                        }
                        format!("{}...", &text[..end])
                    } else {
                        text.clone()
                    };
                    format!("Paragraph: {}", preview)
                }
                BodyElement::Chart { chart_type, title, .. } => {
                    format!(
                        "Chart ({:?}){}",
                        chart_type,
                        title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
                    )
                }
                BodyElement::Image { alt, .. } => {
                    format!(
                        "Image{}",
                        alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default()
                    )
                }
            };
            context.push_str(&format!(" [{}] {}\n", i, desc));
        }
    }

    // Add transcript preview if available
    if !file.transcript.is_empty() {
        context.push_str("\nTranscript preview (first 5 entries):\n");
        for entry in file.transcript.iter().take(TRANSCRIPT_PREVIEW_ENTRIES) {
            context.push_str(&format!(" - {}: {}\n", entry.speaker, entry.text));
        }
        if file.transcript.len() > TRANSCRIPT_PREVIEW_ENTRIES {
            context.push_str(&format!(
                " ... and {} more entries\n",
                file.transcript.len() - TRANSCRIPT_PREVIEW_ENTRIES
            ));
        }
    }
    context
}