summaryrefslogblamecommitdiff
path: root/makima/frontend/src/hooks/useMultiTaskSubscription.ts
blob: 4303f1bae1b451232d6d785d2952865ed0aa4fda (plain) (tree)







































                                                                                    


                                                                                                                             




















































































































































                                                                                   
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
import { TASK_SUBSCRIBE_ENDPOINT } from "../lib/api";
import type { TaskOutputEvent } from "./useTaskSubscription";

export interface MultiTaskOutputEntry extends TaskOutputEvent {
  /** Client-side receipt time, as returned by `Date.now()` when the message arrived. */
  receivedAt: number;
  /**
   * Display name of the task that produced this entry, such as a step
   * name or "Orchestrator". The hook falls back to the raw task ID when
   * the task has no label.
   */
  taskLabel: string;
}

interface UseMultiTaskSubscriptionOptions {
  /**
   * Tasks to follow, keyed by task ID. Each value is the display label
   * attached to that task's output entries.
   */
  taskMap: Map<string, string>;
  /**
   * Set to false to stop subscribing and close the connection.
   * The hook treats an omitted value as `true`.
   */
  enabled?: boolean;
  /**
   * Upper bound on buffered entries; the oldest entries are dropped first
   * once it is exceeded. The hook treats an omitted value as 2000.
   */
  maxEntries?: number;
}

/**
 * Subscribes to the output of several tasks over a single WebSocket and
 * merges their messages into one buffer, in arrival order.
 *
 * The subscribed set follows the keys of `taskMap`. Keys that are added get
 * a `subscribeOutput` message, and keys that are removed get an
 * `unsubscribeOutput` message. If the socket closes on its own while there
 * are subscriptions and the hook is enabled, it reconnects after 3 seconds
 * and re-subscribes to every tracked task.
 *
 * Sockets the hook closes itself (on disable, on an empty task map, or on
 * unmount) are detached first, so they never trigger a reconnect.
 *
 * @param options - task map, enable flag (default `true`) and buffer cap
 *   (default 2000).
 * @returns `connected` (true while the current socket is open), `entries`
 *   (buffered output, oldest first, at most `maxEntries` long) and
 *   `clearEntries` (empties the buffer).
 */
export function useMultiTaskSubscription(options: UseMultiTaskSubscriptionOptions) {
  const { taskMap, enabled = true, maxEntries = 2000 } = options;

  const [connected, setConnected] = useState(false);
  const [entries, setEntries] = useState<MultiTaskOutputEntry[]>([]);
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<number | null>(null);
  const subscribedTasksRef = useRef<Set<string>>(new Set());
  const taskMapRef = useRef(taskMap);
  const enabledRef = useRef(enabled);
  const maxEntriesRef = useRef(maxEntries);

  // Keep refs in sync so long-lived socket handlers always read current values
  useEffect(() => {
    taskMapRef.current = taskMap;
  }, [taskMap]);

  useEffect(() => {
    enabledRef.current = enabled;
  }, [enabled]);

  useEffect(() => {
    maxEntriesRef.current = maxEntries;
  }, [maxEntries]);

  // Derive task IDs from the map, stabilized to avoid unnecessary effect triggers
  const taskIdsKey = useMemo(() => Array.from(taskMap.keys()).sort().join(","), [taskMap]);
  const taskIds = useMemo(() => Array.from(taskMap.keys()), [taskIdsKey]); // eslint-disable-line react-hooks/exhaustive-deps

  const subscribeToTask = useCallback((ws: WebSocket, taskId: string) => {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "subscribeOutput", taskId }));
      subscribedTasksRef.current.add(taskId);
    }
  }, []);

  const unsubscribeFromTask = useCallback((ws: WebSocket, taskId: string) => {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "unsubscribeOutput", taskId }));
      subscribedTasksRef.current.delete(taskId);
    }
  }, []);

  /** Cancels a scheduled reconnect attempt, if one is pending. */
  const cancelReconnect = useCallback(() => {
    if (reconnectTimeoutRef.current !== null) {
      window.clearTimeout(reconnectTimeoutRef.current);
      reconnectTimeoutRef.current = null;
    }
  }, []);

  /**
   * Detaches the current socket from the hook, then closes it. Clearing
   * `wsRef` first makes the socket "stale", so its handlers (which compare
   * against `wsRef.current`) ignore the resulting close event instead of
   * scheduling a reconnect or clobbering a newer socket.
   */
  const disposeSocket = useCallback(() => {
    const ws = wsRef.current;
    wsRef.current = null;
    ws?.close();
  }, []);

  const connect = useCallback(() => {
    const currentState = wsRef.current?.readyState;
    if (currentState === WebSocket.OPEN || currentState === WebSocket.CONNECTING) {
      return;
    }

    // A CLOSING/CLOSED socket is being replaced. Forget it now so its late
    // onclose is ignored rather than nulling the socket created below.
    if (wsRef.current) {
      wsRef.current = null;
      setConnected(false);
    }
    // Connecting now supersedes any scheduled retry
    cancelReconnect();

    try {
      const ws = new WebSocket(TASK_SUBSCRIBE_ENDPOINT);
      wsRef.current = ws;

      ws.onopen = () => {
        if (wsRef.current !== ws) return;
        setConnected(true);
        // Re-subscribe to every task we are tracking (covers reconnects and
        // tasks added while the socket was still connecting)
        for (const taskId of subscribedTasksRef.current) {
          ws.send(JSON.stringify({ type: "subscribeOutput", taskId }));
        }
      };

      ws.onmessage = (event) => {
        if (wsRef.current !== ws) return;
        try {
          const message = JSON.parse(event.data);

          if (message.type === "taskOutput") {
            const label = taskMapRef.current.get(message.taskId) || message.taskId;
            const entry: MultiTaskOutputEntry = {
              taskId: message.taskId,
              messageType: message.messageType,
              content: message.content,
              toolName: message.toolName,
              toolInput: message.toolInput,
              isError: message.isError,
              costUsd: message.costUsd,
              durationMs: message.durationMs,
              isPartial: message.isPartial,
              taskLabel: label,
              receivedAt: Date.now(),
            };
            // Read through the ref so a changed maxEntries applies to an
            // already-open socket
            const limit = maxEntriesRef.current;

            setEntries((prev) => {
              const next = [...prev, entry];
              if (next.length > limit) {
                return next.slice(next.length - limit);
              }
              return next;
            });
          }
        } catch (e) {
          console.error("Failed to parse multi-task subscription message:", e);
        }
      };

      ws.onerror = () => {
        console.error("Multi-task WebSocket connection error");
      };

      ws.onclose = () => {
        // Sockets that were replaced or deliberately disposed must not touch
        // hook state or schedule a reconnect
        if (wsRef.current !== ws) return;
        setConnected(false);
        wsRef.current = null;

        // Reconnect if we still have subscriptions
        if (subscribedTasksRef.current.size > 0 && enabledRef.current) {
          cancelReconnect();
          reconnectTimeoutRef.current = window.setTimeout(() => {
            reconnectTimeoutRef.current = null;
            connect();
          }, 3000);
        }
      };
    } catch (e) {
      console.error("Failed to connect multi-task subscription:", e);
    }
  }, [cancelReconnect]);

  // Manage subscriptions when task IDs change
  useEffect(() => {
    if (!enabled || taskIds.length === 0) {
      // Nothing to follow: drop subscriptions, abort any pending retry and
      // close the connection
      subscribedTasksRef.current.clear();
      cancelReconnect();
      if (wsRef.current) {
        disposeSocket();
        // The disposed socket's onclose is ignored, so update state here
        setConnected(false);
      }
      return;
    }

    const newTaskIds = new Set(taskIds);
    const ws = wsRef.current;

    if (!ws || ws.readyState !== WebSocket.OPEN) {
      // Set desired subscriptions; onopen subscribes to them once connected
      subscribedTasksRef.current = newTaskIds;
      connect();
      return;
    }

    // Unsubscribe from removed tasks
    for (const existingId of subscribedTasksRef.current) {
      if (!newTaskIds.has(existingId)) {
        unsubscribeFromTask(ws, existingId);
      }
    }

    // Subscribe to new tasks
    for (const newId of newTaskIds) {
      if (!subscribedTasksRef.current.has(newId)) {
        subscribeToTask(ws, newId);
      }
    }
  }, [taskIds, enabled, connect, subscribeToTask, unsubscribeFromTask, cancelReconnect, disposeSocket]);

  // Cleanup on unmount: stop retries and close the socket without letting
  // its onclose schedule a reconnect
  useEffect(() => {
    return () => {
      cancelReconnect();
      disposeSocket();
    };
  }, [cancelReconnect, disposeSocket]);

  const clearEntries = useCallback(() => {
    setEntries([]);
  }, []);

  return {
    connected,
    entries,
    clearEntries,
  };
}