import Foundation
/// Minimal WebSocket client for `/api/v1/mesh/tasks/subscribe`.
///
/// Protocol lines up with `src/server/handlers/mesh_ws.rs`:
///
/// client -> server
/// { "type": "subscribeAll" }
/// { "type": "subscribe", "taskId": "<uuid>" }
/// { "type": "subscribeOutput", "taskId": "<uuid>" }
///
/// server -> client
/// { "type": "taskUpdated", "taskId", "status", "updatedFields", ... }
/// { "type": "taskOutput", "taskId", "messageType", "content", ... }
/// { "type": "error", "code", "message" }
///
/// This client additionally sends `unsubscribe` / `unsubscribeOutput`
/// (see `ClientMessage`) and ignores the server's acknowledgement frames.
///
/// Auth: the current server's WS upgrade handler does NOT check auth headers
/// (verified in mesh_ws.rs — handler lacks an extractor). Filtering is done
/// by owner_id server-side when subscribing to a specific task. We still
/// attach `x-makima-api-key` on the upgrade request so future tightening
/// doesn't break us.
///
/// Reconnect: a send or receive failure on the *current* socket schedules a
/// reconnect with exponential backoff (2s, 4s, 8s, 16s, then capped at 30s).
/// Registered per-task handlers survive a reconnect and are re-subscribed on
/// the new socket. Only an explicit `disconnect()` forgets them.
@MainActor
final class TaskWebSocket {
    /// Events delivered to per-task handlers.
    ///
    /// `error` is declared for callers, but this class currently only logs
    /// server error frames; it never dispatches `.error` to a handler.
    enum Event {
        case taskUpdated(TaskUpdatedEvent)
        case taskOutput(TaskOutputEvent)
        case error(code: String, message: String)
    }

    /// Payload of a `taskUpdated` frame.
    struct TaskUpdatedEvent: Decodable {
        let taskId: String
        let version: Int?
        let status: String?
        let updatedFields: [String]?
        let updatedBy: String?
    }

    /// Payload of a `taskOutput` frame.
    struct TaskOutputEvent: Decodable {
        let taskId: String
        let messageType: String
        let content: String
        let toolName: String?
        let toolInput: JSONValue?
        let isError: Bool?
        let costUsd: Double?
        let durationMs: Int?
        let isPartial: Bool?
    }

    /// Invoked every time the connection status is assigned.
    var onStatusChange: ((WebSocketStatus) -> Void)?

    private let profile: ServerProfile
    private let apiKey: String
    /// The live socket, or `nil` while disconnected / waiting to reconnect.
    private var task: URLSessionWebSocketTask?
    private let session: URLSession
    /// Consecutive failed attempts; drives the backoff delay.
    private var retries: Int = 0
    /// Set when a `subscribeAll` message has been issued; cleared by `disconnect()`.
    private var subscribedAll = false
    /// Handlers keyed by task id. Events for ids without a handler are dropped.
    private var perTaskHandlers: [String: (Event) -> Void] = [:]
    private var status: WebSocketStatus = .idle {
        didSet { onStatusChange?(status) }
    }
    /// Pending backoff timer; non-nil means a reconnect is already scheduled.
    private var reconnectTask: Task<Void, Never>?

    /// - Parameters:
    ///   - profile: Supplies the WebSocket base URL.
    ///   - apiKey: Sent as `x-makima-api-key` on the upgrade request.
    ///   - session: URL session used to create the socket.
    init(profile: ServerProfile, apiKey: String, session: URLSession = .shared) {
        self.profile = profile
        self.apiKey = apiKey
        self.session = session
    }

    // MARK: - Lifecycle

    /// Opens a new socket, replacing any existing one, and (re)subscribes.
    ///
    /// Does nothing if the profile has no WebSocket base URL.
    func connect() {
        guard let url = profile.apiWebSocketBaseURL?.appendingPathComponent("mesh/tasks/subscribe") else {
            return
        }

        // Supersede a pending backoff timer and any previous socket so that
        // at most one socket is ever live. Callbacks from the old socket are
        // ignored by the identity checks in `send` / `receiveLoop`.
        reconnectTask?.cancel()
        reconnectTask = nil
        task?.cancel(with: .goingAway, reason: nil)

        var request = URLRequest(url: url)
        request.setValue(apiKey, forHTTPHeaderField: "x-makima-api-key")
        request.setValue("makima-ios/\(APIClient.appVersion)", forHTTPHeaderField: "user-agent")

        let socket = session.webSocketTask(with: request)
        task = socket
        status = .connecting
        socket.resume()

        // Subscribe-all on open so Home/Tasks get updates without per-task
        // subscriptions racing the initial render.
        send(.subscribeAll)

        // Replay per-task subscriptions: covers both reconnects and handlers
        // that were registered before the first `connect()`.
        for taskId in perTaskHandlers.keys {
            send(.subscribe(taskId: taskId))
            send(.subscribeOutput(taskId: taskId))
        }

        receiveLoop(on: socket)
    }

    /// Closes the socket, cancels any pending reconnect and forgets all
    /// per-task handlers. The client stays offline until `connect()` is
    /// called again.
    func disconnect() {
        reconnectTask?.cancel()
        reconnectTask = nil
        task?.cancel(with: .goingAway, reason: nil)
        task = nil
        status = .offline
        retries = 0
        perTaskHandlers.removeAll()
        subscribedAll = false
    }

    // MARK: - Subscriptions

    /// Registers `handler` for `taskId` (replacing any previous handler) and
    /// asks the server for both status and output updates for that task.
    func subscribe(taskId: String, handler: @escaping (Event) -> Void) {
        perTaskHandlers[taskId] = handler
        send(.subscribe(taskId: taskId))
        send(.subscribeOutput(taskId: taskId))
    }

    /// Removes the handler for `taskId` and tells the server to stop sending
    /// status and output updates for it.
    func unsubscribe(taskId: String) {
        perTaskHandlers.removeValue(forKey: taskId)
        send(.unsubscribe(taskId: taskId))
        send(.unsubscribeOutput(taskId: taskId))
    }

    // MARK: - Wire protocol

    /// Messages this client can send to the server.
    enum ClientMessage {
        case subscribeAll
        case subscribe(taskId: String)
        case unsubscribe(taskId: String)
        case subscribeOutput(taskId: String)
        case unsubscribeOutput(taskId: String)

        /// JSON-serializable representation of the message.
        var jsonObject: [String: Any] {
            switch self {
            case .subscribeAll: return ["type": "subscribeAll"]
            case .subscribe(let id): return ["type": "subscribe", "taskId": id]
            case .unsubscribe(let id): return ["type": "unsubscribe", "taskId": id]
            case .subscribeOutput(let id): return ["type": "subscribeOutput", "taskId": id]
            case .unsubscribeOutput(let id): return ["type": "unsubscribeOutput", "taskId": id]
            }
        }
    }

    /// Serializes `message` and sends it on the current socket, if any.
    /// A send failure on the still-current socket schedules a reconnect.
    private func send(_ message: ClientMessage) {
        guard let data = try? JSONSerialization.data(withJSONObject: message.jsonObject),
              let str = String(data: data, encoding: .utf8)
        else { return }
        if case .subscribeAll = message { subscribedAll = true }

        guard let socket = task else { return }
        socket.send(.string(str)) { [weak self] error in
            guard let error else { return }
            print("[TaskWebSocket] send failed: \(error)")
            Task { @MainActor in
                // Ignore failures from a socket that has since been replaced
                // or explicitly disconnected.
                guard let self, self.task === socket else { return }
                self.scheduleReconnect()
            }
        }
    }

    /// Receives one message from `socket`, handles it, then re-arms itself.
    /// Stops as soon as `socket` is no longer the current socket.
    private func receiveLoop(on socket: URLSessionWebSocketTask) {
        socket.receive { [weak self] result in
            Task { @MainActor in
                // A cancelled socket completes its pending receive with a
                // failure; without this check an explicit `disconnect()`
                // would trigger an unwanted reconnect.
                guard let self, self.task === socket else { return }
                switch result {
                case .success(let message):
                    self.status = .online
                    self.retries = 0
                    switch message {
                    case .string(let text):
                        self.handle(text: text)
                    case .data(let data):
                        if let text = String(data: data, encoding: .utf8) {
                            self.handle(text: text)
                        }
                    @unknown default:
                        break
                    }
                    self.receiveLoop(on: socket)
                case .failure(let error):
                    print("[TaskWebSocket] receive failed: \(error)")
                    self.scheduleReconnect()
                }
            }
        }
    }

    /// Parses one server frame and routes it by its `type` field.
    /// Malformed frames and frames that fail to decode are dropped silently.
    private func handle(text: String) {
        let data = Data(text.utf8)
        guard let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
              let type = obj["type"] as? String
        else { return }
        let decoder = JSONDecoder()
        decoder.dateDecodingStrategy = .iso8601
        switch type {
        case "taskUpdated":
            if let event = try? decoder.decode(TaskUpdatedEvent.self, from: data) {
                dispatch(.taskUpdated(event), forTaskId: event.taskId)
            }
        case "taskOutput":
            if let event = try? decoder.decode(TaskOutputEvent.self, from: data) {
                dispatch(.taskOutput(event), forTaskId: event.taskId)
            }
        case "error":
            // Error frames carry no task id here, so they are logged rather
            // than routed to a per-task handler.
            let code = (obj["code"] as? String) ?? "ERROR"
            let msg = (obj["message"] as? String) ?? ""
            print("[TaskWebSocket] server error \(code): \(msg)")
        case "subscribed", "subscribedAll", "unsubscribed", "unsubscribedAll",
             "outputSubscribed", "outputUnsubscribed":
            // Acknowledgements; nothing to do.
            break
        default:
            break
        }
    }

    /// Delivers `event` to the handler registered for `id`, if there is one.
    private func dispatch(_ event: Event, forTaskId id: String) {
        perTaskHandlers[id]?(event)
    }

    // MARK: - Reconnect

    /// Drops the failed socket and schedules a single reconnect attempt with
    /// exponential backoff. Per-task handlers are kept so that `connect()`
    /// can re-subscribe them.
    private func scheduleReconnect() {
        guard reconnectTask == nil else { return }

        // Tear down only the socket, not the subscription state.
        task?.cancel(with: .goingAway, reason: nil)
        task = nil
        status = .offline

        retries = min(retries + 1, 6)
        let delay = min(30.0, pow(2.0, Double(retries)))
        reconnectTask = Task { [weak self] in
            try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
            if Task.isCancelled { return }
            guard let self else { return }
            // Clear before connecting: `connect()` cancels any pending
            // reconnect task, and that must not be the task running this code.
            self.reconnectTask = nil
            self.connect()
        }
    }
}
/// A schema-less JSON value, used for WS tool inputs whose shape varies.
///
/// Decoding probes the payload in a fixed order — null, bool, number,
/// string, array, object — and takes the first representation that decodes.
/// All numbers are stored as `Double`.
enum JSONValue: Decodable {
    case null
    case bool(Bool)
    case number(Double)
    case string(String)
    case array([JSONValue])
    case object([String: JSONValue])

    /// Decodes whichever JSON value sits at the decoder's current position.
    ///
    /// - Throws: `DecodingError.dataCorrupted` when the value matches none of
    ///   the supported representations, or any error raised while obtaining
    ///   the single-value container.
    init(from decoder: Decoder) throws {
        let container = try decoder.singleValueContainer()

        guard !container.decodeNil() else {
            self = .null
            return
        }

        if let flag = try? container.decode(Bool.self) {
            self = .bool(flag)
        } else if let numeric = try? container.decode(Double.self) {
            self = .number(numeric)
        } else if let text = try? container.decode(String.self) {
            self = .string(text)
        } else if let elements = try? container.decode([JSONValue].self) {
            self = .array(elements)
        } else if let members = try? container.decode([String: JSONValue].self) {
            self = .object(members)
        } else {
            throw DecodingError.dataCorruptedError(in: container, debugDescription: "Unknown JSON value")
        }
    }
}