// cgit navigation: summary | refs | log | blame | commit | diff
// path: root/makima/frontend/src/hooks/useWebSocket.ts
// blob: 961951f8aa0c654ea93e3d87c1224a26ae22abf7 (plain) (tree)







































                                                                 

                                             



















































                                                                             












                                                                                                 
                          

                                                               
                   




                                                                       











                                                                              
                                            

                                                                    







                                                                          













































                                                                                                     




                                                             






































                                                                         
                                 








                                              

                                         






















                                                              
import { useState, useCallback, useRef, useEffect } from "react";
import { LISTEN_ENDPOINT } from "../lib/api";
import type {
  ClientMessage,
  ServerMessage,
  TranscriptEntry,
} from "../types/messages";

/**
 * Lifecycle state of the hook's WebSocket connection.
 *
 * - `"disconnected"`: initial state; also set when the socket closes, when the
 *   server sends a "stopped" message, or when `disconnect()` is called.
 * - `"connecting"`: set by `connect()` just before the socket is created.
 * - `"connected"`: set when the socket's `open` event fires.
 * - `"error"`: set when the socket reports an error or cannot be constructed.
 */
export type ConnectionStatus =
  | "disconnected"
  | "connecting"
  | "connected"
  | "error";

/** Reactive state exposed by `useWebSocket`. */
export interface WebSocketState {
  /** Current connection lifecycle state. */
  status: ConnectionStatus;
  /** Session id taken from the server's "ready" message; `null` until then and after the socket closes. */
  sessionId: string | null;
  /** Most recent error text (server "error" message or a connection failure), or `null`. */
  error: string | null;
  /** Transcript entries received so far, kept sorted by ascending `start`. */
  transcripts: TranscriptEntry[];
}

/**
 * Optional callbacks for `useWebSocket`. Each one is invoked when the
 * corresponding server message type is received on the socket.
 */
interface UseWebSocketOptions {
  /** Called on a "ready" message with the session id it carries. */
  onReady?: (sessionId: string) => void;
  /** Called on every "transcript" message with the entry built from it. */
  onTranscript?: (transcript: TranscriptEntry) => void;
  /** Called on a server "error" message (not on socket-level connection errors). */
  onError?: (code: string, message: string) => void;
  /** Called on a "stopped" message with the reason it carries. */
  onStopped?: (reason: string) => void;
}

/**
 * Maximum difference between two entries' `start` values for them to be
 * treated as the same utterance (same unit as `TranscriptEntry.start`,
 * presumably seconds — confirm against the server protocol).
 */
const TRANSCRIPT_START_TOLERANCE = 0.1;

/**
 * Returns a new transcript list with `entry` merged in, sorted by `start`.
 *
 * An existing entry from the same speaker whose `start` lies within
 * `TRANSCRIPT_START_TOLERANCE` of the new one is replaced (e.g. a final result
 * superseding an interim one); otherwise the entry is added. The input array
 * is never mutated.
 */
function mergeTranscript(
  transcripts: readonly TranscriptEntry[],
  entry: TranscriptEntry
): TranscriptEntry[] {
  const merged = [...transcripts];
  const existingIdx = merged.findIndex(
    (t) =>
      t.speaker === entry.speaker &&
      Math.abs(t.start - entry.start) < TRANSCRIPT_START_TOLERANCE
  );

  if (existingIdx >= 0) {
    merged[existingIdx] = entry;
  } else {
    merged.push(entry);
  }

  // Keep chronological order regardless of arrival order.
  merged.sort((a, b) => a.start - b.start);
  return merged;
}

/**
 * Parses a raw WebSocket payload into a server message.
 *
 * Returns `undefined` (after logging) when the payload is not valid JSON or
 * is not a JSON object. The shape of the object is not validated further.
 */
function parseServerMessage(raw: unknown): ServerMessage | undefined {
  try {
    const parsed: unknown = JSON.parse(String(raw));
    if (typeof parsed === "object" && parsed !== null) {
      return parsed as ServerMessage;
    }
  } catch {
    // Fall through to the shared error report below.
  }
  console.error("Failed to parse WebSocket message:", raw);
  return undefined;
}

/**
 * React hook managing a single WebSocket connection to `LISTEN_ENDPOINT`.
 *
 * It tracks connection status, the server-assigned session id, the last error
 * and the accumulated transcripts, and exposes helpers to connect, disconnect,
 * start/stop a session and stream audio samples.
 *
 * @param options Optional callbacks fired when the matching server message
 *   arrives. The latest callbacks are always used without re-creating the
 *   socket handlers.
 * @returns The current `WebSocketState` fields plus `connect`, `disconnect`,
 *   `sendAudio`, `startSession`, `stopSession`, `clearTranscripts` and the
 *   derived `isConnected` flag.
 */
export function useWebSocket(options: UseWebSocketOptions = {}) {
  const { onReady, onTranscript, onError, onStopped } = options;

  const [state, setState] = useState<WebSocketState>({
    status: "disconnected",
    sessionId: null,
    error: null,
    transcripts: [],
  });

  // The socket currently owned by this hook; handlers of any other socket are stale.
  const wsRef = useRef<WebSocket | null>(null);
  // Monotonic counter used to build unique transcript ids.
  const transcriptIdRef = useRef(0);
  // True between a successfully sent "stop" and the matching "stopped" (or socket close).
  const stoppingRef = useRef(false);
  // Set when disconnect() is requested while a stop is still in flight.
  const pendingDisconnectRef = useRef(false);

  // Store callbacks in a ref so socket handlers always see the latest ones
  // without having to be re-created.
  const callbacksRef = useRef({ onReady, onTranscript, onError, onStopped });
  useEffect(() => {
    callbacksRef.current = { onReady, onTranscript, onError, onStopped };
  }, [onReady, onTranscript, onError, onStopped]);

  /**
   * Opens the connection. Resolves `true` once the socket is open (or already
   * was), `false` if the attempt fails or is superseded/closed before opening.
   * Never rejects.
   */
  const connect = useCallback((): Promise<boolean> => {
    return new Promise((resolve) => {
      if (wsRef.current?.readyState === WebSocket.OPEN) {
        resolve(true);
        return;
      }

      // Close any existing (connecting/closing) socket. Its handlers become
      // stale and are ignored below.
      if (wsRef.current) {
        wsRef.current.close();
        wsRef.current = null;
      }
      // Stop/disconnect bookkeeping belonged to the previous socket.
      stoppingRef.current = false;
      pendingDisconnectRef.current = false;

      setState((s) => ({ ...s, status: "connecting", error: null }));

      try {
        const ws = new WebSocket(LISTEN_ENDPOINT);
        wsRef.current = ws;

        // A socket that was replaced by a later connect() or detached by
        // disconnect() must not touch shared state or the current ref.
        const isCurrent = () => wsRef.current === ws;

        ws.onopen = () => {
          if (!isCurrent()) {
            resolve(false);
            return;
          }
          setState((s) => ({ ...s, status: "connected", error: null }));
          resolve(true);
        };

        ws.onmessage = (event) => {
          const message = parseServerMessage(event.data);
          if (!message) {
            return;
          }

          try {
            switch (message.type) {
              case "ready":
                setState((s) => ({ ...s, sessionId: message.sessionId }));
                callbacksRef.current.onReady?.(message.sessionId);
                break;

              case "transcript": {
                const entry: TranscriptEntry = {
                  id: `t-${++transcriptIdRef.current}`,
                  speaker: message.speaker,
                  start: message.start,
                  end: message.end,
                  text: message.text,
                  isFinal: message.isFinal,
                };

                setState((s) => ({
                  ...s,
                  transcripts: mergeTranscript(s.transcripts, entry),
                }));

                callbacksRef.current.onTranscript?.(entry);
                break;
              }

              case "error":
                setState((s) => ({ ...s, error: message.message }));
                callbacksRef.current.onError?.(message.code, message.message);
                break;

              case "stopped":
                stoppingRef.current = false;
                setState((s) => ({ ...s, status: "disconnected" }));
                callbacksRef.current.onStopped?.(message.reason);
                // Execute pending disconnect if requested during stopping
                if (pendingDisconnectRef.current) {
                  pendingDisconnectRef.current = false;
                  if (wsRef.current) {
                    wsRef.current.close(1000, "User disconnected");
                    wsRef.current = null;
                  }
                  // The detached socket's close event is ignored, so clear
                  // the session here.
                  setState((s) => ({ ...s, sessionId: null }));
                }
                break;
            }
          } catch (err) {
            // Most likely a consumer callback threw; report it as such rather
            // than as a parse failure.
            console.error("Error while handling WebSocket message:", err);
          }
        };

        ws.onerror = () => {
          resolve(false);
          if (!isCurrent()) {
            return;
          }
          setState((s) => ({
            ...s,
            status: "error",
            error: "Failed to connect to server",
          }));
        };

        ws.onclose = (event) => {
          // Settles the promise if the socket closed before ever opening
          // (no-op when it was already resolved).
          resolve(false);
          if (!isCurrent()) {
            return;
          }

          // No "stopped" message can arrive any more.
          stoppingRef.current = false;
          pendingDisconnectRef.current = false;

          // Check for specific close codes
          let errorMessage: string | null = null;
          if (event.code === 1006) {
            errorMessage = "Connection failed - server may be unavailable";
          } else if (event.code !== 1000 && event.code !== 1001) {
            errorMessage = `Connection closed unexpectedly (code: ${event.code})`;
          }

          setState((s) => ({
            ...s,
            status: "disconnected",
            sessionId: null,
            error: errorMessage ?? s.error,
          }));
          wsRef.current = null;
        };
      } catch (err) {
        const message =
          err instanceof Error
            ? err.message
            : "Failed to create WebSocket connection";
        setState((s) => ({
          ...s,
          status: "error",
          error: message,
        }));
        resolve(false);
      }
    });
  }, []);

  /**
   * Closes the connection. If a stop is still in flight, the close is
   * deferred until the server's "stopped" message arrives.
   */
  const disconnect = useCallback(() => {
    if (stoppingRef.current) {
      // Defer disconnect until "stopped" message is received
      pendingDisconnectRef.current = true;
      return;
    }
    if (wsRef.current) {
      wsRef.current.close(1000, "User disconnected");
      wsRef.current = null;
    }
    setState((s) => ({ ...s, status: "disconnected", sessionId: null }));
  }, []);

  /** Sends a JSON control message; returns whether it was actually sent. */
  const sendMessage = useCallback((message: ClientMessage): boolean => {
    const ws = wsRef.current;
    if (ws?.readyState !== WebSocket.OPEN) {
      return false;
    }
    ws.send(JSON.stringify(message));
    return true;
  }, []);

  /**
   * Sends audio samples as raw little-endian 32-bit floats. Silently drops the
   * samples when the socket is not open.
   */
  const sendAudio = useCallback((samples: Float32Array) => {
    const ws = wsRef.current;
    if (ws?.readyState !== WebSocket.OPEN) {
      return;
    }
    // Serialize explicitly so the byte order is little-endian on every platform.
    const bytes = new Uint8Array(samples.length * 4);
    const view = new DataView(bytes.buffer);
    samples.forEach((sample, i) => {
      view.setFloat32(i * 4, sample, true);
    });
    ws.send(bytes);
  }, []);

  /** Sends a "start" message announcing the audio format ("pcm32f"). */
  const startSession = useCallback(
    (sampleRate: number, channels: number = 1) => {
      sendMessage({
        type: "start",
        sampleRate,
        channels,
        encoding: "pcm32f",
      });
    },
    [sendMessage]
  );

  /** Sends a "stop" message and marks the session as stopping until "stopped" arrives. */
  const stopSession = useCallback(
    (reason?: string) => {
      // Only wait for "stopped" if the request actually went out; otherwise a
      // later disconnect() would be deferred forever.
      stoppingRef.current = sendMessage({
        type: "stop",
        reason,
      });
    },
    [sendMessage]
  );

  /** Clears transcripts and the last error, and resets stop/disconnect bookkeeping. */
  const clearTranscripts = useCallback(() => {
    stoppingRef.current = false;
    pendingDisconnectRef.current = false;
    setState((s) => ({ ...s, transcripts: [], error: null }));
  }, []);

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      if (wsRef.current) {
        wsRef.current.close();
      }
    };
  }, []);

  return {
    ...state,
    connect,
    disconnect,
    sendAudio,
    startSession,
    stopSession,
    clearTranscripts,
    isConnected: state.status === "connected",
  };
}