// Source: makima/frontend/src/hooks/useMultiTaskSubscription.ts
// (blob 41489c770533d1de93f86f8d2657517ef049505d)
                                                                          
                                                                    






                                                               

                                                                                   


















                                                                                    
                                                            

                                     

                                                                               









                                 

















































































                                                                                          


                                                                                                                             














                                                                              



















































                                                                                  





























































































                                                                                   




                                                               









                                                          
                                                        


                                                   

                                                             

       
                                                                                      














                                                  
                                       







                 
import { useState, useCallback, useRef, useEffect, useMemo } from "react";
import { TASK_SUBSCRIBE_ENDPOINT, getTaskOutput } from "../lib/api";
import type { TaskOutputEvent } from "./useTaskSubscription";

/**
 * One line of task output in the merged multi-task feed: the underlying
 * `TaskOutputEvent` plus the metadata `useMultiTaskSubscription` attaches.
 */
export interface MultiTaskOutputEntry extends TaskOutputEvent {
  /**
   * Display label for the originating task (e.g. a step name or
   * "Orchestrator"); falls back to the raw task ID when no label is known.
   */
  taskLabel: string;
  /**
   * Epoch milliseconds used to order the feed. For live-streamed entries this
   * is the client's receive time; for backfilled entries it is the event's
   * stored `createdAt` (or the fetch time when that is missing).
   */
  receivedAt: number;
  /** When `true`, the entry was backfilled from stored history rather than live-streamed. */
  isBackfill?: boolean;
}

/**
 * Configuration accepted by `useMultiTaskSubscription`.
 */
interface UseMultiTaskSubscriptionOptions {
  /**
   * Tasks to follow: each key is a task ID to subscribe to, and each value is
   * the human-readable label attached to that task's output entries.
   */
  taskMap: Map<string, string>;
  /**
   * Turns the subscription on or off. The hook treats an omitted value as
   * `true`; `false` tears down the connection.
   */
  enabled?: boolean;
  /**
   * Upper bound on buffered entries; the oldest are dropped first. The hook
   * uses 2000 when omitted.
   */
  maxEntries?: number;
}

/** Delay before automatically re-opening a dropped WebSocket connection. */
const RECONNECT_DELAY_MS = 3000;

/** Number of leading content characters used when de-duplicating backfilled entries. */
const DEDUPE_CONTENT_PREFIX = 100;

/** Keep only the newest `max` entries (the tail of the array). */
function capEntries(list: MultiTaskOutputEntry[], max: number): MultiTaskOutputEntry[] {
  return list.length > max ? list.slice(list.length - max) : list;
}

/**
 * Heuristic identity used to drop backfilled entries that already arrived on
 * the live stream. A historical entry whose key matches any existing entry is
 * dropped, even if it is a genuinely repeated message.
 */
function dedupeKey(entry: MultiTaskOutputEntry): string {
  // Live entries come from unvalidated JSON, so guard against missing content.
  const content = entry.content ?? "";
  return `${entry.taskId}:${entry.messageType}:${content.slice(0, DEDUPE_CONTENT_PREFIX)}`;
}

/**
 * Subscribe to live output from several tasks over a single WebSocket and
 * merge it into one capped, labelled feed.
 *
 * The first time a task ID is seen, its stored output is fetched with
 * `getTaskOutput` and merged in (flagged `isBackfill`). Entries that already
 * arrived live are skipped. The socket reconnects after `RECONNECT_DELAY_MS`
 * while the hook is enabled and at least one task is subscribed.
 *
 * @param options - see `UseMultiTaskSubscriptionOptions`.
 * @returns `connected` (socket currently open), `entries` (live entries
 *   appended in arrival order, backfilled history merged by timestamp, at
 *   most `maxEntries`), and `clearEntries` (empties the feed and forgets
 *   which tasks have been backfilled).
 */
export function useMultiTaskSubscription(options: UseMultiTaskSubscriptionOptions) {
  const { taskMap, enabled = true, maxEntries = 2000 } = options;

  const [connected, setConnected] = useState(false);
  const [entries, setEntries] = useState<MultiTaskOutputEntry[]>([]);
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<number | null>(null);
  /** Task IDs we want subscribed; replayed to the server on every (re)connect. */
  const subscribedTasksRef = useRef<Set<string>>(new Set());
  /** Task IDs whose history has been (or is being) fetched, to avoid re-fetching. */
  const backfilledTasksRef = useRef<Set<string>>(new Set());
  const taskMapRef = useRef(taskMap);
  const enabledRef = useRef(enabled);

  // Keep refs in sync so socket callbacks always read the latest values.
  useEffect(() => {
    taskMapRef.current = taskMap;
  }, [taskMap]);

  useEffect(() => {
    enabledRef.current = enabled;
  }, [enabled]);

  // Derive task IDs from the map, stabilized to avoid unnecessary effect triggers
  const taskIdsKey = useMemo(() => Array.from(taskMap.keys()).sort().join(","), [taskMap]);
  const taskIds = useMemo(() => Array.from(taskMap.keys()), [taskIdsKey]); // eslint-disable-line react-hooks/exhaustive-deps

  const subscribeToTask = useCallback((ws: WebSocket, taskId: string) => {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "subscribeOutput", taskId }));
      subscribedTasksRef.current.add(taskId);
    }
  }, []);

  const unsubscribeFromTask = useCallback((ws: WebSocket, taskId: string) => {
    if (ws.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "unsubscribeOutput", taskId }));
      subscribedTasksRef.current.delete(taskId);
    }
  }, []);

  /**
   * Fetch a task's stored output once and merge it into the feed.
   * On failure the task is un-marked so a later subscription pass can retry.
   */
  const backfillTask = useCallback(
    async (taskId: string, label: string) => {
      if (backfilledTasksRef.current.has(taskId)) return;
      backfilledTasksRef.current.add(taskId);

      try {
        const response = await getTaskOutput(taskId);
        if (response.entries.length === 0) return;

        const historicalEntries: MultiTaskOutputEntry[] = response.entries.map((entry) => {
          const createdAtMs = new Date(entry.createdAt || Date.now()).getTime();
          return {
            taskId: entry.taskId,
            messageType: entry.messageType,
            content: entry.content,
            toolName: entry.toolName,
            toolInput: entry.toolInput,
            isError: entry.isError,
            costUsd: entry.costUsd,
            durationMs: entry.durationMs,
            isPartial: false,
            taskLabel: label,
            // An unparseable timestamp would make the sort comparator return NaN.
            receivedAt: Number.isNaN(createdAtMs) ? Date.now() : createdAtMs,
            isBackfill: true,
          };
        });

        setEntries((prev) => {
          const existingKeys = new Set(prev.map(dedupeKey));
          const fresh = historicalEntries.filter((e) => !existingKeys.has(dedupeKey(e)));
          if (fresh.length === 0) return prev;
          // Merge by timestamp so history from a late-added task does not land
          // ahead of older entries from other tasks (Array#sort is stable).
          const merged = [...fresh, ...prev].sort((a, b) => a.receivedAt - b.receivedAt);
          return capEntries(merged, maxEntries);
        });
      } catch (e) {
        console.error(`Failed to backfill task ${taskId}:`, e);
        backfilledTasksRef.current.delete(taskId);
      }
    },
    [maxEntries]
  );

  const connect = useCallback(() => {
    const currentState = wsRef.current?.readyState;
    if (currentState === WebSocket.OPEN || currentState === WebSocket.CONNECTING) {
      return;
    }

    // An explicit connect supersedes any pending automatic reconnect.
    if (reconnectTimeoutRef.current !== null) {
      clearTimeout(reconnectTimeoutRef.current);
      reconnectTimeoutRef.current = null;
    }

    try {
      const ws = new WebSocket(TASK_SUBSCRIBE_ENDPOINT);
      // Overwriting the ref orphans any CLOSING/CLOSED socket; every handler
      // below checks ownership so an orphaned socket cannot touch shared state.
      wsRef.current = ws;

      ws.onopen = () => {
        if (wsRef.current !== ws) return;
        setConnected(true);
        // Re-subscribe to all tasks
        for (const taskId of subscribedTasksRef.current) {
          ws.send(JSON.stringify({ type: "subscribeOutput", taskId }));
        }
      };

      ws.onmessage = (event) => {
        if (wsRef.current !== ws) return;
        try {
          const message = JSON.parse(event.data);

          if (message.type === "taskOutput") {
            const label = taskMapRef.current.get(message.taskId) || message.taskId;
            const entry: MultiTaskOutputEntry = {
              taskId: message.taskId,
              messageType: message.messageType,
              content: message.content,
              toolName: message.toolName,
              toolInput: message.toolInput,
              isError: message.isError,
              costUsd: message.costUsd,
              durationMs: message.durationMs,
              isPartial: message.isPartial,
              taskLabel: label,
              receivedAt: Date.now(),
            };

            setEntries((prev) => capEntries([...prev, entry], maxEntries));
          }
        } catch (e) {
          console.error("Failed to parse multi-task subscription message:", e);
        }
      };

      ws.onerror = () => {
        console.error("Multi-task WebSocket connection error");
      };

      ws.onclose = () => {
        // A socket that was replaced or deliberately torn down must not null
        // out the current socket or schedule a reconnect.
        if (wsRef.current !== ws) return;
        wsRef.current = null;
        setConnected(false);

        // Reconnect if we still have subscriptions
        if (subscribedTasksRef.current.size > 0 && enabledRef.current) {
          reconnectTimeoutRef.current = window.setTimeout(() => {
            reconnectTimeoutRef.current = null;
            if (enabledRef.current) connect();
          }, RECONNECT_DELAY_MS);
        }
      };
    } catch (e) {
      console.error("Failed to connect multi-task subscription:", e);
    }
  }, [maxEntries]);

  /**
   * Tear down the connection: cancel any pending reconnect, forget the desired
   * subscriptions, and close the socket. wsRef is cleared before close() so the
   * socket's own onclose sees itself as stale and does not reconnect.
   */
  const closeSocket = useCallback(() => {
    if (reconnectTimeoutRef.current !== null) {
      clearTimeout(reconnectTimeoutRef.current);
      reconnectTimeoutRef.current = null;
    }
    subscribedTasksRef.current.clear();
    const ws = wsRef.current;
    wsRef.current = null;
    ws?.close();
  }, []);

  // Manage subscriptions when task IDs change
  useEffect(() => {
    if (!enabled || taskIds.length === 0) {
      closeSocket();
      setConnected(false);
      return;
    }

    const newTaskIds = new Set(taskIds);
    const ws = wsRef.current;

    if (!ws || ws.readyState !== WebSocket.OPEN) {
      // Record desired subscriptions; onopen replays them once connected.
      subscribedTasksRef.current = newTaskIds;
      connect();
      // Backfill all initial tasks (already-backfilled IDs are skipped).
      for (const taskId of newTaskIds) {
        const label = taskMapRef.current.get(taskId) || taskId;
        void backfillTask(taskId, label);
      }
      return;
    }

    // Unsubscribe from removed tasks
    for (const existingId of subscribedTasksRef.current) {
      if (!newTaskIds.has(existingId)) {
        unsubscribeFromTask(ws, existingId);
      }
    }

    // Subscribe to new tasks and backfill their history
    for (const newId of newTaskIds) {
      if (!subscribedTasksRef.current.has(newId)) {
        subscribeToTask(ws, newId);
        const label = taskMapRef.current.get(newId) || newId;
        void backfillTask(newId, label);
      }
    }
  }, [taskIds, enabled, connect, closeSocket, subscribeToTask, unsubscribeFromTask, backfillTask]);

  // Cleanup on unmount; closeSocket prevents onclose from reconnecting afterwards.
  useEffect(() => {
    return () => closeSocket();
  }, [closeSocket]);

  const clearEntries = useCallback(() => {
    setEntries([]);
    backfilledTasksRef.current.clear();
  }, []);

  return {
    connected,
    entries,
    clearEntries,
  };
}