summary refs log blame commit diff
path: root/makima/frontend/src/hooks/useMultiTaskSubscription.ts
blob: 84b936697fe5ca31f1cb0e9eb654392549df29c2 (plain) (tree)
1
2
3
4
5
6
7
8
9
                                                                          
                                                                    






                                                               

                                                                                   






























                                                                                    


                                                                                                                             














                                                                              



















































                                                                                  





























































































                                                                                   




                                                               









                                                          
                                                        


                                                   

                                                             

       
                                                                                      














                                                  
                                       







                 
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
import { TASK_SUBSCRIBE_ENDPOINT, getTaskOutput } from "../lib/api";
import type { TaskOutputEvent } from "./useTaskSubscription";

/** A single task output event, tagged with its task's label for display in a merged multi-task feed. */
export interface MultiTaskOutputEntry extends TaskOutputEvent {
  /**
   * Epoch milliseconds: `Date.now()` when a live event arrived, or the
   * server-reported creation time for entries loaded from history.
   */
  receivedAt: number;
  /** Human-readable name of the originating task (e.g. a step name or "Orchestrator"). */
  taskLabel: string;
  /** Set when the entry came from stored history instead of the live stream. */
  isBackfill?: boolean;
}

/** Configuration accepted by `useMultiTaskSubscription`. */
interface UseMultiTaskSubscriptionOptions {
  /** Tasks to follow, keyed by task ID; each value is the label attached to that task's entries. */
  taskMap: Map<string, string>;
  /** Upper bound on buffered entries; the earliest ones are discarded beyond it (hook default: 2000). */
  maxEntries?: number;
  /** When false, all subscriptions are dropped and the socket is closed (hook default: true). */
  enabled?: boolean;
}

/**
 * Follow the live output of several tasks over one shared WebSocket and merge
 * it into a single bounded entry buffer.
 *
 * - While `enabled` is true and `taskMap` is non-empty, a socket to
 *   `TASK_SUBSCRIBE_ENDPOINT` is kept open and every task ID is subscribed.
 *   Subscriptions are replayed on each (re)connect, and an unexpected close
 *   triggers a reconnect after 3 seconds.
 * - The first time a task ID appears, its stored output is fetched with
 *   `getTaskOutput` and prepended to the buffer with `isBackfill: true`.
 * - Live entries are appended; the buffer is trimmed from the front so it
 *   never holds more than `maxEntries` items.
 *
 * @param options - task map, enable flag and buffer size.
 * @returns `connected` (the current socket is open), `entries` (the merged
 *   buffer) and `clearEntries` (empties the buffer and lets tasks be
 *   backfilled again).
 */
export function useMultiTaskSubscription(options: UseMultiTaskSubscriptionOptions) {
  const { taskMap, enabled = true, maxEntries = 2000 } = options;

  const [connected, setConnected] = useState(false);
  const [entries, setEntries] = useState<MultiTaskOutputEntry[]>([]);
  // The one socket this hook currently owns; handlers of any other socket are ignored.
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<number | null>(null);
  // Task IDs we want to be subscribed to; replayed on every (re)connect.
  const subscribedTasksRef = useRef<Set<string>>(new Set());
  // Task IDs whose history has already been requested, so each is fetched once.
  const backfilledTasksRef = useRef<Set<string>>(new Set());
  const taskMapRef = useRef(taskMap);
  const enabledRef = useRef(enabled);

  // Keep refs in sync so socket handlers always read the latest values.
  useEffect(() => {
    taskMapRef.current = taskMap;
  }, [taskMap]);

  useEffect(() => {
    enabledRef.current = enabled;
  }, [enabled]);

  // Derive task IDs from the map, stabilized to avoid unnecessary effect triggers
  const taskIdsKey = useMemo(() => Array.from(taskMap.keys()).sort().join(","), [taskMap]);
  const taskIds = useMemo(() => Array.from(taskMap.keys()), [taskIdsKey]); // eslint-disable-line react-hooks/exhaustive-deps

  /** Cancel a pending reconnect attempt, if any. */
  const clearReconnectTimeout = useCallback(() => {
    if (reconnectTimeoutRef.current !== null) {
      clearTimeout(reconnectTimeoutRef.current);
      reconnectTimeoutRef.current = null;
    }
  }, []);

  const subscribeToTask = useCallback((ws: WebSocket, taskId: string) => {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "subscribeOutput", taskId }));
      subscribedTasksRef.current.add(taskId);
    }
  }, []);

  const unsubscribeFromTask = useCallback((ws: WebSocket, taskId: string) => {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "unsubscribeOutput", taskId }));
      subscribedTasksRef.current.delete(taskId);
    }
  }, []);

  /** Load a task's stored output once and prepend it to the buffer. */
  const backfillTask = useCallback(
    async (taskId: string, label: string) => {
      if (backfilledTasksRef.current.has(taskId)) return;
      backfilledTasksRef.current.add(taskId);

      try {
        const response = await getTaskOutput(taskId);
        if (response.entries.length === 0) return;

        const historicalEntries: MultiTaskOutputEntry[] = response.entries.map((entry) => ({
          taskId: entry.taskId,
          messageType: entry.messageType,
          content: entry.content,
          toolName: entry.toolName,
          toolInput: entry.toolInput,
          isError: entry.isError,
          costUsd: entry.costUsd,
          durationMs: entry.durationMs,
          isPartial: false,
          taskLabel: label,
          receivedAt: new Date(entry.createdAt || Date.now()).getTime(),
          isBackfill: true,
        }));

        // Heuristic identity used to skip history already received live:
        // task + message type + first 100 characters of content.
        const keyOf = (e: MultiTaskOutputEntry) =>
          `${e.taskId}:${e.messageType}:${e.content.slice(0, 100)}`;

        setEntries((prev) => {
          const existingKeys = new Set(prev.map(keyOf));
          const newHistorical = historicalEntries.filter((e) => !existingKeys.has(keyOf(e)));
          const combined = [...newHistorical, ...prev];
          return combined.length > maxEntries
            ? combined.slice(combined.length - maxEntries)
            : combined;
        });
      } catch (e) {
        console.error(`Failed to backfill task ${taskId}:`, e);
      }
    },
    [maxEntries]
  );

  const connect = useCallback(() => {
    const currentState = wsRef.current?.readyState;
    if (currentState === WebSocket.OPEN || currentState === WebSocket.CONNECTING) {
      return;
    }

    if (wsRef.current && currentState === WebSocket.CLOSING) {
      // Abandon the closing socket; its handlers see it is stale and do nothing.
      wsRef.current = null;
      setConnected(false);
    }

    clearReconnectTimeout();

    try {
      const ws = new WebSocket(TASK_SUBSCRIBE_ENDPOINT);
      wsRef.current = ws;

      ws.onopen = () => {
        if (wsRef.current !== ws) {
          // Superseded or deliberately dropped before it opened.
          ws.close();
          return;
        }
        setConnected(true);
        // Re-subscribe to all tasks
        for (const taskId of subscribedTasksRef.current) {
          ws.send(JSON.stringify({ type: "subscribeOutput", taskId }));
        }
      };

      ws.onmessage = (event) => {
        if (wsRef.current !== ws) return;
        try {
          const message = JSON.parse(event.data);

          if (message.type === "taskOutput") {
            const label = taskMapRef.current.get(message.taskId) || message.taskId;
            const entry: MultiTaskOutputEntry = {
              taskId: message.taskId,
              messageType: message.messageType,
              content: message.content,
              toolName: message.toolName,
              toolInput: message.toolInput,
              isError: message.isError,
              costUsd: message.costUsd,
              durationMs: message.durationMs,
              isPartial: message.isPartial,
              taskLabel: label,
              receivedAt: Date.now(),
            };

            setEntries((prev) => {
              const next = [...prev, entry];
              if (next.length > maxEntries) {
                return next.slice(next.length - maxEntries);
              }
              return next;
            });
          }
        } catch (e) {
          console.error("Failed to parse multi-task subscription message:", e);
        }
      };

      ws.onerror = () => {
        console.error("Multi-task WebSocket connection error");
      };

      ws.onclose = () => {
        // A socket we already replaced or deliberately closed must not touch
        // the current connection or schedule a reconnect.
        if (wsRef.current !== ws) return;
        wsRef.current = null;
        setConnected(false);

        // Reconnect if we still have subscriptions
        if (subscribedTasksRef.current.size > 0 && enabledRef.current) {
          clearReconnectTimeout();
          reconnectTimeoutRef.current = window.setTimeout(() => {
            reconnectTimeoutRef.current = null;
            // Re-check: the hook may have been disabled or emptied meanwhile.
            if (enabledRef.current && subscribedTasksRef.current.size > 0) {
              connect();
            }
          }, 3000);
        }
      };
    } catch (e) {
      console.error("Failed to connect multi-task subscription:", e);
    }
  }, [maxEntries, clearReconnectTimeout]);

  // Manage subscriptions when task IDs change
  useEffect(() => {
    if (!enabled || taskIds.length === 0) {
      // Drop everything, including a pending reconnect, and close the socket.
      subscribedTasksRef.current.clear();
      clearReconnectTimeout();
      const ws = wsRef.current;
      if (ws) {
        // Detach first so the close event is treated as stale.
        wsRef.current = null;
        ws.close();
        setConnected(false);
      }
      return;
    }

    const newTaskIds = new Set(taskIds);
    const ws = wsRef.current;

    if (!ws || ws.readyState !== WebSocket.OPEN) {
      // Set desired subscriptions; onopen subscribes them once connected.
      subscribedTasksRef.current = newTaskIds;
      connect();
      // Backfill all initial tasks (already-backfilled IDs are skipped).
      for (const taskId of newTaskIds) {
        const label = taskMapRef.current.get(taskId) || taskId;
        backfillTask(taskId, label);
      }
      return;
    }

    // Unsubscribe from removed tasks
    for (const existingId of subscribedTasksRef.current) {
      if (!newTaskIds.has(existingId)) {
        unsubscribeFromTask(ws, existingId);
      }
    }

    // Subscribe to new tasks and backfill their history
    for (const newId of newTaskIds) {
      if (!subscribedTasksRef.current.has(newId)) {
        subscribeToTask(ws, newId);
        const label = taskMapRef.current.get(newId) || newId;
        backfillTask(newId, label);
      }
    }
  }, [
    taskIds,
    enabled,
    connect,
    subscribeToTask,
    unsubscribeFromTask,
    backfillTask,
    clearReconnectTimeout,
  ]);

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      clearReconnectTimeout();
      // Clearing subscriptions and detaching the socket before closing it
      // prevents the close event from scheduling a reconnect after unmount.
      subscribedTasksRef.current.clear();
      const ws = wsRef.current;
      if (ws) {
        wsRef.current = null;
        ws.close();
      }
    };
  }, [clearReconnectTimeout]);

  const clearEntries = useCallback(() => {
    setEntries([]);
    // Allow every task's history to be fetched again.
    backfilledTasksRef.current.clear();
  }, []);

  return {
    connected,
    entries,
    clearEntries,
  };
}