import { useState, useCallback, useRef, useEffect } from "react";
import { SPEAK_ENDPOINT } from "../lib/api";
/**
 * Phase of the speak connection / playback lifecycle reported by
 * {@link useSpeakWebSocket}.
 */
export type SpeakStatus =
  /** No socket is open (initial state, or after the socket closed). */
  | "disconnected"
  /** A socket is being opened. */
  | "connecting"
  /** The socket is open and idle. */
  | "connected"
  /** A speak request was sent but no audio has arrived within the loading timeout. */
  | "loading_model"
  /** Audio for the current request is being received / played. */
  | "speaking"
  /** The connection failed or the server reported an error. */
  | "error";
/** Snapshot of the speak hook's connection state. */
export interface SpeakWebSocketState {
  /** Human-readable description of the most recent failure, or `null` if none is recorded. */
  error: string | null;
  /** Current phase of the connection / playback lifecycle. */
  status: SpeakStatus;
}
/** Sample rate (Hz) assumed for the incoming PCM stream and used for playback. */
const SPEAK_SAMPLE_RATE = 24000;

/** Loose shape of the JSON control messages this hook reacts to. */
interface SpeakControlMessage {
  type?: unknown;
  message?: unknown;
  code?: unknown;
}

/** Structural stand-in for a React ref that holds a pending timeout handle. */
type TimerRef = { current: ReturnType<typeof setTimeout> | null };

/** Cancels the timeout stored in `ref`, if any, and clears the ref. */
function clearTimerRef(ref: TimerRef): void {
  if (ref.current !== null) {
    clearTimeout(ref.current);
    ref.current = null;
  }
}

/**
 * Decodes signed 16-bit little-endian PCM into Float32 samples in [-1, 1).
 *
 * Reads through a DataView so the byte order is explicit instead of the
 * platform's native order (which `Int16Array` uses). A trailing odd byte is
 * ignored; `new Int16Array(buffer)` would throw a RangeError instead.
 */
function pcm16LeToFloat32(data: ArrayBuffer): Float32Array<ArrayBuffer> {
  const view = new DataView(data);
  const samples = new Float32Array(Math.floor(data.byteLength / 2));
  for (let i = 0; i < samples.length; i++) {
    samples[i] = view.getInt16(i * 2, true) / 32768;
  }
  return samples;
}

/**
 * Parses a text frame as a JSON object.
 *
 * @returns the parsed object, or `null` if the frame is not a string, is not
 *   valid JSON, or is not a JSON object.
 */
function parseControlMessage(raw: unknown): SpeakControlMessage | null {
  if (typeof raw !== "string") return null;
  try {
    const parsed: unknown = JSON.parse(raw);
    return typeof parsed === "object" && parsed !== null
      ? (parsed as SpeakControlMessage)
      : null;
  } catch {
    return null;
  }
}

/**
 * React hook that streams text-to-speech audio from the speak WebSocket
 * server (`SPEAK_ENDPOINT`) and plays it through the Web Audio API.
 *
 * Protocol as handled here:
 * - client -> server JSON: `{type: "speak", text}`, `{type: "cancel"}`,
 *   `{type: "stop"}`.
 * - server -> client: binary frames, decoded as PCM16 LE mono at
 *   {@link SPEAK_SAMPLE_RATE} Hz; JSON `{type: "audio_end"}` and
 *   `{type: "error", message?, code?}`.
 *
 * @returns the current {@link SpeakWebSocketState}, derived flags
 *   (`isConnected`, `isSpeaking`, `isModelLoading`), and the `speak`,
 *   `cancel`, `connect` and `disconnect` actions.
 */
export function useSpeakWebSocket() {
  const [state, setState] = useState<SpeakWebSocketState>({
    status: "disconnected",
    error: null,
  });
  const wsRef = useRef<WebSocket | null>(null);
  const audioContextRef = useRef<AudioContext | null>(null);
  const audioQueueRef = useRef<Float32Array<ArrayBuffer>[]>([]);
  // True while at least one scheduled source node has not yet ended.
  const isPlayingRef = useRef(false);
  const modelLoadingTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
  // AudioContext time at which the next chunk should start (back-to-back playback).
  const nextPlayTimeRef = useRef(0);
  // Source nodes that are scheduled or playing, so they can be stopped.
  const activeSourcesRef = useRef<Set<AudioBufferSourceNode>>(new Set());
  // Bumped whenever playback is torn down; pending audio_end polls compare it.
  const playbackGenerationRef = useRef(0);
  // Set by cancel(): audio frames are dropped until the next speak().
  const discardAudioRef = useRef(false);

  const getAudioContext = useCallback((): AudioContext => {
    if (!audioContextRef.current || audioContextRef.current.state === "closed") {
      audioContextRef.current = new AudioContext({ sampleRate: SPEAK_SAMPLE_RATE });
    }
    return audioContextRef.current;
  }, []);

  /** Silences all scheduled audio and resets the playback bookkeeping. */
  const stopPlayback = useCallback(() => {
    playbackGenerationRef.current += 1;
    audioQueueRef.current = [];
    for (const source of activeSourcesRef.current) {
      // Detach first so stopping doesn't run the normal end-of-chunk handler.
      source.onended = null;
      try {
        source.stop();
      } catch {
        // Older implementations throw if stop() was already called; ignore.
      }
      source.disconnect();
    }
    activeSourcesRef.current.clear();
    isPlayingRef.current = false;
    nextPlayTimeRef.current = 0;
  }, []);

  // Clean up on unmount.
  useEffect(() => {
    return () => {
      stopPlayback();
      clearTimerRef(modelLoadingTimerRef);
      const ws = wsRef.current;
      // Detach before closing so the socket's onclose doesn't touch state.
      wsRef.current = null;
      ws?.close();
      const ctx = audioContextRef.current;
      audioContextRef.current = null;
      if (ctx) {
        void ctx.close();
      }
    };
  }, [stopPlayback]);

  /**
   * Schedules every queued chunk immediately, each starting where the
   * previous one ends on the audio clock, so consecutive chunks play without
   * gaps. Scheduling from `onended` instead would always start late.
   */
  const playAudioQueue = useCallback(() => {
    const ctx = getAudioContext();
    const active = activeSourcesRef.current;
    const pending = audioQueueRef.current;
    audioQueueRef.current = [];
    for (const chunk of pending) {
      const buffer = ctx.createBuffer(1, chunk.length, SPEAK_SAMPLE_RATE);
      buffer.copyToChannel(chunk, 0);
      const source = ctx.createBufferSource();
      source.buffer = buffer;
      source.connect(ctx.destination);
      // If we've fallen behind (e.g. a network stall), start now rather than in the past.
      const startTime = Math.max(ctx.currentTime, nextPlayTimeRef.current);
      source.start(startTime);
      nextPlayTimeRef.current = startTime + buffer.duration;
      active.add(source);
      source.onended = () => {
        active.delete(source);
        source.disconnect();
        isPlayingRef.current = active.size > 0;
      };
    }
    isPlayingRef.current = active.size > 0;
  }, [getAudioContext]);

  const connect = useCallback((): Promise<boolean> => {
    return new Promise((resolve) => {
      const previous = wsRef.current;
      if (previous?.readyState === WebSocket.OPEN) {
        resolve(true);
        return;
      }
      if (previous) {
        // Detach before closing: the old socket's handlers only act while they
        // own wsRef, so its late close/error events can't clobber the new one.
        wsRef.current = null;
        previous.close();
      }
      setState({ status: "connecting", error: null });
      try {
        const ws = new WebSocket(SPEAK_ENDPOINT);
        ws.binaryType = "arraybuffer";
        wsRef.current = ws;
        const isCurrent = () => wsRef.current === ws;

        ws.onopen = () => {
          if (!isCurrent()) {
            resolve(false);
            return;
          }
          setState({ status: "connected", error: null });
          resolve(true);
        };

        ws.onmessage = (event) => {
          if (!isCurrent()) return;

          // Binary data = PCM audio chunk
          if (event.data instanceof ArrayBuffer) {
            // Frames for a cancelled request may still be in flight.
            if (discardAudioRef.current) return;
            const samples = pcm16LeToFloat32(event.data);
            // createBuffer() rejects zero-length buffers.
            if (samples.length === 0) return;
            // First audio has arrived, so the model is no longer "loading".
            clearTimerRef(modelLoadingTimerRef);
            setState((s) =>
              s.status === "loading_model" || s.status === "connected"
                ? { ...s, status: "speaking" }
                : s
            );
            audioQueueRef.current.push(samples);
            playAudioQueue();
            return;
          }

          // Text data = JSON message
          const message = parseControlMessage(event.data);
          if (!message) {
            console.error("Failed to parse speak WebSocket message:", event.data);
            return;
          }
          switch (message.type) {
            case "audio_end": {
              clearTimerRef(modelLoadingTimerRef);
              // All audio has been sent; go back to "connected" once the
              // scheduled audio has finished. Stop polling if playback is torn
              // down or restarted (cancel, new speak, disconnect, unmount).
              const generation = playbackGenerationRef.current;
              const checkDone = () => {
                if (playbackGenerationRef.current !== generation) return;
                if (audioQueueRef.current.length === 0 && !isPlayingRef.current) {
                  setState((s) =>
                    s.status === "speaking" || s.status === "loading_model"
                      ? { ...s, status: "connected" }
                      : s
                  );
                } else {
                  setTimeout(checkDone, 100);
                }
              };
              checkDone();
              break;
            }
            case "error":
              clearTimerRef(modelLoadingTimerRef);
              setState({
                status: "error",
                error:
                  (typeof message.message === "string" && message.message) ||
                  `Error: ${String(message.code)}`,
              });
              break;
          }
        };

        ws.onerror = () => {
          if (isCurrent()) {
            setState({
              status: "error",
              error: "Failed to connect to speak server",
            });
          }
          resolve(false);
        };

        ws.onclose = (event) => {
          // Settles the promise if the socket closed before it ever opened
          // (no-op if it already resolved).
          resolve(false);
          if (!isCurrent()) return;
          clearTimerRef(modelLoadingTimerRef);
          let errorMessage: string | null = null;
          if (event.code === 1006) {
            errorMessage = "Connection failed - server may be unavailable";
          } else if (event.code !== 1000 && event.code !== 1001) {
            errorMessage = `Connection closed unexpectedly (code: ${event.code})`;
          }
          setState((s) => ({
            status: "disconnected",
            error: errorMessage ?? s.error,
          }));
          wsRef.current = null;
        };
      } catch (err) {
        // The WebSocket constructor throws synchronously, e.g. on an invalid URL.
        const message =
          err instanceof Error ? err.message : "Failed to create WebSocket connection";
        setState({ status: "error", error: message });
        resolve(false);
      }
    });
  }, [playAudioQueue]);

  const speak = useCallback(
    async (text: string) => {
      if (!text.trim()) return;
      // Connect if not connected
      if (wsRef.current?.readyState !== WebSocket.OPEN) {
        const connected = await connect();
        if (!connected) return;
      }
      // Silence anything still playing from a previous utterance, and accept
      // audio frames again (cancel() may have been discarding them).
      stopPlayback();
      discardAudioRef.current = false;
      // Resume audio context if suspended (browser autoplay policy)
      const ctx = getAudioContext();
      if (ctx.state === "suspended") {
        await ctx.resume();
      }
      // The socket may have closed while we were awaiting.
      const ws = wsRef.current;
      if (!ws || ws.readyState !== WebSocket.OPEN) return;
      // If no audio arrives within 2 seconds, show the loading state. Replace
      // any timer left by an earlier speak() so it can't fire later.
      clearTimerRef(modelLoadingTimerRef);
      modelLoadingTimerRef.current = setTimeout(() => {
        modelLoadingTimerRef.current = null;
        setState((s) =>
          s.status === "connected" || s.status === "connecting"
            ? { ...s, status: "loading_model" }
            : s
        );
      }, 2000);
      ws.send(JSON.stringify({ type: "speak", text }));
      setState((s) => ({ ...s, error: null }));
    },
    [connect, getAudioContext, stopPlayback]
  );

  const cancel = useCallback(() => {
    stopPlayback();
    clearTimerRef(modelLoadingTimerRef);
    // Frames for the cancelled request may still arrive; drop them until the
    // next speak() so they don't restart playback.
    discardAudioRef.current = true;
    const ws = wsRef.current;
    if (ws?.readyState === WebSocket.OPEN) {
      ws.send(JSON.stringify({ type: "cancel" }));
    }
    const status: SpeakStatus =
      ws?.readyState === WebSocket.OPEN ? "connected" : "disconnected";
    setState((s) => ({ ...s, status }));
  }, [stopPlayback]);

  const disconnect = useCallback(() => {
    stopPlayback();
    clearTimerRef(modelLoadingTimerRef);
    const ws = wsRef.current;
    // Detach first so this intentional close isn't processed by ws.onclose.
    wsRef.current = null;
    if (ws) {
      // Send stop message before closing
      if (ws.readyState === WebSocket.OPEN) {
        ws.send(JSON.stringify({ type: "stop" }));
      }
      ws.close(1000, "User disconnected");
    }
    setState({ status: "disconnected", error: null });
  }, [stopPlayback]);

  return {
    ...state,
    isConnected:
      state.status === "connected" ||
      state.status === "speaking" ||
      state.status === "loading_model",
    isSpeaking: state.status === "speaking",
    isModelLoading: state.status === "loading_model",
    speak,
    cancel,
    connect,
    disconnect,
  };
}