import { useState, useCallback, useRef, useEffect } from "react";
import { LISTEN_ENDPOINT } from "../lib/api";
import type {
ClientMessage,
ServerMessage,
TranscriptEntry,
} from "../types/messages";
/**
 * Connection lifecycle reported by `useWebSocket`:
 * - "connecting": a socket is being opened.
 * - "connected": the socket's open event has fired.
 * - "disconnected": no usable session (initial state, after the socket
 *   closes, or after the server reports the session stopped).
 * - "error": the socket raised an error or could not be created.
 */
export type ConnectionStatus = "connecting" | "connected" | "disconnected" | "error";
/** State snapshot held by `useWebSocket` and spread into its return value. */
export interface WebSocketState {
  /** Current connection lifecycle status. */
  status: ConnectionStatus;
  /** Most recent error message, or null when none is recorded. */
  error: string | null;
  /** Session id from the server's "ready" message; null before that or once the socket closes. */
  sessionId: string | null;
  /** Transcript entries received so far. */
  transcripts: Array<TranscriptEntry>;
}
/** Optional callbacks `useWebSocket` invokes as server messages arrive. */
interface UseWebSocketOptions {
  /** Server sent "ready"; receives the announced session id. */
  onReady?: (id: string) => void;
  /** A "transcript" message arrived; receives the entry merged into state. */
  onTranscript?: (entry: TranscriptEntry) => void;
  /** Server sent an "error" message; receives its code and message text. */
  onError?: (errorCode: string, errorMessage: string) => void;
  /** Server sent "stopped"; receives the reason it gave. */
  onStopped?: (stopReason: string) => void;
  /** Server sent "transcriptSaved"; receives the saved file id and contract id. */
  onTranscriptSaved?: (savedFileId: string, savedContractId: string) => void;
}
/**
 * Transcripts from the same speaker whose start times differ by less than
 * this are treated as the same utterance (units follow the server's
 * `start` field — presumably seconds; confirm against the server).
 */
const TRANSCRIPT_START_TOLERANCE = 0.1;

/**
 * Returns a new transcript list with `entry` merged in; `transcripts` is not
 * mutated.
 *
 * An existing entry from the same speaker whose start time lies within
 * TRANSCRIPT_START_TOLERANCE of `entry.start` is replaced (e.g. a final
 * result superseding an interim one); otherwise `entry` is appended. The
 * result is sorted by start time.
 */
function mergeTranscript(
  transcripts: TranscriptEntry[],
  entry: TranscriptEntry
): TranscriptEntry[] {
  const merged = [...transcripts];
  const existingIdx = merged.findIndex(
    (t) =>
      t.speaker === entry.speaker &&
      Math.abs(t.start - entry.start) < TRANSCRIPT_START_TOLERANCE
  );
  if (existingIdx >= 0) {
    merged[existingIdx] = entry;
  } else {
    merged.push(entry);
  }
  merged.sort((a, b) => a.start - b.start);
  return merged;
}

/** Serializes samples as consecutive little-endian IEEE-754 float32 values. */
function encodeFloat32LE(samples: Float32Array): Uint8Array {
  const bytes = new Uint8Array(samples.length * 4);
  const view = new DataView(bytes.buffer);
  samples.forEach((sample, i) => view.setFloat32(i * 4, sample, true));
  return bytes;
}

/**
 * Maps a WebSocket close code to a user-facing error message, or null for a
 * normal (1000) or going-away (1001) closure.
 */
function describeCloseCode(code: number): string | null {
  if (code === 1000 || code === 1001) return null;
  if (code === 1006) return "Connection failed - server may be unavailable";
  return `Connection closed unexpectedly (code: ${code})`;
}

/**
 * Clears `ref` and then closes the socket it held (if any). Clearing first
 * means that socket's handlers see themselves as stale and make no further
 * state updates when its error/close events fire.
 */
function releaseSocket(
  ref: { current: WebSocket | null },
  code?: number,
  reason?: string
): void {
  const socket = ref.current;
  if (!socket) return;
  ref.current = null;
  if (code === undefined) {
    socket.close();
  } else {
    socket.close(code, reason);
  }
}

/**
 * React hook managing the transcription WebSocket at LISTEN_ENDPOINT.
 *
 * Exposes connection/session state and received transcripts, plus imperative
 * helpers to connect, stream audio, start/stop a session and disconnect.
 * Server messages ("ready", "transcript", "error", "stopped",
 * "transcriptSaved") update state and invoke the matching optional callback
 * from `options`. Callbacks are read through a ref, so changing them does not
 * recreate the returned functions.
 *
 * Each socket's handlers first check that the socket is still the current
 * one. A socket replaced by connect(), closed by disconnect(), or abandoned
 * on unmount therefore cannot overwrite newer state.
 *
 * @param options Optional server-event callbacks.
 * @returns The current WebSocketState fields plus `connect`, `disconnect`,
 *   `sendAudio`, `startSession`, `stopSession`, `clearTranscripts` and
 *   `isConnected`.
 */
export function useWebSocket(options: UseWebSocketOptions = {}) {
  const { onReady, onTranscript, onError, onStopped, onTranscriptSaved } = options;
  const [state, setState] = useState<WebSocketState>({
    status: "disconnected",
    sessionId: null,
    error: null,
    transcripts: [],
  });
  const wsRef = useRef<WebSocket | null>(null);
  const transcriptIdRef = useRef(0);
  // True from a successfully sent "stop" until the server replies "stopped"
  // or the socket closes.
  // NOTE(review): if the server never replies and the socket stays open,
  // disconnect() stays deferred; consider a timeout if that can happen.
  const stoppingRef = useRef(false);
  // disconnect() was requested while stopping; performed once "stopped" arrives.
  const pendingDisconnectRef = useRef(false);

  // Store callbacks in refs to avoid recreating handlers
  const callbacksRef = useRef({ onReady, onTranscript, onError, onStopped, onTranscriptSaved });
  useEffect(() => {
    callbacksRef.current = { onReady, onTranscript, onError, onStopped, onTranscriptSaved };
  }, [onReady, onTranscript, onError, onStopped, onTranscriptSaved]);

  /**
   * Opens the socket (or reuses an already-open one).
   * Resolves true once open, false on creation failure, error, or close
   * before opening.
   */
  const connect = useCallback((): Promise<boolean> => {
    // Applies one decoded server message to state and forwards it to callbacks.
    const handleServerMessage = (message: ServerMessage) => {
      switch (message.type) {
        case "ready":
          setState((s) => ({ ...s, sessionId: message.sessionId }));
          callbacksRef.current.onReady?.(message.sessionId);
          break;
        case "transcript": {
          const entry: TranscriptEntry = {
            id: `t-${++transcriptIdRef.current}`,
            speaker: message.speaker,
            start: message.start,
            end: message.end,
            text: message.text,
            isFinal: message.isFinal,
          };
          setState((s) => ({ ...s, transcripts: mergeTranscript(s.transcripts, entry) }));
          callbacksRef.current.onTranscript?.(entry);
          break;
        }
        case "error":
          setState((s) => ({ ...s, error: message.message }));
          callbacksRef.current.onError?.(message.code, message.message);
          break;
        case "stopped":
          stoppingRef.current = false;
          setState((s) => ({ ...s, status: "disconnected" }));
          callbacksRef.current.onStopped?.(message.reason);
          // Execute pending disconnect if requested during stopping. This
          // socket's own close handler will be stale, so clear the session here.
          if (pendingDisconnectRef.current) {
            pendingDisconnectRef.current = false;
            releaseSocket(wsRef, 1000, "User disconnected");
            setState((s) => ({ ...s, sessionId: null }));
          }
          break;
        case "transcriptSaved":
          callbacksRef.current.onTranscriptSaved?.(message.fileId, message.contractId);
          break;
      }
    };

    return new Promise((resolve) => {
      if (wsRef.current?.readyState === WebSocket.OPEN) {
        // The socket can still be open after "stopped" set status to
        // "disconnected"; bring the status back in line with reality.
        setState((s) =>
          s.status === "connected" ? s : { ...s, status: "connected", error: null }
        );
        resolve(true);
        return;
      }

      // Abandon any existing (connecting/closing) socket; its handlers go stale.
      releaseSocket(wsRef);
      // A fresh socket has no stop request in flight.
      stoppingRef.current = false;
      pendingDisconnectRef.current = false;
      setState((s) => ({ ...s, status: "connecting", error: null }));

      let ws: WebSocket;
      try {
        ws = new WebSocket(LISTEN_ENDPOINT);
      } catch (err) {
        const message = err instanceof Error ? err.message : "Failed to create WebSocket connection";
        setState((s) => ({
          ...s,
          status: "error",
          error: message,
        }));
        resolve(false);
        return;
      }
      wsRef.current = ws;
      const isCurrent = () => wsRef.current === ws;

      ws.onopen = () => {
        if (!isCurrent()) {
          resolve(false);
          return;
        }
        setState((s) => ({ ...s, status: "connected", error: null }));
        resolve(true);
      };

      ws.onmessage = (event) => {
        if (!isCurrent()) return;
        let message: ServerMessage;
        try {
          message = JSON.parse(event.data);
        } catch {
          console.error("Failed to parse WebSocket message:", event.data);
          return;
        }
        try {
          handleServerMessage(message);
        } catch (err) {
          // Exceptions from consumer callbacks are not parse errors.
          console.error("Error handling WebSocket message:", err);
        }
      };

      ws.onerror = () => {
        if (isCurrent()) {
          setState((s) => ({
            ...s,
            status: "error",
            error: "Failed to connect to server",
          }));
        }
        resolve(false);
      };

      ws.onclose = (event) => {
        // No-op if already settled; covers a close that was not preceded by
        // an error event.
        resolve(false);
        if (!isCurrent()) return;
        wsRef.current = null;
        // A closed socket will never deliver "stopped"; don't leave disconnect() deferred.
        stoppingRef.current = false;
        pendingDisconnectRef.current = false;
        const errorMessage = describeCloseCode(event.code);
        setState((s) => ({
          ...s,
          status: "disconnected",
          sessionId: null,
          error: errorMessage ?? s.error,
        }));
      };
    });
  }, []);

  /**
   * Closes the socket with code 1000. If a stop request is in flight, the
   * close is deferred until the server's "stopped" message arrives.
   */
  const disconnect = useCallback(() => {
    if (stoppingRef.current) {
      // Defer disconnect until "stopped" message is received
      pendingDisconnectRef.current = true;
      return;
    }
    releaseSocket(wsRef, 1000, "User disconnected");
    setState((s) => ({ ...s, status: "disconnected", sessionId: null }));
  }, []);

  /** Sends a JSON control message; returns whether the socket was open to send it. */
  const sendMessage = useCallback((message: ClientMessage): boolean => {
    const socket = wsRef.current;
    if (socket && socket.readyState === WebSocket.OPEN) {
      socket.send(JSON.stringify(message));
      return true;
    }
    return false;
  }, []);

  /** Streams audio as little-endian float32 bytes; dropped if the socket is not open. */
  const sendAudio = useCallback((samples: Float32Array) => {
    const socket = wsRef.current;
    if (socket && socket.readyState === WebSocket.OPEN) {
      socket.send(encodeFloat32LE(samples));
    }
  }, []);

  /**
   * Sends a "start" message for pcm32f audio. `contractId` and `authToken`
   * are included only when both are non-empty.
   */
  const startSession = useCallback(
    (sampleRate: number, channels: number = 1, contractId?: string | null, authToken?: string | null) => {
      sendMessage({
        type: "start",
        sampleRate,
        channels,
        encoding: "pcm32f",
        ...(contractId && authToken ? { contractId, authToken } : {}),
      });
    },
    [sendMessage]
  );

  /**
   * Asks the server to stop the session. Only a "stop" that was actually
   * sent marks the hook as stopping; otherwise a later disconnect() would
   * wait for a "stopped" reply that can never come.
   */
  const stopSession = useCallback(
    (reason?: string) => {
      if (sendMessage({ type: "stop", reason })) {
        stoppingRef.current = true;
      }
    },
    [sendMessage]
  );

  /** Clears transcripts and the last error, and resets the stop/disconnect flags. */
  const clearTranscripts = useCallback(() => {
    stoppingRef.current = false;
    pendingDisconnectRef.current = false;
    setState((s) => ({ ...s, transcripts: [], error: null }));
  }, []);

  // Cleanup on unmount. The socket's handlers go stale, so no state updates follow.
  useEffect(() => {
    return () => {
      releaseSocket(wsRef);
    };
  }, []);

  return {
    ...state,
    connect,
    disconnect,
    sendAudio,
    startSession,
    stopSession,
    clearTranscripts,
    isConnected: state.status === "connected",
  };
}