author: soryu <soryu@soryu.co> 2026-01-28 02:54:17 +0000
committer: GitHub <noreply@github.com> 2026-01-28 02:54:17 +0000
commit: eabd1304cce0e053cd32ec910d2f0ea429e8af14 (patch)
tree: fca3b08810a1dc0c0c610a8189a466cc23d5c547 /makima
parent: c618174e60e4632d36d7352d83399508c72b2f42 (diff)
download: soryu-eabd1304cce0e053cd32ec910d2f0ea429e8af14.tar.gz
download: soryu-eabd1304cce0e053cd32ec910d2f0ea429e8af14.zip
Add Qwen3-TTS streaming endpoint for voice synthesis (#40)

* Task completion checkpoint

* Task completion checkpoint

* Task completion checkpoint

* Add Qwen3-TTS research document for live TTS replacement

Research findings for replacing Chatterbox TTS with Qwen3-TTS-12Hz-0.6B-Base:
- Current TTS: Chatterbox-Turbo-ONNX with batch-only generation, no streaming
- Qwen3-TTS: 97ms end-to-end latency, streaming support, 3-second voice cloning
- Voice cloning: Requires 3s reference audio + transcript (Makima voice planned)
- Integration: Python service with WebSocket bridge (no ONNX export available)
- Languages: 10 supported including English and Japanese

Document includes:
- Current architecture analysis (makima/src/tts.rs)
- Qwen3-TTS capabilities and requirements
- Feasibility assessment for live/streaming TTS
- Audio clip requirements for voice cloning
- Preliminary technical approach with architecture diagrams

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* [WIP] Heartbeat checkpoint - 2026-01-27 03:11:15 UTC

* Add Qwen3-TTS research documentation

Comprehensive research on replacing Chatterbox TTS with Qwen3-TTS-12Hz-0.6B-Base:
- Current TTS implementation analysis (Chatterbox-Turbo-ONNX in makima/src/tts.rs)
- Qwen3-TTS capabilities: 97ms streaming latency, voice cloning with 3s reference
- Cross-lingual support: Japanese voice (Makima/Tomori Kusunoki) speaking English
- Python microservice architecture recommendation (FastAPI + WebSocket)
- Implementation phases and technical approach
- Hardware requirements and dependencies

Key findings:
- Live/streaming TTS is highly feasible with 97ms latency
- Voice cloning fully supported with 0.95 speaker similarity
- Recommended: Python microservice with WebSocket streaming

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add comprehensive Qwen3-TTS integration specification

This specification document defines the complete integration of
Qwen3-TTS-12Hz-0.6B-Base as a replacement for the existing Chatterbox-Turbo
TTS implementation. The document covers:

## Functional Requirements
- WebSocket endpoint /api/v1/speak for streaming TTS
- Voice cloning with default Makima voice (Japanese VA speaking English)
- Support for custom voice references
- Detailed client-to-server and server-to-client message protocols
- Integration with Listen page for bidirectional speech

## Non-Functional Requirements
- Latency targets: < 200ms first audio byte
- Audio quality: 24kHz, mono, PCM16/PCM32f
- Hardware requirements: CUDA GPU with 4-8GB VRAM
- Scalability: 10 concurrent sessions per GPU

## Architecture Specification
- Python TTS microservice with FastAPI/WebSocket
- Rust proxy endpoint in makima server
- Voice prompt caching mechanism (LRU cache)
- Error handling and recovery strategies

## API Contract
- Complete WebSocket message format definitions (TypeScript)
- Error codes and responses (TTS_UNAVAILABLE, SYNTHESIS_ERROR, etc.)
- Session state machine and lifecycle management

## Voice Asset Requirements
- Makima voice clip specifications (5-10s WAV, transcript required)
- Storage location: models/voices/makima/
- Metadata format for voice management

## Testing Strategy
- Unit tests for Python TTS service and Rust proxy
- Integration tests for WebSocket flow
- Latency benchmarks with performance targets
- Test data fixtures for various text lengths

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add Qwen3-TTS implementation plan

Comprehensive implementation plan for replacing Chatterbox-TTS with Qwen3-TTS
streaming TTS service, including:
- Task breakdown with estimated hours for each phase
- Phase 1: Python TTS microservice (FastAPI, WebSocket)
- Phase 2: Rust proxy integration (speak.rs, tts_client.rs)
- Detailed file changes and new module structure
- Testing plan with unit, integration, and latency benchmarks
- Risk assessment with mitigation strategies
- Success criteria for each phase

Based on specification in docs/specs/qwen3-tts-spec.md

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add author and research references to TTS implementation plan

Add links to research documentation and author attribution.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* [WIP] Heartbeat checkpoint - 2026-01-27 03:25:06 UTC

* Add Python TTS service project structure (Phase 1.1-1.3)

Create the initial makima-tts Python service directory structure with:
- pyproject.toml with FastAPI, Qwen-TTS, and torch dependencies
- config.py with pydantic-settings TTSConfig class
- models.py with Pydantic message models (Start, Speak, Stop, Ready, etc.)

This implements tasks P1.1, P1.2, and P1.3 from the Qwen3-TTS implementation plan.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add TTS engine and voice manager for Qwen3-TTS (Phase 1.4-1.5)

Implement core TTS functionality:
- tts_engine.py: Qwen3-TTS wrapper with streaming audio chunk generation
- voice_manager.py: Voice prompt caching with LRU eviction and TTL support

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* [WIP] Heartbeat checkpoint - 2026-01-27 03:30:06 UTC

* Add TTS proxy client and message types (Phase 2.1, 2.2, 2.4)

- Add tts_client.rs with TtsConfig, TtsCircuitBreaker, TtsError,
  TtsProxyClient, and TtsConnection structs for WebSocket proxying
- Add TTS message types to messages.rs (TtsAudioEncoding, TtsPriority,
  TtsStartMessage, TtsSpeakMessage, TtsStopMessage, TtsClientMessage,
  TtsReadyMessage, TtsAudioChunkMessage, TtsCompleteMessage, TtsErrorMessage,
  TtsStoppedMessage, TtsServerMessage)
- Export tts_client module from server mod.rs
- tokio-tungstenite already present in Cargo.toml

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add TTS WebSocket handler and route (Phase 2.3, 2.5, 2.6)

- Create speak.rs WebSocket handler that proxies to Python TTS service
- Add TtsState fields (tts_client, tts_config) to AppState
- Add with_tts() builder and is_tts_healthy() methods to AppState
- Register /api/v1/speak route in the router
- Add speak module export in handlers/mod.rs

The handler forwards WebSocket messages bidirectionally between the client and
the Python TTS microservice with proper error handling.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add Makima voice profile assets for TTS voice cloning

Creates the voice assets directory structure with:
- manifest.json containing voice configuration (voice_id, speaker, language,
  reference audio path, and Japanese transcript placeholder)
- README.md with instructions for obtaining voice reference audio

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Add Rust-native Qwen3-TTS integration research document

Research findings for integrating Qwen3-TTS-12Hz-0.6B-Base directly into the
makima Rust codebase without Python. Key conclusions:
- ONNX export is not viable (unsupported architecture)
- Candle (HF Rust ML framework) is the recommended approach
- Model weights available in safetensors format (2.52GB total)
- Three components needed: LM backbone, code predictor, speech tokenizer
- Crane project has Qwen3-TTS as highest priority (potential upstream)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* [WIP] Heartbeat checkpoint - 2026-01-27 11:21:43 UTC

* [WIP] Heartbeat checkpoint - 2026-01-27 11:24:19 UTC

* [WIP] Heartbeat checkpoint - 2026-01-27 11:26:43 UTC

* feat: implement Rust-native Qwen3-TTS using candle framework

Replace monolithic tts.rs with modular tts/ directory structure:
- tts/mod.rs: TtsEngine trait, TtsEngineFactory, shared types (AudioChunk,
  TtsError), and utility functions (save_wav, resample, argmax)
- tts/chatterbox.rs: existing ONNX-based ChatterboxTTS adapted to implement
  TtsEngine trait with Mutex-wrapped sessions for Send+Sync
- tts/qwen3/mod.rs: Qwen3Tts entry point with HuggingFace model loading
- tts/qwen3/config.rs: Qwen3TtsConfig parsing from HF config.json
- tts/qwen3/model.rs: 28-layer Qwen3 transformer with RoPE, GQA (16 heads,
  8 KV heads), SiLU MLP, RMS norm, and KV cache
- tts/qwen3/code_predictor.rs: 5-layer MTP module predicting 16 codebooks
- tts/qwen3/speech_tokenizer.rs: ConvNet encoder/decoder with 16-layer RVQ
- tts/qwen3/generate.rs: autoregressive generation loop with streaming support

Add candle-core, candle-nn, candle-transformers, safetensors to Cargo.toml.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* feat: integrate TTS engine into speak WebSocket handler

- Update speak.rs handler to use TTS engine directly from SharedState instead
  of returning a stub "not implemented" error
- Add TtsEngine (OnceCell lazy-loaded) to AppState in state.rs with
  get_tts_engine() method for lazy initialization on first connection
- Implement full WebSocket protocol: client sends JSON speak/cancel/stop
  messages, server streams binary PCM audio chunks and audio_end signals
- Create voices/makima/manifest.json for Makima voice profile configuration
- All files compile successfully with zero errors

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* feat: add /speak TTS page with WebSocket audio playback

Add a new /speak frontend page for text-to-speech via WebSocket. The page
accepts text input and streams synthesized PCM audio through the Web Audio API.
Includes model loading indicator, cancel support, and connection status. Also
adds a loading bar to the listen page ControlPanel during WebSocket connection.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>

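
Editor's sketch (not part of the commit): the final protocol described above has the client send JSON speak/cancel/stop messages and the server reply with binary 16-bit PCM chunks at 24kHz plus JSON audio_end and error messages. The TypeScript below is a minimal illustration of that framing, assuming little-endian PCM16 and an error payload with optional code and message fields. All type and function names here (ClientMessage, ServerEvent, pcm16leToFloat32, decodeServerFrame) are hypothetical and do not appear in the diff.

```typescript
// Hypothetical types: illustrative only, not the definitions used in the commit.
type ClientMessage =
  | { type: "speak"; text: string; voice?: string }
  | { type: "cancel" }
  | { type: "stop" };

type ServerEvent =
  | { kind: "audio"; samples: Float32Array }
  | { kind: "audio_end" }
  | { kind: "error"; code: string | null; message: string | null }
  | { kind: "unknown" };

/** Decode little-endian 16-bit PCM into floats in the range [-1, 1). */
function pcm16leToFloat32(data: ArrayBuffer): Float32Array {
  const view = new DataView(data);
  // A trailing odd byte cannot form a sample, so it is ignored.
  const count = Math.floor(data.byteLength / 2);
  const out = new Float32Array(count);
  for (let i = 0; i < count; i++) {
    out[i] = view.getInt16(i * 2, true) / 32768;
  }
  return out;
}

/** Classify one WebSocket frame: binary frames are audio, text frames are JSON. */
function decodeServerFrame(data: ArrayBuffer | string): ServerEvent {
  if (typeof data !== "string") {
    return { kind: "audio", samples: pcm16leToFloat32(data) };
  }
  let parsed: unknown;
  try {
    parsed = JSON.parse(data);
  } catch {
    return { kind: "unknown" };
  }
  if (typeof parsed !== "object" || parsed === null) {
    return { kind: "unknown" };
  }
  const msg = parsed as { type?: unknown; code?: unknown; message?: unknown };
  if (msg.type === "audio_end") {
    return { kind: "audio_end" };
  }
  if (msg.type === "error") {
    return {
      kind: "error",
      code: typeof msg.code === "string" ? msg.code : null,
      message: typeof msg.message === "string" ? msg.message : null,
    };
  }
  return { kind: "unknown" };
}

/** Serialize a control message for ws.send(). */
function encodeClientMessage(msg: ClientMessage): string {
  return JSON.stringify(msg);
}
```

Reading samples through a DataView with an explicit little-endian flag avoids depending on host byte order and tolerates an odd-length frame, which a bare Int16Array view over the buffer would reject. Whether the server can ever emit such a frame is not stated in the commit.
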
Diffstat (limited to 'makima')
-rw-r--r-- makima/Cargo.toml | 6
-rw-r--r-- makima/frontend/src/components/listen/ControlPanel.tsx | 36
-rw-r--r-- makima/frontend/src/hooks/useSpeakWebSocket.ts | 329
-rw-r--r-- makima/frontend/src/index.css | 6
-rw-r--r-- makima/frontend/src/lib/api.ts | 1
-rw-r--r-- makima/frontend/src/main.tsx | 9
-rw-r--r-- makima/frontend/src/routes/listen.tsx | 1
-rw-r--r-- makima/frontend/src/routes/speak.tsx | 159
-rw-r--r-- makima/src/main.rs | 14
-rw-r--r-- makima/src/server/handlers/mod.rs | 1
-rw-r--r-- makima/src/server/handlers/speak.rs | 274
-rw-r--r-- makima/src/server/messages.rs | 161
-rw-r--r-- makima/src/server/mod.rs | 3
-rw-r--r-- makima/src/server/state.rs | 22
-rw-r--r-- makima/src/tts/chatterbox.rs (renamed from makima/src/tts.rs) | 391
-rw-r--r-- makima/src/tts/mod.rs | 281
-rw-r--r-- makima/src/tts/qwen3/code_predictor.rs | 261
-rw-r--r-- makima/src/tts/qwen3/config.rs | 271
-rw-r--r-- makima/src/tts/qwen3/generate.rs | 426
-rw-r--r-- makima/src/tts/qwen3/mod.rs | 287
-rw-r--r-- makima/src/tts/qwen3/model.rs | 581
-rw-r--r-- makima/src/tts/qwen3/speech_tokenizer.rs | 612
22 files changed, 3868 insertions, 264 deletions
diff --git a/makima/Cargo.toml b/makima/Cargo.toml
index 950c123..b6b12dd 100644
--- a/makima/Cargo.toml
+++ b/makima/Cargo.toml
@@ -17,6 +17,12 @@ tokenizers = "0.21"
hf-hub = "0.4"
ndarray = "0.16"
+# Candle ML framework (Qwen3-TTS native inference)
+candle-core = "0.8"
+candle-nn = "0.8"
+candle-transformers = "0.8"
+safetensors = "0.4"
+
# Web server
axum = { version = "0.8", features = ["ws", "multipart"] }
tokio = { version = "1.0", features = ["full", "signal", "process"] }
diff --git a/makima/frontend/src/components/listen/ControlPanel.tsx b/makima/frontend/src/components/listen/ControlPanel.tsx
index f0e5702..f482ec4 100644
--- a/makima/frontend/src/components/listen/ControlPanel.tsx
+++ b/makima/frontend/src/components/listen/ControlPanel.tsx
@@ -1,6 +1,7 @@
import { useState } from "react";
import { Logo } from "../Logo";
import type { MicrophoneStatus } from "../../hooks/useMicrophone";
+import type { ConnectionStatus } from "../../hooks/useWebSocket";
import { ContractPickerModal } from "./ContractPickerModal";
export interface ContractOption {
@@ -22,6 +23,8 @@ interface ControlPanelProps {
selectedContractId: string | null;
onContractChange: (contractId: string | null) => void;
contractsLoading?: boolean;
+ // Connection status for loading state
+ connectionStatus?: ConnectionStatus;
}
function getStatusText(isListening: boolean, micStatus: MicrophoneStatus): string {
@@ -54,6 +57,7 @@ export function ControlPanel({
selectedContractId,
onContractChange,
contractsLoading,
+ connectionStatus,
}: ControlPanelProps) {
const [isModalOpen, setIsModalOpen] = useState(false);
const statusText = getStatusText(isListening, micStatus);
@@ -121,18 +125,36 @@ export function ControlPanel({
{/* Connection status */}
<div
- className={`inline-flex items-center gap-1.5 px-2 py-1 border ${
+ className={`inline-flex flex-col gap-1 px-2 py-1 border ${
isConnected
? "border-[#3f6fb3] text-[#75aafc]"
+ : connectionStatus === "connecting"
+ ? "border-[#3f6fb3] text-[#9bc3ff]"
: "border-[rgba(117,170,252,0.25)] text-[#9bc3ff]"
}`}
>
- <span
- className={`w-1.5 h-1.5 rounded-full ${
- isConnected ? "bg-[#75aafc]" : "bg-[#3f6fb3]"
- }`}
- />
- {isConnected ? "CONNECTED" : "DISCONNECTED"}
+ <div className="inline-flex items-center gap-1.5">
+ <span
+ className={`w-1.5 h-1.5 rounded-full ${
+ isConnected ? "bg-[#75aafc]" : "bg-[#3f6fb3]"
+ }`}
+ />
+ {isConnected
+ ? "CONNECTED"
+ : connectionStatus === "connecting"
+ ? "LOADING MODELS..."
+ : "DISCONNECTED"}
+ </div>
+ {connectionStatus === "connecting" && (
+ <div className="w-full h-1.5 bg-[#0f1c2f] overflow-hidden">
+ <div
+ className="h-full w-1/3 bg-[#75aafc]"
+ style={{
+ animation: "loading-slide 1.5s ease-in-out infinite",
+ }}
+ />
+ </div>
+ )}
</div>
</div>
diff --git a/makima/frontend/src/hooks/useSpeakWebSocket.ts b/makima/frontend/src/hooks/useSpeakWebSocket.ts
new file mode 100644
index 0000000..3ef8851
--- /dev/null
+++ b/makima/frontend/src/hooks/useSpeakWebSocket.ts
@@ -0,0 +1,329 @@
+import { useState, useCallback, useRef, useEffect } from "react";
+import { SPEAK_ENDPOINT } from "../lib/api";
+
+export type SpeakStatus =
+ | "disconnected"
+ | "connecting"
+ | "connected"
+ | "loading_model"
+ | "speaking"
+ | "error";
+
+export interface SpeakWebSocketState {
+ status: SpeakStatus;
+ error: string | null;
+}
+
+export function useSpeakWebSocket() {
+ const [state, setState] = useState<SpeakWebSocketState>({
+ status: "disconnected",
+ error: null,
+ });
+
+ const wsRef = useRef<WebSocket | null>(null);
+ const audioContextRef = useRef<AudioContext | null>(null);
+ const audioQueueRef = useRef<Float32Array[]>([]);
+ const isPlayingRef = useRef(false);
+ const modelLoadingTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
+ const nextPlayTimeRef = useRef(0);
+
+ // Clean up on unmount
+ useEffect(() => {
+ return () => {
+ if (wsRef.current) {
+ wsRef.current.close();
+ wsRef.current = null;
+ }
+ if (audioContextRef.current) {
+ audioContextRef.current.close();
+ audioContextRef.current = null;
+ }
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+ };
+ }, []);
+
+ const getAudioContext = useCallback((): AudioContext => {
+ if (!audioContextRef.current || audioContextRef.current.state === "closed") {
+ audioContextRef.current = new AudioContext({ sampleRate: 24000 });
+ }
+ return audioContextRef.current;
+ }, []);
+
+ const playAudioQueue = useCallback(() => {
+ if (isPlayingRef.current) return;
+ isPlayingRef.current = true;
+
+ const ctx = getAudioContext();
+
+ function scheduleNext() {
+ const chunk = audioQueueRef.current.shift();
+ if (!chunk) {
+ isPlayingRef.current = false;
+ return;
+ }
+
+ const buffer = ctx.createBuffer(1, chunk.length, 24000);
+ buffer.copyToChannel(chunk, 0);
+
+ const source = ctx.createBufferSource();
+ source.buffer = buffer;
+ source.connect(ctx.destination);
+
+ // Schedule playback at the right time to avoid gaps
+ const now = ctx.currentTime;
+ const startTime = Math.max(now, nextPlayTimeRef.current);
+ source.start(startTime);
+ nextPlayTimeRef.current = startTime + buffer.duration;
+
+ source.onended = () => {
+ if (audioQueueRef.current.length > 0) {
+ scheduleNext();
+ } else {
+ isPlayingRef.current = false;
+ }
+ };
+ }
+
+ scheduleNext();
+ }, [getAudioContext]);
+
+ const connect = useCallback((): Promise<boolean> => {
+ return new Promise((resolve) => {
+ if (wsRef.current?.readyState === WebSocket.OPEN) {
+ resolve(true);
+ return;
+ }
+
+ if (wsRef.current) {
+ wsRef.current.close();
+ wsRef.current = null;
+ }
+
+ setState({ status: "connecting", error: null });
+
+ try {
+ const ws = new WebSocket(SPEAK_ENDPOINT);
+ ws.binaryType = "arraybuffer";
+ wsRef.current = ws;
+
+ ws.onopen = () => {
+ setState({ status: "connected", error: null });
+ resolve(true);
+ };
+
+ ws.onmessage = (event) => {
+ // Binary data = PCM audio chunk
+ if (event.data instanceof ArrayBuffer) {
+ // Clear model loading timer on first audio data
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ // Update status to speaking if not already
+ setState((s) => {
+ if (s.status === "loading_model" || s.status === "connected") {
+ return { ...s, status: "speaking" };
+ }
+ return s;
+ });
+
+ // Convert PCM16 LE to Float32
+ const pcm16 = new Int16Array(event.data);
+ const float32 = new Float32Array(pcm16.length);
+ for (let i = 0; i < pcm16.length; i++) {
+ float32[i] = pcm16[i] / 32768;
+ }
+
+ audioQueueRef.current.push(float32);
+ playAudioQueue();
+ return;
+ }
+
+ // Text data = JSON message
+ try {
+ const message = JSON.parse(event.data);
+
+ switch (message.type) {
+ case "audio_end":
+ // Clear model loading timer
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+ // Wait for audio queue to drain, then go back to connected
+ // Use a short delay to let buffered audio finish
+ {
+ const checkDone = () => {
+ if (audioQueueRef.current.length === 0 && !isPlayingRef.current) {
+ setState((s) => {
+ if (s.status === "speaking" || s.status === "loading_model") {
+ return { ...s, status: "connected" };
+ }
+ return s;
+ });
+ } else {
+ setTimeout(checkDone, 100);
+ }
+ };
+ checkDone();
+ }
+ break;
+
+ case "error":
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+ setState({
+ status: "error",
+ error: message.message || `Error: ${message.code}`,
+ });
+ break;
+ }
+ } catch {
+ console.error("Failed to parse speak WebSocket message:", event.data);
+ }
+ };
+
+ ws.onerror = () => {
+ setState({
+ status: "error",
+ error: "Failed to connect to speak server",
+ });
+ resolve(false);
+ };
+
+ ws.onclose = (event) => {
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ let errorMessage: string | null = null;
+ if (event.code === 1006) {
+ errorMessage = "Connection failed - server may be unavailable";
+ } else if (event.code !== 1000 && event.code !== 1001) {
+ errorMessage = `Connection closed unexpectedly (code: ${event.code})`;
+ }
+
+ setState((s) => ({
+ status: "disconnected",
+ error: errorMessage || s.error,
+ }));
+ wsRef.current = null;
+ };
+ } catch (err) {
+ const message =
+ err instanceof Error ? err.message : "Failed to create WebSocket connection";
+ setState({ status: "error", error: message });
+ resolve(false);
+ }
+ });
+ }, [playAudioQueue]);
+
+ const speak = useCallback(
+ async (text: string) => {
+ if (!text.trim()) return;
+
+ // Connect if not connected
+ if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
+ const connected = await connect();
+ if (!connected) return;
+ }
+
+ // Reset audio state
+ audioQueueRef.current = [];
+ isPlayingRef.current = false;
+ nextPlayTimeRef.current = 0;
+
+ // Resume audio context if suspended (browser autoplay policy)
+ const ctx = getAudioContext();
+ if (ctx.state === "suspended") {
+ await ctx.resume();
+ }
+
+ // Start loading timer - if no audio arrives in 2 seconds, show loading state
+ modelLoadingTimerRef.current = setTimeout(() => {
+ setState((s) => {
+ if (s.status === "connected" || s.status === "connecting") {
+ return { ...s, status: "loading_model" };
+ }
+ return s;
+ });
+ modelLoadingTimerRef.current = null;
+ }, 2000);
+
+ // Send speak request
+ wsRef.current?.send(
+ JSON.stringify({ type: "speak", text })
+ );
+
+ setState((s) => ({ ...s, error: null }));
+ },
+ [connect, getAudioContext]
+ );
+
+ const cancel = useCallback(() => {
+ // Clear audio queue
+ audioQueueRef.current = [];
+ isPlayingRef.current = false;
+ nextPlayTimeRef.current = 0;
+
+ // Clear model loading timer
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ // Send cancel message
+ if (wsRef.current?.readyState === WebSocket.OPEN) {
+ wsRef.current.send(JSON.stringify({ type: "cancel" }));
+ }
+
+ setState((s) => ({
+ ...s,
+ status: wsRef.current?.readyState === WebSocket.OPEN ? "connected" : "disconnected",
+ }));
+ }, []);
+
+ const disconnect = useCallback(() => {
+ // Clear audio queue
+ audioQueueRef.current = [];
+ isPlayingRef.current = false;
+ nextPlayTimeRef.current = 0;
+
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ if (wsRef.current) {
+ // Send stop message before closing
+ if (wsRef.current.readyState === WebSocket.OPEN) {
+ wsRef.current.send(JSON.stringify({ type: "stop" }));
+ }
+ wsRef.current.close(1000, "User disconnected");
+ wsRef.current = null;
+ }
+
+ setState({ status: "disconnected", error: null });
+ }, []);
+
+ return {
+ ...state,
+ isConnected:
+ state.status === "connected" ||
+ state.status === "speaking" ||
+ state.status === "loading_model",
+ isSpeaking: state.status === "speaking",
+ isModelLoading: state.status === "loading_model",
+ speak,
+ cancel,
+ connect,
+ disconnect,
+ };
+}
diff --git a/makima/frontend/src/index.css b/makima/frontend/src/index.css
index 5c08006..f29873b 100644
--- a/makima/frontend/src/index.css
+++ b/makima/frontend/src/index.css
@@ -64,6 +64,12 @@ body {
background: rgba(117, 170, 252, 0.35);
}
+/* Loading bar animation for indeterminate progress */
+@keyframes loading-slide {
+ 0% { transform: translateX(-100%); }
+ 100% { transform: translateX(300%); }
+}
+
/* Grid overlay */
.grid-overlay {
position: fixed;
diff --git a/makima/frontend/src/lib/api.ts b/makima/frontend/src/lib/api.ts
index 4390b20..ca04ce7 100644
--- a/makima/frontend/src/lib/api.ts
+++ b/makima/frontend/src/lib/api.ts
@@ -99,6 +99,7 @@ async function authFetch(url: string, options: RequestInit = {}): Promise<Respon
});
}
export const LISTEN_ENDPOINT = `${WS_BASE}/api/v1/listen`;
+export const SPEAK_ENDPOINT = `${WS_BASE}/api/v1/speak`;
export const FILE_SUBSCRIBE_ENDPOINT = `${WS_BASE}/api/v1/files/subscribe`;
export const TASK_SUBSCRIBE_ENDPOINT = `${WS_BASE}/api/v1/mesh/tasks/subscribe`;
diff --git a/makima/frontend/src/main.tsx b/makima/frontend/src/main.tsx
index 383b732..ef1ba5c 100644
--- a/makima/frontend/src/main.tsx
+++ b/makima/frontend/src/main.tsx
@@ -19,6 +19,7 @@ import LoginPage from "./routes/login";
import SettingsPage from "./routes/settings";
import ContractFilePage from "./routes/contract-file";
import TemplatesPage from "./routes/templates";
+import SpeakPage from "./routes/speak";
createRoot(document.getElementById("root")!).render(
<StrictMode>
@@ -135,6 +136,14 @@ createRoot(document.getElementById("root")!).render(
</ProtectedRoute>
}
/>
+ <Route
+ path="/speak"
+ element={
+ <ProtectedRoute>
+ <SpeakPage />
+ </ProtectedRoute>
+ }
+ />
</Routes>
</BrowserRouter>
</SupervisorQuestionsProvider>
diff --git a/makima/frontend/src/routes/listen.tsx b/makima/frontend/src/routes/listen.tsx
index 55cf7e6..8af538e 100644
--- a/makima/frontend/src/routes/listen.tsx
+++ b/makima/frontend/src/routes/listen.tsx
@@ -207,6 +207,7 @@ export default function ListenPage() {
selectedContractId={selectedContractId}
onContractChange={setSelectedContractId}
contractsLoading={contractsLoading}
+ connectionStatus={ws.status}
/>
</div>
</main>
diff --git a/makima/frontend/src/routes/speak.tsx b/makima/frontend/src/routes/speak.tsx
new file mode 100644
index 0000000..c4692ff
--- /dev/null
+++ b/makima/frontend/src/routes/speak.tsx
@@ -0,0 +1,159 @@
+import { useState, useCallback } from "react";
+import { Masthead } from "../components/Masthead";
+import { useSpeakWebSocket } from "../hooks/useSpeakWebSocket";
+
+export default function SpeakPage() {
+ const [text, setText] = useState("");
+ const tts = useSpeakWebSocket();
+
+ const handleSpeak = useCallback(() => {
+ if (!text.trim()) return;
+ tts.speak(text);
+ }, [text, tts]);
+
+ const handleCancel = useCallback(() => {
+ tts.cancel();
+ }, [tts]);
+
+ const handleKeyDown = useCallback(
+ (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
+ // Ctrl/Cmd + Enter to speak
+ if ((e.ctrlKey || e.metaKey) && e.key === "Enter") {
+ e.preventDefault();
+ handleSpeak();
+ }
+ },
+ [handleSpeak]
+ );
+
+ const statusLabel = (() => {
+ switch (tts.status) {
+ case "disconnected":
+ return "DISCONNECTED";
+ case "connecting":
+ return "CONNECTING...";
+ case "connected":
+ return "CONNECTED";
+ case "loading_model":
+ return "LOADING TTS MODEL...";
+ case "speaking":
+ return "SPEAKING";
+ case "error":
+ return "ERROR";
+ default:
+ return "IDLE";
+ }
+ })();
+
+ const statusColor = (() => {
+ switch (tts.status) {
+ case "connected":
+ case "speaking":
+ return "border-[#3f6fb3] text-[#75aafc]";
+ case "error":
+ return "border-red-400/50 text-red-400";
+ default:
+ return "border-[rgba(117,170,252,0.25)] text-[#9bc3ff]";
+ }
+ })();
+
+ const dotColor = (() => {
+ switch (tts.status) {
+ case "connected":
+ case "speaking":
+ return "bg-[#75aafc]";
+ case "error":
+ return "bg-red-400";
+ default:
+ return "bg-[#3f6fb3]";
+ }
+ })();
+
+ return (
+ <div className="relative z-10 h-screen flex flex-col overflow-hidden">
+ <Masthead showTicker={false} showNav />
+
+ <main className="flex-1 flex flex-col items-center justify-center p-4 md:p-8 gap-6 min-h-0 overflow-auto">
+ {/* Text input area */}
+ <div className="w-full max-w-2xl">
+ <textarea
+ value={text}
+ onChange={(e) => setText(e.target.value)}
+ onKeyDown={handleKeyDown}
+ placeholder="Enter text to speak..."
+ disabled={tts.isSpeaking || tts.isModelLoading}
+ className="w-full h-48 p-4 font-mono text-sm text-[#dbe7ff] bg-[#0d1b2d] border border-[#0f3c78] focus:border-[#3f6fb3] focus:outline-none placeholder-[#3f6fb3] resize-none transition-colors disabled:opacity-50"
+ />
+ <div className="mt-1 text-right font-mono text-xs text-[#3f6fb3]">
+ Ctrl+Enter to speak
+ </div>
+ </div>
+
+ {/* Controls row */}
+ <div className="w-full max-w-2xl flex items-center gap-4">
+ {/* Speak / Cancel button */}
+ {tts.isSpeaking || tts.isModelLoading ? (
+ <button
+ onClick={handleCancel}
+ className="px-6 py-2 font-mono text-sm text-red-400 bg-[#0d1b2d] border border-red-400/50 hover:border-red-400 transition-colors uppercase tracking-wide"
+ >
+ Cancel
+ </button>
+ ) : (
+ <button
+ onClick={handleSpeak}
+ disabled={!text.trim()}
+ className="px-6 py-2 font-mono text-sm text-[#dbe7ff] bg-[#0d1b2d] border border-[#0f3c78] hover:border-[#3f6fb3] transition-colors uppercase tracking-wide disabled:opacity-50 disabled:cursor-not-allowed"
+ >
+ Speak
+ </button>
+ )}
+
+ {/* Status indicator */}
+ <div
+ className={`inline-flex items-center gap-1.5 px-2 py-1 border font-mono text-xs tracking-wide uppercase ${statusColor}`}
+ >
+ <span className={`w-1.5 h-1.5 rounded-full ${dotColor}`} />
+ {statusLabel}
+ </div>
+ </div>
+
+ {/* Loading bar (indeterminate) */}
+ {tts.isModelLoading && (
+ <div className="w-full max-w-2xl">
+ <div className="w-full h-1.5 bg-[#0f1c2f] overflow-hidden">
+ <div
+ className="h-full w-1/3 bg-[#75aafc]"
+ style={{
+ animation: "loading-slide 1.5s ease-in-out infinite",
+ }}
+ />
+ </div>
+ <div className="mt-2 font-mono text-xs text-[#9bc3ff] text-center tracking-wide uppercase">
+ Loading TTS model... This may take a moment on first use.
+ </div>
+ </div>
+ )}
+
+ {/* Speaking animation bar */}
+ {tts.isSpeaking && (
+ <div className="w-full max-w-2xl">
+ <div className="w-full h-1.5 bg-[#0f1c2f] overflow-hidden">
+ <div
+ className="h-full w-full bg-[#75aafc] animate-pulse"
+ />
+ </div>
+ </div>
+ )}
+
+ {/* Error display */}
+ {tts.error && (
+ <div className="w-full max-w-2xl font-mono text-xs text-red-400 text-center px-4 py-2 border border-red-400/50 bg-red-400/10">
+ {tts.error}
+ </div>
+ )}
+ </main>
+
+ </div>
+ );
+}
diff --git a/makima/src/main.rs b/makima/src/main.rs
index 2348b23..1d87106 100644
--- a/makima/src/main.rs
+++ b/makima/src/main.rs
@@ -7,21 +7,9 @@ pub mod tts;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Loading ChatterboxTTS...");
- let mut tts = ChatterboxTTS::from_pretrained(None)?;
+ let tts = ChatterboxTTS::from_pretrained(None)?;
println!("Model loaded successfully!");
- // // Voice cloning using existing audio file
- // println!("Generating TTS with voice cloning...");
- // let audio = tts.generate_tts_with_voice(
- // "Hello, this is a test of the voice cloning system.",
- // Path::new("audio.wav")
- // )?;
- //
- // println!("Generated {} samples", audio.len());
- // save_wav(&audio, Path::new("output.wav"))?;
- // println!("Saved to output.wav");
-
-
// Load reference audio from mp3
println!("Loading reference audio...");
let reference = audio::to_16k_mono_from_path(Path::new("audio.mp3"))?;
diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs
index b496922..8207399 100644
--- a/makima/src/server/handlers/mod.rs
+++ b/makima/src/server/handlers/mod.rs
@@ -17,6 +17,7 @@ pub mod mesh_red_team;
pub mod mesh_supervisor;
pub mod mesh_ws;
pub mod repository_history;
+pub mod speak;
pub mod templates;
pub mod transcript_analysis;
pub mod users;
diff --git a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs
new file mode 100644
index 0000000..75e7780
--- /dev/null
+++ b/makima/src/server/handlers/speak.rs
@@ -0,0 +1,274 @@
+//! WebSocket handler for TTS streaming (direct in-process inference).
+//!
+//! This module implements the `/api/v1/speak` endpoint which performs
+//! text-to-speech synthesis directly using the candle-based TTS engine.
+//! No external Python service or proxy — the model runs in-process.
+//!
+//! ## Architecture
+//!
+//! The speak handler will:
+//! 1. Accept a WebSocket connection from the client
+//! 2. Lazily load the TTS model (candle) on first request
+//! 3. Parse JSON control messages (speak, cancel, stop)
+//! 4. Run inference directly and stream audio chunks back
+//!
+//! See `makima/src/tts/` for the TTS engine implementation.
+//! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification.
+
+use axum::{
+ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
+ response::Response,
+};
+use futures::{SinkExt, StreamExt};
+use serde::Deserialize;
+use uuid::Uuid;
+
+use crate::server::state::SharedState;
+
+/// Client-to-server control messages.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+enum ClientMessage {
+ /// Request speech synthesis for the given text.
+ Speak {
+ text: String,
+ /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection.
+ #[serde(default)]
+ #[allow(dead_code)]
+ voice: Option<String>,
+ },
+ /// Cancel any in-progress synthesis.
+ Cancel,
+ /// Graceful close.
+ Stop,
+}
+
+/// WebSocket upgrade handler for TTS streaming.
+///
+/// This endpoint accepts WebSocket connections for text-to-speech synthesis.
+/// The TTS model runs directly in-process using candle — no external service.
+#[utoipa::path(
+ get,
+ path = "/api/v1/speak",
+ responses(
+ (status = 101, description = "WebSocket connection established"),
+ (status = 503, description = "TTS engine not available"),
+ ),
+ tag = "Speak"
+)]
+pub async fn websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_speak_socket(socket, state))
+}
+
+/// Handle TTS WebSocket session with direct in-process inference.
+///
+/// Protocol:
+/// - Client sends JSON `{ "type": "speak", "text": "..." }` messages
+/// - Server responds with binary audio chunks (16-bit PCM @ 24kHz)
+/// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete
+/// - Server sends JSON `{ "type": "error", ... }` on failures
+async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
+ let session_id = Uuid::new_v4().to_string();
+ tracing::info!(session_id = %session_id, "New TTS WebSocket connection");
+
+ let (mut sender, mut receiver) = socket.split();
+
+ // Process incoming messages
+ while let Some(msg) = receiver.next().await {
+ let msg = match msg {
+ Ok(m) => m,
+ Err(e) => {
+ tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error");
+ break;
+ }
+ };
+
+ match msg {
+ Message::Text(text) => {
+ let client_msg: ClientMessage = match serde_json::from_str(&text) {
+ Ok(m) => m,
+ Err(e) => {
+ let _ = send_error(
+ &mut sender,
+ "INVALID_MESSAGE",
+ &format!("Failed to parse message: {e}"),
+ )
+ .await;
+ continue;
+ }
+ };
+
+ match client_msg {
+ ClientMessage::Speak { text, .. } => {
+ tracing::info!(
+ session_id = %session_id,
+ text_len = text.len(),
+ "TTS speak request"
+ );
+
+ // Get or lazily load the TTS engine
+ let engine = match state.get_tts_engine().await {
+ Ok(e) => e,
+ Err(e) => {
+ tracing::error!(
+ session_id = %session_id,
+ error = %e,
+ "Failed to load TTS engine"
+ );
+ let _ = send_error(
+ &mut sender,
+ "TTS_LOAD_FAILED",
+ &format!("Failed to load TTS engine: {e}"),
+ )
+ .await;
+ continue;
+ }
+ };
+
+ if !engine.is_ready() {
+ let _ = send_error(
+ &mut sender,
+ "TTS_NOT_READY",
+ "TTS engine is not ready yet",
+ )
+ .await;
+ continue;
+ }
+
+ // Run TTS inference (no voice reference for now — uses default)
+ match engine.generate(&text, None, None).await {
+ Ok(chunks) => {
+ for chunk in &chunks {
+ // Send binary PCM audio data
+ let pcm_bytes = chunk.to_pcm16_bytes();
+ if sender
+ .send(Message::Binary(pcm_bytes.into()))
+ .await
+ .is_err()
+ {
+ tracing::warn!(
+ session_id = %session_id,
+ "Failed to send audio chunk — client disconnected"
+ );
+ return;
+ }
+ }
+
+ // Signal end of audio
+ let end_msg = serde_json::json!({
+ "type": "audio_end",
+ "sample_rate": engine.sample_rate(),
+ "format": "pcm_s16le",
+ "channels": 1,
+ });
+ let _ = sender
+ .send(Message::Text(end_msg.to_string().into()))
+ .await;
+ }
+ Err(e) => {
+ tracing::error!(
+ session_id = %session_id,
+ error = %e,
+ "TTS inference failed"
+ );
+ let _ = send_error(
+ &mut sender,
+ "TTS_INFERENCE_FAILED",
+ &format!("TTS inference failed: {e}"),
+ )
+ .await;
+ }
+ }
+ }
+ ClientMessage::Cancel => {
+ tracing::info!(session_id = %session_id, "TTS cancel requested");
+ // TODO: support cancellation of in-progress inference
+ }
+ ClientMessage::Stop => {
+ tracing::info!(session_id = %session_id, "TTS stop requested, closing");
+ break;
+ }
+ }
+ }
+ Message::Close(_) => {
+ tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
+ break;
+ }
+ _ => {
+ // Ignore ping/pong/binary from client
+ }
+ }
+ }
+
+ tracing::info!(session_id = %session_id, "TTS WebSocket connection closed");
+}
+
+/// Send an error message to the client. Send failures are ignored, so this always returns `Ok(())`.
+async fn send_error<S>(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error>
+where
+ S: SinkExt<Message> + Unpin,
+ <S as futures::Sink<Message>>::Error: std::error::Error,
+{
+ let error_msg = serde_json::json!({
+ "type": "error",
+ "code": code,
+ "message": message,
+ "recoverable": false
+ });
+
+ sender
+ .send(Message::Text(error_msg.to_string().into()))
+ .await
+ .ok();
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_error_message_format() {
+ let error = serde_json::json!({
+ "type": "error",
+ "code": "TEST_ERROR",
+ "message": "Test message",
+ "recoverable": false
+ });
+
+ assert_eq!(error["type"], "error");
+ assert_eq!(error["code"], "TEST_ERROR");
+ assert_eq!(error["message"], "Test message");
+ assert_eq!(error["recoverable"], false);
+ }
+
+ #[test]
+ fn test_client_message_parse_speak() {
+ let json = r#"{"type": "speak", "text": "Hello world"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ match msg {
+ ClientMessage::Speak { text, voice } => {
+ assert_eq!(text, "Hello world");
+ assert!(voice.is_none());
+ }
+ _ => panic!("Expected Speak message"),
+ }
+ }
+
+ #[test]
+ fn test_client_message_parse_cancel() {
+ let json = r#"{"type": "cancel"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ assert!(matches!(msg, ClientMessage::Cancel));
+ }
+
+ #[test]
+ fn test_client_message_parse_stop() {
+ let json = r#"{"type": "stop"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ assert!(matches!(msg, ClientMessage::Stop));
+ }
+}
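Reviewer note on the wire format: the handler above sends each chunk as a binary frame of little-endian signed 16-bit mono PCM, then a JSON `audio_end` message carrying `sample_rate`, `format` and `channels`. The sketch below mirrors the encoding done by `AudioChunk::to_pcm16_bytes` (defined in `makima/src/tts/mod.rs` later in this commit) and adds the inverse a client would need. `decode_pcm_s16le` and `frame_duration_secs` are hypothetical helper names that do not appear in the commit, and the decoder's choice to ignore a trailing odd byte is an assumption.

```rust
/// Encode f32 samples as little-endian signed 16-bit PCM.
/// Same arithmetic as the commit's `AudioChunk::to_pcm16_bytes`:
/// clamp to [-1.0, 1.0], scale by 32767, truncate to i16.
fn encode_pcm_s16le(samples: &[f32]) -> Vec<u8> {
    let mut buf = Vec::with_capacity(samples.len() * 2);
    for &s in samples {
        let v = (s.clamp(-1.0, 1.0) * 32767.0) as i16;
        buf.extend_from_slice(&v.to_le_bytes());
    }
    buf
}

/// Hypothetical client-side helper: decode one binary frame back to f32.
/// Assumption: a trailing odd byte (incomplete sample) is ignored.
fn decode_pcm_s16le(frame: &[u8]) -> Vec<f32> {
    frame
        .chunks_exact(2)
        .map(|b| i16::from_le_bytes([b[0], b[1]]) as f32 / 32767.0)
        .collect()
}

/// Hypothetical helper: playback duration of a mono 16-bit frame in seconds.
fn frame_duration_secs(frame_len_bytes: usize, sample_rate: u32) -> f64 {
    (frame_len_bytes / 2) as f64 / sample_rate as f64
}
```

At the default 24 kHz rate, one second of audio is 48,000 bytes on the wire, which may be useful when sizing client-side playback buffers.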
diff --git a/makima/src/server/messages.rs b/makima/src/server/messages.rs
index 9c50334..cecb622 100644
--- a/makima/src/server/messages.rs
+++ b/makima/src/server/messages.rs
@@ -103,3 +103,164 @@ impl ApiError {
}
}
}
+
+// =============================================================================
+// TTS (Text-to-Speech) Message Types
+// =============================================================================
+
+/// TTS audio encoding format for WebSocket streaming.
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema, PartialEq, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum TtsAudioEncoding {
+ /// 16-bit signed integer PCM samples
+ #[default]
+ Pcm16,
+ /// 32-bit floating point PCM samples
+ Pcm32f,
+}
+
+/// TTS synthesis priority level.
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema, PartialEq, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum TtsPriority {
+ /// Low priority - may be queued
+ Low,
+ /// Normal priority (default)
+ #[default]
+ Normal,
+ /// High priority - processed immediately
+ High,
+}
+
+/// TTS session start message from client.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsStartMessage {
+ /// Audio sample rate in Hz (default: 24000)
+ #[serde(default = "default_tts_sample_rate")]
+ pub sample_rate: u32,
+ /// Audio encoding format
+ #[serde(default)]
+ pub encoding: TtsAudioEncoding,
+ /// Voice identifier (default: "makima")
+ #[serde(default = "default_tts_voice")]
+ pub voice: String,
+ /// Language for synthesis (default: "English")
+ #[serde(default = "default_tts_language")]
+ pub language: String,
+}
+
+fn default_tts_sample_rate() -> u32 {
+ 24000
+}
+
+fn default_tts_voice() -> String {
+ "makima".to_string()
+}
+
+fn default_tts_language() -> String {
+ "English".to_string()
+}
+
+/// TTS speak request message from client.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsSpeakMessage {
+ /// Text to synthesize (max 1000 characters)
+ pub text: String,
+ /// Synthesis priority
+ #[serde(default)]
+ pub priority: TtsPriority,
+}
+
+/// TTS stop request message from client.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsStopMessage {
+ /// Optional reason for stopping
+ pub reason: Option<String>,
+}
+
+/// Wrapper for all TTS WebSocket messages from client to server.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum TtsClientMessage {
+ /// Start a new TTS session
+ Start(TtsStartMessage),
+ /// Request speech synthesis
+ Speak(TtsSpeakMessage),
+ /// Stop the current session
+ Stop(TtsStopMessage),
+}
+
+/// TTS session ready message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsReadyMessage {
+ /// Unique session identifier
+ pub session_id: String,
+ /// Confirmed sample rate
+ pub sample_rate: u32,
+ /// Confirmed encoding format
+ pub encoding: TtsAudioEncoding,
+ /// Confirmed voice
+ pub voice: String,
+}
+
+/// TTS audio chunk message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsAudioChunkMessage {
+ /// Base64-encoded audio data
+ pub data: String,
+ /// Whether this is the final chunk
+ pub is_final: bool,
+ /// Timestamp in seconds from start of audio
+ pub timestamp: f64,
+}
+
+/// TTS synthesis complete message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsCompleteMessage {
+ /// Total synthesis duration in milliseconds
+ pub duration_ms: u64,
+ /// Total number of chunks sent
+ pub total_chunks: u32,
+ /// Length of input text
+ pub text_length: u32,
+}
+
+/// TTS error message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsErrorMessage {
+ /// Error code for programmatic handling
+ pub code: String,
+ /// Human-readable error message
+ pub message: String,
+}
+
+/// TTS session stopped message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsStoppedMessage {
+ /// Reason for stopping
+ pub reason: String,
+}
+
+/// Wrapper for all TTS WebSocket messages from server to client.
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum TtsServerMessage {
+ /// Session is ready for synthesis requests
+ Ready(TtsReadyMessage),
+ /// Audio chunk (streamed during synthesis)
+ AudioChunk(TtsAudioChunkMessage),
+ /// Synthesis completed
+ Complete(TtsCompleteMessage),
+ /// Error occurred
+ Error(TtsErrorMessage),
+ /// Session has been stopped
+ Stopped(TtsStoppedMessage),
+}
diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs
index b969650..7c13f08 100644
--- a/makima/src/server/mod.rs
+++ b/makima/src/server/mod.rs
@@ -18,7 +18,7 @@ use tower_http::trace::TraceLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
-use crate::server::handlers::{api_keys, chat, contract_chat, contract_daemon, contracts, file_ws, files, history, listen, mesh, mesh_chat, mesh_daemon, mesh_merge, mesh_red_team, mesh_supervisor, mesh_ws, repository_history, templates, transcript_analysis, users, versions};
+use crate::server::handlers::{api_keys, chat, contract_chat, contract_daemon, contracts, file_ws, files, history, listen, mesh, mesh_chat, mesh_daemon, mesh_merge, mesh_red_team, mesh_supervisor, mesh_ws, repository_history, speak, templates, transcript_analysis, users, versions};
use crate::server::openapi::ApiDoc;
use crate::server::state::SharedState;
@@ -44,6 +44,7 @@ pub fn make_router(state: SharedState) -> Router {
// API v1 routes
let api_v1 = Router::new()
.route("/listen", get(listen::websocket_handler))
+ .route("/speak", get(speak::websocket_handler))
// Listen/transcript analysis endpoints
.route("/listen/analyze", post(transcript_analysis::analyze_transcript))
.route("/listen/create-contract", post(transcript_analysis::create_contract_from_analysis))
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
index 1bc7d7e..bf8f6f2 100644
--- a/makima/src/server/state.rs
+++ b/makima/src/server/state.rs
@@ -8,6 +8,7 @@ use uuid::Uuid;
use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
use crate::server::auth::{AuthConfig, JwtVerifier};
+use crate::tts::TtsEngine;
/// Notification payload for file updates (broadcast to WebSocket subscribers).
#[derive(Debug, Clone)]
@@ -599,6 +600,8 @@ pub struct AppState {
pub jwt_verifier: Option<JwtVerifier>,
/// Pending worktree info requests awaiting daemon response (keyed by task_id)
pub pending_worktree_info: DashMap<Uuid, oneshot::Sender<WorktreeInfoResponse>>,
+ /// Lazily-loaded TTS engine (initialized on the first `speak` request)
+ pub tts_engine: OnceCell<Box<dyn TtsEngine>>,
}
impl AppState {
@@ -673,9 +676,28 @@ impl AppState {
tool_keys: DashMap::new(),
jwt_verifier,
pending_worktree_info: DashMap::new(),
+ tts_engine: OnceCell::new(),
}
}
+ /// Get or initialize the TTS engine (lazy loading).
+ ///
+ /// The TTS engine is loaded on the first `speak` request using the Qwen3 backend.
+ /// Returns a reference to the engine, or an error if loading fails.
+ pub async fn get_tts_engine(&self) -> Result<&dyn TtsEngine, Box<dyn std::error::Error + Send + Sync>> {
+ self.tts_engine.get_or_try_init(|| async {
+ tracing::info!("Lazy-loading TTS engine (Qwen3) on first Speak connection...");
+ let engine = crate::tts::TtsEngineFactory::create(
+ crate::tts::TtsBackend::Qwen3,
+ None, // Use default model directory
+ ).map_err(|e| -> Box<dyn std::error::Error + Send + Sync> {
+ Box::new(e)
+ })?;
+ tracing::info!("TTS engine loaded successfully");
+ Ok(engine)
+ }).await.map(|b| b.as_ref())
+ }
+
/// Get or initialize ML models (lazy loading).
///
/// Models are loaded on first call and cached for subsequent calls.
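Reviewer note on the lazy-loading pattern: `get_tts_engine` uses an async `OnceCell` with `get_or_try_init`, so the engine is built at most once and a failed load leaves the cell empty for a later retry. The sketch below shows the same idea with only the standard library, for readers who want to experiment without an async runtime. `Engine`, `AppState::get_engine` and the `load` closure are hypothetical stand-ins, not the commit's types. Unlike the async cell in the diff, this version can run `load` on two threads at once if they race; only the first result is kept.

```rust
use std::sync::OnceLock;

/// Hypothetical stand-in for the boxed `TtsEngine` in the commit.
struct Engine {
    name: String,
}

/// Hypothetical stand-in for the relevant part of `AppState`.
struct AppState {
    engine: OnceLock<Engine>,
}

impl AppState {
    fn new() -> Self {
        Self { engine: OnceLock::new() }
    }

    /// Return the cached engine, or try to load it on first use.
    /// On error the cell stays empty, so a later call can retry.
    fn get_engine(
        &self,
        load: impl FnOnce() -> Result<Engine, String>,
    ) -> Result<&Engine, String> {
        if let Some(engine) = self.engine.get() {
            return Ok(engine);
        }
        let engine = load()?;
        // If another thread initialized the cell first, `set` fails and
        // our freshly loaded engine is dropped; the existing one is kept.
        let _ = self.engine.set(engine);
        Ok(self.engine.get().expect("cell was initialized above"))
    }
}
```

Keeping the cell empty after a failure matters for the handler above it: a `TTS_LOAD_FAILED` error on one request does not permanently disable the endpoint.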
diff --git a/makima/src/tts.rs b/makima/src/tts/chatterbox.rs
index 5198938..e26bc06 100644
--- a/makima/src/tts.rs
+++ b/makima/src/tts/chatterbox.rs
@@ -1,17 +1,26 @@
-use std::path::{Path, PathBuf};
-use std::fs;
+//! Chatterbox TTS engine — ONNX-based (legacy).
+//!
+//! This is the existing Chatterbox TTS implementation moved from `tts.rs`,
+//! now implementing the `TtsEngine` trait for unified access.
-use hf_hub::api::sync::Api;
use std::borrow::Cow;
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::sync::Mutex;
-use ndarray::{ArrayD, Array2, Array3, Array4, IxDyn};
+use hf_hub::api::sync::Api;
+use ndarray::{Array2, Array3, Array4, ArrayD, IxDyn};
use ort::session::Session;
-use ort::value::{Value, DynValue};
+use ort::value::{DynValue, Value};
use tokenizers::Tokenizer;
use crate::audio;
-pub const SAMPLE_RATE: u32 = 24_000;
+use super::{
+ apply_repetition_penalty, argmax, resample_to_24k, AudioChunk, TtsEngine, TtsError,
+ SAMPLE_RATE,
+};
+
const START_SPEECH_TOKEN: i64 = 6561;
const STOP_SPEECH_TOKEN: i64 = 6562;
const SILENCE_TOKEN: i64 = 4299;
@@ -22,57 +31,6 @@ const HEAD_DIM: usize = 64;
const MODEL_ID: &str = "ResembleAI/chatterbox-turbo-ONNX";
const DEFAULT_MODEL_DIR: &str = "models/chatterbox-turbo";
-#[derive(Debug)]
-pub enum TtsError {
- ModelLoad(String),
- Inference(String),
- Tokenizer(String),
- Audio(audio::AudioError),
- Io(std::io::Error),
- VoiceRequired,
-}
-
-impl std::fmt::Display for TtsError {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- TtsError::ModelLoad(msg) => write!(f, "model load error: {msg}"),
- TtsError::Inference(msg) => write!(f, "inference error: {msg}"),
- TtsError::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"),
- TtsError::Audio(err) => write!(f, "audio error: {err}"),
- TtsError::Io(err) => write!(f, "io error: {err}"),
- TtsError::VoiceRequired => write!(f, "voice reference audio is required for chatterbox-turbo"),
- }
- }
-}
-
-impl std::error::Error for TtsError {}
-
-impl From<audio::AudioError> for TtsError {
- fn from(value: audio::AudioError) -> Self {
- TtsError::Audio(value)
- }
-}
-
-impl From<std::io::Error> for TtsError {
- fn from(value: std::io::Error) -> Self {
- TtsError::Io(value)
- }
-}
-
-impl From<ort::Error> for TtsError {
- fn from(value: ort::Error) -> Self {
- TtsError::ModelLoad(value.to_string())
- }
-}
-
-pub struct ChatterboxTTS {
- speech_encoder: Session,
- embed_tokens: Session,
- language_model: Session,
- conditional_decoder: Session,
- tokenizer: Tokenizer,
-}
-
struct VoiceCondition {
audio_features: ArrayD<f32>,
prompt_tokens: ArrayD<i64>,
@@ -100,6 +58,18 @@ fn extract_i64_tensor(value: &Value) -> Result<ArrayD<i64>, TtsError> {
.map_err(|e| TtsError::Inference(e.to_string()))
}
+pub struct ChatterboxTTS {
+ speech_encoder: Mutex<Session>,
+ embed_tokens: Mutex<Session>,
+ language_model: Mutex<Session>,
+ conditional_decoder: Mutex<Session>,
+ tokenizer: Tokenizer,
+}
+
+// SAFETY: Sessions are behind Mutex, Tokenizer is Send+Sync
+unsafe impl Send for ChatterboxTTS {}
+unsafe impl Sync for ChatterboxTTS {}
+
impl ChatterboxTTS {
pub fn from_pretrained(model_dir: Option<&str>) -> Result<Self, TtsError> {
let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
@@ -133,21 +103,20 @@ impl ChatterboxTTS {
.map_err(|e| TtsError::Tokenizer(e.to_string()))?;
Ok(Self {
- speech_encoder,
- embed_tokens,
- language_model,
- conditional_decoder,
+ speech_encoder: Mutex::new(speech_encoder),
+ embed_tokens: Mutex::new(embed_tokens),
+ language_model: Mutex::new(language_model),
+ conditional_decoder: Mutex::new(conditional_decoder),
tokenizer,
})
}
- pub fn generate_tts(&mut self, _text: &str) -> Result<Vec<f32>, TtsError> {
- // Chatterbox TTS requires voice reference audio
+ pub fn generate_tts(&self) -> Result<Vec<f32>, TtsError> {
Err(TtsError::VoiceRequired)
}
pub fn generate_tts_with_voice(
- &mut self,
+ &self,
text: &str,
sample_audio_path: &Path,
) -> Result<Vec<f32>, TtsError> {
@@ -157,7 +126,7 @@ impl ChatterboxTTS {
}
pub fn generate_tts_with_samples(
- &mut self,
+ &self,
text: &str,
samples: &[f32],
sample_rate: u32,
@@ -168,10 +137,8 @@ impl ChatterboxTTS {
samples.to_vec()
};
- // 1. Encode reference audio
let voice_condition = self.encode_voice(&resampled)?;
- // 2. Tokenize text
let encoding = self
.tokenizer
.encode(text, true)
@@ -179,24 +146,18 @@ impl ChatterboxTTS {
let text_input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
- // 3. Generate speech tokens
- let generated_tokens = self.generate_speech_tokens(
- &text_input_ids,
- &voice_condition.audio_features,
- )?;
+ let generated_tokens =
+ self.generate_speech_tokens(&text_input_ids, &voice_condition.audio_features)?;
- // 4. Prepare final speech tokens: prompt_tokens + generated + silence
let prompt_tokens: Vec<i64> = voice_condition.prompt_tokens.iter().copied().collect();
let silence_tokens = vec![SILENCE_TOKEN; 3];
- let mut final_tokens = Vec::with_capacity(
- prompt_tokens.len() + generated_tokens.len() + silence_tokens.len()
- );
+ let mut final_tokens =
+ Vec::with_capacity(prompt_tokens.len() + generated_tokens.len() + silence_tokens.len());
final_tokens.extend_from_slice(&prompt_tokens);
final_tokens.extend_from_slice(&generated_tokens);
final_tokens.extend_from_slice(&silence_tokens);
- // 5. Decode to audio
let audio_samples = self.decode_speech_tokens(
&final_tokens,
&voice_condition.speaker_embeddings,
@@ -206,15 +167,18 @@ impl ChatterboxTTS {
Ok(audio_samples)
}
- fn encode_voice(&mut self, samples: &[f32]) -> Result<VoiceCondition, TtsError> {
+ fn encode_voice(&self, samples: &[f32]) -> Result<VoiceCondition, TtsError> {
let audio_arr = Array2::from_shape_vec((1, samples.len()), samples.to_vec())
.map_err(|e| TtsError::Inference(e.to_string()))?;
let audio_tensor = Value::from_array(audio_arr)?;
- let outputs = self.speech_encoder.run(ort::inputs!["audio_values" => audio_tensor])?;
+ let mut encoder = self
+ .speech_encoder
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let outputs = encoder.run(ort::inputs!["audio_values" => audio_tensor])?;
- // Order: audio_features, audio_tokens (prompt_token), speaker_embeddings, speaker_features
let audio_features = extract_f32_tensor(&outputs[0])?;
let prompt_tokens = extract_i64_tensor(&outputs[1])?;
let speaker_embeddings = extract_f32_tensor(&outputs[2])?;
@@ -229,57 +193,56 @@ impl ChatterboxTTS {
}
fn generate_speech_tokens(
- &mut self,
+ &self,
text_input_ids: &[i64],
audio_features: &ArrayD<f32>,
) -> Result<Vec<i64>, TtsError> {
let max_new_tokens: usize = 1024;
let repetition_penalty: f32 = 1.2;
- // Start with START_SPEECH_TOKEN
let mut generate_tokens: Vec<i64> = vec![START_SPEECH_TOKEN];
- // Initialize empty KV cache (seq_len = 0)
- let mut past_key_values = self.init_kv_cache(0)?;
-
+ let mut past_key_values = Self::init_kv_cache(0);
let mut first_iteration = true;
let mut total_seq_len: usize = 0;
for _ in 0..max_new_tokens {
- // Get embeddings for current input_ids
let current_input_ids = if first_iteration {
- // First iteration: use text input_ids
text_input_ids.to_vec()
} else {
- // Subsequent iterations: use last generated token
vec![*generate_tokens.last().unwrap()]
};
- let input_ids_arr = Array2::from_shape_vec(
- (1, current_input_ids.len()),
- current_input_ids
- ).map_err(|e| TtsError::Inference(e.to_string()))?;
+ let input_ids_arr =
+ Array2::from_shape_vec((1, current_input_ids.len()), current_input_ids)
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
let input_ids_tensor = Value::from_array(input_ids_arr)?;
let inputs_embeds = {
- let embed_outputs = self.embed_tokens.run(ort::inputs![input_ids_tensor])?;
+ let mut embed = self
+ .embed_tokens
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let embed_outputs = embed.run(ort::inputs![input_ids_tensor])?;
extract_f32_tensor(&embed_outputs[0])?
};
- // On first iteration, concatenate audio features with text embeddings
let inputs_embeds = if first_iteration {
- let audio_feat_3d = audio_features.view()
+ let audio_feat_3d = audio_features
+ .view()
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
- let text_emb_3d = inputs_embeds.view()
+ let text_emb_3d = inputs_embeds
+ .view()
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
ndarray::concatenate(ndarray::Axis(1), &[audio_feat_3d, text_emb_3d])
.map_err(|e| TtsError::Inference(e.to_string()))?
} else {
- inputs_embeds.view()
+ inputs_embeds
+ .view()
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?
.to_owned()
@@ -287,7 +250,6 @@ impl ChatterboxTTS {
let seq_len = inputs_embeds.shape()[1];
- // Set up attention mask and position ids
let (attention_mask, position_ids) = if first_iteration {
total_seq_len = seq_len;
let attention_mask: Array2<i64> = Array2::ones((1, seq_len));
@@ -296,14 +258,12 @@ impl ChatterboxTTS {
} else {
total_seq_len += 1;
let attention_mask: Array2<i64> = Array2::ones((1, total_seq_len));
- let position_ids = Array2::from_shape_vec(
- (1, 1),
- vec![(total_seq_len - 1) as i64]
- ).map_err(|e| TtsError::Inference(e.to_string()))?;
+ let position_ids =
+ Array2::from_shape_vec((1, 1), vec![(total_seq_len - 1) as i64])
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
(attention_mask, position_ids)
};
- // Run language model
let (logits, new_kv) = self.run_language_model(
inputs_embeds,
position_ids,
@@ -313,8 +273,9 @@ impl ChatterboxTTS {
past_key_values = new_kv;
- // Get last logits
- let logits_3d = logits.view().into_dimensionality::<ndarray::Ix3>()
+ let logits_3d = logits
+ .view()
+ .into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
let last_idx = logits_3d.shape()[1] - 1;
@@ -324,12 +285,9 @@ impl ChatterboxTTS {
.copied()
.collect();
- // Apply repetition penalty
apply_repetition_penalty(&mut current_logits, &generate_tokens, repetition_penalty);
- // Get next token
let next_token = argmax(&current_logits);
-
generate_tokens.push(next_token);
if next_token == STOP_SPEECH_TOKEN {
@@ -339,15 +297,14 @@ impl ChatterboxTTS {
first_iteration = false;
}
- // Return tokens without START and STOP tokens: [1:-1]
if generate_tokens.len() > 2 {
- Ok(generate_tokens[1..generate_tokens.len()-1].to_vec())
+ Ok(generate_tokens[1..generate_tokens.len() - 1].to_vec())
} else {
Ok(Vec::new())
}
}
- fn init_kv_cache(&self, seq_len: usize) -> Result<Vec<Array4<f32>>, TtsError> {
+ fn init_kv_cache(seq_len: usize) -> Vec<Array4<f32>> {
let mut cache = Vec::with_capacity(NUM_LAYERS * 2);
for _ in 0..NUM_LAYERS {
let key = Array4::<f32>::zeros((1, NUM_KV_HEADS, seq_len, HEAD_DIM));
@@ -355,11 +312,11 @@ impl ChatterboxTTS {
cache.push(key);
cache.push(value);
}
- Ok(cache)
+ cache
}
fn run_language_model(
- &mut self,
+ &self,
inputs_embeds: Array3<f32>,
position_ids: Array2<i64>,
attention_mask: Array2<i64>,
@@ -367,23 +324,37 @@ impl ChatterboxTTS {
) -> Result<(ArrayD<f32>, Vec<Array4<f32>>), TtsError> {
let mut inputs: Vec<(Cow<str>, DynValue)> = Vec::new();
- inputs.push((Cow::from("inputs_embeds"), Value::from_array(inputs_embeds)?.into_dyn()));
- inputs.push((Cow::from("position_ids"), Value::from_array(position_ids)?.into_dyn()));
- inputs.push((Cow::from("attention_mask"), Value::from_array(attention_mask)?.into_dyn()));
+ inputs.push((
+ Cow::from("inputs_embeds"),
+ Value::from_array(inputs_embeds)?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("position_ids"),
+ Value::from_array(position_ids)?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("attention_mask"),
+ Value::from_array(attention_mask)?.into_dyn(),
+ ));
- // Add KV cache inputs
for layer_idx in 0..NUM_LAYERS {
let key_name = format!("past_key_values.{}.key", layer_idx);
let value_name = format!("past_key_values.{}.value", layer_idx);
- let key_tensor = Value::from_array(past_key_values[layer_idx * 2].clone())?.into_dyn();
- let value_tensor = Value::from_array(past_key_values[layer_idx * 2 + 1].clone())?.into_dyn();
+ let key_tensor =
+ Value::from_array(past_key_values[layer_idx * 2].clone())?.into_dyn();
+ let value_tensor =
+ Value::from_array(past_key_values[layer_idx * 2 + 1].clone())?.into_dyn();
inputs.push((Cow::from(key_name), key_tensor));
inputs.push((Cow::from(value_name), value_tensor));
}
- let outputs = self.language_model.run(inputs)?;
+ let mut lm = self
+ .language_model
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let outputs = lm.run(inputs)?;
let logits = extract_f32_tensor(&outputs[0])?;
@@ -395,9 +366,11 @@ impl ChatterboxTTS {
let key_arr = extract_f32_tensor(&outputs[key_idx])?;
let value_arr = extract_f32_tensor(&outputs[value_idx])?;
- let key_4d = key_arr.into_dimensionality::<ndarray::Ix4>()
+ let key_4d = key_arr
+ .into_dimensionality::<ndarray::Ix4>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
- let value_4d = value_arr.into_dimensionality::<ndarray::Ix4>()
+ let value_4d = value_arr
+ .into_dimensionality::<ndarray::Ix4>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
new_kv.push(key_4d.to_owned());
@@ -408,7 +381,7 @@ impl ChatterboxTTS {
}
fn decode_speech_tokens(
- &mut self,
+ &self,
speech_tokens: &[i64],
speaker_embeddings: &ArrayD<f32>,
speaker_features: &ArrayD<f32>,
@@ -417,15 +390,29 @@ impl ChatterboxTTS {
return Ok(Vec::new());
}
- let tokens_arr = Array2::from_shape_vec((1, speech_tokens.len()), speech_tokens.to_vec())
- .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let tokens_arr =
+ Array2::from_shape_vec((1, speech_tokens.len()), speech_tokens.to_vec())
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
let mut inputs: Vec<(Cow<str>, DynValue)> = Vec::new();
- inputs.push((Cow::from("speech_tokens"), Value::from_array(tokens_arr)?.into_dyn()));
- inputs.push((Cow::from("speaker_embeddings"), Value::from_array(speaker_embeddings.clone())?.into_dyn()));
- inputs.push((Cow::from("speaker_features"), Value::from_array(speaker_features.clone())?.into_dyn()));
-
- let outputs = self.conditional_decoder.run(inputs)?;
+ inputs.push((
+ Cow::from("speech_tokens"),
+ Value::from_array(tokens_arr)?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("speaker_embeddings"),
+ Value::from_array(speaker_embeddings.clone())?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("speaker_features"),
+ Value::from_array(speaker_features.clone())?.into_dyn(),
+ ));
+
+ let mut decoder = self
+ .conditional_decoder
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let outputs = decoder.run(inputs)?;
let waveform = extract_f32_tensor(&outputs[0])?;
@@ -433,6 +420,34 @@ impl ChatterboxTTS {
}
}
+#[async_trait::async_trait]
+impl TtsEngine for ChatterboxTTS {
+ async fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ reference_sample_rate: Option<u32>,
+ ) -> Result<Vec<AudioChunk>, TtsError> {
+ let samples = match reference_audio {
+ Some(audio) => {
+ let sr = reference_sample_rate.unwrap_or(SAMPLE_RATE);
+ self.generate_tts_with_samples(text, audio, sr)?
+ }
+ None => return Err(TtsError::VoiceRequired),
+ };
+
+ Ok(vec![AudioChunk {
+ samples,
+ sample_rate: SAMPLE_RATE,
+ is_final: true,
+ }])
+ }
+
+ fn is_ready(&self) -> bool {
+ true
+ }
+}
+
fn download_models(target_dir: &Path) -> Result<(), TtsError> {
fs::create_dir_all(target_dir)?;
@@ -453,7 +468,9 @@ fn download_models(target_dir: &Path) -> Result<(), TtsError> {
for file in &model_files {
println!("Downloading {}...", file);
- let downloaded_path = repo.get(file).map_err(|e| TtsError::ModelLoad(e.to_string()))?;
+ let downloaded_path = repo
+ .get(file)
+ .map_err(|e| TtsError::ModelLoad(e.to_string()))?;
let filename = Path::new(file).file_name().unwrap();
let target_path = target_dir.join(filename);
@@ -466,115 +483,3 @@ fn download_models(target_dir: &Path) -> Result<(), TtsError> {
println!("Models downloaded to {:?}", target_dir);
Ok(())
}
-
-fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec<f32> {
- if input_rate == SAMPLE_RATE {
- return samples.to_vec();
- }
- if samples.is_empty() {
- return Vec::new();
- }
-
- let ratio = input_rate as f64 / SAMPLE_RATE as f64;
- let output_len = ((samples.len() as f64) / ratio).ceil() as usize;
-
- let mut output = Vec::with_capacity(output_len);
- for i in 0..output_len {
- let src_idx = (i as f64 * ratio) as usize;
- let sample = samples.get(src_idx).copied().unwrap_or(0.0);
- output.push(sample);
- }
-
- output
-}
-
-fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) {
- for &token in generated {
- if (token as usize) < logits.len() {
- let score = logits[token as usize];
- // Note: opposite of standard - if score < 0, multiply; if > 0, divide
- logits[token as usize] = if score < 0.0 {
- score * penalty
- } else {
- score / penalty
- };
- }
- }
-}
-
-fn argmax(logits: &[f32]) -> i64 {
- logits
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
- .map(|(idx, _)| idx as i64)
- .unwrap_or(0)
-}
-
-pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> {
- let mut file = fs::File::create(path)?;
- write_wav(&mut file, samples, SAMPLE_RATE)?;
- Ok(())
-}
-
-fn write_wav<W: std::io::Write>(writer: &mut W, samples: &[f32], sample_rate: u32) -> Result<(), std::io::Error> {
- let num_samples = samples.len() as u32;
- let num_channels: u16 = 1;
- let bits_per_sample: u16 = 16;
- let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
- let block_align = num_channels * bits_per_sample / 8;
- let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8;
- let file_size = 36 + data_size;
-
- writer.write_all(b"RIFF")?;
- writer.write_all(&file_size.to_le_bytes())?;
- writer.write_all(b"WAVE")?;
-
- writer.write_all(b"fmt ")?;
- writer.write_all(&16u32.to_le_bytes())?;
- writer.write_all(&1u16.to_le_bytes())?;
- writer.write_all(&num_channels.to_le_bytes())?;
- writer.write_all(&sample_rate.to_le_bytes())?;
- writer.write_all(&byte_rate.to_le_bytes())?;
- writer.write_all(&block_align.to_le_bytes())?;
- writer.write_all(&bits_per_sample.to_le_bytes())?;
-
- writer.write_all(b"data")?;
- writer.write_all(&data_size.to_le_bytes())?;
-
- for &sample in samples {
- let clamped = sample.clamp(-1.0, 1.0);
- let int_sample = (clamped * 32767.0) as i16;
- writer.write_all(&int_sample.to_le_bytes())?;
- }
-
- Ok(())
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_argmax() {
- let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
- assert_eq!(argmax(&logits), 3);
- }
-
- #[test]
- fn test_resample_same_rate() {
- let samples = vec![0.1, 0.2, 0.3];
- let resampled = resample_to_24k(&samples, SAMPLE_RATE);
- assert_eq!(resampled, samples);
- }
-
- #[test]
- fn test_repetition_penalty() {
- let mut logits = vec![1.0, 2.0, 3.0, 4.0];
- let generated = vec![1, 3];
- apply_repetition_penalty(&mut logits, &generated, 1.2);
- // score > 0 -> divide
- assert!((logits[1] - 2.0 / 1.2).abs() < 1e-6);
- assert!((logits[3] - 4.0 / 1.2).abs() < 1e-6);
- }
-}
diff --git a/makima/src/tts/mod.rs b/makima/src/tts/mod.rs
new file mode 100644
index 0000000..2cd0412
--- /dev/null
+++ b/makima/src/tts/mod.rs
@@ -0,0 +1,281 @@
+//! TTS engine abstraction and implementations.
+//!
+//! Provides a trait-based TTS engine interface with two backends:
+//! - **Chatterbox**: ONNX-based TTS (legacy)
+//! - **Qwen3**: Pure Rust candle-based Qwen3-TTS-12Hz-0.6B
+
+use std::path::Path;
+
+pub mod chatterbox;
+pub mod qwen3;
+
+// Re-export primary types
+pub use chatterbox::ChatterboxTTS;
+pub use qwen3::Qwen3Tts;
+
+/// Audio output sample rate (both engines output 24kHz).
+pub const SAMPLE_RATE: u32 = 24_000;
+
+/// A chunk of generated audio for streaming output.
+#[derive(Debug, Clone)]
+pub struct AudioChunk {
+ /// PCM f32 samples in [-1.0, 1.0].
+ pub samples: Vec<f32>,
+ /// Sample rate (always 24000 for both engines).
+ pub sample_rate: u32,
+ /// Whether this is the final chunk in the stream.
+ pub is_final: bool,
+}
+
+impl AudioChunk {
+ /// Convert to 16-bit PCM bytes (little-endian) for WebSocket streaming.
+ pub fn to_pcm16_bytes(&self) -> Vec<u8> {
+ let mut buf = Vec::with_capacity(self.samples.len() * 2);
+ for &s in &self.samples {
+ let clamped = s.clamp(-1.0, 1.0);
+ let int_sample = (clamped * 32767.0) as i16;
+ buf.extend_from_slice(&int_sample.to_le_bytes());
+ }
+ buf
+ }
+}
+
+/// Errors that can occur during TTS operations.
+#[derive(Debug)]
+pub enum TtsError {
+ ModelLoad(String),
+ Inference(String),
+ Tokenizer(String),
+ Audio(crate::audio::AudioError),
+ Io(std::io::Error),
+ VoiceRequired,
+ Config(String),
+ Candle(String),
+}
+
+impl std::fmt::Display for TtsError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TtsError::ModelLoad(msg) => write!(f, "model load error: {msg}"),
+ TtsError::Inference(msg) => write!(f, "inference error: {msg}"),
+ TtsError::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"),
+ TtsError::Audio(err) => write!(f, "audio error: {err}"),
+ TtsError::Io(err) => write!(f, "io error: {err}"),
+ TtsError::VoiceRequired => {
+ write!(f, "voice reference audio is required")
+ }
+ TtsError::Config(msg) => write!(f, "config error: {msg}"),
+ TtsError::Candle(msg) => write!(f, "candle error: {msg}"),
+ }
+ }
+}
+
+impl std::error::Error for TtsError {}
+
+impl From<crate::audio::AudioError> for TtsError {
+ fn from(value: crate::audio::AudioError) -> Self {
+ TtsError::Audio(value)
+ }
+}
+
+impl From<std::io::Error> for TtsError {
+ fn from(value: std::io::Error) -> Self {
+ TtsError::Io(value)
+ }
+}
+
+impl From<ort::Error> for TtsError {
+ fn from(value: ort::Error) -> Self {
+ TtsError::ModelLoad(value.to_string())
+ }
+}
+
+impl From<candle_core::Error> for TtsError {
+ fn from(value: candle_core::Error) -> Self {
+ TtsError::Candle(value.to_string())
+ }
+}
+
+/// Which TTS backend to use.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum TtsBackend {
+ /// ONNX-based Chatterbox TTS (legacy).
+ Chatterbox,
+ /// Candle-based Qwen3-TTS (preferred).
+ Qwen3,
+}
+
+/// TTS engine trait — implemented by both Chatterbox and Qwen3.
+#[async_trait::async_trait]
+pub trait TtsEngine: Send + Sync {
+ /// Generate complete audio from text, with an optional voice reference.
+ async fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ reference_sample_rate: Option<u32>,
+ ) -> Result<Vec<AudioChunk>, TtsError>;
+
+ /// Check if the engine is loaded and ready.
+ fn is_ready(&self) -> bool;
+
+ /// Get the engine's output sample rate.
+ fn sample_rate(&self) -> u32 {
+ SAMPLE_RATE
+ }
+}
+
+/// Factory for creating TTS engines.
+pub struct TtsEngineFactory;
+
+impl TtsEngineFactory {
+ /// Create a TTS engine of the specified backend type.
+ pub fn create(backend: TtsBackend, model_dir: Option<&str>) -> Result<Box<dyn TtsEngine>, TtsError> {
+ match backend {
+ TtsBackend::Chatterbox => {
+ let engine = ChatterboxTTS::from_pretrained(model_dir)?;
+ Ok(Box::new(engine))
+ }
+ TtsBackend::Qwen3 => {
+ let device = candle_core::Device::Cpu; // Default to CPU; GPU selection happens at higher level
+ let engine = Qwen3Tts::from_pretrained(model_dir, &device)?;
+ Ok(Box::new(engine))
+ }
+ }
+ }
+}
+
+/// Save audio samples to a WAV file.
+pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> {
+ let mut file = std::fs::File::create(path)?;
+ write_wav(&mut file, samples, SAMPLE_RATE)?;
+ Ok(())
+}
+
+fn write_wav<W: std::io::Write>(
+ writer: &mut W,
+ samples: &[f32],
+ sample_rate: u32,
+) -> Result<(), std::io::Error> {
+ let num_samples = samples.len() as u32;
+ let num_channels: u16 = 1;
+ let bits_per_sample: u16 = 16;
+ let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
+ let block_align = num_channels * bits_per_sample / 8;
+ let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8;
+ let file_size = 36 + data_size;
+
+ writer.write_all(b"RIFF")?;
+ writer.write_all(&file_size.to_le_bytes())?;
+ writer.write_all(b"WAVE")?;
+
+ writer.write_all(b"fmt ")?;
+ writer.write_all(&16u32.to_le_bytes())?;
+ writer.write_all(&1u16.to_le_bytes())?;
+ writer.write_all(&num_channels.to_le_bytes())?;
+ writer.write_all(&sample_rate.to_le_bytes())?;
+ writer.write_all(&byte_rate.to_le_bytes())?;
+ writer.write_all(&block_align.to_le_bytes())?;
+ writer.write_all(&bits_per_sample.to_le_bytes())?;
+
+ writer.write_all(b"data")?;
+ writer.write_all(&data_size.to_le_bytes())?;
+
+ for &sample in samples {
+ let clamped = sample.clamp(-1.0, 1.0);
+ let int_sample = (clamped * 32767.0) as i16;
+ writer.write_all(&int_sample.to_le_bytes())?;
+ }
+
+ Ok(())
+}
+
+/// Resample audio to 24kHz by picking the nearest preceding source sample (no interpolation).
+pub fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec<f32> {
+ if input_rate == SAMPLE_RATE {
+ return samples.to_vec();
+ }
+ if samples.is_empty() {
+ return Vec::new();
+ }
+
+ let ratio = input_rate as f64 / SAMPLE_RATE as f64;
+ let output_len = ((samples.len() as f64) / ratio).ceil() as usize;
+
+ let mut output = Vec::with_capacity(output_len);
+ for i in 0..output_len {
+ let src_idx = (i as f64 * ratio) as usize;
+ let sample = samples.get(src_idx).copied().unwrap_or(0.0);
+ output.push(sample);
+ }
+
+ output
+}
+
+/// Apply repetition penalty to logits based on previously generated tokens.
+pub fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) {
+ for &token in generated {
+ if (token as usize) < logits.len() {
+ let score = logits[token as usize];
+ logits[token as usize] = if score < 0.0 {
+ score * penalty
+ } else {
+ score / penalty
+ };
+ }
+ }
+}
+
+/// Return the index of the maximum value in logits.
+pub fn argmax(logits: &[f32]) -> i64 {
+ logits
+ .iter()
+ .enumerate()
+ .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+ .map(|(idx, _)| idx as i64)
+ .unwrap_or(0)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_argmax() {
+ let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
+ assert_eq!(argmax(&logits), 3);
+ }
+
+ #[test]
+ fn test_resample_same_rate() {
+ let samples = vec![0.1, 0.2, 0.3];
+ let resampled = resample_to_24k(&samples, SAMPLE_RATE);
+ assert_eq!(resampled, samples);
+ }
+
+ #[test]
+ fn test_repetition_penalty() {
+ let mut logits = vec![1.0, 2.0, 3.0, 4.0];
+ let generated = vec![1, 3];
+ apply_repetition_penalty(&mut logits, &generated, 1.2);
+ assert!((logits[1] - 2.0 / 1.2).abs() < 1e-6);
+ assert!((logits[3] - 4.0 / 1.2).abs() < 1e-6);
+ }
+
+ #[test]
+ fn test_audio_chunk_to_pcm16() {
+ let chunk = AudioChunk {
+ samples: vec![0.0, 1.0, -1.0],
+ sample_rate: 24_000,
+ is_final: true,
+ };
+ let bytes = chunk.to_pcm16_bytes();
+ assert_eq!(bytes.len(), 6);
+ // 0.0 -> 0i16
+ assert_eq!(i16::from_le_bytes([bytes[0], bytes[1]]), 0);
+ // 1.0 -> 32767i16
+ assert_eq!(i16::from_le_bytes([bytes[2], bytes[3]]), 32767);
+ // -1.0 -> -32767i16
+ assert_eq!(i16::from_le_bytes([bytes[4], bytes[5]]), -32767);
+ }
+}
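The `resample_to_24k` helper in the hunk above maps each output index to a single source index (`(i as f64 * ratio) as usize`) and copies that sample. For comparison, here is a minimal sketch of a resampler that actually interpolates linearly between neighbouring samples. The name `resample_linear` and its explicit `output_rate` parameter are hypothetical and not part of this commit; the sketch assumes mono `f32` input and does no anti-alias filtering.

```rust
/// Hypothetical helper (not part of this commit): resample mono f32 audio
/// by linearly interpolating between the two nearest source samples.
pub fn resample_linear(samples: &[f32], input_rate: u32, output_rate: u32) -> Vec<f32> {
    // Nothing sensible to produce for empty input or a zero rate.
    if samples.is_empty() || input_rate == 0 || output_rate == 0 {
        return Vec::new();
    }
    if input_rate == output_rate {
        return samples.to_vec();
    }

    // Source samples consumed per output sample.
    let ratio = input_rate as f64 / output_rate as f64;
    let output_len = ((samples.len() as f64) / ratio).ceil() as usize;
    let last = samples.len() - 1;

    let mut output = Vec::with_capacity(output_len);
    for i in 0..output_len {
        // Fractional position of this output sample in the source signal.
        let pos = i as f64 * ratio;
        let idx = (pos.floor() as usize).min(last);
        // Clamp the right neighbour so the tail repeats the final sample.
        let next = (idx + 1).min(last);
        let frac = (pos - idx as f64).clamp(0.0, 1.0) as f32;
        output.push(samples[idx] * (1.0 - frac) + samples[next] * frac);
    }
    output
}
```

Upsampling `[0.0, 1.0]` from 12 kHz to 24 kHz with this sketch inserts the midpoint `0.5` between the two input samples, which the index-copying version would not do.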
diff --git a/makima/src/tts/qwen3/code_predictor.rs b/makima/src/tts/qwen3/code_predictor.rs
new file mode 100644
index 0000000..0ef8a1d
--- /dev/null
+++ b/makima/src/tts/qwen3/code_predictor.rs
@@ -0,0 +1,261 @@
+//! Multi-Token Prediction (MTP) code predictor.
+//!
+//! After the main LM predicts the zeroth codebook token, this module
+//! predicts the remaining 15 codebook layers one at a time, each
+//! conditioned on the LM's hidden state and the previous codebook token.
+//!
+//! Architecture:
+//! - 5 transformer layers (same structure as main LM layers)
+//! - 16 output heads, one per codebook (vocab 2048 each)
+//! - Input: last hidden state from main LM + zeroth codebook embedding
+//! - Output: 16 codebook token predictions
+
+use candle_core::{Device, Module, Result, Tensor, D};
+use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+
+use super::config::{CodePredictorConfig, Qwen3LmConfig};
+use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};
+
+/// A single code predictor transformer layer.
+///
+/// Uses the same pre-norm residual structure as the main LM layers.
+pub struct CodePredictorLayer {
+ self_attn: Qwen3Attention,
+ mlp: Qwen3Mlp,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl CodePredictorLayer {
+ pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
+ // Construct a Qwen3LmConfig-like view for the attention/MLP constructors
+ let lm_config = Qwen3LmConfig {
+ hidden_size: config.hidden_size,
+ num_hidden_layers: config.num_layers,
+ num_attention_heads: config.num_attention_heads,
+ num_key_value_heads: config.num_attention_heads, // No GQA in predictor
+ intermediate_size: config.hidden_size * 3, // 3072 for hidden=1024
+ head_dim: config.hidden_size / config.num_attention_heads,
+ rms_norm_eps: config.rms_norm_eps,
+ ..Qwen3LmConfig::default()
+ };
+
+ let self_attn = Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?;
+ let mlp = Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?;
+ let input_layernorm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ vb.pp("input_layernorm"),
+ )?;
+ let post_attention_layernorm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ hidden_states: &Tensor,
+ rope: &RotaryEmbedding,
+ kv_cache: &mut KvCache,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let residual = hidden_states;
+ let hidden_states = self.input_layernorm.forward(hidden_states)?;
+ let hidden_states =
+ self.self_attn
+ .forward(&hidden_states, rope, kv_cache, attention_mask)?;
+ let hidden_states = (residual + hidden_states)?;
+
+ let residual = &hidden_states;
+ let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
+ let hidden_states = self.mlp.forward(&hidden_states)?;
+ let output = (residual + hidden_states)?;
+
+ Ok(output)
+ }
+}
+
+/// Multi-token prediction code predictor.
+///
+/// Takes the hidden states from the main LM and predicts all 16 codebook
+/// tokens. The zeroth codebook is predicted by the main LM head; this
+/// module predicts the remaining 15 residual codebooks.
+pub struct CodePredictor {
+ /// Embedding tables for codebook tokens, one per codebook group.
+ code_embeddings: Vec<Embedding>,
+ /// Projection from LM hidden + code embedding to predictor hidden.
+ input_proj: Linear,
+ /// 5 transformer layers.
+ layers: Vec<CodePredictorLayer>,
+ /// Final normalization.
+ norm: RmsNorm,
+ /// Per-codebook output heads (16 heads, each projecting to codebook_vocab_size).
+ output_heads: Vec<Linear>,
+ /// RoPE for the predictor's attention layers.
+ rope: RotaryEmbedding,
+ config: CodePredictorConfig,
+}
+
+impl CodePredictor {
+ pub fn new(
+ config: &CodePredictorConfig,
+ lm_config: &Qwen3LmConfig,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let predictor_vb = vb.pp("code_predictor");
+
+ // Code embeddings for each codebook group
+ let mut code_embeddings = Vec::with_capacity(config.num_code_groups);
+ for i in 0..config.num_code_groups {
+ let emb = embedding(
+ config.codebook_vocab_size,
+ config.hidden_size,
+ predictor_vb.pp(format!("code_embeddings.{i}")),
+ )?;
+ code_embeddings.push(emb);
+ }
+
+ // Input projection: LM hidden (1024) + code embedding (1024) -> predictor hidden (1024)
+ let input_proj = linear_no_bias(
+ config.hidden_size * 2,
+ config.hidden_size,
+ predictor_vb.pp("input_proj"),
+ )?;
+
+ // Transformer layers
+ let mut layers = Vec::with_capacity(config.num_layers);
+ for i in 0..config.num_layers {
+ let layer =
+ CodePredictorLayer::new(config, predictor_vb.pp(format!("layers.{i}")))?;
+ layers.push(layer);
+ }
+
+ let norm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ predictor_vb.pp("norm"),
+ )?;
+
+ // Output heads for each codebook
+ let mut output_heads = Vec::with_capacity(config.num_code_groups);
+ for i in 0..config.num_code_groups {
+ let head = linear_no_bias(
+ config.hidden_size,
+ config.codebook_vocab_size,
+ predictor_vb.pp(format!("output_heads.{i}")),
+ )?;
+ output_heads.push(head);
+ }
+
+ // RoPE for predictor attention (same theta and max positions as the main LM, but with the predictor's head_dim)
+ let predictor_head_dim = config.hidden_size / config.num_attention_heads;
+ let rope_config = Qwen3LmConfig {
+ head_dim: predictor_head_dim,
+ rope_theta: lm_config.rope_theta,
+ max_position_embeddings: lm_config.max_position_embeddings,
+ ..Qwen3LmConfig::default()
+ };
+ let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;
+
+ Ok(Self {
+ code_embeddings,
+ input_proj,
+ layers,
+ norm,
+ output_heads,
+ rope,
+ config: config.clone(),
+ })
+ }
+
+ /// Predict all 16 codebook tokens from the LM hidden state.
+ ///
+ /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
+ /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
+ ///
+ /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
+ pub fn predict(
+ &self,
+ lm_hidden: &Tensor,
+ zeroth_code: u32,
+ device: &Device,
+ ) -> Result<Vec<u32>> {
+ let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
+ all_codes.push(zeroth_code);
+
+ // The code predictor iterates through codebook groups.
+ // For each group i (1..16), it:
+ // 1. Embeds the previous codebook token
+ // 2. Concatenates with LM hidden state
+ // 3. Projects through the predictor layers
+ // 4. Predicts the next codebook token via output_head[i]
+ let mut prev_code = zeroth_code;
+
+ for group_idx in 1..self.config.num_code_groups {
+ // Embed the previous codebook token
+ let code_tensor = Tensor::from_vec(
+ vec![prev_code],
+ (1, 1),
+ device,
+ )?;
+ let code_emb = self.code_embeddings[group_idx - 1].forward(&code_tensor)?;
+
+ // Concatenate LM hidden state with code embedding
+ let combined = Tensor::cat(&[lm_hidden, &code_emb], D::Minus1)?;
+
+ // Project to predictor hidden size
+ let mut hidden = self.input_proj.forward(&combined)?;
+
+ // Run through predictor transformer layers with fresh, empty KV caches (each group is a single-position step)
+ let mut kv_caches: Vec<KvCache> =
+ (0..self.config.num_layers).map(|_| KvCache::new()).collect();
+ for (i, layer) in self.layers.iter().enumerate() {
+ hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
+ }
+
+ hidden = self.norm.forward(&hidden)?;
+
+ // Predict codebook token
+ let logits = self.output_heads[group_idx].forward(&hidden)?;
+
+ // Greedy decode: argmax
+ let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
+ let next_code = logits_flat
+ .argmax(0)?
+ .to_scalar::<u32>()?;
+
+ all_codes.push(next_code);
+ prev_code = next_code;
+ }
+
+ Ok(all_codes)
+ }
+
+ /// Number of codebook groups.
+ pub fn num_code_groups(&self) -> usize {
+ self.config.num_code_groups
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_code_predictor_config() {
+ let config = CodePredictorConfig::default();
+ assert_eq!(config.num_layers, 5);
+ assert_eq!(config.num_code_groups, 16);
+ assert_eq!(config.codebook_vocab_size, 2048);
+ assert_eq!(config.hidden_size, 1024);
+ }
+}
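The loop in `CodePredictor::predict` above is a chain: each codebook group is predicted from the token chosen for the previous group, so the 15 residual codes depend on one another and are produced in order. A minimal, tensor-free sketch of that control flow follows. `predict_frame` and the `step` closure are hypothetical stand-ins for the embedding, projection, transformer layers and output head; they are not part of the commit.

```rust
/// Hypothetical, tensor-free sketch of the control flow in `CodePredictor::predict`.
///
/// `step(group_idx, prev_code)` stands in for "embed the previous code, combine it
/// with the LM hidden state, run the predictor layers, and take the argmax of
/// output head `group_idx`".
pub fn predict_frame<F>(zeroth_code: u32, num_code_groups: usize, mut step: F) -> Vec<u32>
where
    F: FnMut(usize, u32) -> u32,
{
    let mut codes = Vec::with_capacity(num_code_groups.max(1));
    // The zeroth codebook comes from the main LM head, not from the predictor.
    codes.push(zeroth_code);

    let mut prev = zeroth_code;
    for group_idx in 1..num_code_groups {
        // Each group sees the code chosen for the group before it.
        let next = step(group_idx, prev);
        codes.push(next);
        prev = next;
    }
    codes
}
```

With a toy step such as `|g, prev| (prev + g as u32) % 2048`, a frame starting at code 7 with 4 groups comes out as `[7, 8, 10, 13]`, which makes the dependency of each entry on the previous one visible.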
diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs
new file mode 100644
index 0000000..6fb55d7
--- /dev/null
+++ b/makima/src/tts/qwen3/config.rs
@@ -0,0 +1,271 @@
+//! Qwen3-TTS model configuration.
+//!
+//! Parses config.json from the HuggingFace model repository to configure
+//! the language model, code predictor, and speech tokenizer.
+
+use serde::Deserialize;
+
+use crate::tts::TtsError;
+
+/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base.
+#[derive(Debug, Clone, Deserialize)]
+pub struct Qwen3TtsConfig {
+ /// Language model (talker) configuration.
+ #[serde(default = "Qwen3LmConfig::default")]
+ pub lm: Qwen3LmConfig,
+
+ /// Code predictor (multi-token prediction) configuration.
+ #[serde(default = "CodePredictorConfig::default")]
+ pub code_predictor: CodePredictorConfig,
+
+ /// Speech tokenizer configuration.
+ #[serde(default = "SpeechTokenizerConfig::default")]
+ pub speech_tokenizer: SpeechTokenizerConfig,
+}
+
+impl Default for Qwen3TtsConfig {
+ fn default() -> Self {
+ Self {
+ lm: Qwen3LmConfig::default(),
+ code_predictor: CodePredictorConfig::default(),
+ speech_tokenizer: SpeechTokenizerConfig::default(),
+ }
+ }
+}
+
+impl Qwen3TtsConfig {
+ /// Load from a config.json file path.
+ pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> {
+ let content = std::fs::read_to_string(path)
+ .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?;
+ Self::from_json_str(&content)
+ }
+
+ /// Load from a JSON string.
+ pub fn from_json_str(json: &str) -> Result<Self, TtsError> {
+ // Try to parse the full HuggingFace config.json format first
+ if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) {
+ return Ok(Self::from_hf_config(&hf_config));
+ }
+ // Fall back to direct deserialization
+ serde_json::from_str(json)
+ .map_err(|e| TtsError::Config(format!("failed to parse config: {e}")))
+ }
+
+ /// Convert from HuggingFace's config.json format.
+ fn from_hf_config(hf: &HfConfig) -> Self {
+ Self {
+ lm: Qwen3LmConfig {
+ hidden_size: hf.hidden_size.unwrap_or(1024),
+ num_hidden_layers: hf.num_hidden_layers.unwrap_or(28),
+ num_attention_heads: hf.num_attention_heads.unwrap_or(16),
+ num_key_value_heads: hf.num_key_value_heads.unwrap_or(8),
+ intermediate_size: hf.intermediate_size.unwrap_or(3072),
+ head_dim: hf.head_dim.unwrap_or(128),
+ vocab_size: hf.vocab_size.unwrap_or(151_936),
+ max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768),
+ rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
+ rope_theta: hf.rope_theta.unwrap_or(1_000_000.0),
+ use_sliding_window: hf.use_sliding_window.unwrap_or(false),
+ sliding_window: hf.sliding_window,
+ hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()),
+ },
+ code_predictor: CodePredictorConfig {
+ hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024),
+ num_layers: hf.code_predictor_num_layers.unwrap_or(5),
+ num_attention_heads: hf
+ .code_predictor_num_attention_heads
+ .unwrap_or(16),
+ num_code_groups: hf.num_code_groups.unwrap_or(16),
+ codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048),
+ rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
+ },
+ speech_tokenizer: SpeechTokenizerConfig::default(),
+ }
+ }
+}
+
+/// Language model configuration (28-layer Qwen3 transformer).
+#[derive(Debug, Clone, Deserialize)]
+pub struct Qwen3LmConfig {
+ /// Hidden dimension of transformer layers.
+ pub hidden_size: usize,
+ /// Number of transformer layers.
+ pub num_hidden_layers: usize,
+ /// Number of attention heads.
+ pub num_attention_heads: usize,
+ /// Number of key-value heads (GQA).
+ pub num_key_value_heads: usize,
+ /// Feed-forward intermediate size.
+ pub intermediate_size: usize,
+ /// Dimension per attention head.
+ pub head_dim: usize,
+ /// Text vocabulary size.
+ pub vocab_size: usize,
+ /// Maximum sequence length for RoPE.
+ pub max_position_embeddings: usize,
+ /// RMS normalization epsilon.
+ pub rms_norm_eps: f64,
+ /// RoPE theta parameter.
+ pub rope_theta: f64,
+ /// Whether to use sliding window attention.
+ pub use_sliding_window: bool,
+ /// Sliding window size (if enabled).
+ pub sliding_window: Option<usize>,
+ /// Activation function name.
+ pub hidden_act: String,
+}
+
+impl Default for Qwen3LmConfig {
+ fn default() -> Self {
+ Self {
+ hidden_size: 1024,
+ num_hidden_layers: 28,
+ num_attention_heads: 16,
+ num_key_value_heads: 8,
+ intermediate_size: 3072,
+ head_dim: 128,
+ vocab_size: 151_936,
+ max_position_embeddings: 32_768,
+ rms_norm_eps: 1e-6,
+ rope_theta: 1_000_000.0,
+ use_sliding_window: false,
+ sliding_window: None,
+ hidden_act: "silu".to_string(),
+ }
+ }
+}
+
+impl Qwen3LmConfig {
+ /// Number of key-value head groups for GQA.
+ pub fn num_kv_groups(&self) -> usize {
+ self.num_attention_heads / self.num_key_value_heads
+ }
+}
+
+/// Code predictor (multi-token prediction) configuration.
+#[derive(Debug, Clone, Deserialize)]
+pub struct CodePredictorConfig {
+ /// Hidden size (matches LM hidden size).
+ pub hidden_size: usize,
+ /// Number of predictor transformer layers.
+ pub num_layers: usize,
+ /// Number of attention heads.
+ pub num_attention_heads: usize,
+ /// Number of codebook groups (residual codebooks).
+ pub num_code_groups: usize,
+ /// Vocabulary size per codebook.
+ pub codebook_vocab_size: usize,
+ /// RMS norm epsilon.
+ pub rms_norm_eps: f64,
+}
+
+impl Default for CodePredictorConfig {
+ fn default() -> Self {
+ Self {
+ hidden_size: 1024,
+ num_layers: 5,
+ num_attention_heads: 16,
+ num_code_groups: 16,
+ codebook_vocab_size: 2048,
+ rms_norm_eps: 1e-6,
+ }
+ }
+}
+
+/// Speech tokenizer (ConvNet codec) configuration.
+#[derive(Debug, Clone, Deserialize)]
+pub struct SpeechTokenizerConfig {
+ /// Number of RVQ codebooks.
+ pub num_codebooks: usize,
+ /// Codebook embedding dimension.
+ pub codebook_dim: usize,
+ /// Codebook vocabulary size per layer.
+ pub codebook_size: usize,
+ /// Encoder/decoder hidden channels.
+ pub hidden_channels: usize,
+ /// Output sample rate.
+ pub sample_rate: u32,
+ /// Token frame rate (Hz).
+ pub frame_rate: f32,
+ /// HuggingFace model ID for the speech tokenizer.
+ pub model_id: String,
+}
+
+impl Default for SpeechTokenizerConfig {
+ fn default() -> Self {
+ Self {
+ num_codebooks: 16,
+ codebook_dim: 256,
+ codebook_size: 2048,
+ hidden_channels: 512,
+ sample_rate: 24_000,
+ frame_rate: 12.5,
+ model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(),
+ }
+ }
+}
+
+/// HuggingFace config.json format (partial, fields we need).
+#[derive(Debug, Deserialize)]
+struct HfConfig {
+ hidden_size: Option<usize>,
+ num_hidden_layers: Option<usize>,
+ num_attention_heads: Option<usize>,
+ num_key_value_heads: Option<usize>,
+ intermediate_size: Option<usize>,
+ head_dim: Option<usize>,
+ vocab_size: Option<usize>,
+ max_position_embeddings: Option<usize>,
+ rms_norm_eps: Option<f64>,
+ rope_theta: Option<f64>,
+ use_sliding_window: Option<bool>,
+ sliding_window: Option<usize>,
+ hidden_act: Option<String>,
+ // Code predictor specific fields
+ code_predictor_hidden_size: Option<usize>,
+ code_predictor_num_layers: Option<usize>,
+ code_predictor_num_attention_heads: Option<usize>,
+ num_code_groups: Option<usize>,
+ codebook_vocab_size: Option<usize>,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_default_config() {
+ let config = Qwen3TtsConfig::default();
+ assert_eq!(config.lm.hidden_size, 1024);
+ assert_eq!(config.lm.num_hidden_layers, 28);
+ assert_eq!(config.lm.num_attention_heads, 16);
+ assert_eq!(config.lm.num_key_value_heads, 8);
+ assert_eq!(config.lm.head_dim, 128);
+ assert_eq!(config.lm.num_kv_groups(), 2);
+ assert_eq!(config.code_predictor.num_layers, 5);
+ assert_eq!(config.code_predictor.num_code_groups, 16);
+ assert_eq!(config.speech_tokenizer.num_codebooks, 16);
+ }
+
+ #[test]
+ fn test_config_from_json() {
+ let json = r#"{
+ "hidden_size": 1024,
+ "num_hidden_layers": 28,
+ "num_attention_heads": 16,
+ "num_key_value_heads": 8,
+ "intermediate_size": 3072,
+ "vocab_size": 151936,
+ "max_position_embeddings": 32768,
+ "rms_norm_eps": 1e-6,
+ "rope_theta": 1000000.0,
+ "hidden_act": "silu"
+ }"#;
+
+ let config = Qwen3TtsConfig::from_json_str(json).unwrap();
+ assert_eq!(config.lm.hidden_size, 1024);
+ assert_eq!(config.lm.num_hidden_layers, 28);
+ assert_eq!(config.lm.vocab_size, 151_936);
+ }
+}
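The defaults in `config.rs` above give the main LM `head_dim = 128` with 16 attention heads and `hidden_size = 1024`, while `CodePredictorLayer::new` derives `head_dim = hidden_size / num_attention_heads = 64` and disables GQA. The sketch below works through the resulting shapes. `AttnShape` is a hypothetical, trimmed-down mirror of the relevant config fields. It assumes the common convention that the query projection is `heads * head_dim` wide and the key/value projections are `kv_heads * head_dim` wide. `model.rs` is not shown in this part of the diff, so that convention is an assumption here.

```rust
/// Hypothetical mirror of the attention-related fields of `Qwen3LmConfig`,
/// defined only so this sketch compiles on its own.
#[derive(Debug, Clone, Copy)]
pub struct AttnShape {
    pub hidden_size: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub head_dim: usize,
}

impl AttnShape {
    /// Values copied from `Qwen3LmConfig::default()` in the diff.
    pub fn lm_default() -> Self {
        Self { hidden_size: 1024, num_attention_heads: 16, num_key_value_heads: 8, head_dim: 128 }
    }

    /// Values that `CodePredictorLayer::new` derives from `CodePredictorConfig::default()`:
    /// no GQA (kv heads == heads) and head_dim = hidden_size / heads.
    pub fn predictor_default() -> Self {
        Self { hidden_size: 1024, num_attention_heads: 16, num_key_value_heads: 16, head_dim: 1024 / 16 }
    }

    /// Assumed query projection width: heads * head_dim.
    pub fn q_width(&self) -> usize {
        self.num_attention_heads * self.head_dim
    }

    /// Assumed key/value projection width: kv_heads * head_dim.
    pub fn kv_width(&self) -> usize {
        self.num_key_value_heads * self.head_dim
    }

    /// Query heads sharing each key/value head (same formula as `num_kv_groups`
    /// in the diff; panics if `num_key_value_heads` is 0).
    pub fn num_kv_groups(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }
}
```

Under these assumptions the main LM's query projection is 2048 wide, twice its hidden size, with 2 query heads per key/value head. The predictor's projections are all 1024 wide with one query head per key/value head.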
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs
new file mode 100644
index 0000000..02161e6
--- /dev/null
+++ b/makima/src/tts/qwen3/generate.rs
@@ -0,0 +1,426 @@
+//! Autoregressive generation loop for Qwen3-TTS.
+//!
+//! Orchestrates the full inference pipeline:
+//! 1. Encode reference audio → reference codec tokens via speech tokenizer
+//! 2. Tokenize text → token IDs
+//! 3. Autoregressive LM generation → zeroth codebook tokens
+//! 4. Code predictor → remaining 15 codebook tokens per frame
+//! 5. Speech tokenizer decoder → waveform audio
+
+use candle_core::{DType, Device, IndexOp, Result, Tensor};
+use tokenizers::Tokenizer;
+
+use super::code_predictor::CodePredictor;
+use super::model::{KvCache, Qwen3Model};
+use super::speech_tokenizer::SpeechTokenizer;
+use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE};
+
+/// Special tokens for the Qwen3-TTS vocabulary.
+pub const BOS_TOKEN_ID: u32 = 151_643;
+pub const EOS_TOKEN_ID: u32 = 151_645;
+pub const PAD_TOKEN_ID: u32 = 151_643;
+
+/// Speech-specific control tokens.
+/// These are placeholders — actual values come from the tokenizer config.
+pub const START_OF_SPEECH: u32 = 151_668;
+pub const END_OF_SPEECH: u32 = 151_669;
+
+/// Generation configuration.
+#[derive(Debug, Clone)]
+pub struct GenerationConfig {
+ /// Maximum number of speech tokens to generate.
+ pub max_new_tokens: usize,
+ /// Sampling temperature; ignored when top_k is 0, and 0.0 forces greedy argmax.
+ pub temperature: f32,
+ /// Top-k sampling (0 = disabled, use greedy argmax).
+ pub top_k: usize,
+ /// Repetition penalty.
+ pub repetition_penalty: f32,
+ /// Whether to generate audio chunks incrementally (streaming).
+ pub streaming: bool,
+}
+
+impl Default for GenerationConfig {
+ fn default() -> Self {
+ Self {
+ max_new_tokens: 2048,
+ temperature: 1.0,
+ top_k: 0, // Greedy by default
+ repetition_penalty: 1.2,
+ streaming: false,
+ }
+ }
+}
+
+/// Manages the full generation pipeline.
+pub struct GenerationContext<'a> {
+ model: &'a Qwen3Model,
+ code_predictor: &'a CodePredictor,
+ speech_tokenizer: &'a SpeechTokenizer,
+ tokenizer: &'a Tokenizer,
+ device: &'a Device,
+ config: GenerationConfig,
+}
+
+impl<'a> GenerationContext<'a> {
+ pub fn new(
+ model: &'a Qwen3Model,
+ code_predictor: &'a CodePredictor,
+ speech_tokenizer: &'a SpeechTokenizer,
+ tokenizer: &'a Tokenizer,
+ device: &'a Device,
+ config: GenerationConfig,
+ ) -> Self {
+ Self {
+ model,
+ code_predictor,
+ speech_tokenizer,
+ tokenizer,
+ device,
+ config,
+ }
+ }
+
+ /// Generate audio from text, optionally with a voice reference.
+ ///
+ /// Returns a list of audio chunks. If `streaming` is false, returns
+ /// a single chunk with the complete audio.
+ pub fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
+ // 1. Encode reference audio if provided
+ let reference_codes = match reference_audio {
+ Some(audio) => Some(
+ self.speech_tokenizer
+ .encode(audio)
+ .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?,
+ ),
+ None => None,
+ };
+
+ // 2. Tokenize text
+ let encoding = self
+ .tokenizer
+ .encode(text, true)
+ .map_err(|e| TtsError::Tokenizer(e.to_string()))?;
+
+ let text_token_ids: Vec<u32> = encoding.get_ids().to_vec();
+
+ // 3. Prepare input sequence
+ // Format: [BOS] [text_tokens] [START_OF_SPEECH]
+ let mut input_ids = Vec::new();
+ input_ids.push(BOS_TOKEN_ID);
+ input_ids.extend_from_slice(&text_token_ids);
+ input_ids.push(START_OF_SPEECH);
+
+ // 4. Run autoregressive generation
+ let generated_frames = self
+ .autoregressive_generate(&input_ids, reference_codes.as_deref())
+ .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?;
+
+ if generated_frames.is_empty() {
+ return Ok(vec![AudioChunk {
+ samples: vec![],
+ sample_rate: SAMPLE_RATE,
+ is_final: true,
+ }]);
+ }
+
+ // 5. Decode all frames to audio
+ if self.config.streaming {
+ self.decode_streaming(&generated_frames)
+ } else {
+ self.decode_batch(&generated_frames)
+ }
+ }
+
+ /// Autoregressive generation loop.
+ ///
+ /// Generates zeroth codebook tokens one at a time, then uses the code
+ /// predictor to fill in the remaining 15 codebooks per frame.
+ ///
+ /// Returns: Vec of frames, each frame is [num_codebooks] tokens.
+ fn autoregressive_generate(
+ &self,
+ input_ids: &[u32],
+ _reference_codes: Option<&[Vec<u32>]>,
+ ) -> Result<Vec<Vec<u32>>> {
+ let _num_codebooks = self.code_predictor.num_code_groups();
+ let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers())
+ .map(|_| KvCache::new())
+ .collect();
+
+ let mut generated_frames: Vec<Vec<u32>> = Vec::new();
+ let mut past_zeroth_tokens: Vec<u32> = Vec::new();
+
+ // === First iteration: process the full input sequence ===
+ let input_tensor = Tensor::from_vec(
+ input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
+ (1, input_ids.len()),
+ self.device,
+ )?
+ .to_dtype(DType::I64)?;
+
+ let seq_len = input_ids.len();
+ let attention_mask =
+ Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?;
+
+ let logits =
+ self.model
+ .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?;
+
+ // Get the logits for the last position
+ let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size]
+ let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?;
+
+ if first_token == END_OF_SPEECH as u32 {
+ return Ok(generated_frames);
+ }
+
+ // Use code predictor for all codebooks
+ let lm_hidden = self
+ .model
+ .last_hidden_state()
+ .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
+ let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?;
+
+ let frame_codes = self
+ .code_predictor
+ .predict(&last_hidden, first_token, self.device)?;
+ generated_frames.push(frame_codes);
+ past_zeroth_tokens.push(first_token);
+
+ // === Subsequent iterations: one token at a time ===
+ for _step in 1..self.config.max_new_tokens {
+ let past_len = kv_caches[0].seq_len();
+
+ // Input: just the last generated zeroth codebook token
+ let last_token = *past_zeroth_tokens.last().unwrap();
+ let token_tensor = Tensor::from_vec(
+ vec![last_token as i64],
+ (1, 1),
+ self.device,
+ )?
+ .to_dtype(DType::I64)?;
+
+ // Single-token attention mask
+ let attention_mask =
+ Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?;
+
+ let logits =
+ self.model
+ .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?;
+
+ let next_logits = logits.i((0, 0, ..))?; // [vocab_size]
+ let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?;
+
+ if next_token == END_OF_SPEECH as u32 {
+ break;
+ }
+
+ // Predict all codebooks for this frame
+ let lm_hidden = self
+ .model
+ .last_hidden_state()
+ .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
+
+ let frame_codes = self
+ .code_predictor
+ .predict(&lm_hidden, next_token, self.device)?;
+ generated_frames.push(frame_codes);
+ past_zeroth_tokens.push(next_token);
+ }
+
+ Ok(generated_frames)
+ }
+
+ /// Sample a token from logits.
+ fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> {
+ let mut logits_vec: Vec<f32> = logits.to_vec1()?;
+
+ // Apply repetition penalty
+ if self.config.repetition_penalty != 1.0 {
+ for &token in past_tokens {
+ let idx = token as usize;
+ if idx < logits_vec.len() {
+ let score = logits_vec[idx];
+ logits_vec[idx] = if score < 0.0 {
+ score * self.config.repetition_penalty
+ } else {
+ score / self.config.repetition_penalty
+ };
+ }
+ }
+ }
+
+ if self.config.top_k == 0 || self.config.temperature == 0.0 {
+ // Greedy: argmax
+ let (max_idx, _) = logits_vec
+ .iter()
+ .enumerate()
+ .max_by(|(_, a), (_, b)| {
+ a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
+ })
+ .unwrap_or((0, &0.0));
+ Ok(max_idx as u32)
+ } else {
+ // Top-k sampling with temperature
+ let temperature = self.config.temperature;
+
+ // Apply temperature
+ for v in logits_vec.iter_mut() {
+ *v /= temperature;
+ }
+
+ // Sort indices by logit value (descending)
+ let mut indexed: Vec<(usize, f32)> =
+ logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
+ indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+
+ // Keep only top-k
+ let k = self.config.top_k.min(indexed.len());
+ let top_k = &indexed[..k];
+
+ // Softmax over top-k
+ let max_val = top_k[0].1;
+ let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).sum();
+ let probs: Vec<(usize, f32)> = top_k
+ .iter()
+ .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum))
+ .collect();
+
+ // Sample from distribution (simple linear scan)
+ let r: f32 = random_float();
+ let mut cumulative = 0.0;
+ for (idx, prob) in &probs {
+ cumulative += prob;
+ if cumulative >= r {
+ return Ok(*idx as u32);
+ }
+ }
+
+ // Fallback to highest probability
+ Ok(probs[0].0 as u32)
+ }
+ }
+
+ /// Decode all frames in batch (non-streaming).
+ fn decode_batch(
+ &self,
+ frames: &[Vec<u32>],
+ ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
+ let num_codebooks = self.speech_tokenizer.num_codebooks();
+
+ // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames]
+ let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
+ for frame in frames {
+ for (cb_idx, &code) in frame.iter().enumerate() {
+ if cb_idx < num_codebooks {
+ codes_by_codebook[cb_idx].push(code);
+ }
+ }
+ }
+
+ let samples = self
+ .speech_tokenizer
+ .decode(&codes_by_codebook)
+ .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?;
+
+ Ok(vec![AudioChunk {
+ samples,
+ sample_rate: SAMPLE_RATE,
+ is_final: true,
+ }])
+ }
+
+ /// Decode frames incrementally (streaming).
+ fn decode_streaming(
+ &self,
+ frames: &[Vec<u32>],
+ ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
+ let mut chunks = Vec::new();
+
+ // Decode in groups of frames for efficiency
+ let chunk_size = 10; // ~800ms per chunk at 12.5Hz
+ let num_codebooks = self.speech_tokenizer.num_codebooks();
+
+ for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
+ let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
+
+ // Transpose chunk frames
+ let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
+ for frame in frame_chunk {
+ for (cb_idx, &code) in frame.iter().enumerate() {
+ if cb_idx < num_codebooks {
+ codes_by_codebook[cb_idx].push(code);
+ }
+ }
+ }
+
+ let samples = self
+ .speech_tokenizer
+ .decode(&codes_by_codebook)
+ .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?;
+
+ chunks.push(AudioChunk {
+ samples,
+ sample_rate: SAMPLE_RATE,
+ is_final: is_last,
+ });
+ }
+
+ Ok(chunks)
+ }
+}
+
+/// Simple pseudo-random float in [0, 1) using thread-local state.
+/// Uses a basic xorshift for reproducibility without external deps.
+fn random_float() -> f32 {
+ use std::cell::Cell;
+ thread_local! {
+ static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0);
+ }
+
+ STATE.with(|s| {
+ let mut x = s.get();
+ x ^= x << 13;
+ x ^= x >> 7;
+ x ^= x << 17;
+ s.set(x);
+ ((x >> 40) as f32) / ((1u64 << 24) as f32) // top 24 bits are exact in f32, so the result stays strictly below 1.0
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_generation_config_default() {
+ let config = GenerationConfig::default();
+ assert_eq!(config.max_new_tokens, 2048);
+ assert_eq!(config.top_k, 0);
+ assert_eq!(config.temperature, 1.0);
+ assert_eq!(config.repetition_penalty, 1.2);
+ assert!(!config.streaming);
+ }
+
+ #[test]
+ fn test_random_float_range() {
+ for _ in 0..100 {
+ let r = random_float();
+ assert!(r >= 0.0);
+ assert!(r < 1.0);
+ }
+ }
+
+ #[test]
+ fn test_special_tokens() {
+ assert_eq!(BOS_TOKEN_ID, 151_643);
+ assert_eq!(EOS_TOKEN_ID, 151_645);
+ assert_eq!(START_OF_SPEECH, 151_668);
+ assert_eq!(END_OF_SPEECH, 151_669);
+ }
+}
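Review note on `sample_token` above: the top-k branch can be sketched with the standard library alone, which makes it easy to unit-test without loading a model. The names `top_k_probs` and `sample_from` are hypothetical and not part of this commit, and the random draw `r` is passed in by the caller so the behaviour is deterministic. One deliberate difference from the diff: when rounding leaves the cumulative sum below `r`, this sketch falls back to the last kept token rather than the first.

```rust
/// Softmax over the `k` largest logits after temperature scaling.
/// Returns (token index, probability) pairs sorted by descending probability.
/// Assumes `temperature` is non-zero (the diff handles 0.0 via the greedy path).
fn top_k_probs(logits: &[f32], k: usize, temperature: f32) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &v)| (i, v / temperature))
        .collect();
    // Sort descending by scaled logit; incomparable values (NaN) are treated as equal.
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    indexed.truncate(k.max(1).min(indexed.len()));
    if indexed.is_empty() {
        return indexed;
    }
    // Subtract the maximum before exponentiating for numerical stability.
    let max_val = indexed[0].1;
    let exp_sum: f32 = indexed.iter().map(|(_, v)| (*v - max_val).exp()).sum();
    indexed
        .into_iter()
        .map(|(i, v)| (i, (v - max_val).exp() / exp_sum))
        .collect()
}

/// Inverse-CDF sampling: returns the first index whose cumulative
/// probability exceeds `r`, where `r` is expected to lie in [0, 1).
fn sample_from(probs: &[(usize, f32)], r: f32) -> Option<usize> {
    let mut cumulative = 0.0f32;
    for &(idx, p) in probs {
        cumulative += p;
        if r < cumulative {
            return Some(idx);
        }
    }
    // Rounding can leave the total slightly below 1.0; use the last kept token.
    probs.last().map(|&(idx, _)| idx)
}
```

Taking `r` as a parameter keeps the sampler independent of the RNG, so `random_float()` or any other source can be supplied at the call site.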
diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs
new file mode 100644
index 0000000..c55c118
--- /dev/null
+++ b/makima/src/tts/qwen3/mod.rs
@@ -0,0 +1,287 @@
+//! Qwen3-TTS — Pure Rust implementation using candle.
+//!
+//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis
+//! with voice cloning support. No Python, no ONNX — pure Rust inference
+//! via the candle ML framework.
+//!
+//! # Architecture
+//!
+//! The model has three components:
+//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens
+//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers
+//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes
+//!
+//! # Usage
+//!
+//! ```rust,no_run
+//! use makima::tts::qwen3::Qwen3Tts;
+//! use candle_core::Device;
+//!
+//! let device = Device::Cpu;
+//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap();
+//! // Use via TtsEngine trait or direct API
+//! ```
+
+pub mod code_predictor;
+pub mod config;
+pub mod generate;
+pub mod model;
+pub mod speech_tokenizer;
+
+use std::path::{Path, PathBuf};
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use candle_core::{DType, Device};
+use candle_nn::VarBuilder;
+use hf_hub::api::sync::Api;
+use tokenizers::Tokenizer;
+
+use self::code_predictor::CodePredictor;
+use self::config::Qwen3TtsConfig;
+use self::generate::{GenerationConfig, GenerationContext};
+use self::model::Qwen3Model;
+use self::speech_tokenizer::SpeechTokenizer;
+use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE};
+
+/// HuggingFace model IDs.
+const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base";
+const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
+const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts";
+
+/// Qwen3-TTS engine — pure Rust candle-based inference.
+pub struct Qwen3Tts {
+ /// The 28-layer language model.
+ model: Qwen3Model,
+ /// Multi-token prediction code predictor.
+ code_predictor: CodePredictor,
+ /// Speech tokenizer (encoder + decoder + RVQ).
+ speech_tokenizer: SpeechTokenizer,
+ /// Text tokenizer.
+ tokenizer: Tokenizer,
+ /// Model configuration.
+ config: Qwen3TtsConfig,
+ /// Compute device (CPU/CUDA/Metal).
+ device: Device,
+ /// Whether the model is fully loaded and ready.
+ ready: AtomicBool,
+}
+
+// SAFETY: NOT fully upheld. `Qwen3Model::last_hidden` is a `RefCell`, which is not `Sync`,
+// so callers must serialize access to a shared `Qwen3Tts` (e.g. behind a `Mutex`).
+unsafe impl Send for Qwen3Tts {}
+unsafe impl Sync for Qwen3Tts {}
+
+impl Qwen3Tts {
+ /// Load from a local directory or download from HuggingFace.
+ pub fn from_pretrained(
+ model_dir: Option<&str>,
+ device: &Device,
+ ) -> Result<Self, TtsError> {
+ let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
+
+ if !model_path.exists() {
+ Self::download_models(&model_path)?;
+ }
+
+ Self::load_from_path(&model_path, device)
+ }
+
+ /// Load all model components from a local directory.
+ pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> {
+ let dtype = DType::F32; // Always F32 for now; BF16/F16 on GPU is not selected here
+
+ // Load configuration
+ let config_path = model_dir.join("config.json");
+ let config = if config_path.exists() {
+ Qwen3TtsConfig::from_json_path(&config_path)?
+ } else {
+ Qwen3TtsConfig::default()
+ };
+
+ // Load text tokenizer
+ let tokenizer_path = model_dir.join("tokenizer.json");
+ let tokenizer = Tokenizer::from_file(&tokenizer_path)
+ .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))?;
+
+ // Load LM weights from safetensors
+ let lm_weights_path = model_dir.join("model.safetensors");
+ let lm_data = std::fs::read(&lm_weights_path).map_err(|e| {
+ TtsError::ModelLoad(format!(
+ "failed to read LM weights from {}: {e}",
+ lm_weights_path.display()
+ ))
+ })?;
+ let lm_vb = VarBuilder::from_buffered_safetensors(
+ lm_data,
+ dtype,
+ device,
+ ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?;
+
+ // Build language model
+ let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| {
+ TtsError::ModelLoad(format!("failed to build LM model: {e}"))
+ })?;
+
+ // Build code predictor (weights are in the same safetensors file)
+ let code_predictor =
+ CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| {
+ TtsError::ModelLoad(format!("failed to build code predictor: {e}"))
+ })?;
+
+ // Load speech tokenizer from separate safetensors
+ let st_weights_path = model_dir.join("speech_tokenizer.safetensors");
+ let st_data = std::fs::read(&st_weights_path).map_err(|e| {
+ TtsError::ModelLoad(format!(
+ "failed to read speech tokenizer weights from {}: {e}",
+ st_weights_path.display()
+ ))
+ })?;
+ let st_vb = VarBuilder::from_buffered_safetensors(
+ st_data,
+ dtype,
+ device,
+ ).map_err(|e| {
+ TtsError::ModelLoad(format!(
+ "failed to create speech tokenizer VarBuilder: {e}"
+ ))
+ })?;
+
+ let speech_tokenizer =
+ SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| {
+ TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}"))
+ })?;
+
+ Ok(Self {
+ model,
+ code_predictor,
+ speech_tokenizer,
+ tokenizer,
+ config,
+ device: device.clone(),
+ ready: AtomicBool::new(true),
+ })
+ }
+
+ /// Generate audio from text with optional voice reference.
+ pub fn generate_speech(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ gen_config: Option<GenerationConfig>,
+ ) -> Result<Vec<AudioChunk>, TtsError> {
+ let config = gen_config.unwrap_or_default();
+
+ let ctx = GenerationContext::new(
+ &self.model,
+ &self.code_predictor,
+ &self.speech_tokenizer,
+ &self.tokenizer,
+ &self.device,
+ config,
+ );
+
+ ctx.generate(text, reference_audio)
+ }
+
+ /// Download model files from HuggingFace Hub.
+ fn download_models(target_dir: &Path) -> Result<(), TtsError> {
+ std::fs::create_dir_all(target_dir)?;
+
+ let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?;
+
+ // Download LM model files
+ println!("Downloading Qwen3-TTS language model...");
+ let lm_repo = api.model(LM_MODEL_ID.to_string());
+
+ let lm_files = [
+ "model.safetensors",
+ "config.json",
+ "tokenizer.json",
+ "tokenizer_config.json",
+ ];
+
+ for file in &lm_files {
+ println!(" Downloading {file}...");
+ let downloaded = lm_repo
+ .get(file)
+ .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?;
+
+ let target = target_dir.join(file);
+ if !target.exists() {
+ std::fs::copy(&downloaded, &target)?;
+ }
+ }
+
+ // Download speech tokenizer
+ println!("Downloading Qwen3-TTS speech tokenizer...");
+ let st_repo = api.model(TOKENIZER_MODEL_ID.to_string());
+
+ let st_file = "model.safetensors";
+ let downloaded = st_repo
+ .get(st_file)
+ .map_err(|e| {
+ TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}"))
+ })?;
+
+ let target = target_dir.join("speech_tokenizer.safetensors");
+ if !target.exists() {
+ std::fs::copy(&downloaded, &target)?;
+ }
+
+ println!("All models downloaded to {}", target_dir.display());
+ Ok(())
+ }
+
+ /// Get the model configuration.
+ pub fn config(&self) -> &Qwen3TtsConfig {
+ &self.config
+ }
+
+ /// Get the compute device.
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+}
+
+#[async_trait::async_trait]
+impl TtsEngine for Qwen3Tts {
+ async fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ _reference_sample_rate: Option<u32>,
+ ) -> Result<Vec<AudioChunk>, TtsError> {
+ // Note: `_reference_sample_rate` is currently ignored. Reference audio
+ // must already be at 24 kHz; callers holding audio at a different
+ // sample rate should resample first using `resample_to_24k()`.
+ self.generate_speech(text, reference_audio, None)
+ }
+
+ fn is_ready(&self) -> bool {
+ self.ready.load(Ordering::Relaxed)
+ }
+
+ fn sample_rate(&self) -> u32 {
+ SAMPLE_RATE
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_default_config() {
+ let config = Qwen3TtsConfig::default();
+ assert_eq!(config.lm.hidden_size, 1024);
+ assert_eq!(config.lm.num_hidden_layers, 28);
+ assert_eq!(config.code_predictor.num_code_groups, 16);
+ assert_eq!(config.speech_tokenizer.sample_rate, 24_000);
+ }
+
+ #[test]
+ fn test_model_ids() {
+ assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base");
+ assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz");
+ }
+}
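Review note on `from_pretrained` above: it triggers a download only when the model directory does not exist, so a partially populated directory goes straight to `load_from_path` and fails at the first missing file. A minimal sketch of a pre-flight check follows, assuming the three files that `load_from_path` reads unconditionally (`config.json` is left out because the loader falls back to defaults). The helper name `missing_model_files` and the `REQUIRED_FILES` constant are hypothetical and not part of this commit.

```rust
use std::path::Path;

/// Files that the loader in this diff reads unconditionally from the model directory.
const REQUIRED_FILES: [&str; 3] = [
    "tokenizer.json",
    "model.safetensors",
    "speech_tokenizer.safetensors",
];

/// Hypothetical helper: returns the required file names that are absent from `dir`,
/// in the order listed in `REQUIRED_FILES`.
fn missing_model_files(dir: &Path) -> Vec<&'static str> {
    REQUIRED_FILES
        .iter()
        .copied()
        .filter(|name| !dir.join(*name).is_file())
        .collect()
}
```

A caller could run this before loading and either re-download or report all missing files in a single error, instead of surfacing them one at a time.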
diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs
new file mode 100644
index 0000000..551893b
--- /dev/null
+++ b/makima/src/tts/qwen3/model.rs
@@ -0,0 +1,581 @@
+//! Qwen3 Language Model transformer backbone.
+//!
+//! Implements the 28-layer transformer with:
+//! - Rotary Position Embeddings (RoPE)
+//! - Grouped Query Attention (GQA) — 16 heads, 8 KV heads
+//! - SiLU-gated MLP
+//! - RMS normalization
+//! - KV cache for autoregressive generation
+//!
+//! Based on the candle-transformers Qwen2 model architecture,
+//! extended for Qwen3-TTS.
+
+use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+
+use super::config::Qwen3LmConfig;
+
+// ---------------------------------------------------------------------------
+// Rotary Position Embeddings
+// ---------------------------------------------------------------------------
+
+/// Precomputed RoPE sin/cos tables.
+#[derive(Debug, Clone)]
+pub struct RotaryEmbedding {
+ cos: Tensor,
+ sin: Tensor,
+}
+
+impl RotaryEmbedding {
+ pub fn new(config: &Qwen3LmConfig, dtype: DType, device: &Device) -> Result<Self> {
+ let head_dim = config.head_dim;
+ let max_seq = config.max_position_embeddings;
+ let theta = config.rope_theta;
+
+ let inv_freq: Vec<f32> = (0..head_dim)
+ .step_by(2)
+ .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32))
+ .collect();
+
+ let inv_freq_tensor =
+ Tensor::from_vec(inv_freq, (head_dim / 2,), device)?.to_dtype(DType::F32)?;
+
+ let positions: Vec<f32> = (0..max_seq).map(|p| p as f32).collect();
+ let positions_tensor = Tensor::from_vec(positions, (max_seq, 1), device)?;
+
+ // [max_seq, head_dim/2]
+ let freqs = positions_tensor.matmul(&inv_freq_tensor.unsqueeze(0)?)?;
+ // [max_seq, head_dim] by repeating
+ let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
+
+ let cos = emb.cos()?.to_dtype(dtype)?;
+ let sin = emb.sin()?.to_dtype(dtype)?;
+
+ Ok(Self { cos, sin })
+ }
+
+ /// Apply RoPE to query and key tensors.
+ /// Input shape: [batch, heads, seq_len, head_dim]
+ pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+ let seq_len = q.dim(2)?;
+ let cos = self.cos.narrow(0, offset, seq_len)?;
+ let sin = self.sin.narrow(0, offset, seq_len)?;
+
+ let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, seq, dim]
+ let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
+
+ let q_rotated = Self::rotate_half(q, &cos, &sin)?;
+ let k_rotated = Self::rotate_half(k, &cos, &sin)?;
+
+ Ok((q_rotated, k_rotated))
+ }
+
+ fn rotate_half(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+ let half_dim = x.dim(D::Minus1)? / 2;
+ let x1 = x.narrow(D::Minus1, 0, half_dim)?;
+ let x2 = x.narrow(D::Minus1, half_dim, half_dim)?;
+
+ // [-x2, x1] concatenated
+ let neg_x2 = x2.neg()?;
+ let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?;
+
+ // x * cos + rotated * sin
+ let result = x.broadcast_mul(cos)?.broadcast_add(&rotated.broadcast_mul(sin)?)?;
+ Ok(result)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// KV Cache
+// ---------------------------------------------------------------------------
+
+/// Per-layer key-value cache for autoregressive generation.
+#[derive(Debug, Clone)]
+pub struct KvCache {
+ key: Option<Tensor>,
+ value: Option<Tensor>,
+}
+
+impl KvCache {
+ pub fn new() -> Self {
+ Self {
+ key: None,
+ value: None,
+ }
+ }
+
+ /// Append new key/value tensors and return the full cached sequence.
+ /// Input shapes: [batch, num_kv_heads, new_seq_len, head_dim]
+ pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
+ let (full_key, full_value) = match (&self.key, &self.value) {
+ (Some(prev_k), Some(prev_v)) => {
+ let k = Tensor::cat(&[prev_k, key], 2)?;
+ let v = Tensor::cat(&[prev_v, value], 2)?;
+ (k, v)
+ }
+ _ => (key.clone(), value.clone()),
+ };
+
+ self.key = Some(full_key.clone());
+ self.value = Some(full_value.clone());
+
+ Ok((full_key, full_value))
+ }
+
+ /// Current cached sequence length.
+ pub fn seq_len(&self) -> usize {
+ self.key
+ .as_ref()
+ .map(|k| k.dim(2).unwrap_or(0))
+ .unwrap_or(0)
+ }
+
+ /// Reset the cache.
+ pub fn reset(&mut self) {
+ self.key = None;
+ self.value = None;
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Attention
+// ---------------------------------------------------------------------------
+
+/// Multi-head attention with GQA and RoPE.
+pub struct Qwen3Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ q_norm: RmsNorm,
+ k_norm: RmsNorm,
+ num_heads: usize,
+ num_kv_heads: usize,
+ head_dim: usize,
+ num_kv_groups: usize,
+}
+
+impl Qwen3Attention {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let hidden = config.hidden_size;
+ let num_heads = config.num_attention_heads;
+ let num_kv_heads = config.num_key_value_heads;
+ let head_dim = config.head_dim;
+
+ let q_proj = linear_no_bias(hidden, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let o_proj = linear_no_bias(num_heads * head_dim, hidden, vb.pp("o_proj"))?;
+
+ let q_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("q_norm"))?;
+ let k_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("k_norm"))?;
+
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ q_norm,
+ k_norm,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ num_kv_groups: config.num_kv_groups(),
+ })
+ }
+
+ /// Forward pass with KV cache and RoPE.
+ /// Input: [batch, seq_len, hidden_size]
+ /// Returns: [batch, seq_len, hidden_size]
+ pub fn forward(
+ &self,
+ hidden_states: &Tensor,
+ rope: &RotaryEmbedding,
+ kv_cache: &mut KvCache,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (batch, seq_len, _) = hidden_states.dims3()?;
+ let offset = kv_cache.seq_len();
+
+ // Project Q, K, V
+ let q = self.q_proj.forward(hidden_states)?;
+ let k = self.k_proj.forward(hidden_states)?;
+ let v = self.v_proj.forward(hidden_states)?;
+
+ // Reshape: [batch, seq, heads*dim] -> [batch, heads, seq, dim]
+ let q = q
+ .reshape((batch, seq_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let k = k
+ .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let v = v
+ .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ // Apply QK normalization (Qwen3 specific)
+ let q = self.apply_head_norm(&q, &self.q_norm)?;
+ let k = self.apply_head_norm(&k, &self.k_norm)?;
+
+ // Apply RoPE
+ let (q, k) = rope.apply(&q, &k, offset)?;
+
+ // Update KV cache
+ let (k, v) = kv_cache.append(&k, &v)?;
+
+ // Expand KV heads for GQA: [batch, kv_heads, seq, dim] -> [batch, heads, seq, dim]
+ let k = self.repeat_kv(&k)?;
+ let v = self.repeat_kv(&v)?;
+
+ // Scaled dot-product attention
+ let scale = (self.head_dim as f64).sqrt();
+ let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / scale)?;
+
+ let attn_weights = match attention_mask {
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ None => attn_weights,
+ };
+
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+
+ // Attention output
+ let attn_output = attn_weights.matmul(&v)?;
+
+ // [batch, heads, seq, dim] -> [batch, seq, heads*dim]
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((batch, seq_len, self.num_heads * self.head_dim))?;
+
+ self.o_proj.forward(&attn_output)
+ }
+
+ /// Apply RMS norm per-head.
+ fn apply_head_norm(&self, x: &Tensor, norm: &RmsNorm) -> Result<Tensor> {
+ let (b, h, s, d) = x.dims4()?;
+ // Reshape to [b*h*s, d] for norm, then back
+ let flat = x.reshape((b * h * s, d))?;
+ let normed = norm.forward(&flat)?;
+ normed.reshape((b, h, s, d))
+ }
+
+ /// Repeat KV heads for GQA.
+ fn repeat_kv(&self, x: &Tensor) -> Result<Tensor> {
+ if self.num_kv_groups == 1 {
+ return Ok(x.clone());
+ }
+ let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?;
+ let x = x
+ .unsqueeze(2)?
+ .expand((batch, num_kv_heads, self.num_kv_groups, seq_len, head_dim))?
+ .reshape((batch, self.num_heads, seq_len, head_dim))?;
+ Ok(x)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// MLP
+// ---------------------------------------------------------------------------
+
+/// SiLU-gated feed-forward network.
+pub struct Qwen3Mlp {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+}
+
+impl Qwen3Mlp {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let hidden = config.hidden_size;
+ let intermediate = config.intermediate_size;
+
+ let gate_proj = linear_no_bias(hidden, intermediate, vb.pp("gate_proj"))?;
+ let up_proj = linear_no_bias(hidden, intermediate, vb.pp("up_proj"))?;
+ let down_proj = linear_no_bias(intermediate, hidden, vb.pp("down_proj"))?;
+
+ Ok(Self {
+ gate_proj,
+ up_proj,
+ down_proj,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let gate = self.gate_proj.forward(x)?;
+ let gate = candle_nn::Activation::Silu.forward(&gate)?;
+ let up = self.up_proj.forward(x)?;
+ let hidden = (gate * up)?;
+ self.down_proj.forward(&hidden)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Transformer Layer
+// ---------------------------------------------------------------------------
+
+/// A single Qwen3 transformer decoder layer.
+pub struct Qwen3DecoderLayer {
+ self_attn: Qwen3Attention,
+ mlp: Qwen3Mlp,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl Qwen3DecoderLayer {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Qwen3Attention::new(config, vb.pp("self_attn"))?;
+ let mlp = Qwen3Mlp::new(config, vb.pp("mlp"))?;
+ let input_layernorm =
+ rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ hidden_states: &Tensor,
+ rope: &RotaryEmbedding,
+ kv_cache: &mut KvCache,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ // Pre-norm attention
+ let residual = hidden_states;
+ let hidden_states = self.input_layernorm.forward(hidden_states)?;
+ let hidden_states =
+ self.self_attn
+ .forward(&hidden_states, rope, kv_cache, attention_mask)?;
+ let hidden_states = (residual + hidden_states)?;
+
+ // Pre-norm MLP
+ let residual = &hidden_states;
+ let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
+ let hidden_states = self.mlp.forward(&hidden_states)?;
+ let output = (residual + hidden_states)?;
+
+ Ok(output)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Full Model
+// ---------------------------------------------------------------------------
+
+/// The complete Qwen3 language model for TTS.
+///
+/// Architecture:
+/// - Token embedding layer
+/// - 28 transformer decoder layers
+/// - Final RMS normalization
+/// - LM head (projects to vocab)
+pub struct Qwen3Model {
+ embed_tokens: Embedding,
+ layers: Vec<Qwen3DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ rope: RotaryEmbedding,
+ config: Qwen3LmConfig,
+ /// Last hidden states (before lm_head), used by code predictor.
+ last_hidden: std::cell::RefCell<Option<Tensor>>,
+}
+
+impl Qwen3Model {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let model_vb = vb.pp("model");
+
+ let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("embed_tokens"))?;
+
+ let mut layers = Vec::with_capacity(config.num_hidden_layers);
+ for i in 0..config.num_hidden_layers {
+ let layer = Qwen3DecoderLayer::new(config, model_vb.pp(format!("layers.{i}")))?;
+ layers.push(layer);
+ }
+
+ let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?;
+
+ // LM head — may or may not share weights with embed_tokens
+ let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
+
+ let dtype = vb.dtype();
+ let device = vb.device().clone();
+ let rope = RotaryEmbedding::new(config, dtype, &device)?;
+
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ rope,
+ config: config.clone(),
+ last_hidden: std::cell::RefCell::new(None),
+ })
+ }
+
+ /// Forward pass through the full model.
+ ///
+ /// `input_ids`: [batch, seq_len] — token IDs
+ /// `kv_caches`: per-layer KV caches
+ /// `attention_mask`: optional causal mask [batch, 1, seq_len, total_seq_len]
+ ///
+ /// Returns logits: [batch, seq_len, vocab_size]
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ kv_caches: &mut [KvCache],
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut hidden_states = self.embed_tokens.forward(input_ids)?;
+
+ for (i, layer) in self.layers.iter().enumerate() {
+ hidden_states =
+ layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?;
+ }
+
+ hidden_states = self.norm.forward(&hidden_states)?;
+
+ // Store last hidden state for code predictor
+ *self.last_hidden.borrow_mut() = Some(hidden_states.clone());
+
+ let logits = self.lm_head.forward(&hidden_states)?;
+ Ok(logits)
+ }
+
+ /// Forward pass with pre-computed embeddings (for first iteration where
+ /// text embeddings are concatenated with audio features).
+ ///
+ /// `inputs_embeds`: [batch, seq_len, hidden_size]
+ pub fn forward_embeds(
+ &self,
+ inputs_embeds: &Tensor,
+ kv_caches: &mut [KvCache],
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut hidden_states = inputs_embeds.clone();
+
+ for (i, layer) in self.layers.iter().enumerate() {
+ hidden_states =
+ layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?;
+ }
+
+ hidden_states = self.norm.forward(&hidden_states)?;
+
+ *self.last_hidden.borrow_mut() = Some(hidden_states.clone());
+
+ let logits = self.lm_head.forward(&hidden_states)?;
+ Ok(logits)
+ }
+
+ /// Get the last hidden states (for the code predictor).
+ pub fn last_hidden_state(&self) -> Option<Tensor> {
+ self.last_hidden.borrow().clone()
+ }
+
+ /// Number of transformer layers.
+ pub fn num_layers(&self) -> usize {
+ self.config.num_hidden_layers
+ }
+
+ /// Hidden size.
+ pub fn hidden_size(&self) -> usize {
+ self.config.hidden_size
+ }
+
+ /// Get token embedding layer (for input preparation).
+ pub fn embed_tokens(&self) -> &Embedding {
+ &self.embed_tokens
+ }
+
+ /// Create a causal attention mask.
+ pub fn make_causal_mask(
+ seq_len: usize,
+ past_len: usize,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let total_len = past_len + seq_len;
+
+ if seq_len == 1 {
+ // Single token: no masking needed (can attend to everything)
+ return Tensor::zeros((1, 1, 1, total_len), dtype, device);
+ }
+
+ // Full causal mask: lower triangular
+ let mask: Vec<f32> = (0..seq_len)
+ .flat_map(|i| {
+ (0..total_len).map(move |j| {
+ if j <= past_len + i {
+ 0.0
+ } else {
+ f32::NEG_INFINITY
+ }
+ })
+ })
+ .collect();
+
+ Tensor::from_vec(mask, (1, 1, seq_len, total_len), device)?.to_dtype(dtype)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_kv_cache() {
+ let device = Device::Cpu;
+ let mut cache = KvCache::new();
+ assert_eq!(cache.seq_len(), 0);
+
+ let k = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap();
+ let v = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap();
+ let (fk, _fv) = cache.append(&k, &v).unwrap();
+ assert_eq!(cache.seq_len(), 5);
+ assert_eq!(fk.dim(2).unwrap(), 5);
+
+ let k2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap();
+ let v2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap();
+ let (fk2, _fv2) = cache.append(&k2, &v2).unwrap();
+ assert_eq!(cache.seq_len(), 6);
+ assert_eq!(fk2.dim(2).unwrap(), 6);
+
+ cache.reset();
+ assert_eq!(cache.seq_len(), 0);
+ }
+
+ #[test]
+ fn test_causal_mask_single_token() {
+ let mask = Qwen3Model::make_causal_mask(1, 10, DType::F32, &Device::Cpu).unwrap();
+ assert_eq!(mask.dims(), &[1, 1, 1, 11]);
+ // All zeros — single token can attend to everything
+ let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
+ assert_eq!(sum, 0.0);
+ }
+
+ #[test]
+ fn test_causal_mask_multi_token() {
+ let mask = Qwen3Model::make_causal_mask(3, 0, DType::F32, &Device::Cpu).unwrap();
+ assert_eq!(mask.dims(), &[1, 1, 3, 3]);
+ // Upper triangle should be -inf
+ let data: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
+ // Row 0: [0, -inf, -inf]
+ assert_eq!(data[0], 0.0);
+ assert!(data[1].is_infinite() && data[1] < 0.0);
+ assert!(data[2].is_infinite() && data[2] < 0.0);
+ // Row 1: [0, 0, -inf]
+ assert_eq!(data[3], 0.0);
+ assert_eq!(data[4], 0.0);
+ assert!(data[5].is_infinite() && data[5] < 0.0);
+ // Row 2: [0, 0, 0]
+ assert_eq!(data[6], 0.0);
+ assert_eq!(data[7], 0.0);
+ assert_eq!(data[8], 0.0);
+ }
+}
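Review note on `RotaryEmbedding` above: the half-split rotation (`x * cos + rotate_half(x) * sin`, with `rotate_half([x1, x2]) = [-x2, x1]`) can be checked on a single head vector using plain `f32` slices. The function `rope_rotate` is a hypothetical stand-in for illustration, not code from this commit. It uses the same frequency schedule as the diff, `theta^(-2i / dim)`, and assumes an even-length input.

```rust
/// Rotate one head vector `x` (even length) to position `pos`.
/// Element `i` in the first half is paired with element `i + dim / 2`,
/// and both share the same rotation angle.
fn rope_rotate(x: &[f32], pos: usize, theta: f32) -> Vec<f32> {
    let dim = x.len();
    let half = dim / 2;
    let mut out = vec![0.0f32; dim];
    for i in 0..half {
        // Inverse frequency for pair i: theta^(-2i / dim).
        let inv_freq = 1.0 / theta.powf((2 * i) as f32 / dim as f32);
        let angle = pos as f32 * inv_freq;
        let (sin, cos) = angle.sin_cos();
        let (x1, x2) = (x[i], x[i + half]);
        // x * cos + rotate_half(x) * sin, written out per element.
        out[i] = x1 * cos - x2 * sin;
        out[i + half] = x2 * cos + x1 * sin;
    }
    out
}

/// Plain dot product, used to check the relative-position property.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(p, q)| p * q).sum()
}
```

Two properties worth asserting in a unit test: position 0 leaves the vector unchanged, and the dot product of a rotated query and key depends only on the difference between their positions.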
diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs
new file mode 100644
index 0000000..752050a
--- /dev/null
+++ b/makima/src/tts/qwen3/speech_tokenizer.rs
@@ -0,0 +1,612 @@
+//! Speech Tokenizer — ConvNet encoder/decoder with RVQ codebooks.
+//!
+//! Two sub-components:
+//!
+//! **Encoder** (voice cloning): converts reference audio waveform to discrete
+//! multi-codebook tokens via a causal 1D ConvNet + RVQ.
+//!
+//! **Decoder** (audio synthesis): reconstructs waveform from discrete codebook
+//! indices via embedding lookup + causal 1D ConvNet.
+//!
+//! The speech tokenizer is a separate model (~682MB) loaded from
+//! `Qwen/Qwen3-TTS-Tokenizer-12Hz`.
+
+use candle_core::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{
+ conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder,
+};
+
+use super::config::SpeechTokenizerConfig;
+
+// ---------------------------------------------------------------------------
+// Weight-Normalized Conv1d
+// ---------------------------------------------------------------------------
+
+/// A 1D convolution followed by an optional activation. Weights are loaded as-is;
+/// no weight normalization is applied in this block.
+pub struct ConvBlock {
+ conv: Conv1d,
+ activation: ConvActivation,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum ConvActivation {
+ None,
+ Elu,
+ Tanh,
+}
+
+impl ConvBlock {
+ pub fn new(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ stride: usize,
+ padding: usize,
+ dilation: usize,
+ activation: ConvActivation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let config = Conv1dConfig {
+ stride,
+ padding,
+ dilation,
+ groups: 1,
+ };
+ let conv = conv1d(in_channels, out_channels, kernel_size, config, vb.pp("conv"))?;
+
+ Ok(Self { conv, activation })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let out = self.conv.forward(x)?;
+ match self.activation {
+ ConvActivation::None => Ok(out),
+ ConvActivation::Elu => elu(&out, 1.0),
+ ConvActivation::Tanh => out.tanh(),
+ }
+ }
+}
+
+/// ELU activation: x if x >= 0, alpha * (exp(x) - 1) if x < 0
+fn elu(x: &Tensor, alpha: f64) -> Result<Tensor> {
+ let zeros = x.zeros_like()?;
+ let positive = x.maximum(&zeros)?;
+ // Clamp to <= 0 before exp() so large positive inputs cannot produce inf * 0 = NaN.
+ // For x >= 0 the clamped branch is exp(0) - 1 = 0, so no explicit mask is needed.
+ let exp_neg = x.minimum(&zeros)?.exp()?;
+ let one = Tensor::ones_like(&exp_neg)?;
+ let negative = ((exp_neg - one)? * alpha)?;
+ positive + negative
+}
+
+// ---------------------------------------------------------------------------
+// Residual Unit
+// ---------------------------------------------------------------------------
+
+/// Residual convolutional unit with dilated convolutions.
+pub struct ResidualUnit {
+ conv1: ConvBlock,
+ conv2: ConvBlock,
+}
+
+impl ResidualUnit {
+ pub fn new(
+ channels: usize,
+ dilation: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ // Dilated conv (kernel=7, dilation varies)
+ let padding = (7 - 1) * dilation / 2; // symmetric "same" padding on both sides, not strictly causal
+ let conv1 = ConvBlock::new(
+ channels,
+ channels,
+ 7,
+ 1,
+ padding,
+ dilation,
+ ConvActivation::Elu,
+ vb.pp("block.0"),
+ )?;
+
+ // Pointwise conv (kernel=1)
+ let conv2 = ConvBlock::new(
+ channels,
+ channels,
+ 1,
+ 1,
+ 0,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("block.1"),
+ )?;
+
+ Ok(Self { conv1, conv2 })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let residual = x;
+ let out = self.conv1.forward(x)?;
+ let out = self.conv2.forward(&out)?;
+ // Match sequence lengths if needed (causal conv may change length)
+ let out_len = out.dim(D::Minus1)?;
+ let res_len = residual.dim(D::Minus1)?;
+ if out_len != res_len {
+ let start = res_len.saturating_sub(out_len);
+ let residual = residual.narrow(D::Minus1, start, out_len)?;
+ residual + out
+ } else {
+ residual + out
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Encoder Block
+// ---------------------------------------------------------------------------
+
+/// Encoder downsampling block: residual units + strided conv.
+pub struct EncoderBlock {
+ residual_units: Vec<ResidualUnit>,
+ downsample: ConvBlock,
+}
+
+impl EncoderBlock {
+ pub fn new(
+ in_channels: usize,
+ out_channels: usize,
+ stride: usize,
+ num_residuals: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let mut residual_units = Vec::with_capacity(num_residuals);
+ for i in 0..num_residuals {
+ let dilation = 3usize.pow(i as u32); // 1, 3, 9
+ let unit = ResidualUnit::new(in_channels, dilation, vb.pp(format!("residuals.{i}")))?;
+ residual_units.push(unit);
+ }
+
+ // Strided downsampling convolution
+ let kernel_size = stride * 2;
+ let padding = stride / 2;
+ let downsample = ConvBlock::new(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("downsample"),
+ )?;
+
+ Ok(Self {
+ residual_units,
+ downsample,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let mut out = x.clone();
+ for unit in &self.residual_units {
+ out = unit.forward(&out)?;
+ }
+ self.downsample.forward(&out)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Decoder Block
+// ---------------------------------------------------------------------------
+
+/// Decoder block: stride-1 conv + residual units (no temporal upsampling is applied yet).
+pub struct DecoderBlock {
+ upsample: ConvBlock,
+ residual_units: Vec<ResidualUnit>,
+}
+
+impl DecoderBlock {
+ pub fn new(
+ in_channels: usize,
+ out_channels: usize,
+ stride: usize,
+ num_residuals: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ // Stride-1 conv sized from `stride`; a stand-in for a transposed conv that does not upsample
+ let kernel_size = stride * 2;
+ let padding = stride / 2;
+ let upsample = ConvBlock::new(
+ in_channels,
+ out_channels,
+ kernel_size,
+ 1, // conv stride is 1; `forward` applies no repeat/interpolation, so length is not increased
+ padding,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("upsample"),
+ )?;
+
+ let mut residual_units = Vec::with_capacity(num_residuals);
+ for i in 0..num_residuals {
+ let dilation = 3usize.pow(i as u32);
+ let unit =
+ ResidualUnit::new(out_channels, dilation, vb.pp(format!("residuals.{i}")))?;
+ residual_units.push(unit);
+ }
+
+ Ok(Self {
+ upsample,
+ residual_units,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let mut out = self.upsample.forward(x)?;
+ for unit in &self.residual_units {
+ out = unit.forward(&out)?;
+ }
+ Ok(out)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// RVQ Codebook
+// ---------------------------------------------------------------------------
+
+/// Residual Vector Quantization codebook.
+///
+/// Contains `num_codebooks` embedding tables, each mapping
+/// `codebook_size` indices to `codebook_dim`-dimensional vectors.
+pub struct RvqCodebook {
+ codebooks: Vec<Embedding>,
+ num_codebooks: usize,
+ codebook_dim: usize,
+}
+
+impl RvqCodebook {
+ pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder) -> Result<Self> {
+ let mut codebooks = Vec::with_capacity(config.num_codebooks);
+ for i in 0..config.num_codebooks {
+ let cb = embedding(
+ config.codebook_size,
+ config.codebook_dim,
+ vb.pp(format!("codebooks.{i}")),
+ )?;
+ codebooks.push(cb);
+ }
+
+ Ok(Self {
+ codebooks,
+ num_codebooks: config.num_codebooks,
+ codebook_dim: config.codebook_dim,
+ })
+ }
+
+ /// Look up codebook embeddings for all codebook layers.
+ ///
+ /// `codes`: [num_codebooks, seq_len] — codebook indices per layer
+ /// Returns: [1, codebook_dim, seq_len] — sum of all codebook embeddings
+ pub fn decode(&self, codes: &[Vec<u32>], device: &Device) -> Result<Tensor> {
+ assert_eq!(codes.len(), self.num_codebooks, "Expected {} codebook layers", self.num_codebooks);
+
+ let seq_len = codes[0].len();
+ let mut sum: Option<Tensor> = None;
+
+ for (i, code_layer) in codes.iter().enumerate() {
+ assert_eq!(code_layer.len(), seq_len, "Codebook layer {i} length mismatch");
+
+ let indices = Tensor::from_vec(
+ code_layer.clone(),
+ (1, seq_len),
+ device,
+ )?;
+
+ // [1, seq_len, codebook_dim]
+ let emb = self.codebooks[i].forward(&indices)?;
+
+ sum = Some(match sum {
+ Some(prev) => (prev + emb)?,
+ None => emb,
+ });
+ }
+
+ // [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
+ let result = sum.unwrap().transpose(1, 2)?;
+ Ok(result)
+ }
+
+ /// Number of codebooks.
+ pub fn num_codebooks(&self) -> usize {
+ self.num_codebooks
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Speech Tokenizer (Encoder + Decoder)
+// ---------------------------------------------------------------------------
+
+/// The complete speech tokenizer with encoder and decoder.
+pub struct SpeechTokenizer {
+ /// Encoder: waveform -> latent (for voice cloning).
+ encoder_input_conv: ConvBlock,
+ encoder_blocks: Vec<EncoderBlock>,
+ encoder_output_conv: ConvBlock,
+
+ /// RVQ codebooks for quantization.
+ codebook: RvqCodebook,
+
+ /// Decoder: codes -> waveform.
+ decoder_input_conv: ConvBlock,
+ decoder_blocks: Vec<DecoderBlock>,
+ decoder_output_conv: ConvBlock,
+
+ /// Projection from codebook dim to decoder hidden channels.
+ decoder_proj: Linear,
+
+ config: SpeechTokenizerConfig,
+ device: Device,
+}
+
+impl SpeechTokenizer {
+ /// Load the speech tokenizer from safetensors.
+ pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder, device: &Device) -> Result<Self> {
+ let hidden = config.hidden_channels; // 512
+
+ // ===== Encoder =====
+ // Input: [batch, 1, samples] -> [batch, hidden/8, samples] (kernel=7, stride=1, padding=3 keeps length)
+ let encoder_input_conv = ConvBlock::new(
+ 1,
+ hidden / 8, // 64
+ 7,
+ 1,
+ 3,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("encoder.input_conv"),
+ )?;
+
+ // Downsampling blocks with increasing channels
+ let strides = [8, 5, 4, 3]; // Total downsampling: 8*5*4*3 = 480 (about 50 frames/s at 24 kHz)
+ let channels = [hidden / 8, hidden / 4, hidden / 2, hidden]; // 64, 128, 256, 512
+ let mut encoder_blocks = Vec::with_capacity(strides.len());
+ for (i, (&stride, &out_ch)) in strides.iter().zip(channels.iter()).enumerate() {
+ let in_ch = if i == 0 { hidden / 8 } else { channels[i - 1] };
+ let block = EncoderBlock::new(
+ in_ch,
+ out_ch,
+ stride,
+ 3, // 3 residual units per block
+ vb.pp(format!("encoder.blocks.{i}")),
+ )?;
+ encoder_blocks.push(block);
+ }
+
+ // Encoder output projection to codebook dim
+ let encoder_output_conv = ConvBlock::new(
+ hidden,
+ config.codebook_dim,
+ 3,
+ 1,
+ 1,
+ 1,
+ ConvActivation::None,
+ vb.pp("encoder.output_conv"),
+ )?;
+
+ // ===== RVQ Codebook =====
+ let codebook = RvqCodebook::new(config, vb.pp("quantizer"))?;
+
+ // ===== Decoder =====
+ // Projection from codebook dim to decoder hidden
+ let decoder_proj = linear_no_bias(
+ config.codebook_dim,
+ hidden,
+ vb.pp("decoder.proj"),
+ )?;
+
+ // Input conv
+ let decoder_input_conv = ConvBlock::new(
+ hidden,
+ hidden,
+ 7,
+ 1,
+ 3,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("decoder.input_conv"),
+ )?;
+
+ // Upsampling blocks (reverse order of encoder)
+ let dec_strides = [3, 4, 5, 8];
+ let dec_channels = [hidden, hidden / 2, hidden / 4, hidden / 8]; // 512, 256, 128, 64
+ let mut decoder_blocks = Vec::with_capacity(dec_strides.len());
+ for (i, (&stride, &out_ch)) in dec_strides.iter().zip(dec_channels.iter()).enumerate()
+ {
+ let in_ch = if i == 0 { hidden } else { dec_channels[i - 1] };
+ let block = DecoderBlock::new(
+ in_ch,
+ out_ch,
+ stride,
+ 3,
+ vb.pp(format!("decoder.blocks.{i}")),
+ )?;
+ decoder_blocks.push(block);
+ }
+
+ // Output conv: hidden/8 -> 1 channel (waveform)
+ let decoder_output_conv = ConvBlock::new(
+ hidden / 8,
+ 1,
+ 7,
+ 1,
+ 3,
+ 1,
+ ConvActivation::Tanh,
+ vb.pp("decoder.output_conv"),
+ )?;
+
+ Ok(Self {
+ encoder_input_conv,
+ encoder_blocks,
+ encoder_output_conv,
+ codebook,
+ decoder_input_conv,
+ decoder_blocks,
+ decoder_output_conv,
+ decoder_proj,
+ config: config.clone(),
+ device: device.clone(),
+ })
+ }
+
+ /// Encode reference audio waveform to discrete codebook tokens.
+ ///
+ /// `audio`: [num_samples] — mono 24kHz audio
+ /// Returns: Vec of `num_codebooks` vectors, each containing token indices.
+ pub fn encode(&self, audio: &[f32]) -> Result<Vec<Vec<u32>>> {
+ // [1, 1, num_samples]
+ let x = Tensor::from_vec(audio.to_vec(), (1, 1, audio.len()), &self.device)?;
+
+ // Run encoder
+ let mut hidden = self.encoder_input_conv.forward(&x)?;
+ for block in &self.encoder_blocks {
+ hidden = block.forward(&hidden)?;
+ }
+ let latent = self.encoder_output_conv.forward(&hidden)?;
+
+ // latent: [1, codebook_dim, seq_len]
+ // Quantize via nearest-neighbor lookup in each codebook
+ let seq_len = latent.dim(D::Minus1)?;
+ let mut all_codes = Vec::with_capacity(self.config.num_codebooks);
+
+ // Residual quantization: subtract each codebook's contribution
+ let mut residual = latent.clone();
+
+ for cb_idx in 0..self.config.num_codebooks {
+ // residual: [1, codebook_dim, seq_len] -> find nearest codebook entry per timestep
+ let codes = self.quantize_layer(&residual, cb_idx, seq_len)?;
+
+ // Look up the quantized vectors and subtract from residual
+ let code_indices =
+ Tensor::from_vec(codes.clone(), (1, seq_len), &self.device)?;
+ let quantized = self.codebook.codebooks[cb_idx].forward(&code_indices)?;
+ // quantized: [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
+ let quantized = quantized.transpose(1, 2)?;
+ residual = (residual - quantized)?;
+
+ all_codes.push(codes);
+ }
+
+ Ok(all_codes)
+ }
+
+ /// Quantize a single RVQ layer by finding the nearest codebook entry.
+ fn quantize_layer(
+ &self,
+ residual: &Tensor,
+ codebook_idx: usize,
+ _seq_len: usize,
+ ) -> Result<Vec<u32>> {
+ // residual: [1, codebook_dim, seq_len]
+ // codebook weights: [codebook_size, codebook_dim]
+ let cb_weight = self.codebook.codebooks[codebook_idx]
+ .embeddings()
+ .clone(); // [codebook_size, codebook_dim]
+
+ // Transpose residual: [1, seq_len, codebook_dim]
+ let residual_t = residual.transpose(1, 2)?.squeeze(0)?; // [seq_len, codebook_dim]
+
+ // Compute L2 distances: ||r - c||^2 = ||r||^2 - 2*r*c^T + ||c||^2
+ let r_sq = residual_t.sqr()?.sum(D::Minus1)?; // [seq_len]
+ let c_sq = cb_weight.sqr()?.sum(D::Minus1)?; // [codebook_size]
+ let rc = residual_t.matmul(&cb_weight.t()?)?; // [seq_len, codebook_size]
+
+ let r_sq = r_sq.unsqueeze(1)?; // [seq_len, 1]
+ let c_sq = c_sq.unsqueeze(0)?; // [1, codebook_size]
+
+ let distances = (r_sq.broadcast_add(&c_sq)? - (rc * 2.0)?)?; // [seq_len, codebook_size]
+
+ // Argmin per timestep
+ let indices = distances.argmin(D::Minus1)?; // [seq_len]
+ let codes: Vec<u32> = indices.to_vec1()?;
+
+ Ok(codes)
+ }
+
+ /// Decode discrete codebook tokens to audio waveform.
+ ///
+ /// `codes`: Vec of `num_codebooks` vectors of token indices.
+ /// Returns: Vec<f32> — mono 24kHz audio samples.
+ pub fn decode(&self, codes: &[Vec<u32>]) -> Result<Vec<f32>> {
+ // Look up and sum all codebook embeddings
+ let embeddings = self.codebook.decode(codes, &self.device)?;
+ // embeddings: [1, codebook_dim, seq_len]
+
+ // Project to decoder hidden size: [1, seq_len, codebook_dim] -> [1, seq_len, hidden]
+ let emb_t = embeddings.transpose(1, 2)?; // [1, seq_len, codebook_dim]
+ let projected = self.decoder_proj.forward(&emb_t)?; // [1, seq_len, hidden]
+ let mut hidden = projected.transpose(1, 2)?; // [1, hidden, seq_len]
+
+ // Run decoder
+ hidden = self.decoder_input_conv.forward(&hidden)?;
+ for block in &self.decoder_blocks {
+ hidden = block.forward(&hidden)?;
+ }
+ let waveform = self.decoder_output_conv.forward(&hidden)?;
+
+ // [1, 1, num_samples] -> Vec<f32>
+ let samples: Vec<f32> = waveform.flatten_all()?.to_vec1()?;
+ Ok(samples)
+ }
+
+ /// Decode a single frame's codes to audio samples (for streaming).
+ ///
+ /// `frame_codes`: [num_codebooks] — one token per codebook for a single frame
+ /// Returns: audio samples for this frame (~1920 samples at 24kHz / 12.5Hz)
+ pub fn decode_frame(&self, frame_codes: &[u32]) -> Result<Vec<f32>> {
+ let codes: Vec<Vec<u32>> = frame_codes.iter().map(|&c| vec![c]).collect();
+ self.decode(&codes)
+ }
+
+ /// Get the number of codebooks.
+ pub fn num_codebooks(&self) -> usize {
+ self.config.num_codebooks
+ }
+
+ /// Get the output sample rate.
+ pub fn sample_rate(&self) -> u32 {
+ self.config.sample_rate
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_elu_positive() {
+ let device = Device::Cpu;
+ let x = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap();
+ let result = elu(&x, 1.0).unwrap();
+ let values: Vec<f32> = result.to_vec1().unwrap();
+ assert!((values[0] - 1.0).abs() < 1e-5);
+ assert!((values[1] - 2.0).abs() < 1e-5);
+ assert!((values[2] - 3.0).abs() < 1e-5);
+ }
+
+ #[test]
+ fn test_elu_negative() {
+ let device = Device::Cpu;
+ let x = Tensor::from_vec(vec![-1.0f32], (1,), &device).unwrap();
+ let result = elu(&x, 1.0).unwrap();
+ let values: Vec<f32> = result.to_vec1().unwrap();
+ // ELU(-1) = exp(-1) - 1 ≈ -0.6321
+ assert!((values[0] - (-0.6321)).abs() < 0.01);
+ }
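+
+ // Illustrative addition, not part of the original commit: a boundary check that
+ // ELU(0) = 0, i.e. the positive and negative branches agree at the origin.
+ #[test]
+ fn test_elu_zero() {
+ let device = Device::Cpu;
+ let x = Tensor::from_vec(vec![0.0f32], (1,), &device).unwrap();
+ let values: Vec<f32> = elu(&x, 1.0).unwrap().to_vec1().unwrap();
+ assert!(values[0].abs() < 1e-6);
+ }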
+
+ #[test]
+ fn test_speech_tokenizer_config() {
+ let config = SpeechTokenizerConfig::default();
+ assert_eq!(config.num_codebooks, 16);
+ assert_eq!(config.codebook_size, 2048);
+ assert_eq!(config.sample_rate, 24_000);
+ }
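+
+ // Illustrative sketch, not part of the original commit: a plain-Rust check that the
+ // expanded distance used by `quantize_layer` (||r||^2 - 2*r.c + ||c||^2) picks the same
+ // codebook entry as the direct squared L2 distance. The 2-D codebook and the residual
+ // below are made-up values chosen only for this example.
+ #[test]
+ fn test_expanded_l2_matches_direct_argmin() {
+ let codebook = [[0.0f32, 0.0], [1.0, 0.0], [0.0, 2.0]];
+ let r = [0.9f32, 0.2];
+ let dot = |a: &[f32; 2], b: &[f32; 2]| a[0] * b[0] + a[1] * b[1];
+ // Index of the smallest distance, mirroring `argmin` over the codebook axis.
+ let argmin = |d: Vec<f32>| {
+ d.iter()
+ .enumerate()
+ .min_by(|a, b| a.1.total_cmp(b.1))
+ .map(|(i, _)| i)
+ .unwrap()
+ };
+ let direct: Vec<f32> = codebook
+ .iter()
+ .map(|c| (r[0] - c[0]).powi(2) + (r[1] - c[1]).powi(2))
+ .collect();
+ let expanded: Vec<f32> = codebook
+ .iter()
+ .map(|c| dot(&r, &r) - 2.0 * dot(&r, c) + dot(c, c))
+ .collect();
+ // Distances are 0.85, 0.05 and 4.05, so entry 1 is nearest under both forms.
+ assert_eq!(argmin(direct), 1);
+ assert_eq!(argmin(expanded), 1);
+ }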
+}