summaryrefslogblamecommitdiff
path: root/makima/frontend/src/hooks/useTaskSubscription.ts
blob: 9316c3a19825888a07adb53413b5b07a17186451 (plain) (tree)












































































































































































































































































































































                                                                                                 
import { useState, useCallback, useRef, useEffect } from "react";
import { TASK_SUBSCRIBE_ENDPOINT } from "../lib/api";

/**
 * Payload handed to `onUpdate` when the subscription socket receives a
 * `taskUpdated` message. Each field is copied verbatim from the
 * same-named field of the incoming message; no validation is performed.
 */
export interface TaskUpdateEvent {
  /** ID of the task the update refers to. */
  taskId: string;
  /** Task version reported by the server (presumably increases on every change — confirm against the backend). */
  version: number;
  /** Task status string as sent by the server; the set of valid values is not defined in this file. */
  status: string;
  /** Names of the task fields that changed in this update. */
  updatedFields: string[];
  /** Origin of the change. */
  updatedBy: "user" | "daemon" | "system";
}

/**
 * Payload handed to `onOutput` when the subscription socket receives a
 * `taskOutput` message. Each field is copied verbatim from the
 * same-named field of the incoming message; optional fields are only
 * meaningful for the message types noted on them.
 */
export interface TaskOutputEvent {
  /** ID of the task that produced this output. */
  taskId: string;
  /** Message type: "assistant", "tool_use", "tool_result", "result", "system", "error", "raw" */
  messageType: string;
  /** Main text content */
  content: string;
  /** Tool name if tool_use message */
  toolName?: string;
  /** Tool input JSON if tool_use message */
  toolInput?: Record<string, unknown>;
  /** Whether tool result was an error */
  isError?: boolean;
  /** Cost in USD if result message */
  costUsd?: number;
  /** Duration in ms if result message */
  durationMs?: number;
  /** Partial-output flag from the server (presumably marks an incomplete, streamed chunk — confirm against the backend). */
  isPartial: boolean;
}

/** Options accepted by `useTaskSubscription`. */
interface UseTaskSubscriptionOptions {
  /**
   * Task to receive update events for. Ignored while `subscribeAll` is true;
   * when null (and `subscribeAll` is false) the hook unsubscribes from updates.
   */
  taskId: string | null;
  /** Subscribe to updates for all tasks; takes precedence over `taskId`. Defaults to false. */
  subscribeAll?: boolean;
  /** Also subscribe to the output stream of the effective output task. Defaults to false. */
  subscribeOutput?: boolean;
  /** Task ID to subscribe output for (defaults to taskId if not specified) */
  outputTaskId?: string;
  /** Called for every `taskUpdated` message received. */
  onUpdate?: (event: TaskUpdateEvent) => void;
  /** Called for every `taskOutput` message received. */
  onOutput?: (event: TaskOutputEvent) => void;
  /**
   * Called with a message string when the server sends an `error` message,
   * when the socket reports an error, or when opening the socket throws.
   */
  onError?: (error: string) => void;
}

/**
 * React hook that keeps a single WebSocket open to `TASK_SUBSCRIBE_ENDPOINT`
 * and forwards task update / task output messages to the supplied callbacks.
 *
 * Behaviour:
 * - The socket is opened lazily the first time a subscription is requested.
 * - Desired subscriptions are remembered in refs, so they are replayed on
 *   every (re)connect.
 * - When the socket closes while at least one subscription is still wanted,
 *   a reconnect is attempted after 3 seconds.
 * - Callbacks are read through a ref, so changing them never forces a
 *   reconnect or a re-subscribe.
 * - Events from a socket that has since been replaced are ignored, so a
 *   late `close`/`error` from an abandoned socket cannot clobber the live
 *   connection or trigger a duplicate one.
 *
 * @param options - See `UseTaskSubscriptionOptions`. `subscribeAll` takes
 *   precedence over `taskId`; output is subscribed for
 *   `outputTaskId || taskId` when `subscribeOutput` is true.
 * @returns `connected` (true between the socket's `open` and `close`
 *   events) plus imperative subscribe/unsubscribe helpers.
 */
export function useTaskSubscription(options: UseTaskSubscriptionOptions) {
  const {
    taskId,
    subscribeAll = false,
    subscribeOutput = false,
    outputTaskId,
    onUpdate,
    onOutput,
    onError,
  } = options;

  // The task ID to use for output subscription (defaults to taskId)
  const effectiveOutputTaskId = outputTaskId || taskId;

  const [connected, setConnected] = useState(false);
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<number | null>(null);
  // Desired subscriptions; these survive reconnects and drive re-subscription.
  const subscribedTaskRef = useRef<string | null>(null);
  const subscribedAllRef = useRef(false);
  const subscribedOutputRef = useRef<string | null>(null);

  // Store callbacks in refs to avoid re-connecting when callbacks change
  const callbacksRef = useRef({ onUpdate, onOutput, onError });
  useEffect(() => {
    callbacksRef.current = { onUpdate, onOutput, onError };
  }, [onUpdate, onOutput, onError]);

  const connect = useCallback(() => {
    // Prevent multiple connections - check for OPEN or CONNECTING states
    const currentState = wsRef.current?.readyState;
    if (currentState === WebSocket.OPEN || currentState === WebSocket.CONNECTING) {
      return;
    }

    // We are connecting right now, so any scheduled reconnect is redundant.
    // Cancelling it here also guarantees at most one timer is ever pending,
    // which is what lets the unmount cleanup reliably cancel it.
    if (reconnectTimeoutRef.current !== null) {
      clearTimeout(reconnectTimeoutRef.current);
      reconnectTimeoutRef.current = null;
    }

    // Drop the reference to a socket that is already closing; its handlers
    // detect that it is no longer current and ignore any late events.
    if (wsRef.current && currentState === WebSocket.CLOSING) {
      wsRef.current = null;
    }

    try {
      const ws = new WebSocket(TASK_SUBSCRIBE_ENDPOINT);
      wsRef.current = ws;

      ws.onopen = () => {
        setConnected(true);
        // Re-subscribe if we had subscriptions
        if (subscribedAllRef.current) {
          ws.send(JSON.stringify({ type: "subscribeAll" }));
        }
        if (subscribedTaskRef.current) {
          ws.send(
            JSON.stringify({
              type: "subscribe",
              taskId: subscribedTaskRef.current,
            })
          );
        }
        if (subscribedOutputRef.current) {
          ws.send(
            JSON.stringify({
              type: "subscribeOutput",
              taskId: subscribedOutputRef.current,
            })
          );
        }
      };

      ws.onmessage = (event) => {
        // Ignore traffic from a socket that has been superseded.
        if (wsRef.current !== ws) {
          return;
        }

        try {
          const message = JSON.parse(event.data);

          switch (message.type) {
            case "taskUpdated":
              callbacksRef.current.onUpdate?.({
                taskId: message.taskId,
                version: message.version,
                status: message.status,
                updatedFields: message.updatedFields,
                updatedBy: message.updatedBy,
              });
              break;
            case "taskOutput":
              callbacksRef.current.onOutput?.({
                taskId: message.taskId,
                messageType: message.messageType,
                content: message.content,
                toolName: message.toolName,
                toolInput: message.toolInput,
                isError: message.isError,
                costUsd: message.costUsd,
                durationMs: message.durationMs,
                isPartial: message.isPartial,
              });
              break;
            case "error":
              callbacksRef.current.onError?.(message.message);
              break;
            // Acknowledgement messages - could add callbacks if needed
            case "subscribed":
            case "unsubscribed":
            case "subscribedAll":
            case "unsubscribedAll":
            case "outputSubscribed":
            case "outputUnsubscribed":
              break;
          }
        } catch (e) {
          console.error("Failed to parse task subscription message:", e);
        }
      };

      ws.onerror = () => {
        // An abandoned socket failing is expected; only report live errors.
        if (wsRef.current !== ws) {
          return;
        }
        callbacksRef.current.onError?.("WebSocket connection error");
      };

      ws.onclose = () => {
        // A replaced socket must not reset state that now belongs to its
        // successor, nor schedule a reconnect that would open a duplicate
        // connection alongside the live one.
        if (wsRef.current !== ws) {
          return;
        }

        setConnected(false);
        wsRef.current = null;

        // Attempt reconnection after 3 seconds if we still have a subscription
        if (
          subscribedTaskRef.current ||
          subscribedAllRef.current ||
          subscribedOutputRef.current
        ) {
          reconnectTimeoutRef.current = window.setTimeout(() => {
            // The timer has fired; forget its handle before reconnecting.
            reconnectTimeoutRef.current = null;
            connect();
          }, 3000);
        }
      };
    } catch (e) {
      callbacksRef.current.onError?.(
        e instanceof Error ? e.message : "Failed to connect"
      );
    }
  }, []);

  // Record the desired task and subscribe now if the socket is open;
  // otherwise connect, and onopen will replay the subscription.
  const subscribeToTask = useCallback(
    (id: string) => {
      subscribedTaskRef.current = id;

      if (wsRef.current?.readyState === WebSocket.OPEN) {
        wsRef.current.send(
          JSON.stringify({
            type: "subscribe",
            taskId: id,
          })
        );
      } else {
        connect();
      }
    },
    [connect]
  );

  // Tell the server (if reachable) and forget the desired task.
  const unsubscribeFromTask = useCallback(() => {
    if (
      subscribedTaskRef.current &&
      wsRef.current?.readyState === WebSocket.OPEN
    ) {
      wsRef.current.send(
        JSON.stringify({
          type: "unsubscribe",
          taskId: subscribedTaskRef.current,
        })
      );
    }
    subscribedTaskRef.current = null;
  }, []);

  const subscribeToAll = useCallback(() => {
    subscribedAllRef.current = true;

    if (wsRef.current?.readyState === WebSocket.OPEN) {
      wsRef.current.send(JSON.stringify({ type: "subscribeAll" }));
    } else {
      connect();
    }
  }, [connect]);

  // Note: the unsubscribeAll message is sent whenever the socket is open,
  // regardless of whether subscribeAll was previously requested.
  const unsubscribeFromAll = useCallback(() => {
    if (wsRef.current?.readyState === WebSocket.OPEN) {
      wsRef.current.send(JSON.stringify({ type: "unsubscribeAll" }));
    }
    subscribedAllRef.current = false;
  }, []);

  const subscribeToOutput = useCallback(
    (id: string) => {
      // First unsubscribe from any previous output subscription
      if (subscribedOutputRef.current && subscribedOutputRef.current !== id) {
        if (wsRef.current?.readyState === WebSocket.OPEN) {
          wsRef.current.send(
            JSON.stringify({
              type: "unsubscribeOutput",
              taskId: subscribedOutputRef.current,
            })
          );
        }
      }

      subscribedOutputRef.current = id;

      if (wsRef.current?.readyState === WebSocket.OPEN) {
        wsRef.current.send(
          JSON.stringify({
            type: "subscribeOutput",
            taskId: id,
          })
        );
      } else {
        connect();
      }
    },
    [connect]
  );

  const unsubscribeFromOutput = useCallback(() => {
    if (
      subscribedOutputRef.current &&
      wsRef.current?.readyState === WebSocket.OPEN
    ) {
      wsRef.current.send(
        JSON.stringify({
          type: "unsubscribeOutput",
          taskId: subscribedOutputRef.current,
        })
      );
    }
    subscribedOutputRef.current = null;
  }, []);

  // Auto-subscribe based on options
  useEffect(() => {
    if (subscribeAll) {
      subscribeToAll();
    } else if (taskId) {
      subscribeToTask(taskId);
    } else {
      unsubscribeFromTask();
      unsubscribeFromAll();
    }

    return () => {
      unsubscribeFromTask();
      unsubscribeFromAll();
    };
  }, [
    taskId,
    subscribeAll,
    subscribeToTask,
    unsubscribeFromTask,
    subscribeToAll,
    unsubscribeFromAll,
  ]);

  // Handle output subscription separately
  // Uses effectiveOutputTaskId which may be different from taskId when viewing subtask output
  useEffect(() => {
    if (subscribeOutput && effectiveOutputTaskId) {
      subscribeToOutput(effectiveOutputTaskId);
    } else {
      unsubscribeFromOutput();
    }

    return () => {
      unsubscribeFromOutput();
    };
  }, [effectiveOutputTaskId, subscribeOutput, subscribeToOutput, unsubscribeFromOutput]);

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      if (reconnectTimeoutRef.current !== null) {
        clearTimeout(reconnectTimeoutRef.current);
        reconnectTimeoutRef.current = null;
      }
      if (wsRef.current) {
        wsRef.current.close();
      }
    };
  }, []);

  return {
    connected,
    subscribeToTask,
    unsubscribeFromTask,
    subscribeToAll,
    unsubscribeFromAll,
    subscribeToOutput,
    unsubscribeFromOutput,
  };
}