summaryrefslogblamecommitdiff
path: root/makima/frontend/src/hooks/useSpeakWebSocket.ts
blob: d9fb826fa411272c184edc82fe29394dd6b78eca (plain) (tree)























                                                                 
                                                                















































































































































































































































































































                                                                                          
import { useState, useCallback, useRef, useEffect } from "react";
import { SPEAK_ENDPOINT } from "../lib/api";

/**
 * Connection / playback phase reported by `useSpeakWebSocket`.
 */
export type SpeakStatus =
  | "disconnected" // no open socket
  | "connecting" // socket handshake in progress
  | "connected" // socket open, nothing currently being spoken
  | "loading_model" // request sent, but no audio has arrived yet after a delay
  | "speaking" // audio for the current request is arriving / playing
  | "error"; // connection or server-side failure; details in `error`

/**
 * Snapshot of the speak connection exposed by `useSpeakWebSocket`.
 */
export interface SpeakWebSocketState {
  /** Current phase of the connection. */
  status: SpeakStatus;
  /** Human-readable description of the last failure, or `null` if none. */
  error: string | null;
}

/** Sample rate (Hz) assumed for incoming PCM chunks and used for the playback AudioContext. */
const SPEAK_SAMPLE_RATE = 24000;

/** Delay after a speak request before a still-silent request is reported as "loading_model". */
const MODEL_LOADING_DELAY_MS = 2000;

/** How often to re-check whether scheduled audio has finished after `audio_end`. */
const DRAIN_POLL_INTERVAL_MS = 100;

/**
 * Convert signed 16-bit little-endian PCM bytes to Float32 samples in [-1, 1).
 *
 * Reads through a `DataView` with an explicit little-endian flag so the result
 * does not depend on host byte order (an `Int16Array` view uses platform
 * endianness). A trailing odd byte — half a sample — is ignored rather than
 * throwing a RangeError.
 *
 * @param buffer Raw PCM16 LE bytes.
 * @returns One Float32 sample per complete 16-bit sample in `buffer`.
 */
function pcm16ToFloat32(buffer: ArrayBuffer): Float32Array<ArrayBuffer> {
  const view = new DataView(buffer);
  const sampleCount = Math.floor(buffer.byteLength / 2);
  const samples = new Float32Array(sampleCount);
  for (let i = 0; i < sampleCount; i++) {
    samples[i] = view.getInt16(i * 2, true) / 32768;
  }
  return samples;
}

/**
 * React hook that streams text-to-speech audio from the speak WebSocket server
 * (`SPEAK_ENDPOINT`) and plays it through the Web Audio API.
 *
 * Protocol as used here: the client sends JSON `{ type: "speak", text }`,
 * `{ type: "cancel" }` and `{ type: "stop" }`; the server sends binary PCM16 LE
 * audio frames plus JSON `{ type: "audio_end" }` and
 * `{ type: "error", message?, code? }` messages.
 *
 * Status flow: "connecting" → "connected" when the socket opens; after
 * `speak()`, "loading_model" if no audio arrives within
 * `MODEL_LOADING_DELAY_MS`, "speaking" on the first audio frame, and back to
 * "connected" once `audio_end` arrives and scheduled playback has drained.
 *
 * @returns The current `SpeakWebSocketState`, derived flags (`isConnected`,
 *   `isSpeaking`, `isModelLoading`) and the actions `speak`, `cancel`,
 *   `connect` and `disconnect`.
 */
export function useSpeakWebSocket() {
  const [state, setState] = useState<SpeakWebSocketState>({
    status: "disconnected",
    error: null,
  });

  const wsRef = useRef<WebSocket | null>(null);
  const audioContextRef = useRef<AudioContext | null>(null);
  // Decoded chunks waiting to be scheduled on the AudioContext.
  const audioQueueRef = useRef<Float32Array<ArrayBuffer>[]>([]);
  // Sources scheduled on the AudioContext that have not ended yet; tracked so
  // cancel/disconnect/speak can silence audio that is already playing.
  const activeSourcesRef = useRef<Set<AudioBufferSourceNode>>(new Set());
  // True while any scheduled source has not yet ended.
  const isPlayingRef = useRef(false);
  // AudioContext time at which the next chunk should start (gapless playback).
  const nextPlayTimeRef = useRef(0);
  // Set by cancel() so frames the server already sent for the cancelled
  // request are dropped; cleared by the next speak().
  const discardAudioRef = useRef(false);
  const modelLoadingTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
  // Pending poll that waits for playback to drain after `audio_end`.
  const drainTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);

  /** Cancel the pending "loading_model" timer, if any. */
  const clearModelLoadingTimer = useCallback(() => {
    if (modelLoadingTimerRef.current !== null) {
      clearTimeout(modelLoadingTimerRef.current);
      modelLoadingTimerRef.current = null;
    }
  }, []);

  /** Drop queued audio, stop every scheduled source and cancel the drain poll. */
  const stopPlayback = useCallback(() => {
    audioQueueRef.current = [];
    for (const source of activeSourcesRef.current) {
      source.onended = null;
      try {
        source.stop();
      } catch {
        // Defensive: stop() can throw (e.g. if start() was never called);
        // every tracked source has been started, so there is nothing to undo.
      }
    }
    activeSourcesRef.current.clear();
    isPlayingRef.current = false;
    nextPlayTimeRef.current = 0;
    if (drainTimerRef.current !== null) {
      clearTimeout(drainTimerRef.current);
      drainTimerRef.current = null;
    }
  }, []);

  // Tear everything down on unmount.
  useEffect(() => {
    return () => {
      const ws = wsRef.current;
      // Detach before closing so the socket's late onclose sees it is stale
      // and does not update state after unmount.
      wsRef.current = null;
      ws?.close();
      clearModelLoadingTimer();
      stopPlayback();
      const ctx = audioContextRef.current;
      audioContextRef.current = null;
      // close() rejects if the context is already closed; nothing to do then.
      void ctx?.close().catch(() => undefined);
    };
  }, [clearModelLoadingTimer, stopPlayback]);

  /** Return the live AudioContext, creating a new one if none exists or it was closed. */
  const getAudioContext = useCallback((): AudioContext => {
    const existing = audioContextRef.current;
    if (existing && existing.state !== "closed") {
      return existing;
    }
    const created = new AudioContext({ sampleRate: SPEAK_SAMPLE_RATE });
    audioContextRef.current = created;
    // Start times recorded against an old context's clock mean nothing here.
    nextPlayTimeRef.current = 0;
    return created;
  }, []);

  /**
   * Schedule every queued chunk back-to-back on the AudioContext clock.
   *
   * Chunks are scheduled as soon as they arrive, not when the previous one
   * ends, so each buffer is already queued at the exact end time of the last
   * one and event-loop latency cannot introduce audible gaps.
   */
  const playAudioQueue = useCallback(() => {
    const ctx = getAudioContext();
    const queue = audioQueueRef.current;

    for (let chunk = queue.shift(); chunk !== undefined; chunk = queue.shift()) {
      // createBuffer() throws for a zero-length buffer; an empty frame has no audio.
      if (chunk.length === 0) continue;

      const buffer = ctx.createBuffer(1, chunk.length, SPEAK_SAMPLE_RATE);
      buffer.copyToChannel(chunk, 0);

      const source = ctx.createBufferSource();
      source.buffer = buffer;
      source.connect(ctx.destination);

      const startTime = Math.max(ctx.currentTime, nextPlayTimeRef.current);
      source.start(startTime);
      nextPlayTimeRef.current = startTime + buffer.duration;

      activeSourcesRef.current.add(source);
      isPlayingRef.current = true;
      source.onended = () => {
        activeSourcesRef.current.delete(source);
        if (activeSourcesRef.current.size === 0 && audioQueueRef.current.length === 0) {
          isPlayingRef.current = false;
        }
      };
    }
  }, [getAudioContext]);

  /**
   * Open the speak WebSocket unless it is already open.
   *
   * Resolves `true` once the socket is open and `false` if it errors or closes
   * first. If a handshake is already in flight, its outcome is reused instead
   * of tearing it down and starting a second one.
   */
  const connect = useCallback((): Promise<boolean> => {
    const existing = wsRef.current;
    if (existing?.readyState === WebSocket.OPEN) {
      return Promise.resolve(true);
    }
    if (existing?.readyState === WebSocket.CONNECTING) {
      return new Promise((resolve) => {
        // resolve() is idempotent, so whichever event fires first wins.
        existing.addEventListener("open", () => resolve(wsRef.current === existing), { once: true });
        existing.addEventListener("error", () => resolve(false), { once: true });
        existing.addEventListener("close", () => resolve(false), { once: true });
      });
    }

    if (existing) {
      // CLOSING or CLOSED: retire it. Clearing the ref first makes its
      // remaining handlers treat it as stale.
      wsRef.current = null;
      existing.close();
    }

    return new Promise((resolve) => {
      setState({ status: "connecting", error: null });

      try {
        const ws = new WebSocket(SPEAK_ENDPOINT);
        ws.binaryType = "arraybuffer";
        wsRef.current = ws;

        // Handlers of a socket that has since been replaced, or closed by
        // disconnect()/unmount, must not touch shared refs or state.
        const isCurrent = (): boolean => wsRef.current === ws;

        ws.onopen = () => {
          if (!isCurrent()) {
            resolve(false);
            return;
          }
          setState({ status: "connected", error: null });
          resolve(true);
        };

        ws.onmessage = (event) => {
          if (!isCurrent()) return;

          // Binary data = PCM audio chunk
          if (event.data instanceof ArrayBuffer) {
            // Frames still in flight after cancel() belong to the cancelled request.
            if (discardAudioRef.current) return;

            // Audio has arrived, so the "loading_model" indicator is no longer needed.
            clearModelLoadingTimer();
            setState((s) =>
              s.status === "loading_model" || s.status === "connected"
                ? { ...s, status: "speaking" }
                : s
            );

            audioQueueRef.current.push(pcm16ToFloat32(event.data));
            playAudioQueue();
            return;
          }

          // Text data = JSON control message
          let parsed: unknown;
          try {
            parsed = JSON.parse(event.data);
          } catch {
            parsed = undefined;
          }
          if (typeof parsed !== "object" || parsed === null) {
            console.error("Failed to parse speak WebSocket message:", event.data);
            return;
          }
          const message = parsed as { type?: unknown; message?: unknown; code?: unknown };

          switch (message.type) {
            case "audio_end": {
              clearModelLoadingTimer();
              // Replace any poll still running for an earlier audio_end.
              if (drainTimerRef.current !== null) {
                clearTimeout(drainTimerRef.current);
              }
              // Return to "connected" only once scheduled audio has finished.
              const checkDone = () => {
                drainTimerRef.current = null;
                if (audioQueueRef.current.length === 0 && !isPlayingRef.current) {
                  setState((s) =>
                    s.status === "speaking" || s.status === "loading_model"
                      ? { ...s, status: "connected" }
                      : s
                  );
                } else {
                  drainTimerRef.current = setTimeout(checkDone, DRAIN_POLL_INTERVAL_MS);
                }
              };
              checkDone();
              break;
            }

            case "error": {
              clearModelLoadingTimer();
              const detail =
                typeof message.message === "string" && message.message !== ""
                  ? message.message
                  : `Error: ${String(message.code)}`;
              setState({ status: "error", error: detail });
              break;
            }
          }
        };

        ws.onerror = () => {
          if (isCurrent()) {
            setState({
              status: "error",
              error: "Failed to connect to speak server",
            });
          }
          resolve(false);
        };

        ws.onclose = (event) => {
          // A socket that closes before ever opening must still settle the promise.
          resolve(false);
          if (!isCurrent()) return;

          wsRef.current = null;
          clearModelLoadingTimer();

          let errorMessage: string | null = null;
          if (event.code === 1006) {
            errorMessage = "Connection failed - server may be unavailable";
          } else if (event.code !== 1000 && event.code !== 1001) {
            errorMessage = `Connection closed unexpectedly (code: ${event.code})`;
          }

          setState((s) => ({
            status: "disconnected",
            error: errorMessage ?? s.error,
          }));
        };
      } catch (err) {
        // The WebSocket constructor throws synchronously, e.g. for an invalid URL.
        const message =
          err instanceof Error ? err.message : "Failed to create WebSocket connection";
        setState({ status: "error", error: message });
        resolve(false);
      }
    });
  }, [clearModelLoadingTimer, playAudioQueue]);

  /**
   * Speak `text`, connecting first if needed. Blank text is ignored.
   *
   * Any audio still playing from a previous request is stopped before the
   * new request is sent.
   */
  const speak = useCallback(
    async (text: string) => {
      if (!text.trim()) return;

      // Connect if not connected
      if (wsRef.current?.readyState !== WebSocket.OPEN) {
        const connected = await connect();
        if (!connected) return;
      }

      // Resume audio context if suspended (browser autoplay policy)
      const ctx = getAudioContext();
      if (ctx.state === "suspended") {
        await ctx.resume();
      }

      // The socket may have dropped while we were awaiting.
      const ws = wsRef.current;
      if (!ws || ws.readyState !== WebSocket.OPEN) return;

      // Start from silence: stop leftover audio and timers from any earlier request.
      stopPlayback();
      clearModelLoadingTimer();
      discardAudioRef.current = false;

      // The socket is open and nothing is playing. This also recovers from a
      // server-reported "error", which leaves the socket open.
      setState({ status: "connected", error: null });

      // If no audio arrives soon, surface the loading state.
      modelLoadingTimerRef.current = setTimeout(() => {
        modelLoadingTimerRef.current = null;
        setState((s) =>
          s.status === "connected" || s.status === "connecting"
            ? { ...s, status: "loading_model" }
            : s
        );
      }, MODEL_LOADING_DELAY_MS);

      ws.send(JSON.stringify({ type: "speak", text }));
    },
    [clearModelLoadingTimer, connect, getAudioContext, stopPlayback]
  );

  /** Stop the current request: silence playback and tell the server to cancel. */
  const cancel = useCallback(() => {
    stopPlayback();
    clearModelLoadingTimer();
    // Ignore frames the server already sent for the cancelled request.
    discardAudioRef.current = true;

    const ws = wsRef.current;
    if (ws?.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "cancel" }));
    }

    setState((s) => ({
      ...s,
      status: wsRef.current?.readyState === WebSocket.OPEN ? "connected" : "disconnected",
    }));
  }, [clearModelLoadingTimer, stopPlayback]);

  /** Stop playback, send "stop" if possible, and close the socket normally (1000). */
  const disconnect = useCallback(() => {
    stopPlayback();
    clearModelLoadingTimer();

    const ws = wsRef.current;
    // Clear the ref before closing so this socket's onclose is treated as stale.
    wsRef.current = null;
    if (ws) {
      if (ws.readyState === WebSocket.OPEN) {
        ws.send(JSON.stringify({ type: "stop" }));
      }
      ws.close(1000, "User disconnected");
    }

    setState({ status: "disconnected", error: null });
  }, [clearModelLoadingTimer, stopPlayback]);

  return {
    ...state,
    isConnected:
      state.status === "connected" ||
      state.status === "speaking" ||
      state.status === "loading_model",
    isSpeaking: state.status === "speaking",
    isModelLoading: state.status === "loading_model",
    speak,
    cancel,
    connect,
    disconnect,
  };
}