//! WebSocket client for connecting to the makima server.
use std::sync::Arc;
use std::time::Duration;
use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio_tungstenite::{connect_async, tungstenite::{client::IntoClientRequest, Message}};
use uuid::Uuid;
use super::protocol::{DaemonCommand, DaemonMessage};
use crate::config::ServerConfig;
use crate::error::{DaemonError, Result};
/// Lifecycle of the link between this daemon and the server.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection is open.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The socket is open and the server has accepted this daemon.
    Connected,
    /// The last connection failed; another attempt is scheduled.
    Reconnecting,
    /// Gave up for good (for example, the server rejected authentication).
    Failed,
}
/// Client side of the daemon <-> server WebSocket link.
///
/// Holds the identity this daemon presents to the server, shared connection
/// bookkeeping, and the channels that tie the socket to the rest of the daemon.
pub struct WsClient {
    /// Server endpoint, credentials and connection timing.
    config: ServerConfig,
    /// Identifier of this machine, sent when authenticating.
    machine_id: String,
    /// Host name of this machine, sent when authenticating.
    hostname: String,
    /// Task capacity advertised to the server when authenticating.
    max_concurrent_tasks: i32,
    /// Current connection state, shared behind an async lock.
    state: Arc<RwLock<ConnectionState>>,
    /// Server-assigned daemon ID, filled in once authentication succeeds.
    daemon_id: Arc<RwLock<Option<Uuid>>>,
    /// Receiving end of the outbound queue; drained and written to the socket.
    outgoing_rx: mpsc::Receiver<DaemonMessage>,
    /// Sending end of the outbound queue; cloned out to message producers.
    outgoing_tx: mpsc::Sender<DaemonMessage>,
    /// Destination for commands decoded from server messages.
    incoming_tx: mpsc::Sender<DaemonCommand>,
}
impl WsClient {
/// Create a new WebSocket client.
pub fn new(
config: ServerConfig,
machine_id: String,
hostname: String,
max_concurrent_tasks: i32,
incoming_tx: mpsc::Sender<DaemonCommand>,
) -> Self {
let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
Self {
config,
machine_id,
hostname,
max_concurrent_tasks,
state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
daemon_id: Arc::new(RwLock::new(None)),
outgoing_rx,
outgoing_tx,
incoming_tx,
}
}
/// Get a sender for outgoing messages.
pub fn sender(&self) -> mpsc::Sender<DaemonMessage> {
self.outgoing_tx.clone()
}
/// Get current connection state.
pub async fn state(&self) -> ConnectionState {
*self.state.read().await
}
/// Get daemon ID if authenticated.
pub async fn daemon_id(&self) -> Option<Uuid> {
*self.daemon_id.read().await
}
/// Run the WebSocket client with automatic reconnection.
pub async fn run(&mut self) -> Result<()> {
let mut backoff = ExponentialBackoff {
initial_interval: Duration::from_secs(self.config.reconnect_interval_secs),
max_interval: Duration::from_secs(60),
max_elapsed_time: if self.config.max_reconnect_attempts > 0 {
Some(Duration::from_secs(
self.config.reconnect_interval_secs * self.config.max_reconnect_attempts as u64 * 10,
))
} else {
None // Infinite retries
},
..Default::default()
};
loop {
*self.state.write().await = ConnectionState::Connecting;
tracing::info!("Connecting to server: {}", self.config.url);
match self.connect_and_run().await {
Ok(()) => {
// Clean shutdown
tracing::info!("WebSocket connection closed cleanly");
break;
}
Err(DaemonError::AuthFailed(msg)) => {
tracing::error!("Authentication failed: {}", msg);
*self.state.write().await = ConnectionState::Failed;
return Err(DaemonError::AuthFailed(msg));
}
Err(e) => {
tracing::warn!("Connection error: {}", e);
*self.state.write().await = ConnectionState::Reconnecting;
if let Some(delay) = backoff.next_backoff() {
tracing::info!("Reconnecting in {:?}...", delay);
tokio::time::sleep(delay).await;
} else {
tracing::error!("Max reconnection attempts reached");
*self.state.write().await = ConnectionState::Failed;
return Err(DaemonError::ConnectionLost);
}
}
}
}
Ok(())
}
/// Connect to server and run the message loop.
async fn connect_and_run(&mut self) -> Result<()> {
// Build WebSocket URL
let ws_url = format!("{}/api/v1/mesh/daemons/connect", self.config.url);
tracing::debug!("Connecting to WebSocket: {}", ws_url);
// Build request with API key header
let mut request = ws_url.into_client_request()?;
request.headers_mut().insert(
"x-makima-api-key",
self.config.api_key.parse().map_err(|_| {
DaemonError::AuthFailed("Invalid API key format".into())
})?,
);
// Connect with API key in headers
let (ws_stream, _response) = connect_async(request).await?;
let (mut write, mut read) = ws_stream.split();
// Send daemon info after connection (server authenticated us via header)
let info_msg = DaemonMessage::authenticate(
&self.config.api_key,
&self.machine_id,
&self.hostname,
self.max_concurrent_tasks,
);
let info_json = serde_json::to_string(&info_msg)?;
write.send(Message::Text(info_json)).await?;
// Wait for authentication response
let auth_response = read
.next()
.await
.ok_or(DaemonError::ConnectionLost)??;
let auth_text = match auth_response {
Message::Text(text) => text,
Message::Close(_) => return Err(DaemonError::ConnectionLost),
_ => return Err(DaemonError::AuthFailed("Unexpected response type".into())),
};
let command: DaemonCommand = serde_json::from_str(&auth_text)?;
match command {
DaemonCommand::Authenticated { daemon_id } => {
tracing::info!("Authenticated with daemon ID: {}", daemon_id);
*self.daemon_id.write().await = Some(daemon_id);
*self.state.write().await = ConnectionState::Connected;
// Send daemon directories info to server
let working_directory = std::env::current_dir()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_else(|_| ".".to_string());
let home_directory = dirs::home_dir()
.map(|h| h.join(".makima").join("home"))
.unwrap_or_else(|| std::path::PathBuf::from("~/.makima/home"));
// Create home directory if it doesn't exist
if let Err(e) = std::fs::create_dir_all(&home_directory) {
tracing::warn!("Failed to create home directory {:?}: {}", home_directory, e);
}
let home_directory_str = home_directory.to_string_lossy().to_string();
let worktrees_directory = dirs::home_dir()
.map(|h| h.join(".makima").join("worktrees").to_string_lossy().to_string())
.unwrap_or_else(|| "~/.makima/worktrees".to_string());
let dirs_msg = DaemonMessage::DaemonDirectories {
working_directory,
home_directory: home_directory_str,
worktrees_directory,
};
let dirs_json = serde_json::to_string(&dirs_msg)?;
write.send(Message::Text(dirs_json)).await?;
tracing::info!("Sent daemon directories info to server");
}
DaemonCommand::Error { code, message } => {
return Err(DaemonError::AuthFailed(format!("{}: {}", code, message)));
}
_ => {
return Err(DaemonError::AuthFailed(
"Unexpected response to authentication".into(),
));
}
}
// Start main message loop
let heartbeat_interval = Duration::from_secs(self.config.heartbeat_interval_secs);
let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);
loop {
tokio::select! {
// Handle incoming server commands
msg = read.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
tracing::info!("Received WebSocket message: {} bytes", text.len());
match serde_json::from_str::<DaemonCommand>(&text) {
Ok(command) => {
tracing::info!("Parsed command: {:?}", command);
tracing::info!("Sending command to task manager channel...");
if self.incoming_tx.send(command).await.is_err() {
tracing::warn!("Command receiver dropped, shutting down");
break;
}
tracing::info!("Command sent to task manager successfully");
}
Err(e) => {
tracing::warn!("Failed to parse server message: {}", e);
tracing::debug!("Raw message: {}", text);
}
}
}
Some(Ok(Message::Ping(data))) => {
write.send(Message::Pong(data)).await?;
}
Some(Ok(Message::Close(_))) | None => {
tracing::info!("Server closed connection");
return Err(DaemonError::ConnectionLost);
}
Some(Err(e)) => {
tracing::warn!("WebSocket error: {}", e);
return Err(e.into());
}
_ => {}
}
}
// Handle outgoing messages
msg = self.outgoing_rx.recv() => {
match msg {
Some(message) => {
let json = serde_json::to_string(&message)?;
tracing::trace!("Sending message: {}", json);
write.send(Message::Text(json)).await?;
}
None => {
// Sender dropped, shutdown
tracing::info!("Outgoing channel closed, shutting down");
break;
}
}
}
// Send heartbeat
_ = heartbeat_timer.tick() => {
// Get active task IDs from task manager
// For now, send empty list - will be connected to task manager
let heartbeat = DaemonMessage::heartbeat(vec![]);
let json = serde_json::to_string(&heartbeat)?;
write.send(Message::Text(json)).await?;
}
}
}
Ok(())
}
}