summaryrefslogblamecommitdiff
path: root/makima/ios/Sources/Makima/Net/TaskWebSocket.swift
blob: 08efc9ce65ca13380ebe2b188c3ed8ece189aac7 (plain) (tree)

















































































































































































































































                                                                                                          
import Foundation

/// Minimal WebSocket client for `/api/v1/mesh/tasks/subscribe`.
///
/// The wire protocol matches `src/server/handlers/mesh_ws.rs`:
///
///   client -> server
///     { "type": "subscribeAll" }
///     { "type": "subscribe", "taskId": "<uuid>" }
///     { "type": "unsubscribe", "taskId": "<uuid>" }
///     { "type": "subscribeOutput", "taskId": "<uuid>" }
///     { "type": "unsubscribeOutput", "taskId": "<uuid>" }
///
///   server -> client
///     { "type": "taskUpdated", "taskId", "status", "updatedFields", ... }
///     { "type": "taskOutput",  "taskId", "messageType", "content", ... }
///     { "type": "error", "code", "message" }
///
/// Auth: the current server's WS upgrade handler does NOT check auth headers
/// (verified in mesh_ws.rs — the handler has no auth extractor). Filtering is
/// done by owner_id server-side when subscribing to a specific task. We still
/// attach `x-makima-api-key` on the upgrade request so future tightening
/// doesn't break us.
///
/// Lifecycle: after a send or receive failure the socket is dropped and a
/// reconnect is scheduled with exponential backoff (2s, 4s, 8s, 16s, then
/// capped at 30s). Per-task handlers survive reconnects: every `connect()`
/// replays `subscribeAll` plus `subscribe`/`subscribeOutput` for each
/// registered task. Only `disconnect()` clears handlers and stops reconnecting.
@MainActor
final class TaskWebSocket {
    /// Events delivered to handlers registered with `subscribe(taskId:handler:)`.
    enum Event {
        case taskUpdated(TaskUpdatedEvent)
        case taskOutput(TaskOutputEvent)
        /// Reserved for server-reported errors. Not currently emitted: server
        /// `error` frames are only logged (see `handle(text:)`).
        case error(code: String, message: String)
    }

    /// Payload of a `taskUpdated` server frame.
    struct TaskUpdatedEvent: Decodable {
        let taskId: String
        let version: Int?
        let status: String?
        let updatedFields: [String]?
        let updatedBy: String?
    }

    /// Payload of a `taskOutput` server frame.
    struct TaskOutputEvent: Decodable {
        let taskId: String
        let messageType: String
        let content: String
        let toolName: String?
        let toolInput: JSONValue?
        let isError: Bool?
        let costUsd: Double?
        let durationMs: Int?
        let isPartial: Bool?
    }

    /// Invoked on the main actor every time the connection status is assigned.
    /// Note: `.online` is re-assigned on every received message.
    var onStatusChange: ((WebSocketStatus) -> Void)?

    private let profile: ServerProfile
    private let apiKey: String
    /// The live socket, or nil while disconnected or waiting to reconnect.
    /// Callbacks coming from any other (stale) socket are ignored.
    private var task: URLSessionWebSocketTask?
    private let session: URLSession
    /// Consecutive failures since the last received message; drives backoff.
    private var retries: Int = 0
    private var subscribedAll = false
    /// Handlers keyed by task id. Kept across reconnects so subscriptions can be replayed.
    private var perTaskHandlers: [String: (Event) -> Void] = [:]
    private var status: WebSocketStatus = .idle {
        didSet { onStatusChange?(status) }
    }
    private var reconnectTask: Task<Void, Never>?
    /// Shared decoder for server frames (avoids allocating one per message).
    private let decoder: JSONDecoder = {
        let decoder = JSONDecoder()
        decoder.dateDecodingStrategy = .iso8601
        return decoder
    }()

    init(profile: ServerProfile, apiKey: String, session: URLSession = .shared) {
        self.profile = profile
        self.apiKey = apiKey
        self.session = session
    }

    /// Opens a fresh socket, replacing any existing one, and (re)sends all subscriptions.
    ///
    /// Does nothing if the profile has no WebSocket base URL.
    func connect() {
        guard let url = profile.apiWebSocketBaseURL?.appendingPathComponent("mesh/tasks/subscribe") else {
            return
        }
        // Never run two sockets at once: a leftover socket would keep
        // dispatching duplicate events through its own receive loop.
        closeSocket()

        var request = URLRequest(url: url)
        request.setValue(apiKey, forHTTPHeaderField: "x-makima-api-key")
        request.setValue("makima-ios/\(APIClient.appVersion)", forHTTPHeaderField: "user-agent")

        let socket = session.webSocketTask(with: request)
        self.task = socket
        self.status = .connecting
        socket.resume()

        // Subscribe-all on open so Home/Tasks get updates without per-task
        // subscriptions racing the initial render.
        send(.subscribeAll)
        // Replay per-task subscriptions registered before this socket existed
        // (while offline, or before a reconnect).
        for taskId in perTaskHandlers.keys {
            send(.subscribe(taskId: taskId))
            send(.subscribeOutput(taskId: taskId))
        }
        receiveLoop(on: socket)
    }

    /// Closes the socket, cancels any pending reconnect, and drops all handlers.
    func disconnect() {
        closeSocket()
        status = .offline
        perTaskHandlers.removeAll()
        subscribedAll = false
    }

    /// Registers `handler` for `taskId` and asks the server for its updates and output.
    /// If currently offline, the subscription is sent on the next `connect()`.
    func subscribe(taskId: String, handler: @escaping (Event) -> Void) {
        perTaskHandlers[taskId] = handler
        send(.subscribe(taskId: taskId))
        send(.subscribeOutput(taskId: taskId))
    }

    /// Removes the handler for `taskId` and unsubscribes from its updates and output.
    func unsubscribe(taskId: String) {
        perTaskHandlers.removeValue(forKey: taskId)
        send(.unsubscribe(taskId: taskId))
        send(.unsubscribeOutput(taskId: taskId))
    }

    // MARK: - Wire protocol

    enum ClientMessage {
        case subscribeAll
        case subscribe(taskId: String)
        case unsubscribe(taskId: String)
        case subscribeOutput(taskId: String)
        case unsubscribeOutput(taskId: String)

        var jsonObject: [String: Any] {
            switch self {
            case .subscribeAll:                 return ["type": "subscribeAll"]
            case .subscribe(let id):            return ["type": "subscribe", "taskId": id]
            case .unsubscribe(let id):          return ["type": "unsubscribe", "taskId": id]
            case .subscribeOutput(let id):      return ["type": "subscribeOutput", "taskId": id]
            case .unsubscribeOutput(let id):    return ["type": "unsubscribeOutput", "taskId": id]
            }
        }
    }

    // MARK: - Socket plumbing

    /// Cancels any pending reconnect and closes the current socket, keeping handlers intact.
    private func closeSocket() {
        reconnectTask?.cancel()
        reconnectTask = nil
        task?.cancel(with: .goingAway, reason: nil)
        task = nil
    }

    /// Serializes `message` and sends it on the current socket (dropped if there is none).
    /// A send failure on the *current* socket schedules a reconnect.
    private func send(_ message: ClientMessage) {
        guard let data = try? JSONSerialization.data(withJSONObject: message.jsonObject),
              let str = String(data: data, encoding: .utf8)
        else { return }

        if case .subscribeAll = message { subscribedAll = true }

        guard let socket = task else { return }
        socket.send(.string(str)) { [weak self] error in
            guard let error else { return }
            print("[TaskWebSocket] send failed: \(error)")
            Task { @MainActor in
                // Ignore failures from a socket we have already replaced or closed.
                guard let self, self.task === socket else { return }
                self.scheduleReconnect()
            }
        }
    }

    /// Receives one frame from `socket`, handles it, and re-arms itself.
    /// Stops silently once `socket` is no longer the current socket, so a
    /// deliberately cancelled socket never triggers a reconnect.
    private func receiveLoop(on socket: URLSessionWebSocketTask) {
        socket.receive { [weak self] result in
            Task { @MainActor in
                guard let self, self.task === socket else { return }
                switch result {
                case .success(let message):
                    self.status = .online
                    self.retries = 0
                    switch message {
                    case .string(let text):
                        self.handle(text: text)
                    case .data(let data):
                        if let text = String(data: data, encoding: .utf8) {
                            self.handle(text: text)
                        }
                    @unknown default:
                        break
                    }
                    self.receiveLoop(on: socket)
                case .failure(let error):
                    print("[TaskWebSocket] receive failed: \(error)")
                    self.scheduleReconnect()
                }
            }
        }
    }

    /// Parses one server frame and routes task events to the matching handler.
    /// Server `error` frames and subscription acknowledgements are not forwarded.
    private func handle(text: String) {
        guard let data = text.data(using: .utf8),
              let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
              let type = obj["type"] as? String
        else { return }

        switch type {
        case "taskUpdated":
            if let event = try? decoder.decode(TaskUpdatedEvent.self, from: data) {
                dispatch(.taskUpdated(event), forTaskId: event.taskId)
            }
        case "taskOutput":
            if let event = try? decoder.decode(TaskOutputEvent.self, from: data) {
                dispatch(.taskOutput(event), forTaskId: event.taskId)
            }
        case "error":
            let code = (obj["code"] as? String) ?? "ERROR"
            let msg  = (obj["message"] as? String) ?? ""
            print("[TaskWebSocket] server error \(code): \(msg)")
        case "subscribed", "subscribedAll", "unsubscribed", "unsubscribedAll",
             "outputSubscribed", "outputUnsubscribed":
            // Acknowledgements; nothing to do.
            break
        default:
            break
        }
    }

    private func dispatch(_ event: Event, forTaskId id: String) {
        perTaskHandlers[id]?(event)
    }

    /// Drops the failed socket and schedules one reconnect with exponential
    /// backoff (2^retries seconds, capped at 30s). Handlers are preserved so
    /// `connect()` can replay them. No-op if a reconnect is already pending.
    private func scheduleReconnect() {
        guard reconnectTask == nil else { return }
        status = .offline
        retries = min(retries + 1, 6)
        let delay = min(30.0, pow(2.0, Double(retries)))
        // Close the dead socket now so any of its remaining callbacks are
        // recognized as stale and ignored.
        closeSocket()
        reconnectTask = Task { [weak self] in
            try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
            guard !Task.isCancelled, let self else { return }
            // Clear before connect() so its closeSocket() doesn't cancel us.
            self.reconnectTask = nil
            self.connect()
        }
    }
}

/// Loosely-typed JSON value used for WebSocket tool inputs whose schema varies per tool.
///
/// Decoding tries each JSON kind in a fixed order (null, bool, number,
/// string, array, object) and throws only if none of them match.
enum JSONValue: Decodable {
    case null
    case bool(Bool)
    case number(Double)
    case string(String)
    case array([JSONValue])
    case object([String: JSONValue])

    init(from decoder: Decoder) throws {
        let container = try decoder.singleValueContainer()
        if container.decodeNil() {
            self = .null
        } else if let flag = try? container.decode(Bool.self) {
            // Bool is tried before Double so `true`/`false` never become numbers.
            self = .bool(flag)
        } else if let number = try? container.decode(Double.self) {
            self = .number(number)
        } else if let text = try? container.decode(String.self) {
            self = .string(text)
        } else if let items = try? container.decode([JSONValue].self) {
            self = .array(items)
        } else if let fields = try? container.decode([String: JSONValue].self) {
            self = .object(fields)
        } else {
            throw DecodingError.dataCorruptedError(in: container, debugDescription: "Unknown JSON value")
        }
    }
}