//! Chat endpoint for LLM-powered file editing.
use axum::{
extract::{Path, State},
http::StatusCode,
response::IntoResponse,
Json,
};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;
use crate::db::{models::BodyElement, repository::{self, RepositoryError}};
use crate::llm::{
claude::{self, ClaudeClient, ClaudeError, ClaudeModel},
execute_tool_call,
groq::{GroqClient, GroqError, Message, ToolCallResponse},
LlmModel, ToolCall, ToolResult, VersionToolRequest, AVAILABLE_TOOLS,
};
use crate::server::state::{FileUpdateNotification, SharedState};
/// Maximum number of tool-calling rounds to prevent infinite loops.
///
/// Each round is one LLM request inside `chat_handler`. If the model is still
/// requesting tools when the limit is reached, the loop ends without a final
/// assistant message and the handler falls back to a generic summary text.
const MAX_TOOL_ROUNDS: usize = 10;
// Request payload accepted by `chat_handler`, deserialized from JSON with
// camelCase keys.
#[derive(Debug, Deserialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatRequest {
/// The user's message/instruction
pub message: String,
/// Optional model selection: "claude-sonnet" (default), "claude-opus", or "groq"
// A missing value, or one that `LlmModel::from_str` does not accept, makes
// `chat_handler` fall back to `LlmModel::default()`.
#[serde(default)]
pub model: Option<String>,
}
// Successful (HTTP 200) payload returned by `chat_handler`, serialized to JSON
// with camelCase keys.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatResponse {
/// The LLM's response message
// When the model produced no final text, `chat_handler` substitutes a generic
// fallback sentence here.
pub response: String,
/// Tool calls that were executed
// One entry per executed tool call, in execution order across all rounds.
pub tool_calls: Vec<ToolCallInfo>,
/// Updated file body after tool execution
pub updated_body: Vec<BodyElement>,
/// Updated summary (if changed)
// NOTE(review): `chat_handler` seeds this from the file's existing summary, so
// it is populated even when no tool changed the summary; the "(if changed)"
// wording above does not match that behavior.
pub updated_summary: Option<String>,
}
// Record of one executed tool call, reported back to the client inside
// `ChatResponse::tool_calls`.
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ToolCallInfo {
// Name of the tool the model invoked.
pub name: String,
// Outcome of executing the tool (for version tools, the outcome of the
// database operation that replaced the placeholder result).
pub result: ToolResult,
}
/// The LLM backend selected for a single chat request.
///
/// `chat_handler` builds exactly one variant per request and matches on it in
/// every tool-calling round to decide which API to call.
enum LlmClient {
/// Groq client; receives the conversation in its native message format.
Groq(GroqClient),
/// Claude client; the conversation is converted with
/// `claude::groq_messages_to_claude` before each call.
Claude(ClaudeClient),
}
/// Provider-independent view of one LLM response, so the tool-calling loop in
/// `chat_handler` can treat Groq and Claude replies the same way.
struct LlmResult {
/// Assistant text, if the model produced any.
content: Option<String>,
/// Parsed tool calls to execute; empty means the conversation is finished.
tool_calls: Vec<ToolCall>,
/// The same calls in Groq wire format, echoed back in the assistant message
/// that is appended to the conversation history.
raw_tool_calls: Vec<ToolCallResponse>,
/// Provider stop indicator (Groq `finish_reason` / Claude `stop_reason`);
/// compared case-insensitively against "stop" and "end_turn".
finish_reason: String,
}
/// Chat with a file using LLM tool calling
#[utoipa::path(
post,
path = "/api/v1/files/{id}/chat",
request_body = ChatRequest,
responses(
(status = 200, description = "Chat completed successfully", body = ChatResponse),
(status = 404, description = "File not found"),
(status = 500, description = "Internal server error")
),
params(
("id" = Uuid, Path, description = "File ID")
),
tag = "chat"
)]
pub async fn chat_handler(
State(state): State<SharedState>,
Path(id): Path<Uuid>,
Json(request): Json<ChatRequest>,
) -> impl IntoResponse {
// Check if database is configured
let Some(ref pool) = state.db_pool else {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "Database not configured"
})),
)
.into_response();
};
// Get the file
let file = match repository::get_file(pool, id).await {
Ok(Some(file)) => file,
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "File not found"
})),
)
.into_response();
}
Err(e) => {
tracing::error!("Database error: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Database error: {}", e)
})),
)
.into_response();
}
};
// Parse model selection (default to Claude Sonnet)
let model = request
.model
.as_ref()
.and_then(|m| LlmModel::from_str(m))
.unwrap_or_default();
tracing::info!("Using LLM model: {:?}", model);
// Initialize the appropriate LLM client
let llm_client = match model {
LlmModel::ClaudeSonnet => {
match ClaudeClient::from_env(ClaudeModel::Sonnet) {
Ok(client) => LlmClient::Claude(client),
Err(ClaudeError::MissingApiKey) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "ANTHROPIC_API_KEY not configured"
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Claude client error: {}", e)
})),
)
.into_response();
}
}
}
LlmModel::ClaudeOpus => {
match ClaudeClient::from_env(ClaudeModel::Opus) {
Ok(client) => LlmClient::Claude(client),
Err(ClaudeError::MissingApiKey) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "ANTHROPIC_API_KEY not configured"
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Claude client error: {}", e)
})),
)
.into_response();
}
}
}
LlmModel::GroqKimi => {
match GroqClient::from_env() {
Ok(client) => LlmClient::Groq(client),
Err(GroqError::MissingApiKey) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({
"error": "GROQ_API_KEY not configured"
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Groq client error: {}", e)
})),
)
.into_response();
}
}
}
};
// Build context about the file
let file_context = build_file_context(&file);
// Build initial messages (Groq/OpenAI format - will be converted for Claude)
let mut messages = vec![
Message {
role: "system".to_string(),
content: Some(format!(
"You are a helpful assistant that helps users edit and analyze document files. \
You have access to tools to add headings, paragraphs, charts, and set summaries. \
When the user asks you to modify the file, use the appropriate tools.\n\n\
IMPORTANT: You can call multiple tools in sequence. For example, if the user provides CSV data \
and asks for a chart, first call parse_csv to convert the data to JSON, then use that JSON \
to call add_chart.\n\n\
Current file context:\n{}",
file_context
)),
tool_calls: None,
tool_call_id: None,
},
Message {
role: "user".to_string(),
content: Some(request.message.clone()),
tool_calls: None,
tool_call_id: None,
},
];
// State for tracking changes
let mut current_body = file.body.clone();
let mut current_summary = file.summary.clone();
let mut all_tool_call_infos: Vec<ToolCallInfo> = Vec::new();
let mut final_response: Option<String> = None;
// Track if a version restore already happened (to avoid double-saving)
let mut version_restored = false;
// Track if there were modifications after a restore
let mut has_changes_after_restore = false;
// Multi-turn tool calling loop
for round in 0..MAX_TOOL_ROUNDS {
tracing::debug!(round = round, "LLM tool calling round");
// Call the appropriate LLM API
let result = match &llm_client {
LlmClient::Groq(groq) => {
match groq.chat_with_tools(messages.clone(), &AVAILABLE_TOOLS).await {
Ok(r) => LlmResult {
content: r.content,
tool_calls: r.tool_calls,
raw_tool_calls: r.raw_tool_calls,
finish_reason: r.finish_reason,
},
Err(e) => {
tracing::error!("Groq API error: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("LLM API error: {}", e)
})),
)
.into_response();
}
}
}
LlmClient::Claude(claude_client) => {
// Convert messages to Claude format
let claude_messages = claude::groq_messages_to_claude(&messages);
match claude_client.chat_with_tools(claude_messages, &AVAILABLE_TOOLS).await {
Ok(r) => {
// Convert Claude tool uses to Groq-style ToolCallResponse for consistency
let raw_tool_calls: Vec<ToolCallResponse> = r
.tool_calls
.iter()
.map(|tc| ToolCallResponse {
id: tc.id.clone(),
call_type: "function".to_string(),
function: crate::llm::groq::FunctionCall {
name: tc.name.clone(),
arguments: tc.arguments.to_string(),
},
})
.collect();
LlmResult {
content: r.content,
tool_calls: r.tool_calls,
raw_tool_calls,
finish_reason: r.stop_reason,
}
}
Err(e) => {
tracing::error!("Claude API error: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("LLM API error: {}", e)
})),
)
.into_response();
}
}
}
};
// Check if there are tool calls to execute
if result.tool_calls.is_empty() {
// No more tool calls - capture the final response and exit loop
final_response = result.content;
break;
}
// Add assistant message with tool calls to conversation
messages.push(Message {
role: "assistant".to_string(),
content: result.content.clone(),
tool_calls: Some(result.raw_tool_calls.clone()),
tool_call_id: None,
});
// Execute each tool call and add results to conversation
for (i, tool_call) in result.tool_calls.iter().enumerate() {
let mut execution_result =
execute_tool_call(tool_call, ¤t_body, current_summary.as_deref());
// Handle version tool requests that need async database access
if let Some(version_request) = &execution_result.version_request {
let version_result = handle_version_request(
pool,
id,
version_request,
¤t_body,
current_summary.as_deref(),
file.version,
)
.await;
// Update execution result with actual version operation result
execution_result.result = version_result.result;
execution_result.parsed_data = version_result.data;
// Apply state changes from restore operation
if let Some(new_body) = version_result.new_body {
current_body = new_body;
// Mark that a restore happened - file was already saved
version_restored = true;
}
if let Some(new_summary) = version_result.new_summary {
current_summary = Some(new_summary);
}
}
// Apply state changes from regular tools
if let Some(new_body) = execution_result.new_body {
current_body = new_body;
// If this is a regular tool (not a version operation), track it
if execution_result.version_request.is_none() && version_restored {
has_changes_after_restore = true;
}
}
if let Some(new_summary) = execution_result.new_summary {
current_summary = Some(new_summary);
if execution_result.version_request.is_none() && version_restored {
has_changes_after_restore = true;
}
}
// Build tool result message content
let result_content = if let Some(parsed_data) = &execution_result.parsed_data {
// Include parsed data in the result for the LLM to use
serde_json::json!({
"success": execution_result.result.success,
"message": execution_result.result.message,
"data": parsed_data
})
.to_string()
} else {
serde_json::json!({
"success": execution_result.result.success,
"message": execution_result.result.message
})
.to_string()
};
// Add tool result message
// Use the appropriate ID format for each provider
let tool_call_id = match &llm_client {
LlmClient::Groq(_) => result.raw_tool_calls[i].id.clone(),
LlmClient::Claude(_) => tool_call.id.clone(),
};
messages.push(Message {
role: "tool".to_string(),
content: Some(result_content),
tool_calls: None,
tool_call_id: Some(tool_call_id),
});
// Track for response
all_tool_call_infos.push(ToolCallInfo {
name: tool_call.name.clone(),
result: execution_result.result,
});
}
// If finish reason indicates completion, exit loop
let finish_lower = result.finish_reason.to_lowercase();
if finish_lower == "stop" || finish_lower == "end_turn" {
final_response = result.content;
break;
}
}
// Save changes to database if any tools were executed
// Skip if a version restore already happened (file was already saved during restore)
// UNLESS there were additional modifications after the restore
if !all_tool_call_infos.is_empty() && (!version_restored || has_changes_after_restore) {
let update_req = crate::db::models::UpdateFileRequest {
name: None,
description: None,
transcript: None,
summary: current_summary.clone(),
body: Some(current_body.clone()),
version: None, // Internal update, skip version check
};
match repository::update_file(pool, id, update_req).await {
Ok(Some(updated_file)) => {
// Broadcast update notification for LLM changes
let mut updated_fields = vec!["body".to_string()];
if current_summary.is_some() {
updated_fields.push("summary".to_string());
}
state.broadcast_file_update(FileUpdateNotification {
file_id: id,
version: updated_file.version,
updated_fields,
updated_by: "llm".to_string(),
});
}
Ok(None) => {
// File was deleted during processing
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "File not found"
})),
)
.into_response();
}
Err(e) => {
tracing::error!("Failed to save file changes: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({
"error": format!("Failed to save changes: {}", e)
})),
)
.into_response();
}
}
}
// Build response
let response_text = final_response.unwrap_or_else(|| {
if all_tool_call_infos.is_empty() {
"I couldn't understand your request. Please try rephrasing.".to_string()
} else {
format!(
"Done! Executed {} tool{}.",
all_tool_call_infos.len(),
if all_tool_call_infos.len() == 1 { "" } else { "s" }
)
}
});
(
StatusCode::OK,
Json(ChatResponse {
response: response_text,
tool_calls: all_tool_call_infos,
updated_body: current_body,
updated_summary: current_summary,
}),
)
.into_response()
}
/// Render a plain-text description of `file` for the LLM system prompt.
///
/// The output lists the file name, optional description and summary, the
/// transcript and body element counts, one indexed line per body element
/// (long paragraphs are shortened), and the first five transcript entries.
fn build_file_context(file: &crate::db::models::File) -> String {
    // Longest paragraph preview, in bytes, before the text is cut and "..." appended.
    const PARAGRAPH_PREVIEW_BYTES: usize = 50;
    // Number of transcript entries shown verbatim.
    const TRANSCRIPT_PREVIEW_ENTRIES: usize = 5;

    // Shorten `text` to at most `max_bytes` bytes followed by "...".
    fn preview(text: &str, max_bytes: usize) -> String {
        if text.len() <= max_bytes {
            return text.to_owned();
        }
        // Slicing at a byte offset inside a multi-byte UTF-8 character panics, so
        // back up to the nearest character boundary (offset 0 always is one).
        let mut end = max_bytes;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &text[..end])
    }

    let mut context = format!("File: {}\n", file.name);
    if let Some(ref desc) = file.description {
        context.push_str(&format!("Description: {}\n", desc));
    }
    if let Some(ref summary) = file.summary {
        context.push_str(&format!("Summary: {}\n", summary));
    }
    context.push_str(&format!("Transcript entries: {}\n", file.transcript.len()));
    context.push_str(&format!("Body elements: {}\n", file.body.len()));

    // Add body overview; the index shown is the element's position in the body.
    if !file.body.is_empty() {
        context.push_str("\nCurrent body elements:\n");
        for (i, element) in file.body.iter().enumerate() {
            let desc = match element {
                BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
                BodyElement::Paragraph { text } => {
                    format!("Paragraph: {}", preview(text, PARAGRAPH_PREVIEW_BYTES))
                }
                BodyElement::Chart { chart_type, title, .. } => {
                    format!(
                        "Chart ({:?}){}",
                        chart_type,
                        title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
                    )
                }
                BodyElement::Image { alt, .. } => {
                    format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
                }
            };
            context.push_str(&format!("  [{}] {}\n", i, desc));
        }
    }

    // Add transcript preview if available
    if !file.transcript.is_empty() {
        context.push_str("\nTranscript preview (first 5 entries):\n");
        for entry in file.transcript.iter().take(TRANSCRIPT_PREVIEW_ENTRIES) {
            context.push_str(&format!("  - {}: {}\n", entry.speaker, entry.text));
        }
        if file.transcript.len() > TRANSCRIPT_PREVIEW_ENTRIES {
            context.push_str(&format!(
                "  ... and {} more entries\n",
                file.transcript.len() - TRANSCRIPT_PREVIEW_ENTRIES
            ));
        }
    }

    context
}
/// Result of handling a version tool request
///
/// Produced by `handle_version_request` and merged into the tool execution
/// result by `chat_handler`.
struct VersionRequestResult {
/// Success flag and human-readable message reported back to the LLM.
result: ToolResult,
/// Structured payload (version list, version details, restore info) attached
/// to the tool result message; `None` on failure.
data: Option<serde_json::Value>,
/// Body of the restored file; `Some` only after a successful restore.
new_body: Option<Vec<BodyElement>>,
/// Summary of the restored file, if it has one; `None` for every
/// non-restore outcome.
new_summary: Option<String>,
}
/// Handle version tool requests that require async database access
async fn handle_version_request(
pool: &sqlx::PgPool,
file_id: Uuid,
request: &VersionToolRequest,
_current_body: &[BodyElement],
_current_summary: Option<&str>,
current_version: i32,
) -> VersionRequestResult {
match request {
VersionToolRequest::ListVersions => {
match repository::list_file_versions(pool, file_id).await {
Ok(versions) => {
let version_data: Vec<serde_json::Value> = versions
.iter()
.map(|v| {
serde_json::json!({
"version": v.version,
"source": v.source,
"createdAt": v.created_at.to_rfc3339(),
"changeDescription": v.change_description,
})
})
.collect();
VersionRequestResult {
result: ToolResult {
success: true,
message: format!("Found {} versions. Current version is {}.", versions.len(), current_version),
},
data: Some(serde_json::json!({
"currentVersion": current_version,
"versions": version_data,
})),
new_body: None,
new_summary: None,
}
}
Err(e) => VersionRequestResult {
result: ToolResult {
success: false,
message: format!("Failed to list versions: {}", e),
},
data: None,
new_body: None,
new_summary: None,
},
}
}
VersionToolRequest::ReadVersion { version } => {
match repository::get_file_version(pool, file_id, *version).await {
Ok(Some(ver)) => {
// Convert body elements to a readable format
let body_preview: Vec<String> = ver
.body
.iter()
.enumerate()
.map(|(i, element)| {
let desc = match element {
BodyElement::Heading { level, text } => format!("H{}: {}", level, text),
BodyElement::Paragraph { text } => {
let preview = if text.len() > 100 {
format!("{}...", &text[..100])
} else {
text.clone()
};
format!("Paragraph: {}", preview)
}
BodyElement::Chart { chart_type, title, .. } => {
format!(
"Chart ({:?}){}",
chart_type,
title.as_ref().map(|t| format!(": {}", t)).unwrap_or_default()
)
}
BodyElement::Image { alt, .. } => {
format!("Image{}", alt.as_ref().map(|a| format!(": {}", a)).unwrap_or_default())
}
};
format!("[{}] {}", i, desc)
})
.collect();
VersionRequestResult {
result: ToolResult {
success: true,
message: format!(
"Version {} from {} (source: {}). {} body elements.",
ver.version,
ver.created_at.format("%Y-%m-%d %H:%M"),
ver.source,
ver.body.len()
),
},
data: Some(serde_json::json!({
"version": ver.version,
"source": ver.source,
"createdAt": ver.created_at.to_rfc3339(),
"summary": ver.summary,
"bodyPreview": body_preview,
"changeDescription": ver.change_description,
})),
new_body: None,
new_summary: None,
}
}
Ok(None) => VersionRequestResult {
result: ToolResult {
success: false,
message: format!("Version {} not found", version),
},
data: None,
new_body: None,
new_summary: None,
},
Err(e) => VersionRequestResult {
result: ToolResult {
success: false,
message: format!("Failed to read version: {}", e),
},
data: None,
new_body: None,
new_summary: None,
},
}
}
VersionToolRequest::RestoreVersion { target_version, reason } => {
// Set change description if provided
if let Some(reason) = reason {
let _ = repository::set_change_description(pool, reason).await;
}
match repository::restore_file_version(pool, file_id, *target_version, current_version).await {
Ok(Some(restored_file)) => {
VersionRequestResult {
result: ToolResult {
success: true,
message: format!(
"Restored to version {}. New version is {}.",
target_version, restored_file.version
),
},
data: Some(serde_json::json!({
"previousVersion": current_version,
"restoredFromVersion": target_version,
"newVersion": restored_file.version,
})),
new_body: Some(restored_file.body),
new_summary: restored_file.summary,
}
}
Ok(None) => VersionRequestResult {
result: ToolResult {
success: false,
message: format!("Version {} not found", target_version),
},
data: None,
new_body: None,
new_summary: None,
},
Err(RepositoryError::VersionConflict { expected, actual }) => {
VersionRequestResult {
result: ToolResult {
success: false,
message: format!(
"Version conflict: expected {}, actual {}. Document was modified.",
expected, actual
),
},
data: None,
new_body: None,
new_summary: None,
}
}
Err(e) => VersionRequestResult {
result: ToolResult {
success: false,
message: format!("Failed to restore version: {}", e),
},
data: None,
new_body: None,
new_summary: None,
},
}
}
}
}