summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-28 02:54:17 +0000
committerGitHub <noreply@github.com>2026-01-28 02:54:17 +0000
commiteabd1304cce0e053cd32ec910d2f0ea429e8af14 (patch)
treefca3b08810a1dc0c0c610a8189a466cc23d5c547
parentc618174e60e4632d36d7352d83399508c72b2f42 (diff)
downloadsoryu-eabd1304cce0e053cd32ec910d2f0ea429e8af14.tar.gz
soryu-eabd1304cce0e053cd32ec910d2f0ea429e8af14.zip
Add Qwen3-TTS streaming endpoint for voice synthesis (#40)
* Task completion checkpoint * Task completion checkpoint * Task completion checkpoint * Add Qwen3-TTS research document for live TTS replacement Research findings for replacing Chatterbox TTS with Qwen3-TTS-12Hz-0.6B-Base: - Current TTS: Chatterbox-Turbo-ONNX with batch-only generation, no streaming - Qwen3-TTS: 97ms end-to-end latency, streaming support, 3-second voice cloning - Voice cloning: Requires 3s reference audio + transcript (Makima voice planned) - Integration: Python service with WebSocket bridge (no ONNX export available) - Languages: 10 supported including English and Japanese Document includes: - Current architecture analysis (makima/src/tts.rs) - Qwen3-TTS capabilities and requirements - Feasibility assessment for live/streaming TTS - Audio clip requirements for voice cloning - Preliminary technical approach with architecture diagrams Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * [WIP] Heartbeat checkpoint - 2026-01-27 03:11:15 UTC * Add Qwen3-TTS research documentation Comprehensive research on replacing Chatterbox TTS with Qwen3-TTS-12Hz-0.6B-Base: - Current TTS implementation analysis (Chatterbox-Turbo-ONNX in makima/src/tts.rs) - Qwen3-TTS capabilities: 97ms streaming latency, voice cloning with 3s reference - Cross-lingual support: Japanese voice (Makima/Tomori Kusunoki) speaking English - Python microservice architecture recommendation (FastAPI + WebSocket) - Implementation phases and technical approach - Hardware requirements and dependencies Key findings: - Live/streaming TTS is highly feasible with 97ms latency - Voice cloning fully supported with 0.95 speaker similarity - Recommended: Python microservice with WebSocket streaming Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add comprehensive Qwen3-TTS integration specification This specification document defines the complete integration of Qwen3-TTS-12Hz-0.6B-Base as a replacement for the existing Chatterbox-Turbo TTS implementation. The document covers: ## Functional Requirements - WebSocket endpoint /api/v1/speak for streaming TTS - Voice cloning with default Makima voice (Japanese VA speaking English) - Support for custom voice references - Detailed client-to-server and server-to-client message protocols - Integration with Listen page for bidirectional speech ## Non-Functional Requirements - Latency targets: < 200ms first audio byte - Audio quality: 24kHz, mono, PCM16/PCM32f - Hardware requirements: CUDA GPU with 4-8GB VRAM - Scalability: 10 concurrent sessions per GPU ## Architecture Specification - Python TTS microservice with FastAPI/WebSocket - Rust proxy endpoint in makima server - Voice prompt caching mechanism (LRU cache) - Error handling and recovery strategies ## API Contract - Complete WebSocket message format definitions (TypeScript) - Error codes and responses (TTS_UNAVAILABLE, SYNTHESIS_ERROR, etc.) - Session state machine and lifecycle management ## Voice Asset Requirements - Makima voice clip specifications (5-10s WAV, transcript required) - Storage location: models/voices/makima/ - Metadata format for voice management ## Testing Strategy - Unit tests for Python TTS service and Rust proxy - Integration tests for WebSocket flow - Latency benchmarks with performance targets - Test data fixtures for various text lengths Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add Qwen3-TTS implementation plan Comprehensive implementation plan for replacing Chatterbox-TTS with Qwen3-TTS streaming TTS service, including: - Task breakdown with estimated hours for each phase - Phase 1: Python TTS microservice (FastAPI, WebSocket) - Phase 2: Rust proxy integration (speak.rs, tts_client.rs) - Detailed file changes and new module structure - Testing plan with unit, integration, and latency benchmarks - Risk assessment with mitigation strategies - Success criteria for each phase Based on specification in docs/specs/qwen3-tts-spec.md Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add author and research references to TTS implementation plan Add links to research documentation and author attribution. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * [WIP] Heartbeat checkpoint - 2026-01-27 03:25:06 UTC * Add Python TTS service project structure (Phase 1.1-1.3) Create the initial makima-tts Python service directory structure with: - pyproject.toml with FastAPI, Qwen-TTS, and torch dependencies - config.py with pydantic-settings TTSConfig class - models.py with Pydantic message models (Start, Speak, Stop, Ready, etc.) This implements tasks P1.1, P1.2, and P1.3 from the Qwen3-TTS implementation plan. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add TTS engine and voice manager for Qwen3-TTS (Phase 1.4-1.5) Implement core TTS functionality: - tts_engine.py: Qwen3-TTS wrapper with streaming audio chunk generation - voice_manager.py: Voice prompt caching with LRU eviction and TTL support Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * [WIP] Heartbeat checkpoint - 2026-01-27 03:30:06 UTC * Add TTS proxy client and message types (Phase 2.1, 2.2, 2.4) - Add tts_client.rs with TtsConfig, TtsCircuitBreaker, TtsError, TtsProxyClient, and TtsConnection structs for WebSocket proxying - Add TTS message types to messages.rs (TtsAudioEncoding, TtsPriority, TtsStartMessage, TtsSpeakMessage, TtsStopMessage, TtsClientMessage, TtsReadyMessage, TtsAudioChunkMessage, TtsCompleteMessage, TtsErrorMessage, TtsStoppedMessage, TtsServerMessage) - Export tts_client module from server mod.rs - tokio-tungstenite already present in Cargo.toml Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add TTS WebSocket handler and route (Phase 2.3, 2.5, 2.6) - Create speak.rs WebSocket handler that proxies to Python TTS service - Add TtsState fields (tts_client, tts_config) to AppState - Add with_tts() builder and is_tts_healthy() methods to AppState - Register /api/v1/speak route in the router - Add speak module export in handlers/mod.rs The handler forwards WebSocket messages bidirectionally between the client and the Python TTS microservice with proper error handling. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add Makima voice profile assets for TTS voice cloning Creates the voice assets directory structure with: - manifest.json containing voice configuration (voice_id, speaker, language, reference audio path, and Japanese transcript placeholder) - README.md with instructions for obtaining voice reference audio Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Add Rust-native Qwen3-TTS integration research document Research findings for integrating Qwen3-TTS-12Hz-0.6B-Base directly into the makima Rust codebase without Python. Key conclusions: - ONNX export is not viable (unsupported architecture) - Candle (HF Rust ML framework) is the recommended approach - Model weights available in safetensors format (2.52GB total) - Three components needed: LM backbone, code predictor, speech tokenizer - Crane project has Qwen3-TTS as highest priority (potential upstream) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * [WIP] Heartbeat checkpoint - 2026-01-27 11:21:43 UTC * [WIP] Heartbeat checkpoint - 2026-01-27 11:24:19 UTC * [WIP] Heartbeat checkpoint - 2026-01-27 11:26:43 UTC * feat: implement Rust-native Qwen3-TTS using candle framework Replace monolithic tts.rs with modular tts/ directory structure: - tts/mod.rs: TtsEngine trait, TtsEngineFactory, shared types (AudioChunk, TtsError), and utility functions (save_wav, resample, argmax) - tts/chatterbox.rs: existing ONNX-based ChatterboxTTS adapted to implement TtsEngine trait with Mutex-wrapped sessions for Send+Sync - tts/qwen3/mod.rs: Qwen3Tts entry point with HuggingFace model loading - tts/qwen3/config.rs: Qwen3TtsConfig parsing from HF config.json - tts/qwen3/model.rs: 28-layer Qwen3 transformer with RoPE, GQA (16 heads, 8 KV heads), SiLU MLP, RMS norm, and KV cache - tts/qwen3/code_predictor.rs: 5-layer MTP module predicting 16 codebooks - tts/qwen3/speech_tokenizer.rs: ConvNet encoder/decoder with 16-layer RVQ - tts/qwen3/generate.rs: autoregressive generation loop with streaming support Add candle-core, candle-nn, candle-transformers, safetensors to Cargo.toml. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat: integrate TTS engine into speak WebSocket handler - Update speak.rs handler to use TTS engine directly from SharedState instead of returning a stub "not implemented" error - Add TtsEngine (OnceCell lazy-loaded) to AppState in state.rs with get_tts_engine() method for lazy initialization on first connection - Implement full WebSocket protocol: client sends JSON speak/cancel/stop messages, server streams binary PCM audio chunks and audio_end signals - Create voices/makima/manifest.json for Makima voice profile configuration - All files compile successfully with zero errors Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat: add /speak TTS page with WebSocket audio playback Add a new /speak frontend page for text-to-speech via WebSocket. The page accepts text input and streams synthesized PCM audio through the Web Audio API. Includes model loading indicator, cancel support, and connection status. Also adds a loading bar to the listen page ControlPanel during WebSocket connection. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
-rw-r--r--Cargo.lock688
-rw-r--r--TTS_RESEARCH.md351
-rw-r--r--docs/plans/qwen3-tts-implementation-plan.md514
-rw-r--r--docs/research/TTS_RESEARCH_NOTES.md405
-rw-r--r--docs/research/rust-native-tts-research.md297
-rw-r--r--docs/research/tts-qwen3-research.md548
-rw-r--r--docs/specs/qwen3-tts-spec.md928
-rw-r--r--makima/Cargo.toml6
-rw-r--r--makima/frontend/src/components/listen/ControlPanel.tsx36
-rw-r--r--makima/frontend/src/hooks/useSpeakWebSocket.ts329
-rw-r--r--makima/frontend/src/index.css6
-rw-r--r--makima/frontend/src/lib/api.ts1
-rw-r--r--makima/frontend/src/main.tsx9
-rw-r--r--makima/frontend/src/routes/listen.tsx1
-rw-r--r--makima/frontend/src/routes/speak.tsx159
-rw-r--r--makima/src/main.rs14
-rw-r--r--makima/src/server/handlers/mod.rs1
-rw-r--r--makima/src/server/handlers/speak.rs274
-rw-r--r--makima/src/server/messages.rs161
-rw-r--r--makima/src/server/mod.rs3
-rw-r--r--makima/src/server/state.rs22
-rw-r--r--makima/src/tts/chatterbox.rs (renamed from makima/src/tts.rs)391
-rw-r--r--makima/src/tts/mod.rs281
-rw-r--r--makima/src/tts/qwen3/code_predictor.rs261
-rw-r--r--makima/src/tts/qwen3/config.rs271
-rw-r--r--makima/src/tts/qwen3/generate.rs426
-rw-r--r--makima/src/tts/qwen3/mod.rs287
-rw-r--r--makima/src/tts/qwen3/model.rs581
-rw-r--r--makima/src/tts/qwen3/speech_tokenizer.rs612
-rw-r--r--voices/makima/manifest.json12
30 files changed, 7602 insertions, 273 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 1aeb184..30e65ff 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -250,6 +250,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a"
[[package]]
+name = "bit-set"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+dependencies = [
+ "bit-vec",
+]
+
+[[package]]
+name = "bit-vec"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+
+[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -284,6 +299,20 @@ name = "bytemuck"
version = "1.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
+dependencies = [
+ "bytemuck_derive",
+]
+
+[[package]]
+name = "bytemuck_derive"
+version = "1.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
[[package]]
name = "byteorder"
@@ -298,6 +327,62 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
[[package]]
+name = "candle-core"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1"
+dependencies = [
+ "byteorder",
+ "gemm 0.17.1",
+ "half",
+ "memmap2",
+ "num-traits",
+ "num_cpus",
+ "rand 0.9.2",
+ "rand_distr",
+ "rayon",
+ "safetensors",
+ "thiserror 1.0.69",
+ "ug",
+ "yoke 0.7.5",
+ "zip 1.1.4",
+]
+
+[[package]]
+name = "candle-nn"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be1160c3b63f47d40d91110a3e1e1e566ae38edddbbf492a60b40ffc3bc1ff38"
+dependencies = [
+ "candle-core",
+ "half",
+ "num-traits",
+ "rayon",
+ "safetensors",
+ "serde",
+ "thiserror 1.0.69",
+]
+
+[[package]]
+name = "candle-transformers"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94a0900d49f8605e0e7e6693a1f560e6271279de98e5fa369e7abf3aac245020"
+dependencies = [
+ "byteorder",
+ "candle-core",
+ "candle-nn",
+ "fancy-regex",
+ "num-traits",
+ "rand 0.9.2",
+ "rayon",
+ "serde",
+ "serde_json",
+ "serde_plain",
+ "tracing",
+]
+
+[[package]]
name = "cassowary"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -885,6 +970,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555"
[[package]]
+name = "dyn-stack"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b"
+dependencies = [
+ "bytemuck",
+ "reborrow",
+]
+
+[[package]]
+name = "dyn-stack"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8"
+dependencies = [
+ "bytemuck",
+ "dyn-stack-macros",
+]
+
+[[package]]
+name = "dyn-stack-macros"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1d926b4d407d372f141f93bb444696142c29d32962ccbd3531117cf3aa0bfa9"
+
+[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -909,6 +1020,18 @@ dependencies = [
]
[[package]]
+name = "enum-as-inner"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -984,6 +1107,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
+name = "fancy-regex"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2"
+dependencies = [
+ "bit-set",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1186,6 +1320,243 @@ dependencies = [
]
[[package]]
+name = "gemm"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32"
+dependencies = [
+ "dyn-stack 0.10.0",
+ "gemm-c32 0.17.1",
+ "gemm-c64 0.17.1",
+ "gemm-common 0.17.1",
+ "gemm-f16 0.17.1",
+ "gemm-f32 0.17.1",
+ "gemm-f64 0.17.1",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 10.7.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451"
+dependencies = [
+ "dyn-stack 0.13.2",
+ "gemm-c32 0.18.2",
+ "gemm-c64 0.18.2",
+ "gemm-common 0.18.2",
+ "gemm-f16 0.18.2",
+ "gemm-f32 0.18.2",
+ "gemm-f64 0.18.2",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 11.6.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-c32"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0"
+dependencies = [
+ "dyn-stack 0.10.0",
+ "gemm-common 0.17.1",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 10.7.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-c32"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847"
+dependencies = [
+ "dyn-stack 0.13.2",
+ "gemm-common 0.18.2",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 11.6.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-c64"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a"
+dependencies = [
+ "dyn-stack 0.10.0",
+ "gemm-common 0.17.1",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 10.7.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-c64"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf"
+dependencies = [
+ "dyn-stack 0.13.2",
+ "gemm-common 0.18.2",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 11.6.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-common"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
+dependencies = [
+ "bytemuck",
+ "dyn-stack 0.10.0",
+ "half",
+ "num-complex",
+ "num-traits",
+ "once_cell",
+ "paste",
+ "pulp 0.18.22",
+ "raw-cpuid 10.7.0",
+ "rayon",
+ "seq-macro",
+ "sysctl 0.5.5",
+]
+
+[[package]]
+name = "gemm-common"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
+dependencies = [
+ "bytemuck",
+ "dyn-stack 0.13.2",
+ "half",
+ "libm",
+ "num-complex",
+ "num-traits",
+ "once_cell",
+ "paste",
+ "pulp 0.21.5",
+ "raw-cpuid 11.6.0",
+ "rayon",
+ "seq-macro",
+ "sysctl 0.6.0",
+]
+
+[[package]]
+name = "gemm-f16"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4"
+dependencies = [
+ "dyn-stack 0.10.0",
+ "gemm-common 0.17.1",
+ "gemm-f32 0.17.1",
+ "half",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 10.7.0",
+ "rayon",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-f16"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109"
+dependencies = [
+ "dyn-stack 0.13.2",
+ "gemm-common 0.18.2",
+ "gemm-f32 0.18.2",
+ "half",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 11.6.0",
+ "rayon",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-f32"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113"
+dependencies = [
+ "dyn-stack 0.10.0",
+ "gemm-common 0.17.1",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 10.7.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-f32"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864"
+dependencies = [
+ "dyn-stack 0.13.2",
+ "gemm-common 0.18.2",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 11.6.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-f64"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0"
+dependencies = [
+ "dyn-stack 0.10.0",
+ "gemm-common 0.17.1",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 10.7.0",
+ "seq-macro",
+]
+
+[[package]]
+name = "gemm-f64"
+version = "0.18.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd"
+dependencies = [
+ "dyn-stack 0.13.2",
+ "gemm-common 0.18.2",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "raw-cpuid 11.6.0",
+ "seq-macro",
+]
+
+[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1246,6 +1617,21 @@ dependencies = [
]
[[package]]
+name = "half"
+version = "2.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
+dependencies = [
+ "bytemuck",
+ "cfg-if",
+ "crunchy",
+ "num-traits",
+ "rand 0.9.2",
+ "rand_distr",
+ "zerocopy",
+]
+
+[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1549,7 +1935,7 @@ checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
dependencies = [
"displaydoc",
"potential_utf",
- "yoke",
+ "yoke 0.8.1",
"zerofrom",
"zerovec",
]
@@ -1616,7 +2002,7 @@ dependencies = [
"displaydoc",
"icu_locale_core",
"writeable",
- "yoke",
+ "yoke 0.8.1",
"zerofrom",
"zerotrie",
"zerovec",
@@ -1896,6 +2282,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
+name = "libloading"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
+[[package]]
name = "libm"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2001,6 +2397,9 @@ dependencies = [
"backoff",
"base64 0.22.1",
"bytes",
+ "candle-core",
+ "candle-nn",
+ "candle-transformers",
"chrono",
"clap",
"config",
@@ -2030,6 +2429,7 @@ dependencies = [
"regex",
"reqwest",
"rusqlite",
+ "safetensors",
"serde",
"serde_json",
"sha2",
@@ -2092,6 +2492,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
+name = "memmap2"
+version = "0.9.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490"
+dependencies = [
+ "libc",
+ "stable_deref_trait",
+]
+
+[[package]]
name = "memoffset"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2261,6 +2671,20 @@ dependencies = [
]
[[package]]
+name = "num"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23"
+dependencies = [
+ "num-bigint",
+ "num-complex",
+ "num-integer",
+ "num-iter",
+ "num-rational",
+ "num-traits",
+]
+
+[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2292,6 +2716,7 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
+ "bytemuck",
"num-traits",
]
@@ -2322,6 +2747,17 @@ dependencies = [
]
[[package]]
+name = "num-rational"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
+dependencies = [
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2342,6 +2778,28 @@ dependencies = [
]
[[package]]
+name = "num_enum"
+version = "0.7.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c"
+dependencies = [
+ "num_enum_derive",
+ "rustversion",
+]
+
+[[package]]
+name = "num_enum_derive"
+version = "0.7.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7"
+dependencies = [
+ "proc-macro-crate",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "number_prefix"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2699,6 +3157,15 @@ dependencies = [
]
[[package]]
+name = "proc-macro-crate"
+version = "3.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983"
+dependencies = [
+ "toml_edit 0.23.10+spec-1.0.0",
+]
+
+[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2708,6 +3175,32 @@ dependencies = [
]
[[package]]
+name = "pulp"
+version = "0.18.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6"
+dependencies = [
+ "bytemuck",
+ "libm",
+ "num-complex",
+ "reborrow",
+]
+
+[[package]]
+name = "pulp"
+version = "0.21.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907"
+dependencies = [
+ "bytemuck",
+ "cfg-if",
+ "libm",
+ "num-complex",
+ "reborrow",
+ "version_check",
+]
+
+[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2782,6 +3275,16 @@ dependencies = [
]
[[package]]
+name = "rand_distr"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
+dependencies = [
+ "num-traits",
+ "rand 0.9.2",
+]
+
+[[package]]
name = "ratatui"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2803,6 +3306,24 @@ dependencies = [
]
[[package]]
+name = "raw-cpuid"
+version = "10.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
+dependencies = [
+ "bitflags 1.3.2",
+]
+
+[[package]]
+name = "raw-cpuid"
+version = "11.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
+dependencies = [
+ "bitflags 2.10.0",
+]
+
+[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2851,6 +3372,12 @@ dependencies = [
]
[[package]]
+name = "reborrow"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
+
+[[package]]
name = "redox_syscall"
version = "0.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3154,6 +3681,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
+name = "safetensors"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
+[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3201,6 +3738,12 @@ dependencies = [
]
[[package]]
+name = "seq-macro"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc"
+
+[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3255,6 +3798,15 @@ dependencies = [
]
[[package]]
+name = "serde_plain"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50"
+dependencies = [
+ "serde",
+]
+
+[[package]]
name = "serde_spanned"
version = "0.6.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -3971,6 +4523,34 @@ dependencies = [
]
[[package]]
+name = "sysctl"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea"
+dependencies = [
+ "bitflags 2.10.0",
+ "byteorder",
+ "enum-as-inner",
+ "libc",
+ "thiserror 1.0.69",
+ "walkdir",
+]
+
+[[package]]
+name = "sysctl"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc"
+dependencies = [
+ "bitflags 2.10.0",
+ "byteorder",
+ "enum-as-inner",
+ "libc",
+ "thiserror 1.0.69",
+ "walkdir",
+]
+
+[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4310,8 +4890,8 @@ checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362"
dependencies = [
"serde",
"serde_spanned",
- "toml_datetime",
- "toml_edit",
+ "toml_datetime 0.6.11",
+ "toml_edit 0.22.27",
]
[[package]]
@@ -4324,6 +4904,15 @@ dependencies = [
]
[[package]]
+name = "toml_datetime"
+version = "0.7.5+spec-1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347"
+dependencies = [
+ "serde_core",
+]
+
+[[package]]
name = "toml_edit"
version = "0.22.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4332,12 +4921,33 @@ dependencies = [
"indexmap",
"serde",
"serde_spanned",
- "toml_datetime",
+ "toml_datetime 0.6.11",
"toml_write",
"winnow",
]
[[package]]
+name = "toml_edit"
+version = "0.23.10+spec-1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269"
+dependencies = [
+ "indexmap",
+ "toml_datetime 0.7.5+spec-1.1.0",
+ "toml_parser",
+ "winnow",
+]
+
+[[package]]
+name = "toml_parser"
+version = "1.0.6+spec-1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44"
+dependencies = [
+ "winnow",
+]
+
+[[package]]
name = "toml_write"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4530,6 +5140,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971"
[[package]]
+name = "ug"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437"
+dependencies = [
+ "gemm 0.18.2",
+ "half",
+ "libloading",
+ "memmap2",
+ "num",
+ "num-traits",
+ "num_cpus",
+ "rayon",
+ "safetensors",
+ "serde",
+ "thiserror 1.0.69",
+ "tracing",
+ "yoke 0.7.5",
+]
+
+[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -4738,7 +5369,7 @@ dependencies = [
"serde_json",
"url",
"utoipa",
- "zip",
+ "zip 3.0.0",
]
[[package]]
@@ -5324,17 +5955,41 @@ dependencies = [
[[package]]
name = "yoke"
+version = "0.7.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40"
+dependencies = [
+ "serde",
+ "stable_deref_trait",
+ "yoke-derive 0.7.5",
+ "zerofrom",
+]
+
+[[package]]
+name = "yoke"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
dependencies = [
"stable_deref_trait",
- "yoke-derive",
+ "yoke-derive 0.8.1",
"zerofrom",
]
[[package]]
name = "yoke-derive"
+version = "0.7.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+ "synstructure",
+]
+
+[[package]]
+name = "yoke-derive"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
@@ -5399,7 +6054,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
dependencies = [
"displaydoc",
- "yoke",
+ "yoke 0.8.1",
"zerofrom",
]
@@ -5409,7 +6064,7 @@ version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
dependencies = [
- "yoke",
+ "yoke 0.8.1",
"zerofrom",
"zerovec-derive",
]
@@ -5427,6 +6082,21 @@ dependencies = [
[[package]]
name = "zip"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164"
+dependencies = [
+ "arbitrary",
+ "crc32fast",
+ "crossbeam-utils",
+ "displaydoc",
+ "indexmap",
+ "num_enum",
+ "thiserror 1.0.69",
+]
+
+[[package]]
+name = "zip"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12598812502ed0105f607f941c386f43d441e00148fce9dec3ca5ffb0bde9308"
diff --git a/TTS_RESEARCH.md b/TTS_RESEARCH.md
new file mode 100644
index 0000000..da7b8b8
--- /dev/null
+++ b/TTS_RESEARCH.md
@@ -0,0 +1,351 @@
+# TTS Research Document: Qwen3-TTS Integration for Makima
+
+## Executive Summary
+
+This document summarizes research on replacing the existing ChatterboxTTS implementation with Qwen3-TTS for live/streaming TTS with Makima's Japanese voice speaking English.
+
+---
+
+## 1. Current TTS Implementation Analysis
+
+### 1.1 Architecture Overview
+
+The current makima codebase uses **ChatterboxTTS** (ResembleAI/chatterbox-turbo-ONNX) with the following components:
+
+| Component | File | Purpose |
+|-----------|------|---------|
+| TTS Module | `makima/src/tts.rs` | Core TTS inference using ONNX Runtime |
+| Audio Processing | `makima/src/audio.rs` | Audio decoding, resampling (Symphonia) |
+| Library Export | `makima/src/lib.rs` | Exposes `pub mod tts` |
+
+### 1.2 ChatterboxTTS Technical Details
+
+```rust
+// Key constants from tts.rs
+pub const SAMPLE_RATE: u32 = 24_000;
+const MODEL_ID: &str = "ResembleAI/chatterbox-turbo-ONNX";
+const DEFAULT_MODEL_DIR: &str = "models/chatterbox-turbo";
+```
+
+**ONNX Model Files:**
+- `speech_encoder.onnx` - Encodes reference voice audio
+- `embed_tokens.onnx` - Text token embedding
+- `language_model.onnx` - Autoregressive token generation (24 layers, 16 KV heads)
+- `conditional_decoder.onnx` - Decodes speech tokens to waveform
+- `tokenizer.json` - Text tokenization
+
+### 1.3 Current API Surface
+
+```rust
+pub struct ChatterboxTTS {
+ pub fn from_pretrained(model_dir: Option<&str>) -> Result<Self, TtsError>
+ pub fn generate_tts(&mut self, text: &str) -> Result<Vec<f32>, TtsError> // Returns VoiceRequired error
+ pub fn generate_tts_with_voice(text: &str, sample_audio_path: &Path) -> Result<Vec<f32>, TtsError>
+ pub fn generate_tts_with_samples(text: &str, samples: &[f32], sample_rate: u32) -> Result<Vec<f32>, TtsError>
+}
+pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError>
+```
+
+### 1.4 Voice Cloning Capabilities (Current)
+
+- **Requires** voice reference audio (returns `VoiceRequired` error without it)
+- Accepts reference audio via file path or raw samples
+- Resamples reference to 24kHz internally
+- Uses speaker embeddings + speaker features for voice cloning
+
+### 1.5 Streaming/Live TTS (Current)
+
+**NOT SUPPORTED** - The current implementation:
+- Generates entire audio in one pass
+- Uses autoregressive token generation with max 1024 tokens
+- No chunked/streaming output capability
+- Full pipeline must complete before audio is available
+
+### 1.6 Server Integration Status
+
+The TTS module is **not currently exposed via HTTP endpoints**. The server (`makima/src/server/mod.rs`) has:
+- `/api/v1/listen` - WebSocket for Speech-to-Text (STT) only
+- No TTS endpoints exist
+
+---
+
+## 2. Qwen3-TTS Model Analysis
+
+### 2.1 Model Specifications
+
+| Attribute | Value |
+|-----------|-------|
+| **Model** | Qwen/Qwen3-TTS-12Hz-0.6B-Base |
+| **Parameters** | 0.6B (also available: 1.7B version) |
+| **Architecture** | Discrete multi-codebook LM (16 codebooks, 2048 size) |
+| **Tokenizer** | Qwen3-TTS-Tokenizer-12Hz |
+| **Sample Rate** | 12 Hz tokenizer (reconstructs to standard rates) |
+| **License** | Apache 2.0 |
+
+### 2.2 Key Capabilities
+
+#### Voice Cloning
+- **3-second rapid voice clone** - Minimal reference audio needed
+- **Flexible input formats**: local files, URLs, base64, (numpy_array, sample_rate) tuples
+- **Reusable voice prompts**: Create once, use for multiple generations
+
+```python
+# Voice cloning example
+model = Qwen3TTSModel.from_pretrained("Qwen/Qwen3-TTS-12Hz-0.6B-Base")
+wavs, sr = model.generate_voice_clone(
+ text="Target text",
+ language="English",
+ ref_audio="reference.wav",
+ ref_text="Reference transcript"
+)
+```
+
+#### Streaming/Live TTS
+- **97ms end-to-end latency** - Ultra-low latency streaming
+- **Dual-track hybrid streaming architecture** - Supports both streaming and non-streaming
+- **First packet after single character** - Immediate response capability
+
+#### Multilingual Support
+10 languages: Chinese, English, Japanese, Korean, German, French, Russian, Portuguese, Spanish, Italian
+
+#### Quality Metrics
+| Metric | Value |
+|--------|-------|
+| WER (English) | 1.32 |
+| Speaker Similarity (English) | 0.829 |
+| PESQ_WB | 3.21 |
+| STOI | 0.96 |
+| UTMOS | 4.16 |
+
+### 2.3 Requirements
+
+```bash
+# Environment
+Python 3.12 (recommended)
+CUDA-compatible GPU with 8GB+ VRAM
+
+# Installation
+pip install -U qwen-tts
+pip install -U flash-attn --no-build-isolation # Optional, reduces GPU memory
+```
+
+### 2.4 Current Deployment Options
+
+| Option | Status | Notes |
+|--------|--------|-------|
+| **Python (qwen-tts)** | Stable | Official package |
+| **vLLM-Omni** | Offline only | Online serving coming |
+| **ONNX Export** | Not available | No official support |
+| **Rust Implementation** | Draft PR #8 | Early development |
+| **DashScope API** | Available | Alibaba Cloud hosted |
+
+---
+
+## 3. Makima Voice Audio Clips
+
+### 3.1 Voice Actress Information
+
+| Attribute | Value |
+|-----------|-------|
+| **Character** | Makima (Chainsaw Man) |
+| **Japanese VA** | Tomori Kusunoki (楠木ともり) |
+| **English VA** | Suzie Yeung |
+| **Agency** | Sony Music Artists |
+
+### 3.2 Audio Clip Sources
+
+1. **Chainsaw Man Anime Episodes** - Primary source for Japanese voice
+2. **Behind The Voice Actors** - Character voice samples
+3. **YouTube Clips** - Interview compilations, scene clips
+4. **Official Media** - Promotional videos, trailers
+
+### 3.3 Audio Requirements for Voice Cloning
+
+#### Qwen3-TTS Requirements
+| Parameter | Requirement |
+|-----------|-------------|
+| **Minimum Duration** | 3 seconds (basic quality) |
+| **Recommended Duration** | 10-30 seconds (professional quality) |
+| **Format** | WAV, FLAC, MP3, OGG, AIFF, AAC |
+| **Sample Rate** | 24kHz or above recommended |
+| **Channels** | Mono preferred |
+
+#### Best Practices
+- **Clean audio**: No background music/noise
+- **Single speaker**: Makima's voice only
+- **Consistent tone**: Avoid dramatic variations
+- **Include transcript**: Reference text improves quality
+- **Varied content**: Mix of sentence types for flexibility
+
+#### Recommended Clip Types
+1. Calm, composed dialogue (Makima's signature tone)
+2. Commands/instructions (authoritative delivery)
+3. Questions (natural intonation)
+4. Longer monologues (for voice consistency)
+
+---
+
+## 4. Feasibility Assessment for Live/Streaming TTS
+
+### 4.1 Technical Challenges
+
+| Challenge | Severity | Notes |
+|-----------|----------|-------|
+| No ONNX export | **High** | Current codebase uses ONNX Runtime |
+| Rust implementation | **High** | Only draft PR available |
+| Python dependency | Medium | Would require sidecar service |
+| GPU memory | Medium | 8GB+ VRAM required |
+| Streaming API | Low | Supported in Qwen3-TTS |
+
+### 4.2 Integration Approaches
+
+#### Option A: Python Sidecar Service (Recommended)
+**Architecture**: Rust server + Python TTS service via HTTP/gRPC
+
+**Pros:**
+- Uses official Qwen3-TTS Python package
+- Full streaming support (97ms latency)
+- Simpler maintenance
+
+**Cons:**
+- Additional deployment complexity
+- Inter-process communication overhead
+
+```
+┌─────────────────┐ HTTP/gRPC ┌─────────────────┐
+│ Makima Server │ ◄──────────────► │ Qwen3-TTS │
+│ (Rust/Axum) │ │ (Python/FastAPI)│
+└─────────────────┘ └─────────────────┘
+```
+
+**Available Implementations:**
+- [ValyrianTech/Qwen3-TTS_server](https://github.com/ValyrianTech/Qwen3-TTS_server) - FastAPI server
+- [Qwen3-TTS-Openai-Fastapi](https://github.com/twolven/Qwen3-TTS-Openai-Fastapi) - OpenAI-compatible API
+
+#### Option B: Wait for Rust Implementation
+**Status**: Draft PR #8 in early development
+
+**Pros:**
+- Native Rust integration
+- No Python dependency
+- Matches current architecture
+
+**Cons:**
+- Unknown timeline
+- May require significant adaptation
+
+#### Option C: Hybrid (ChatterboxTTS + Qwen3-TTS)
+Keep ChatterboxTTS for ONNX compatibility, add Qwen3-TTS for streaming
+
+**Pros:**
+- Gradual migration
+- Fallback capability
+
+**Cons:**
+- Dual model maintenance
+- Increased complexity
+
+### 4.3 Recommendation
+
+**Short-term (1-2 weeks)**: Implement **Option A** with Python sidecar
+- Deploy ValyrianTech/Qwen3-TTS_server or similar
+- Add HTTP client in Rust to call TTS service
+- Implement WebSocket endpoint for streaming audio
+
+**Long-term (3-6 months)**: Monitor Rust implementation progress
+- Evaluate draft PR #8 stability
+- Consider contributing to Rust port
+- Migrate to native Rust when mature
+
+---
+
+## 5. Preliminary Technical Approach
+
+### 5.1 Phase 1: Voice Preparation
+
+1. **Collect Makima Audio Clips**
+ - Extract 3-5 clean clips from anime (10-30 seconds each)
+ - Ensure Japanese voice, clear audio, no BGM
+ - Prepare transcripts for each clip
+
+2. **Test Voice Cloning Quality**
+ - Use Qwen3-TTS demo to validate clips
+ - Iterate on clip selection for best results
+
+### 5.2 Phase 2: TTS Service Setup
+
+1. **Deploy Qwen3-TTS Server**
+ ```bash
+ # Using ValyrianTech server
+ docker run --gpus all -p 7860:7860 qwen3-tts-server
+ ```
+
+2. **Configure Voice Clone Profile**
+ - Upload Makima reference audio
+ - Store voice clone prompt for reuse
+
+### 5.3 Phase 3: Makima Integration
+
+1. **Add TTS Client Module**
+ ```rust
+ // New module: makima/src/tts_client.rs
+ pub struct QwenTTSClient {
+ base_url: String,
+ voice_profile: String,
+ }
+
+ impl QwenTTSClient {
+ pub async fn generate_speech(&self, text: &str) -> Result<Vec<u8>, Error>
+ pub async fn generate_speech_streaming(&self, text: &str) -> impl Stream<Item = Vec<u8>>
+ }
+ ```
+
+2. **Add TTS Endpoint**
+ ```rust
+ // In makima/src/server/mod.rs
+ .route("/api/v1/tts", post(tts_handler))
+ .route("/api/v1/tts/stream", get(tts_streaming_handler))
+ ```
+
+3. **WebSocket Integration for Listen Page**
+ - Bidirectional audio: STT input, TTS output
+ - Low-latency streaming for conversational flow
+
+### 5.4 Phase 4: Listen Page Integration
+
+1. **Update Frontend**
+ - Add TTS playback capability
+ - Handle streaming audio chunks
+ - UI for voice response indicators
+
+2. **Orchestration Logic**
+ - STT → LLM → TTS pipeline
+ - Interrupt handling for user speech
+
+---
+
+## 6. Open Questions
+
+1. **Voice Rights**: Are there legal considerations for cloning Tomori Kusunoki's voice?
+2. **GPU Allocation**: Shared GPU for STT + TTS, or separate?
+3. **Latency Budget**: What's acceptable end-to-end latency for Listen page?
+4. **Fallback Strategy**: What happens if TTS service is unavailable?
+5. **Multi-user**: How to handle concurrent TTS requests?
+
+---
+
+## 7. References
+
+- [Qwen3-TTS HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
+- [Qwen3-TTS GitHub](https://github.com/QwenLM/Qwen3-TTS)
+- [Qwen3-TTS Technical Report](https://arxiv.org/abs/2601.15621)
+- [ValyrianTech Qwen3-TTS Server](https://github.com/ValyrianTech/Qwen3-TTS_server)
+- [Qwen3-TTS OpenAI-Compatible FastAPI](https://github.com/twolven/Qwen3-TTS-Openai-Fastapi)
+- [Makima Voice Actors](https://www.behindthevoiceactors.com/tv-shows/Chainsaw-Man/Makima/)
+- [ChatterboxTTS Audio Guidelines](https://github.com/resemble-ai/chatterbox/issues/39)
+- [Voice Cloning Best Practices - Resemble AI](https://www.resemble.ai/script-to-read-for-voice-cloning-guidelines/)
+- [Qwen3-rs (Rust LLM implementation)](https://github.com/reinterpretcat/qwen3-rs)
+
+---
+
+*Document created: Research phase for Makima TTS replacement contract*
diff --git a/docs/plans/qwen3-tts-implementation-plan.md b/docs/plans/qwen3-tts-implementation-plan.md
new file mode 100644
index 0000000..76ecb33
--- /dev/null
+++ b/docs/plans/qwen3-tts-implementation-plan.md
@@ -0,0 +1,514 @@
+# Qwen3-TTS Implementation Plan — Pure Rust (Candle)
+
+**Version:** 2.0
+**Created:** 2026-01-27
+**Status:** Final
+**Authors:** makima development team
+**Spec Reference:** [docs/specs/qwen3-tts-spec.md](../specs/qwen3-tts-spec.md)
+**Research:** [docs/research/rust-native-tts-research.md](../research/rust-native-tts-research.md)
+
+---
+
+## Table of Contents
+
+1. [Overview](#1-overview)
+2. [Task Breakdown](#2-task-breakdown)
+3. [File Changes](#3-file-changes)
+4. [Phase 1: Candle-Based TTS Module](#4-phase-1-candle-based-tts-module)
+5. [Phase 2: WebSocket Handler + Voice Assets](#5-phase-2-websocket-handler--voice-assets)
+6. [Phase 3: Optimization + Integration](#6-phase-3-optimization--integration)
+7. [Testing Plan](#7-testing-plan)
+8. [Risk Assessment](#8-risk-assessment)
+9. [Dependencies & Prerequisites](#9-dependencies--prerequisites)
+10. [Success Criteria](#10-success-criteria)
+
+---
+
+## 1. Overview
+
+This plan details the implementation of Qwen3-TTS integration for the makima system as a **pure Rust** solution using the **candle** ML framework. There is no Python microservice and no proxy pattern — the TTS model runs directly inside the main makima process, loading safetensors weights via candle.
+
+### Key Objectives
+
+1. Implement Qwen3-TTS model inference natively in Rust using candle
+2. Create a `makima/src/tts/` module with TTS trait, Chatterbox adapter, and Qwen3 submodule
+3. Update the `/api/v1/speak` WebSocket handler to call the TTS engine directly
+4. Enable streaming audio delivery with <200ms time-to-first-audio (TTFA)
+5. Support voice cloning with default Makima voice
+
+### Architecture Summary
+
+```
+Client Browser
+ │
+ │ WebSocket: /api/v1/speak
+ ▼
+Makima Server (Rust/Axum)
+ │
+ │ speak.rs handler → TTS Engine (in-process)
+ │
+ │ candle-based Qwen3-TTS inference
+ │ (safetensors weights loaded directly)
+ ▼
+Audio Stream back to client
+```
+
+**Key architectural decisions:**
+- **No Python.** All inference runs in Rust via candle.
+- **No microservice.** TTS runs in-process, no separate service to deploy.
+- **No proxy.** The speak handler calls the TTS engine directly.
+- **Lazy loading.** Models loaded on first TTS request (like listen.rs pattern).
+- **SafeTensors.** Weights loaded directly — no ONNX conversion needed.
+
+---
+
+## 2. Task Breakdown
+
+### Phase 1: Candle-Based TTS Module (Priority: Critical)
+
+| ID | Task | Depends On | Estimated Hours |
+|----|------|------------|-----------------|
+| P1.1 | Create `makima/src/tts/mod.rs` — TTS trait + factory + types | - | 3 |
+| P1.2 | Move existing `tts.rs` to `makima/src/tts/chatterbox.rs` | P1.1 | 2 |
+| P1.3 | Create `makima/src/tts/qwen3/config.rs` — Model config parsing | P1.1 | 2 |
+| P1.4 | Implement `makima/src/tts/qwen3/model.rs` — 28-layer LM backbone | P1.3 | 12 |
+| P1.5 | Implement `makima/src/tts/qwen3/code_predictor.rs` — MTP module | P1.4 | 8 |
+| P1.6 | Implement `makima/src/tts/qwen3/speech_tokenizer.rs` — ConvNet codec | P1.3 | 10 |
+| P1.7 | Implement `makima/src/tts/qwen3/generate.rs` — Autoregressive generation | P1.4, P1.5, P1.6 | 8 |
+| P1.8 | Create `makima/src/tts/qwen3/mod.rs` — Public API | P1.7 | 3 |
+| P1.9 | Add candle dependencies to `Cargo.toml` | - | 1 |
+| P1.10 | Unit tests for config, model layers, tokenizer | P1.4-P1.6 | 6 |
+
+**Phase 1 Total: ~55 hours**
+
+### Phase 2: WebSocket Handler + Voice Assets (Priority: High)
+
+| ID | Task | Depends On | Estimated Hours |
+|----|------|------------|-----------------|
+| P2.1 | Rewrite `speak.rs` — Direct TTS handler (remove proxy) | P1.8 | 6 |
+| P2.2 | Add TTS models to `SharedState` (lazy loading via `OnceCell`) | P1.8 | 3 |
+| P2.3 | Implement voice prompt caching (LRU) | P1.8 | 3 |
+| P2.4 | Remove `tts_client.rs` (no longer needed) | P2.1 | 1 |
+| P2.5 | Update `state.rs` — Remove TTS proxy fields, add TTS model fields | P2.2 | 2 |
+| P2.6 | Update `mod.rs` — Remove `tts_client` module | P2.4 | 0.5 |
+| P2.7 | Create voice manifest structure (`models/voices/makima/`) | - | 1 |
+| P2.8 | Acquire Makima voice reference audio | - | 2 |
+| P2.9 | Test voice cloning quality | P1.8, P2.8 | 2 |
+
+**Phase 2 Total: ~20.5 hours**
+
+### Phase 3: Optimization + Integration (Priority: Medium)
+
+| ID | Task | Depends On | Estimated Hours |
+|----|------|------------|-----------------|
+| P3.1 | Implement streaming generation (token-by-token waveform decode) | P2.1 | 6 |
+| P3.2 | GPU memory optimization (bf16, cache management) | P3.1 | 4 |
+| P3.3 | Listen page integration for bidirectional speech | P2.1 | 4 |
+| P3.4 | Latency benchmarks | P3.1 | 3 |
+| P3.5 | Integration tests (WebSocket end-to-end) | P2.1 | 4 |
+| P3.6 | Documentation | P3.5 | 2 |
+
+**Phase 3 Total: ~23 hours**
+
+---
+
+## 3. File Changes
+
+### New Files
+
+```
+makima/src/tts/
+├── mod.rs // TTS trait, factory, shared types
+├── chatterbox.rs // Existing ONNX-based Chatterbox (moved from tts.rs)
+└── qwen3/
+ ├── mod.rs // Qwen3TTS public API
+ ├── model.rs // Qwen3 LM transformer (28 layers)
+ ├── code_predictor.rs // MTP module (5 layers, 16 codebooks)
+ ├── speech_tokenizer.rs // Encoder + Decoder (causal ConvNet + RVQ)
+ ├── config.rs // Model config from config.json / safetensors
+ └── generate.rs // Autoregressive generation loop with KV cache
+```
+
+### Modified Files
+
+| File | Change Description |
+|------|-------------------|
+| `makima/src/server/handlers/speak.rs` | Rewrite: direct TTS engine call instead of proxy |
+| `makima/src/server/state.rs` | Remove `tts_client`/`tts_config` fields, add `tts_models: OnceCell<TtsModels>` |
+| `makima/src/server/mod.rs` | Remove `pub mod tts_client;` |
+| `makima/src/server/handlers/mod.rs` | No change (speak already exported) |
+| `makima/Cargo.toml` | Add candle-core, candle-nn, candle-transformers; remove tokio-tungstenite if unused |
+| `makima/src/lib.rs` or `main.rs` | Add `pub mod tts;` |
+
+### Deleted Files
+
+| File | Reason |
+|------|--------|
+| `makima/src/server/tts_client.rs` | No longer needed — no proxy pattern |
+| `tts-service/` (entire directory) | Python service rejected; pure Rust solution |
+
+---
+
+## 4. Phase 1: Candle-Based TTS Module
+
+### 4.1 TTS Trait and Factory (`tts/mod.rs`)
+
+```rust
+use async_trait::async_trait;
+
+/// Audio chunk for streaming output.
+pub struct AudioChunk {
+ pub samples: Vec<f32>,
+ pub sample_rate: u32,
+ pub is_final: bool,
+}
+
+/// TTS engine trait — implemented by Chatterbox and Qwen3.
+#[async_trait]
+pub trait TtsEngine: Send + Sync {
+ /// Generate audio from text.
+ async fn generate(
+ &self,
+ text: &str,
+ voice_id: &str,
+ language: &str,
+ ) -> Result<Vec<AudioChunk>, TtsError>;
+
+ /// Pre-load a voice prompt.
+ async fn load_voice(&self, voice_id: &str) -> Result<(), TtsError>;
+
+ /// Check if the engine is ready.
+ fn is_ready(&self) -> bool;
+}
+```
+
+### 4.2 Qwen3 LM Backbone (`tts/qwen3/model.rs`)
+
+Extend candle-transformers' Qwen2 model implementation:
+
+- **28 transformer layers** with RoPE, GQA (16 heads, 8 KV heads), head dim 128
+- **Hidden size:** 1024, **intermediate size:** 3072
+- **Input:** text tokens + reference audio codes (concatenated)
+- **Output:** zeroth codebook token logits
+
+**Key implementation detail:** The existing `candle_transformers::models::qwen2` module provides the base attention and MLP layers. We extend this with:
+- TTS-specific input embedding (text + audio token embeddings)
+- Speaker encoder concatenation
+- Code predictor output head (instead of standard LM head)
+
+### 4.3 Code Predictor (`tts/qwen3/code_predictor.rs`)
+
+- **5-layer** transformer module
+- **Input:** hidden states from the main LM
+- **Output:** 16 codebook predictions (vocab size 2048 each)
+- After the main LM predicts the zeroth codebook token, this module predicts the remaining 15 codebook layers in parallel
+
+### 4.4 Speech Tokenizer (`tts/qwen3/speech_tokenizer.rs`)
+
+Two sub-components:
+
+**Encoder** (used for voice cloning):
+- Causal 1D ConvNet converting reference audio waveform → discrete multi-codebook tokens
+- 16-layer RVQ (Residual Vector Quantization)
+- First codebook = semantic (WavLM-guided), remaining 15 = acoustic
+
+**Decoder** (used for audio output):
+- Causal 1D ConvNet reconstructing waveforms from discrete codes
+- Input: 16 codebook indices → lookup embeddings → ConvNet → waveform
+- Output: 24kHz mono audio
+
+**candle implementation notes:**
+- `candle_nn::Conv1d` for all convolution layers
+- `candle_nn::Embedding` for codebook lookups
+- Weight normalization handled manually
+
+### 4.5 Autoregressive Generation (`tts/qwen3/generate.rs`)
+
+```rust
+pub async fn generate(
+ model: &Qwen3Model,
+ code_predictor: &CodePredictor,
+ speech_tokenizer: &SpeechTokenizer,
+ text_tokens: &[u32],
+ voice_prompt: &VoicePrompt,
+) -> Result<Vec<AudioChunk>, TtsError> {
+ // 1. Encode reference audio → speaker embedding + audio codes
+ let speaker_emb = speech_tokenizer.encode(&voice_prompt.audio)?;
+
+ // 2. Prepare input: [text_tokens, audio_codes]
+ let input = prepare_input(text_tokens, &speaker_emb)?;
+
+ // 3. Autoregressive loop with KV cache
+ let mut kv_cache = KvCache::new(model.num_layers());
+ let mut generated_codes = Vec::new();
+
+ loop {
+ let logits = model.forward(&input, &mut kv_cache)?;
+ let next_token = sample_token(&logits);
+
+ if next_token == EOS_TOKEN { break; }
+ generated_codes.push(next_token);
+
+ // 4. Code predictor: predict remaining 15 codebooks
+ let all_codes = code_predictor.predict(&model.last_hidden_state(), next_token)?;
+
+ // 5. Decode to audio (can be done incrementally for streaming)
+ let chunk = speech_tokenizer.decode(&all_codes)?;
+ // yield chunk for streaming
+ }
+}
+```
+
+---
+
+## 5. Phase 2: WebSocket Handler + Voice Assets
+
+### 5.1 Speak Handler (Rewritten)
+
+```rust
+// makima/src/server/handlers/speak.rs
+//
+// Direct TTS handler — no proxy, no external service.
+
+pub async fn websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_speak_socket(socket, state))
+}
+
+async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
+ let session_id = Uuid::new_v4().to_string();
+
+ // Lazy-load TTS models (like listen.rs does for STT)
+ let tts = match state.get_tts_models().await {
+ Ok(tts) => tts,
+ Err(e) => {
+ send_error(&mut socket, "MODEL_LOADING", &e.to_string()).await;
+ return;
+ }
+ };
+
+ // Session loop: parse JSON messages, dispatch to TTS engine
+ let (mut sender, mut receiver) = socket.split();
+
+ while let Some(msg) = receiver.next().await {
+ match parse_client_message(msg) {
+ ClientMessage::Start(config) => {
+ // Load voice, send Ready
+ }
+ ClientMessage::Speak(text) => {
+ // Run inference, stream audio chunks
+ for chunk in tts.engine.generate(&text, voice_id, language).await? {
+ sender.send(Message::Binary(chunk.to_pcm16())).await?;
+ }
+ sender.send(complete_message()).await?;
+ }
+ ClientMessage::Stop => break,
+ ClientMessage::Cancel => { /* abort current generation */ }
+ }
+ }
+}
+```
+
+### 5.2 State Changes
+
+Remove from `AppState`:
+- `tts_client: Option<Arc<TtsProxyClient>>`
+- `tts_config: Option<TtsConfig>`
+- `with_tts()` method
+- `is_tts_healthy()` method
+
+Add to `AppState`:
+- `tts_models: OnceCell<TtsModels>` — lazily loaded TTS engine
+- `get_tts_models()` method (async, like `get_ml_models()`)
+
+### 5.3 Voice Assets
+
+```
+models/voices/makima/
+├── manifest.json # Voice metadata
+├── reference.wav # 5-15 second reference audio
+└── transcript.txt # Exact transcript of reference audio
+```
+
+---
+
+## 6. Phase 3: Optimization + Integration
+
+### 6.1 Streaming Generation
+
+The 12Hz model's causal architecture enables token-by-token waveform generation:
+- Each token = ~80ms of audio (12.5 Hz)
+- After generating each token, decode immediately and send audio chunk
+- Client receives audio before full generation completes
+
+### 6.2 GPU Memory Optimization
+
+- Load weights in bf16/f16 (candle supports both)
+- Implement KV cache with fixed maximum size
+- Clear cache between sessions
+- CPU fallback when GPU is unavailable
+
+### 6.3 Listen Page Integration
+
+Following the pattern in `listen.rs`:
+- TTS model protected behind `tokio::sync::Mutex`
+- WebSocket endpoint emits audio chunks as tokens are generated
+- Bidirectional: STT (listen) → process → TTS (speak) loop
+
+---
+
+## 7. Testing Plan
+
+### 7.1 Unit Tests
+
+| Test Area | Coverage | Key Tests |
+|-----------|----------|-----------|
+| Config parsing | 100% | Load config from JSON, validate fields |
+| Model layers | 80% | Attention, MLP, Conv1d shapes |
+| Code predictor | 85% | Multi-codebook output shapes |
+| Speech tokenizer | 80% | Encode/decode round-trip |
+| Voice cache | 95% | LRU eviction, TTL expiration |
+| Message parsing | 100% | All client/server message types |
+
+### 7.2 Integration Tests
+
+| Test | Description |
+|------|-------------|
+| WebSocket flow | Connect → Start → Speak → Audio chunks → Complete → Stop |
+| Error handling | Invalid text, unknown voice, model loading failure |
+| Cancellation | Cancel mid-generation |
+| Voice cloning | Generate with custom reference audio |
+
+### 7.3 Latency Benchmarks
+
+| Metric | Target | Acceptable | Warning |
+|--------|--------|------------|---------|
+| First Audio (short text) | < 150ms | < 200ms | > 300ms |
+| First Audio (medium text) | < 200ms | < 300ms | > 500ms |
+| First Audio (long text) | < 300ms | < 500ms | > 800ms |
+| Inter-chunk latency | < 30ms | < 50ms | > 100ms |
+| GPU memory | < 4GB | < 6GB | > 8GB |
+
+---
+
+## 8. Risk Assessment
+
+### 8.1 Technical Risks
+
+| Risk | Likelihood | Impact | Mitigation |
+|------|------------|--------|------------|
+| **Candle implementation takes longer** | Medium | Medium | Reference Crane's Spark-TTS; use qwen3-rs as LM reference |
+| **Speech tokenizer ConvNet is complex** | Medium | High | Study PyTorch source; ConvNet layers are simpler than transformers |
+| **Model quality differs from PyTorch** | Low | High | Validate with reference audio; ensure bf16 precision |
+| **Crane ships Qwen3-TTS first** | Medium | Positive | Adopt their implementation or use as reference |
+| **GPU memory issues** | Low | Medium | 0.6B model is small (~2.5GB); fits in 4GB VRAM |
+
+### 8.2 Contingency Plans
+
+| Scenario | Response |
+|----------|----------|
+| Candle implementation blocked | Use Crane crate as dependency if they ship Qwen3-TTS |
+| ConvNet decoder too complex | Implement simplified decoder; optimize later |
+| Latency exceeds targets | Start with batch mode + chunked delivery (acceptable UX) |
+| No GPU available | CPU fallback with candle's MKL support (degraded performance) |
+
+---
+
+## 9. Dependencies & Prerequisites
+
+### 9.1 Rust Dependencies
+
+Add to `Cargo.toml`:
+
+```toml
+[dependencies]
+candle-core = "0.8"
+candle-nn = "0.8"
+candle-transformers = "0.8"
+# Keep existing: tokenizers, hf-hub, safetensors, ndarray
+```
+
+### 9.2 Hardware Requirements
+
+| Component | Minimum | Recommended |
+|-----------|---------|-------------|
+| GPU | CUDA 4GB VRAM / Metal (macOS) | NVIDIA RTX 3060+ (8GB+) |
+| RAM | 8GB | 16GB |
+| Storage | 5GB (model weights) | 10GB |
+
+### 9.3 Voice Asset Prerequisites
+
+Before Phase 2 voice testing:
+1. Makima voice reference audio (5-15 seconds, clean speech)
+2. Accurate transcript of reference audio
+3. Format: WAV 24kHz mono, 16-bit PCM
+
+---
+
+## 10. Success Criteria
+
+### 10.1 Phase 1 Completion
+
+- [ ] candle-based Qwen3 model loads safetensors weights
+- [ ] Forward pass produces valid logits
+- [ ] Speech tokenizer encodes/decodes audio
+- [ ] Code predictor generates 16 codebook layers
+- [ ] Unit tests pass with > 80% coverage
+
+### 10.2 Phase 2 Completion
+
+- [ ] `/api/v1/speak` endpoint produces audio from text
+- [ ] No Python service required
+- [ ] Voice cloning works with reference audio
+- [ ] Error handling returns appropriate codes
+- [ ] speak.rs calls TTS engine directly (no proxy)
+
+### 10.3 Phase 3 Completion
+
+- [ ] Streaming generation with < 200ms TTFA
+- [ ] GPU memory usage < 6GB
+- [ ] Integration tests pass
+- [ ] Listen page bidirectional speech works
+- [ ] Latency benchmarks documented
+
+### 10.4 Final Acceptance Criteria
+
+1. **Functional:** End-to-end TTS streaming via WebSocket, pure Rust, no Python
+2. **Performance:** TTFA < 200ms, subsequent chunks < 100ms
+3. **Quality:** Synthesized speech is intelligible and recognizable as Makima
+4. **Reliability:** Error handling is robust; graceful degradation on GPU failure
+5. **Architecture:** Clean `tts/` module with trait-based engine selection
+
+---
+
+## Appendix A: Quick Start Commands
+
+### Development
+
+```bash
+# Build with candle GPU support
+cd makima
+cargo build --features cuda # or --features metal for macOS
+
+# Run server with TTS enabled
+TTS_ENGINE=qwen3 TTS_DEVICE=cuda:0 cargo run
+
+# Run TTS-specific tests
+cargo test tts
+```
+
+### Benchmarks
+
+```bash
+# Run latency benchmarks (requires GPU)
+cargo bench --bench tts_latency
+```
+
+## Appendix B: Reference Implementations
+
+- [candle-transformers qwen2 model](https://docs.rs/candle-transformers/latest/candle_transformers/models/qwen2/index.html) — base attention/MLP layers
+- [qwen3-rs](https://github.com/reinterpretcat/qwen3-rs) — educational Qwen3 in Rust
+- [Crane](https://github.com/lucasjinreal/Crane) — Rust TTS engine (Qwen3-TTS on roadmap)
+- [docs/research/rust-native-tts-research.md](../research/rust-native-tts-research.md) — full feasibility analysis
diff --git a/docs/research/TTS_RESEARCH_NOTES.md b/docs/research/TTS_RESEARCH_NOTES.md
new file mode 100644
index 0000000..72ac8c6
--- /dev/null
+++ b/docs/research/TTS_RESEARCH_NOTES.md
@@ -0,0 +1,405 @@
+# TTS Replacement Research Notes
+
+## Executive Summary
+
+This document summarizes research on replacing the existing TTS endpoint in makima with Qwen3-TTS-12Hz-0.6B-Base, with the goal of supporting voice cloning using Makima's Japanese voice speaking English, and achieving near-live/streaming TTS capabilities.
+
+---
+
+## 1. Current TTS Implementation Analysis
+
+### 1.1 Current Model: Chatterbox-Turbo
+
+The existing TTS implementation in `makima/src/tts.rs` uses **ResembleAI/chatterbox-turbo-ONNX**:
+
+- **Architecture**: 4-component ONNX model pipeline
+ - `speech_encoder.onnx` - Encodes reference voice audio
+ - `embed_tokens.onnx` - Token embedding layer
+ - `language_model.onnx` - Autoregressive language model (24 layers, 16 KV heads, 64 head dim)
+ - `conditional_decoder.onnx` - Decodes speech tokens to audio waveform
+
+- **Sample Rate**: 24,000 Hz output
+- **Voice Cloning**: Required (no default voice support)
+- **Special Tokens**:
+ - START_SPEECH_TOKEN: 6561
+ - STOP_SPEECH_TOKEN: 6562
+ - SILENCE_TOKEN: 4299
+
+### 1.2 Current API Surface
+
+**Core TTS Functions:**
+```rust
+pub fn generate_tts(&mut self, _text: &str) -> Result<Vec<f32>, TtsError>
+ // Returns VoiceRequired error - voice cloning is mandatory
+
+pub fn generate_tts_with_voice(&mut self, text: &str, sample_audio_path: &Path) -> Result<Vec<f32>, TtsError>
+ // Voice cloning from file path
+
+pub fn generate_tts_with_samples(&mut self, text: &str, samples: &[f32], sample_rate: u32) -> Result<Vec<f32>, TtsError>
+ // Voice cloning from raw samples
+```
+
+**Audio Processing:**
+- Input audio resampled to 24kHz
+- Reference voice encoded into:
+ - `audio_features` - Acoustic features
+ - `prompt_tokens` - Token representation
+ - `speaker_embeddings` - Speaker identity
+ - `speaker_features` - Voice characteristics
+
+### 1.3 Current Limitations
+
+1. **No Streaming Support**: Current implementation generates complete audio before returning
+2. **No Default Voice**: Requires voice reference audio for every call
+3. **No HTTP Endpoint**: TTS is only available as a Rust library, not exposed via REST API
+4. **Single Language**: Optimized for English, unclear multilingual support
+5. **High Latency**: Full autoregressive generation before any audio output
+
+### 1.4 Related Components
+
+**Audio Processing (`makima/src/audio.rs`):**
+- Uses Symphonia for audio decoding (MP3, WAV, FLAC, OGG, etc.)
+- Resampling via sinc interpolation
+- Stereo to mono mixdown
+- Target: 16kHz mono for STT
+
+**Listen Endpoint (`makima/src/server/handlers/listen.rs`):**
+- WebSocket-based streaming STT
+- Uses Parakeet for transcription
+- Sortformer for speaker diarization
+- Already has real-time audio streaming infrastructure
+
+---
+
+## 2. Qwen3-TTS-12Hz-0.6B-Base Model Analysis
+
+### 2.1 Model Capabilities
+
+| Feature | Specification |
+|---------|---------------|
+| **Model Size** | 0.6B parameters (lightweight variant) |
+| **Voice Cloning** | 3-second reference audio only |
+| **Streaming** | Dual-track hybrid architecture |
+| **Minimum Latency** | 97ms end-to-end |
+| **Languages** | 10 (Chinese, English, Japanese, Korean, German, French, Russian, Portuguese, Spanish, Italian) |
+| **Cross-lingual Cloning** | Japanese voice to English speech supported |
+| **Speaker Similarity** | 0.95 (near human-level) |
+| **Output Sample Rate** | Up to 48kHz (standard 24kHz) |
+
+### 2.2 Voice Cloning Requirements
+
+**Reference Audio:**
+- **Minimum Duration**: 3 seconds
+- **Recommended Duration**: 5-15 seconds
+- **Format**: WAV preferred; also supports URL, base64, numpy array
+- **Quality**: Clean, noise-free audio essential
+- **Transcript**: Providing `ref_text` significantly improves quality
+
+**Cross-Lingual Usage (Japanese to English):**
+```python
+ref_audio = "makima_japanese.wav" # Japanese reference
+ref_text = "日本語のテキスト" # Japanese transcription
+
+wavs, sr = model.generate_voice_clone(
+ text="This is English text", # English output
+ language="English",
+ ref_audio=ref_audio,
+ ref_text=ref_text,
+)
+```
+
+### 2.3 Technical Requirements
+
+**Python Dependencies:**
+```bash
+pip install -U qwen-tts
+pip install -U flash-attn --no-build-isolation # For optimal performance
+```
+
+**Hardware:**
+- CUDA-compatible GPU required
+- FlashAttention 2 for optimal memory usage
+- Float16/bfloat16 precision support
+- For <96GB RAM: `MAX_JOBS=4` for flash-attn installation
+
+**Model Loading:**
+```python
+from qwen_tts import Qwen3TTSModel
+import torch
+
+model = Qwen3TTSModel.from_pretrained(
+ "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
+ device_map="cuda:0",
+ dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+)
+```
+
+### 2.4 Streaming Architecture
+
+**Dual-Track Hybrid Design:**
+- Single model supports both streaming and non-streaming
+- Audio output begins after minimal text input
+- 97ms minimum latency achieved through:
+ - Proprietary Qwen3-TTS-Tokenizer-12Hz (efficient acoustic compression)
+ - Discrete multi-codebook LM (eliminates LM+DiT bottleneck)
+ - Lightweight non-DiT vocoder
+
+**Reusable Voice Clone Prompt (Critical for Performance):**
+```python
+# Pre-compute once
+prompt_items = model.create_voice_clone_prompt(
+ ref_audio=ref_audio,
+ ref_text=ref_text,
+ x_vector_only_mode=False
+)
+
+# Reuse for multiple generations
+wavs, sr = model.generate_voice_clone(
+ text=["Line 1", "Line 2"],
+ language=["English", "English"],
+ voice_clone_prompt=prompt_items, # Cached prompt
+)
+```
+
+---
+
+## 3. Makima Voice Audio Sources
+
+### 3.1 Character Information
+
+- **Character**: Makima from Chainsaw Man anime
+- **Japanese Voice Actress**: Tomori Kusunoki (楠木ともり)
+- **English Voice Actress**: Suzie Yeung
+
+### 3.2 Potential Audio Sources
+
+| Source | Type | Notes |
+|--------|------|-------|
+| **Voicy Network Soundboard** | Official clips | MP3 download available, 20+ sound effects |
+| **101Soundboards** | Fan-curated clips | Various character sounds |
+| **Anime Episodes** | Source material | Highest quality, requires extraction |
+| **Nikke: Goddess of Victory** | Game voicelines | Same voice actress (Tomori Kusunoki) |
+| **Ko-fi (erusha)** | WAV files | x5 character impression audio files |
+
+### 3.3 Recommended Approach
+
+1. **Primary Source**: Extract 5-15 seconds of clean dialogue from Chainsaw Man anime (Japanese audio track)
+2. **Selection Criteria**:
+ - Clear, isolated dialogue (no background music/effects)
+ - Natural speaking tone (not shouting/whispering)
+ - Variety of phonemes for better cloning
+3. **Transcription**: Provide accurate Japanese transcription for `ref_text`
+4. **Processing**: Convert to WAV format, ensure clean audio quality
+
+### 3.4 Legal Considerations
+
+- Voice cloning of real voice actors for commercial use may have legal implications
+- Synthetic voice generation based on copyrighted characters may require licenses
+- Consider using for internal/personal use only, or creating disclaimer
+
+---
+
+## 4. Feasibility Assessment
+
+### 4.1 Live/Streaming TTS Feasibility: HIGHLY FEASIBLE
+
+**Evidence:**
+- Qwen3-TTS achieves 97ms latency (well under 200ms real-time threshold)
+- Existing WebSocket infrastructure in makima (`/api/v1/listen`) can be adapted
+- Streaming architecture designed for interactive scenarios
+
+**Implementation Approach:**
+1. Create new WebSocket endpoint `/api/v1/speak` mirroring listen endpoint
+2. Pre-compute voice clone prompt on connection start
+3. Stream audio chunks as they're generated
+4. Use chunked audio encoding (similar to listen's binary message handling)
+
+### 4.2 Voice Cloning with Japanese Voice: FULLY SUPPORTED
+
+**Evidence:**
+- Qwen3-TTS explicitly supports cross-lingual voice cloning
+- Japanese is one of 10 supported languages
+- 0.95 speaker similarity maintained across languages
+
+**Implementation Approach:**
+1. Pre-process Makima voice clips (5-15 seconds Japanese audio)
+2. Include Japanese transcription
+3. Generate English speech while preserving voice characteristics
+
+### 4.3 Integration Challenges
+
+| Challenge | Difficulty | Mitigation |
+|-----------|-----------|------------|
+| **Python to Rust Integration** | Medium | Use Python subprocess or microservice |
+| **GPU Memory** | Low | 0.6B model is lightweight |
+| **Latency Target** | Low | 97ms base latency is excellent |
+| **Audio Format Conversion** | Low | Existing symphonia infrastructure |
+| **Default Voice Setup** | Low | One-time voice prompt caching |
+
+### 4.4 Architecture Options
+
+**Option A: Python Microservice**
+```
+[Makima Rust Server] --HTTP/WebSocket--> [Python TTS Service]
+ |
+ [Qwen3-TTS Model]
+```
+Pros: Clean separation, easy Python integration
+Cons: Network overhead, deployment complexity
+
+**Option B: PyO3 Rust Bindings**
+```
+[Makima Rust Server] --FFI--> [pyo3 Python Bindings] --> [Qwen3-TTS]
+```
+Pros: Single process, lower latency
+Cons: Complex build, Python GIL issues
+
+**Option C: ONNX Export (Like Current Chatterbox)**
+```
+[Makima Rust Server] --ort--> [Qwen3-TTS ONNX Models]
+```
+Pros: Pure Rust, consistent with existing architecture
+Cons: May not have ONNX export available for Qwen3-TTS
+
+**Recommended: Option A (Python Microservice)**
+- Fastest time to implementation
+- Aligns with Qwen3-TTS's native Python API
+- Can use WebSocket for streaming audio chunks
+- Easy to deploy alongside existing makima server
+
+---
+
+## 5. Preliminary Technical Approach
+
+### 5.1 Phase 1: Python TTS Microservice
+
+```python
+# tts_service.py
+from fastapi import FastAPI, WebSocket
+from qwen_tts import Qwen3TTSModel
+import torch
+import base64
+
+app = FastAPI()
+model = None
+voice_prompt = None
+
+@app.on_event("startup")
+async def load_model():
+ global model, voice_prompt
+ model = Qwen3TTSModel.from_pretrained(
+ "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
+ device_map="cuda:0",
+ dtype=torch.bfloat16,
+ )
+ # Pre-load Makima voice
+ voice_prompt = model.create_voice_clone_prompt(
+ ref_audio="makima_voice.wav",
+ ref_text="日本語の台詞...",
+ )
+
+@app.websocket("/ws/speak")
+async def speak(websocket: WebSocket):
+ await websocket.accept()
+ while True:
+ text = await websocket.receive_text()
+ wavs, sr = model.generate_voice_clone(
+ text=text,
+ language="English",
+ voice_clone_prompt=voice_prompt,
+ )
+ # Stream audio chunks
+ audio_bytes = wavs[0].tobytes()
+ await websocket.send_bytes(audio_bytes)
+```
+
+### 5.2 Phase 2: Rust Integration
+
+```rust
+// makima/src/server/handlers/speak.rs
+pub async fn websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_speak_socket(socket, state))
+}
+
+async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
+ // Connect to Python TTS service
+ let tts_ws = tokio_tungstenite::connect_async("ws://localhost:8001/ws/speak").await?;
+
+ // Forward text to TTS, stream audio back to client
+ // ...
+}
+```
+
+### 5.3 API Design
+
+**WebSocket Endpoint: `/api/v1/speak`**
+
+**Client to Server Messages:**
+```json
+{
+ "type": "start",
+ "sample_rate": 24000,
+ "encoding": "pcm16"
+}
+
+{
+ "type": "speak",
+ "text": "Hello, I am Makima."
+}
+
+{
+ "type": "stop"
+}
+```
+
+**Server to Client Messages:**
+```json
+{
+ "type": "ready",
+ "session_id": "uuid"
+}
+
+{
+ "type": "audio_chunk",
+ "data": "<base64-encoded-audio>",
+ "is_final": false
+}
+
+{
+ "type": "complete"
+}
+```
+
+---
+
+## 6. Next Steps
+
+### Immediate Actions
+1. [ ] Obtain Makima voice clips (5-15 seconds clean Japanese audio)
+2. [ ] Create Japanese transcription of voice clips
+3. [ ] Test Qwen3-TTS voice cloning with Makima samples
+4. [ ] Benchmark latency on target hardware
+
+### Development Phases
+1. **Phase 1**: Python TTS microservice proof-of-concept
+2. **Phase 2**: WebSocket streaming integration
+3. **Phase 3**: Rust proxy endpoint in makima
+4. **Phase 4**: Listen page integration for bidirectional speech
+
+### Hardware Requirements
+- CUDA-compatible GPU (minimum)
+- Recommended: 8GB+ VRAM for 0.6B model with FlashAttention 2
+- Python 3.12+ environment
+
+---
+
+## References
+
+- [Qwen3-TTS-12Hz-0.6B-Base on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
+- [Qwen3-TTS GitHub Repository](https://github.com/QwenLM/Qwen3-TTS)
+- [Behind The Voice Actors - Makima](https://www.behindthevoiceactors.com/tv-shows/Chainsaw-Man/Makima/)
+- [Voicy Network Chainsaw Man Soundboard](https://www.voicy.network/official-soundboards/anime/chainsaw-man)
diff --git a/docs/research/rust-native-tts-research.md b/docs/research/rust-native-tts-research.md
new file mode 100644
index 0000000..5bc75f7
--- /dev/null
+++ b/docs/research/rust-native-tts-research.md
@@ -0,0 +1,297 @@
+# Rust-Native Qwen3-TTS Integration Research
+
+## Executive Summary
+
+This document researches integrating **Qwen3-TTS-12Hz-0.6B-Base** directly into the makima Rust codebase, replacing the current Chatterbox TTS implementation. The goal is a **pure Rust** solution — no Python, no separate microservice.
+
+**Bottom line:** A Rust-native integration is feasible but requires significant implementation work. The most viable path is using **candle** (HuggingFace's Rust ML framework) to implement the model architecture natively, loading safetensors weights directly. The existing ONNX-based approach used for Chatterbox TTS is **not viable** for Qwen3-TTS due to architecture incompatibilities with ONNX exporters.
+
+---
+
+## 1. Current TTS Implementation Analysis
+
+The existing Chatterbox TTS in `makima/src/tts.rs` uses:
+
+- **ONNX Runtime** via the `ort` crate (v2.0.0-rc.10)
+- **Four ONNX model files**: `speech_encoder.onnx`, `embed_tokens.onnx`, `language_model.onnx`, `conditional_decoder.onnx`
+- **tokenizers** crate for text tokenization
+- **ndarray** for tensor manipulation
+- **hf-hub** for model downloading
+- **Pipeline**: encode voice → tokenize text → autoregressive token generation with KV cache → decode tokens to waveform
+- **Architecture constants**: 24 layers, 16 KV heads, 64 head dim, 24kHz sample rate
+
+The pattern is well-established: download ONNX models from HuggingFace, load sessions, run inference with manual KV cache management.
+
+### STT Pattern (listen.rs)
+
+The Listen/STT handler in `makima/src/server/handlers/listen.rs` demonstrates the broader ML pattern:
+- WebSocket-based streaming
+- Lazy model loading via `SharedState::get_ml_models()`
+- Models held behind `tokio::sync::Mutex` for async access
+- `parakeet-rs` local crate for STT, `sortformer` for diarization
+- All models are Rust-native with ONNX backends
+
+---
+
+## 2. Qwen3-TTS-12Hz-0.6B-Base Architecture
+
+### Model Overview
+
+| Property | Value |
+|----------|-------|
+| **Parameters** | 0.6B |
+| **Architecture** | `Qwen3TTSForConditionalGeneration` |
+| **Output Sample Rate** | 24,000 Hz |
+| **Token Frame Rate** | 12.5 Hz (~80ms per token) |
+| **Model Format** | SafeTensors (1.83 GB main + 682 MB tokenizer) |
+| **Total Size** | ~2.52 GB |
+| **Precision** | bfloat16/float16 |
+
+### Components
+
+The model has **three distinct components**:
+
+#### A. Main Language Model (Talker) — 1.83 GB safetensors
+- Hidden size: 1024
+- Layers: 28
+- Attention heads: 16 (8 KV heads)
+- Intermediate size: 3072
+- Head dimension: 128
+- Text vocab size: 151,936
+- Max position embeddings: 32,768
+- Autoregressive transformer predicting speech token sequences from text
+
+#### B. Code Predictor (Multi-Token Prediction) — embedded in main model
+- Hidden size: 1024
+- Layers: 5
+- Attention heads: 16
+- Number of code groups: 16
+- Codebook vocab size: 2048
+- Predicts residual codebooks (16 layers) after the main LM predicts the zeroth codebook
+
+#### C. Speech Tokenizer (Qwen3-TTS-Tokenizer-12Hz) — 682 MB safetensors
+- Separate model in `speech_tokenizer/` directory
+- GAN-based codec: encoder + decoder
+- 16-layer multi-codebook RVQ (Residual Vector Quantization)
+- First codebook: semantic (WavLM-guided)
+- Remaining 15: acoustic details
+- **Decoder**: lightweight causal ConvNet (no DiT/diffusion needed)
+- Encodes reference audio → discrete codes, decodes codes → waveform
+
+### Inference Pipeline
+
+```
+Text Input + Reference Audio
+ ↓
+[Speech Tokenizer Encoder] → reference audio codes + speaker embedding
+ ↓
+[Text Tokenizer] → text token IDs
+ ↓
+[Language Model] → autoregressive generation of zeroth codebook tokens
+ ↓
+[Code Predictor / MTP] → predict remaining 15 codebook layers
+ ↓
+[Speech Tokenizer Decoder / Causal ConvNet] → waveform output (24kHz)
+```
+
+---
+
+## 3. ONNX Export Feasibility — NOT VIABLE
+
+### Status: No ONNX support exists
+
+- **No official ONNX export** from Qwen team
+- **No community ONNX conversion** for Qwen3-TTS
+- The Qwen3 architecture is **not supported** by HuggingFace Optimum's ONNX exporter
+- Users attempting export get: `ValueError: Trying to export a qwen3 model, that is a custom or unsupported architecture, but no custom onnx configuration was passed`
+- Even for base Qwen3 LLMs (non-TTS), ONNX export has significant issues with MoE routing, hybrid attention, and novel architecture components
+
+### Why ONNX Won't Work for Qwen3-TTS
+
+1. **Custom architecture** — `Qwen3TTSForConditionalGeneration` is not a standard transformer; it combines LM + code predictor + speech tokenizer
+2. **Multi-codebook MTP module** — the code predictor generates 16 codebook layers, a non-standard operation
+3. **Causal ConvNet decoder** — the speech tokenizer's decoder is a custom GAN-trained ConvNet, not a standard vocoder
+4. **Dynamic control flow** — dual-track streaming architecture with conditional branching
+5. **No Optimum support** — would require writing a custom ONNX config from scratch for each sub-component
+
+**Verdict: The ONNX path (matching our Chatterbox approach) is a dead end for Qwen3-TTS.**
+
+---
+
+## 4. Rust-Native Inference Options
+
+### Option A: Candle (HuggingFace) — RECOMMENDED
+
+[candle](https://github.com/huggingface/candle) is HuggingFace's minimalist Rust ML framework.
+
+**Why candle is the best fit:**
+
+| Factor | Assessment |
+|--------|------------|
+| **Qwen model support** | ✅ Has `qwen2` module in candle-transformers; Qwen3 variants supported |
+| **SafeTensors loading** | ✅ Native first-class support (safetensors is a Rust crate) |
+| **GPU support** | ✅ CUDA backend, Metal (macOS), CPU with MKL |
+| **Tokenizer support** | ✅ Uses the same `tokenizers` crate makima already depends on |
+| **Audio models** | ✅ Supports EnCodec, Whisper, MetaVoice, Parler-TTS |
+| **KV cache** | ✅ Well-established patterns in existing model implementations |
+| **Community** | ✅ Active; Crane project already lists Qwen3-TTS as "highest priority" |
+| **Binary size** | ✅ Compiles to single binary, no Python dependency |
+
+**What needs to be implemented:**
+
+1. **Qwen3-TTS transformer layers** — extend existing `qwen2` model code for the 28-layer LM with TTS-specific modifications (speaker encoder concatenation, code predictor output heads)
+2. **Code Predictor (MTP)** — 5-layer module that generates 16 codebook predictions from the LM hidden states
+3. **Speech Tokenizer Encoder** — ConvNet encoder that converts reference audio to discrete multi-codebook tokens + speaker embeddings
+4. **Speech Tokenizer Decoder** — causal ConvNet that reconstructs waveforms from discrete codes
+5. **Multi-codebook handling** — manage 16 parallel codebook sequences
+
+**Estimated effort:** Medium-High. The LM backbone can reuse existing Qwen2/3 code. The speech tokenizer (encoder + decoder) is the most novel component.
+
+**Key crate dependencies to add:**
+```toml
+candle-core = "0.8"
+candle-nn = "0.8"
+candle-transformers = "0.8"
+# Keep existing: tokenizers, hf-hub, ndarray (for compatibility)
+```
+
+### Option B: Crane (Candle-based TTS Engine)
+
+[Crane](https://github.com/lucasjinreal/Crane) is a pure Rust LLM inference engine built on candle, specifically designed for multi-modal models including TTS.
+
+**Key facts:**
+- Already supports Spark-TTS (codec-based TTS with similar architecture)
+- **Qwen3-TTS is listed as "Highest Priority" on their roadmap**
+- Handles multi-module architectures (codec + LLM pipelines)
+- Supports Qwen2.5, Moonshine ASR
+- Claims 50x faster than PyTorch on Apple Silicon
+
+**Strategy:** Monitor Crane's Qwen3-TTS implementation. If they ship it, we could either:
+- Use Crane as a dependency directly
+- Port their implementation into makima's codebase
+- Contribute to Crane and depend on it
+
+**Risk:** Crane is a relatively new project; depending on it adds supply chain risk.
+
+### Option C: qwen3-rs (Educational Reference)
+
+[qwen3-rs](https://github.com/reinterpretcat/qwen3-rs) is an educational project implementing Qwen3 inference from scratch in Rust.
+
+**Useful for:** Reference implementation of Qwen3 transformer layers, tokenization, KV cache, and safetensors loading — all without heavy ML framework dependencies. However, it only implements the base LLM, not the TTS-specific components.
+
+### Option D: Direct ort (ONNX Runtime) with Custom Export — FALLBACK
+
+If we could manually export each sub-component to ONNX:
+
+1. Export the 28-layer LM backbone (similar complexity to Chatterbox)
+2. Export the code predictor separately
+3. Export the speech tokenizer encoder/decoder
+
+This would match our existing Chatterbox pattern but requires Python scripting for the one-time export, and the Qwen3 architecture is explicitly unsupported by standard exporters. **Not recommended unless ONNX support materializes upstream.**
+
+### Option E: PyTorch C++ (libtorch) via FFI — NOT RECOMMENDED
+
+Using libtorch via Rust FFI bindings (`tch-rs` crate). This would:
+- Add a ~2GB libtorch dependency
+- Require complex build setup
+- Introduce C++ dependency management
+- Defeat the purpose of a pure Rust solution
+
+---
+
+## 5. Recommended Approach
+
+### Phase 1: Candle-Based Implementation
+
+**Architecture:**
+
+```
+makima/src/tts/
+├── mod.rs // TTS trait + factory (select Chatterbox vs Qwen3)
+├── chatterbox.rs // Existing ONNX-based Chatterbox (moved from tts.rs)
+├── qwen3/
+│ ├── mod.rs // Qwen3TTS public API
+│ ├── model.rs // Qwen3 LM transformer (28 layers)
+│ ├── code_predictor.rs // MTP module (5 layers, 16 codebooks)
+│ ├── speech_tokenizer.rs // Encoder + Decoder (causal ConvNet)
+│ ├── config.rs // Model config from config.json
+│ └── generate.rs // Autoregressive generation loop with KV cache
+```
+
+**Key implementation details:**
+
+1. **Load safetensors directly** — candle's `safetensors` support reads the 1.83GB main model and 682MB speech tokenizer
+2. **Reuse Qwen2 attention** — candle-transformers already has `qwen2::Model` with RoPE, GQA, and KV cache
+3. **Implement ConvNet codec** — the speech tokenizer's encoder/decoder is a causal 1D ConvNet; candle has `Conv1d` layers
+4. **Multi-codebook RVQ** — implement the 16-codebook residual vector quantization lookup
+5. **Speaker embedding** — extract from reference audio via the speech tokenizer encoder
+6. **Streaming support** — the 12Hz model's causal architecture enables token-by-token waveform generation
+
+### Phase 2: Voice Assets
+
+The model supports voice cloning with reference audio. For the default Makima voice:
+- Need 5-15 second Japanese-accented English audio clip
+- Reference audio + transcript fed to speech tokenizer encoder
+- Speaker embedding cached for reuse
+
+### Phase 3: Integration with Listen Page
+
+Following the pattern in `listen.rs`:
+- TTS model loaded lazily via `SharedState`
+- Protected behind `tokio::sync::Mutex` (or `RwLock` for concurrent reads)
+- WebSocket endpoint for streaming TTS (emit audio chunks as tokens are generated)
+- Bidirectional: STT (listen) → process → TTS (speak) loop
+
+---
+
+## 6. Comparison Matrix
+
+| Criteria | ONNX (current pattern) | Candle | Crane | libtorch |
+|----------|----------------------|--------|-------|----------|
+| Pure Rust | ✅ (ort crate) | ✅ | ✅ | ❌ (C++ FFI) |
+| Qwen3-TTS support | ❌ No export | ⚠️ Needs impl | ⚠️ Planned | ✅ (full PyTorch) |
+| Single binary | ✅ | ✅ | ✅ | ❌ |
+| GPU acceleration | ✅ | ✅ | ✅ | ✅ |
+| SafeTensors loading | ❌ (needs ONNX) | ✅ | ✅ | ✅ |
+| Streaming TTS | ✅ | ✅ | ✅ | ✅ |
+| Maintenance burden | Low | Medium | Low (if adopted) | High |
+| Implementation effort | N/A (blocked) | Medium-High | Low (if available) | Medium |
+| Dependency size | ~50MB | ~5MB | ~5MB | ~2GB |
+
+---
+
+## 7. Risk Assessment
+
+| Risk | Likelihood | Impact | Mitigation |
+|------|-----------|--------|------------|
+| Candle implementation takes longer than expected | Medium | Medium | Reference Crane's Spark-TTS implementation; use qwen3-rs as LM reference |
+| Speech tokenizer ConvNet is complex to port | Medium | High | Study the PyTorch source in qwen-tts package; ConvNet layers are simpler than transformers |
+| Model quality differs from reference PyTorch | Low | High | Validate with reference audio samples; ensure bfloat16 precision |
+| Crane ships Qwen3-TTS before we finish | Medium | Positive | Adopt their implementation |
+| GPU memory issues on target hardware | Low | Medium | 0.6B model is small (~2.5GB); fits in 4GB VRAM with float16 |
+
+---
+
+## 8. Next Steps
+
+1. **Immediate:** Add `candle-core`, `candle-nn`, `candle-transformers` to Cargo.toml
+2. **Week 1:** Implement Qwen3 LM backbone in candle (extend existing qwen2 model)
+3. **Week 2:** Implement speech tokenizer encoder/decoder (ConvNet + RVQ)
+4. **Week 2:** Implement code predictor (MTP module)
+5. **Week 3:** Integration testing with reference audio; validate output quality
+6. **Week 3:** Wire into makima server as TTS endpoint
+7. **Ongoing:** Monitor Crane project for Qwen3-TTS implementation
+
+---
+
+## Sources
+
+- [Qwen3-TTS-12Hz-0.6B-Base on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
+- [Qwen3-TTS Technical Report (arXiv)](https://arxiv.org/html/2601.15621v1)
+- [Qwen3-TTS GitHub Repository](https://github.com/QwenLM/Qwen3-TTS)
+- [Candle — HuggingFace Rust ML Framework](https://github.com/huggingface/candle)
+- [Crane — Rust LLM Inference Engine](https://github.com/lucasjinreal/Crane)
+- [qwen3-rs — Educational Qwen3 Rust Implementation](https://github.com/reinterpretcat/qwen3-rs)
+- [candle-transformers Qwen2 model](https://docs.rs/candle-transformers/latest/candle_transformers/models/qwen2/index.html)
+- [Qwen3-TTS-Tokenizer-12Hz on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-Tokenizer-12Hz)
+- [ONNX export issues for Qwen3](https://huggingface.co/onnx-community/Qwen3-1.7B-ONNX/discussions/1)
diff --git a/docs/research/tts-qwen3-research.md b/docs/research/tts-qwen3-research.md
new file mode 100644
index 0000000..a961b4f
--- /dev/null
+++ b/docs/research/tts-qwen3-research.md
@@ -0,0 +1,548 @@
+# TTS Research: Qwen3-TTS-12Hz-0.6B-Base Integration
+
+## Executive Summary
+
+This document evaluates replacing the current Chatterbox TTS implementation with Qwen3-TTS-12Hz-0.6B-Base for the makima system. The goal is to enable near-real-time voice synthesis with voice cloning capabilities, defaulting to Makima's Japanese voice (Tomori Kusunoki) speaking English.
+
+**Key Findings:**
+- Qwen3-TTS offers superior streaming capabilities (~97ms latency) compared to the current batch-only Chatterbox implementation
+- Voice cloning requires only 3 seconds of reference audio
+- No official ONNX export exists; Python/PyTorch inference required
+- The 0.6B model is optimized for resource-constrained environments
+
+---
+
+## 1. Current TTS Implementation Analysis
+
+### 1.1 Architecture Overview
+
+The current implementation uses **Chatterbox-Turbo-ONNX** from ResembleAI:
+
+```
+Location: makima/src/tts.rs
+Model ID: ResembleAI/chatterbox-turbo-ONNX
+Sample Rate: 24,000 Hz
+```
+
+**Components:**
+| Component | File | Purpose |
+|-----------|------|---------|
+| `speech_encoder.onnx` | ~XX MB | Encodes reference audio to speaker embeddings |
+| `embed_tokens.onnx` | ~XX MB | Token embedding layer |
+| `language_model.onnx` | ~XX MB | Autoregressive text-to-speech token generation |
+| `conditional_decoder.onnx` | ~XX MB | Converts speech tokens to waveform |
+| `tokenizer.json` | ~KB | Text tokenization |
+
+### 1.2 Current API Surface
+
+```rust
+pub struct ChatterboxTTS {
+ speech_encoder: Session,
+ embed_tokens: Session,
+ language_model: Session,
+ conditional_decoder: Session,
+ tokenizer: Tokenizer,
+}
+
+impl ChatterboxTTS {
+ // Load from pretrained models
+ pub fn from_pretrained(model_dir: Option<&str>) -> Result<Self, TtsError>;
+
+ // Generate speech (requires voice reference)
+ pub fn generate_tts(&mut self, _text: &str) -> Result<Vec<f32>, TtsError>;
+
+ // Voice cloning from file path
+ pub fn generate_tts_with_voice(
+ &mut self,
+ text: &str,
+ sample_audio_path: &Path,
+ ) -> Result<Vec<f32>, TtsError>;
+
+ // Voice cloning from raw samples
+ pub fn generate_tts_with_samples(
+ &mut self,
+ text: &str,
+ samples: &[f32],
+ sample_rate: u32,
+ ) -> Result<Vec<f32>, TtsError>;
+}
+```
+
+### 1.3 Current Capabilities
+
+| Feature | Supported | Notes |
+|---------|-----------|-------|
+| Voice Cloning | **Yes** | Required for all synthesis |
+| Streaming | **No** | Batch generation only |
+| Languages | Limited | English-focused |
+| ONNX Runtime | **Yes** | CPU inference via `ort` crate |
+| GPU Acceleration | Partial | ONNX supports CUDA EP |
+| Real-time Factor | Unknown | Not benchmarked |
+
+### 1.4 Integration Points
+
+The TTS module is currently:
+- Exposed as `pub mod tts` in `lib.rs`
+- Used in `main.rs` for testing
+- **Not integrated with the web server** (no `/api/v1/tts` endpoint)
+
+The audio processing infrastructure is shared with the Listen (STT) module:
+- `audio.rs` provides format conversion utilities
+- `symphonia` for decoding various audio formats
+- Resampling to target sample rates (16kHz for STT, 24kHz for TTS)
+
+---
+
+## 2. Qwen3-TTS-12Hz-0.6B-Base Analysis
+
+### 2.1 Model Overview
+
+**Source:** [Hugging Face - Qwen/Qwen3-TTS-12Hz-0.6B-Base](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
+
+| Specification | Value |
+|---------------|-------|
+| Parameters | 0.6B |
+| Release Date | January 22, 2026 |
+| Architecture | Dual-Track hybrid streaming LM |
+| Tokenizer | Qwen3-TTS-Tokenizer-12Hz |
+| Frame Rate | 12.5 Hz |
+| Output Sample Rate | 24 kHz |
+| Languages | 10 (Chinese, English, Japanese, Korean, German, French, Russian, Portuguese, Spanish, Italian) |
+
+### 2.2 Key Features
+
+| Feature | Status | Details |
+|---------|--------|---------|
+| **Voice Cloning** | Yes | 3-second minimum reference audio |
+| **Streaming** | Yes | 97ms end-to-end latency |
+| **Real-time** | Yes | First audio packet after single character |
+| **Multilingual** | Yes | 10 languages supported |
+| **Instruction Control** | No | Base model limitation |
+
+### 2.3 Streaming Architecture
+
+The Dual-Track architecture enables:
+1. **Streaming text input** - Processes text incrementally
+2. **Streaming audio output** - Emits audio packets as generated
+3. **Multi-Token Prediction (MTP)** - Enables immediate speech decoding from first codec frame
+
+**Latency Benchmarks:**
+- First token latency: ~97ms (end-to-end)
+- Optimized TTFT: ~170ms on RTX 5090 (community fork)
+- Initial implementations: ~800ms TTFT (before optimization)
+
+### 2.4 Voice Cloning Requirements
+
+| Requirement | Specification |
+|-------------|---------------|
+| Reference Audio Length | **3 seconds minimum** |
+| Audio Format | WAV, MP3, or common formats |
+| Input Methods | File path, URL, base64, numpy array |
+| Reference Text | **Required** (transcript of reference audio) |
+| X-Vector Only Mode | Optional (speaker embedding only, lower quality) |
+
+### 2.5 Python API
+
+```python
+from qwen_tts import Qwen3TTSModel
+
+# Load model
+model = Qwen3TTSModel.from_pretrained(
+ "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
+ device_map="cuda:0",
+ dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+)
+
+# Voice cloning
+wavs, sr = model.generate_voice_clone(
+ text="Hello, this is a test.",
+ language="English",
+ ref_audio="reference.wav",
+ ref_text="Original speaker text from reference",
+)
+
+# Reusable prompt (efficient for multiple generations)
+prompt = model.create_voice_clone_prompt(
+ ref_audio="reference.wav",
+ ref_text="Reference transcript",
+)
+
+wavs, sr = model.generate_voice_clone(
+ text="New text",
+ language="English",
+ voice_clone_prompt=prompt,
+)
+```
+
+### 2.6 Dependencies
+
+```
+pip install -U qwen-tts
+pip install -U flash-attn --no-build-isolation # Optional, recommended
+```
+
+**Requirements:**
+- Python 3.12 recommended
+- CUDA-capable GPU (for optimal performance)
+- FlashAttention 2 compatible hardware
+- PyTorch with bfloat16 support
+
+---
+
+## 3. Feasibility Assessment
+
+### 3.1 Streaming/Live TTS Feasibility
+
+**Assessment: FEASIBLE with caveats**
+
+| Factor | Current State | Path Forward |
+|--------|---------------|--------------|
+| Streaming API | Experimental (community fork) | Use [dffdeeq/Qwen3-TTS-streaming](https://github.com/dffdeeq/Qwen3-TTS-streaming) or wait for official support |
+| Latency Target | 97ms advertised | Achievable with optimization |
+| First Token | ~170ms optimized | Acceptable for conversational use |
+| Audio Stability | First 1-2s may have timbre issues | Known limitation, may need buffering |
+
+**Streaming Implementation Status:**
+- Official repository: Streaming documented but **not released**
+- Community fork: Working implementation with ~170ms TTFT
+- vLLM-Omni: Offline inference only (online serving planned)
+
+### 3.2 Voice Cloning for Makima
+
+**Assessment: FULLY FEASIBLE**
+
+Requirements for Makima voice cloning:
+1. **3+ seconds of clean audio** - Tomori Kusunoki (Japanese VA) speaking
+2. **Transcript of the audio** - Required for full quality
+3. **Audio format** - WAV/MP3 acceptable
+
+**Audio Sources:**
+- Chainsaw Man anime clips
+- Voice actress promotional material
+- Behind The Voice Actors database
+
+**Considerations:**
+- Japanese VA speaking English may work with explicit `language="English"` setting
+- May need English-speaking Makima clips (Suzie Yeung, English dub VA) as fallback
+- X-vector mode available if transcript is difficult to obtain
+
+### 3.3 Integration Complexity
+
+| Component | Complexity | Notes |
+|-----------|------------|-------|
+| Model Loading | Medium | Python subprocess or PyO3 bridge required |
+| Streaming API | High | WebSocket integration needed |
+| Voice Caching | Low | `create_voice_clone_prompt()` supports this |
+| Audio Format | Low | Both use 24kHz output |
+| ONNX Migration | N/A | **No ONNX export available** |
+
+### 3.4 ONNX vs Python Inference
+
+**Current approach (Chatterbox):** Rust + ONNX Runtime
+- Pros: Native Rust, low latency, CPU-friendly
+- Cons: Limited model ecosystem, no streaming
+
+**Required approach (Qwen3-TTS):** Python + PyTorch
+- Pros: Full model access, streaming support, GPU acceleration
+- Cons: Python subprocess overhead, dependency management
+
+**Integration Options:**
+
+1. **Python Subprocess/Service**
+ - Run `qwen-tts` as separate Python service
+ - Communicate via HTTP/WebSocket
+ - Cleanest separation, easiest to implement
+
+2. **PyO3 Bridge**
+ - Embed Python in Rust binary
+ - Higher complexity, tighter integration
+ - May have GIL contention issues
+
+3. **Custom ONNX Export** (Future)
+ - Not currently available
+ - Would require community effort
+ - No timeline from Qwen team
+
+**Recommendation:** Python service with WebSocket streaming
+
+---
+
+## 4. Audio Clip Requirements
+
+### 4.1 For Voice Cloning Setup
+
+| Requirement | Specification |
+|-------------|---------------|
+| Minimum Duration | 3 seconds |
+| Recommended Duration | 5-10 seconds |
+| Format | WAV (preferred), MP3 |
+| Sample Rate | Any (will be resampled) |
+| Content | Clear speech, minimal background noise |
+| Transcript | Required for full quality |
+
+### 4.2 Makima Voice Sources
+
+**Priority 1: Japanese VA (Tomori Kusunoki) speaking Japanese**
+- Source: Chainsaw Man anime
+- Challenge: Need clear dialogue without music/SFX
+- Fallback: May not transfer well to English output
+
+**Priority 2: English VA (Suzie Yeung)**
+- Source: Chainsaw Man English dub
+- Advantage: Native English speaker for English output
+- Trade-off: Different vocal characteristics from Japanese VA
+
+**Recommended Approach:**
+1. Extract 5-10 second clips of both VAs
+2. Test voice cloning quality with each
+3. Select based on English speech naturalness
+4. Store reference audio + transcript in `models/voices/makima/`
+
+### 4.3 Transcript Requirements
+
+For optimal voice cloning:
+```
+ref_audio: "models/voices/makima/makima-reference.wav"
+ref_text: "The exact words spoken in the reference audio"
+```
+
+X-vector fallback (lower quality, no transcript needed):
+```python
+prompt = model.create_voice_clone_prompt(
+ ref_audio="reference.wav",
+ x_vector_only_mode=True, # No transcript required
+)
+```
+
+---
+
+## 5. Preliminary Technical Approach
+
+### 5.1 Architecture Overview
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ Makima Server (Rust) │
+├─────────────────────────────────────────────────────────────┤
+│ ┌─────────────┐ ┌─────────────┐ ┌──────────────────────┐│
+│ │ Listen (STT)│ │ TTS Proxy │ │ Chat/Contract APIs ││
+│ │ /api/v1/ │ │ /api/v1/tts │ │ /api/v1/... ││
+│ │ listen │ │ │ │ ││
+│ └──────┬──────┘ └──────┬──────┘ └──────────────────────┘│
+│ │ │ │
+│ │ ┌──────▼──────┐ │
+│ │ │ WebSocket │ │
+│ │ │ Bridge │ │
+│ │ └──────┬──────┘ │
+└─────────┼────────────────┼──────────────────────────────────┘
+ │ │
+ │ ┌──────▼──────┐
+ │ │ Python TTS │
+ │ │ Service │
+ │ │ (Qwen3-TTS) │
+ │ └─────────────┘
+ │
+ ┌──────▼──────┐
+ │ ML Models │
+ │ (Parakeet, │
+ │ Sortformer) │
+ └─────────────┘
+```
+
+### 5.2 Python TTS Service
+
+**Proposed Architecture:**
+
+```python
+# tts_service.py
+import asyncio
+from fastapi import FastAPI, WebSocket
+from qwen_tts import Qwen3TTSModel
+
+app = FastAPI()
+model = None
+voice_prompts = {}
+
+@app.on_event("startup")
+async def load_model():
+ global model
+ model = Qwen3TTSModel.from_pretrained(
+ "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
+ device_map="cuda:0",
+ dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ )
+
+ # Pre-load Makima voice prompt
+ voice_prompts["makima"] = model.create_voice_clone_prompt(
+ ref_audio="models/voices/makima/reference.wav",
+ ref_text="[Makima reference transcript]",
+ )
+
+@app.websocket("/ws/tts")
+async def tts_stream(websocket: WebSocket):
+ await websocket.accept()
+ while True:
+ data = await websocket.receive_json()
+ text = data["text"]
+ voice = data.get("voice", "makima")
+ language = data.get("language", "English")
+
+ # Generate with streaming (when available)
+ prompt = voice_prompts.get(voice)
+ wavs, sr = model.generate_voice_clone(
+ text=text,
+ language=language,
+ voice_clone_prompt=prompt,
+ )
+
+ # Send audio chunks
+ await websocket.send_bytes(wavs[0].tobytes())
+
+@app.post("/api/tts")
+async def tts_batch(request: TTSRequest):
+ # Batch fallback for non-streaming use cases
+ ...
+```
+
+### 5.3 Rust Integration
+
+**New Endpoint: `/api/v1/tts`**
+
+```rust
+// server/handlers/tts.rs
+pub async fn tts_websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_tts_socket(socket, state))
+}
+
+async fn handle_tts_socket(socket: WebSocket, state: SharedState) {
+ // Proxy WebSocket to Python TTS service
+ let tts_client = state.tts_client.lock().await;
+
+ let (mut sender, mut receiver) = socket.split();
+
+ while let Some(msg) = receiver.next().await {
+ match msg {
+ Ok(Message::Text(text)) => {
+ // Forward to Python service
+ let response = tts_client.generate(text).await;
+
+ // Stream audio back
+ for chunk in response.audio_chunks {
+ sender.send(Message::Binary(chunk)).await.ok();
+ }
+ }
+ _ => {}
+ }
+ }
+}
+```
+
+### 5.4 Voice Prompt Caching
+
+```rust
+// Pre-computed voice prompts stored in state
+pub struct TtsConfig {
+ pub default_voice: String,
+ pub voices: HashMap<String, VoicePrompt>,
+}
+
+pub struct VoicePrompt {
+ pub name: String,
+ pub ref_audio_path: PathBuf,
+ pub ref_text: String,
+ pub language: String,
+ // Cached prompt from Python service
+ pub cached_prompt_id: Option<String>,
+}
+```
+
+### 5.5 Message Protocol
+
+**Client -> Server:**
+```json
+{
+ "type": "synthesize",
+ "text": "Hello, I am Makima.",
+ "voice": "makima",
+ "language": "English",
+ "stream": true
+}
+```
+
+**Server -> Client:**
+```json
+// Audio chunk
+{"type": "audio", "data": "<base64 PCM>", "sample_rate": 24000, "final": false}
+
+// Completion
+{"type": "complete", "duration_ms": 1234}
+
+// Error
+{"type": "error", "code": "SYNTHESIS_ERROR", "message": "..."}
+```
+
+---
+
+## 6. Implementation Phases
+
+### Phase 1: Research & Setup (Current)
+- [x] Analyze current TTS implementation
+- [x] Research Qwen3-TTS capabilities
+- [x] Document feasibility and approach
+- [ ] Acquire Makima voice reference clips
+- [ ] Test voice cloning quality
+
+### Phase 2: Python Service
+- [ ] Create Python TTS service with FastAPI
+- [ ] Implement batch TTS endpoint
+- [ ] Implement WebSocket streaming (when available)
+- [ ] Add voice prompt management
+- [ ] GPU optimization with FlashAttention 2
+
+### Phase 3: Rust Integration
+- [ ] Add TTS proxy endpoints to makima server
+- [ ] WebSocket bridge implementation
+- [ ] Voice configuration management
+- [ ] Error handling and fallbacks
+
+### Phase 4: Production Ready
+- [ ] Health checks for Python service
+- [ ] Voice prompt caching optimization
+- [ ] Latency benchmarking
+- [ ] Integration with Listen page
+
+---
+
+## 7. Open Questions
+
+1. **Streaming API Availability**: When will official streaming support be released?
+ - Fallback: Use community fork or batch with chunked responses
+
+2. **Voice Quality**: How well does Japanese VA voice clone to English speech?
+ - Action: Test with both Japanese and English VA samples
+
+3. **GPU Requirements**: What's the minimum VRAM for 0.6B model?
+ - Estimate: ~2-4GB with bf16 quantization
+
+4. **Latency Target**: What's acceptable for "close to live" TTS?
+ - Proposal: <500ms first audio, <100ms subsequent chunks
+
+5. **Transcript Acquisition**: How to obtain accurate transcripts for voice clips?
+ - Options: Manual transcription, Whisper ASR, community resources
+
+---
+
+## 8. References
+
+- [Qwen3-TTS-12Hz-0.6B-Base (Hugging Face)](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
+- [Qwen3-TTS GitHub Repository](https://github.com/QwenLM/Qwen3-TTS)
+- [Qwen3-TTS Technical Report (arXiv)](https://arxiv.org/abs/2601.15621)
+- [Streaming Inference Issue #77](https://github.com/QwenLM/Qwen3-TTS/issues/77)
+- [Community Streaming Fork](https://github.com/dffdeeq/Qwen3-TTS-streaming)
+- [Makima Voice Actors](https://www.behindthevoiceactors.com/characters/Chainsaw-Man/Makima/)
+- [Chatterbox-Turbo-ONNX (Current Model)](https://huggingface.co/ResembleAI/chatterbox-turbo-ONNX)
diff --git a/docs/specs/qwen3-tts-spec.md b/docs/specs/qwen3-tts-spec.md
new file mode 100644
index 0000000..91d447d
--- /dev/null
+++ b/docs/specs/qwen3-tts-spec.md
@@ -0,0 +1,928 @@
+# Qwen3-TTS Integration Specification
+
+**Version:** 2.0
+**Date:** 2026-01-27
+**Status:** Draft
+**Author:** Makima Engineering
+
+## Table of Contents
+
+1. [Overview](#1-overview)
+2. [Functional Requirements](#2-functional-requirements)
+3. [Non-Functional Requirements](#3-non-functional-requirements)
+4. [Architecture Specification](#4-architecture-specification)
+5. [API Contract](#5-api-contract)
+6. [Voice Asset Requirements](#6-voice-asset-requirements)
+7. [Testing Strategy](#7-testing-strategy)
+8. [Implementation Phases](#8-implementation-phases)
+9. [Appendix](#appendix)
+
+---
+
+## 1. Overview
+
+### 1.1 Purpose
+
+This specification defines the integration of Qwen3-TTS-12Hz-0.6B-Base as a replacement for the existing Chatterbox-Turbo TTS implementation in the makima system. The new implementation is a **pure Rust** solution using the **candle** ML framework — no Python, no separate microservice. The TTS model runs directly inside the main makima process. The implementation will provide:
+
+- **Streaming TTS** with near-real-time audio synthesis
+- **Voice cloning** with default Makima voice (Japanese voice actress speaking English)
+- **Bidirectional speech integration** for the Listen page
+- **WebSocket-based streaming API** for low-latency delivery
+
+### 1.2 Background
+
+The current TTS implementation (Chatterbox-Turbo-ONNX) has limitations:
+- No streaming support (batch-only generation)
+- No HTTP/WebSocket endpoint exposed
+- High latency for interactive use cases
+
+Qwen3-TTS offers significant improvements:
+- **97ms** end-to-end latency (vs. batch processing)
+- **10 languages** supported including Japanese cross-lingual cloning
+- **3-second** reference audio for voice cloning
+- **Dual-track streaming architecture**
+
+### 1.3 Scope
+
+This specification covers:
+- WebSocket endpoint `/api/v1/speak` for streaming TTS
+- Pure Rust candle-based model inference running in-process
+- Voice asset management
+- Testing and benchmarking
+
+Out of scope:
+- ONNX export of Qwen3-TTS (not available)
+- Instruction-following TTS features (base model only)
+- Full replacement of STT/Listen functionality
+
+---
+
+## 2. Functional Requirements
+
+### 2.1 WebSocket Endpoint: `/api/v1/speak`
+
+The TTS service SHALL be exposed via a WebSocket endpoint at `/api/v1/speak` for streaming audio synthesis.
+
+#### 2.1.1 Connection Flow
+
+```
+Client Server (Rust/Axum)
+ | |
+ |--- WS Connect ------------------>|
+ | | [Load TTS model lazily if needed]
+ |<-- Ready (session_id) -----------|
+ | |
+ |--- Start (config) -------------->|
+ |<-- Started ----------------------| [Load voice prompt]
+ | |
+ |--- Speak (text) ---------------->| [Direct candle model inference]
+ | |
+ |<-- AudioChunk (binary) ----------|
+ |<-- AudioChunk (binary) ----------|
+ |<-- Complete ---------------------|
+ | |
+ |--- Stop ------------------------>|
+ |<-- Stopped ----------------------|
+ | |
+```
+
+#### 2.1.2 Voice Cloning
+
+The system SHALL support voice cloning with the following modes:
+
+| Mode | Description | Requirements |
+|------|-------------|--------------|
+| Default (Makima) | Pre-loaded Makima voice | None (auto-selected) |
+| Custom Voice | User-provided reference | Audio file + transcript |
+| X-Vector Only | Speaker embedding only | Audio file (no transcript) |
+
+**Default Voice Behavior:**
+- If no voice is specified, use the pre-loaded Makima voice prompt
+- Makima voice SHALL be a Japanese voice actress (Tomori Kusunoki) speaking English
+- Voice prompt is pre-computed at model load time for zero-latency switching
+
+#### 2.1.3 Message Protocol
+
+##### Client-to-Server Messages
+
+All messages use JSON format with a `type` field for routing.
+
+**Start Message** - Initialize TTS session
+```json
+{
+ "type": "start",
+ "sampleRate": 24000,
+ "encoding": "pcm16",
+ "voice": "makima",
+ "language": "English",
+ "authToken": "optional-jwt-token",
+ "contractId": "optional-contract-uuid"
+}
+```
+
+| Field | Type | Required | Description |
+|-------|------|----------|-------------|
+| `type` | string | Yes | Must be "start" |
+| `sampleRate` | number | No | Output sample rate (default: 24000) |
+| `encoding` | string | No | Audio encoding: "pcm16", "pcm32f" (default: "pcm16") |
+| `voice` | string | No | Voice ID or "makima" (default: "makima") |
+| `language` | string | No | Output language (default: "English") |
+| `authToken` | string | No | JWT for authenticated sessions |
+| `contractId` | string | No | Contract ID for context |
+
+**Speak Message** - Request speech synthesis
+```json
+{
+ "type": "speak",
+ "text": "Hello, I am Makima.",
+ "priority": "normal"
+}
+```
+
+| Field | Type | Required | Description |
+|-------|------|----------|-------------|
+| `type` | string | Yes | Must be "speak" |
+| `text` | string | Yes | Text to synthesize |
+| `priority` | string | No | "high" or "normal" (default: "normal") |
+
+**Stop Message** - End session
+```json
+{
+ "type": "stop",
+ "reason": "user_requested"
+}
+```
+
+**Cancel Message** - Cancel current synthesis
+```json
+{
+ "type": "cancel"
+}
+```
+
+##### Server-to-Client Messages
+
+**Ready Message** - Session established
+```json
+{
+ "type": "ready",
+ "sessionId": "uuid-string",
+ "voiceLoaded": "makima",
+ "capabilities": {
+ "streaming": true,
+ "languages": ["English", "Japanese", "Chinese", "Korean", "German", "French", "Russian", "Portuguese", "Spanish", "Italian"]
+ }
+}
+```
+
+**Started Message** - TTS session configured
+```json
+{
+ "type": "started",
+ "sampleRate": 24000,
+ "encoding": "pcm16",
+ "voice": "makima"
+}
+```
+
+**AudioChunk Message** - Streaming audio data
+```json
+{
+ "type": "audioChunk",
+ "data": "<base64-encoded-audio>",
+ "sequenceNumber": 1,
+ "isFinal": false,
+ "timestampMs": 1234567890
+}
+```
+
+For binary transport (recommended for performance):
+- Server MAY send raw binary WebSocket frames
+- Binary frames contain PCM audio data directly
+- JSON control messages indicate start/end of audio stream
+
+**Complete Message** - Synthesis finished
+```json
+{
+ "type": "complete",
+ "durationMs": 1500,
+ "charactersProcessed": 25,
+ "audioLengthMs": 2100
+}
+```
+
+**Error Message** - Error occurred
+```json
+{
+ "type": "error",
+ "code": "SYNTHESIS_ERROR",
+ "message": "Failed to generate audio",
+ "recoverable": true
+}
+```
+
+**Stopped Message** - Session ended
+```json
+{
+ "type": "stopped",
+ "reason": "user_requested"
+}
+```
+
+#### 2.1.4 Integration with Listen Page
+
+The TTS endpoint SHALL integrate with the existing Listen page (`/api/v1/listen`) to enable bidirectional speech:
+
+**Bidirectional Flow:**
+```
+User Speech -> /api/v1/listen (STT) -> Transcription
+ |
+ v
+ LLM Processing / Task Creation
+ |
+ v
+Response Text -> /api/v1/speak (TTS) -> Audio -> User
+```
+
+**Implementation Requirements:**
+1. Both endpoints SHALL support the same `contractId` for context sharing
+2. TTS SHALL support interruption when new STT input is detected
+3. Session management SHALL coordinate between STT and TTS
+
+### 2.2 Voice Configuration API
+
+#### 2.2.1 List Available Voices
+
+```
+GET /api/v1/voices
+```
+
+Response:
+```json
+{
+ "voices": [
+ {
+ "id": "makima",
+ "name": "Makima (Default)",
+ "language": "Japanese",
+ "description": "Default Makima voice (Tomori Kusunoki)",
+ "isDefault": true
+ }
+ ]
+}
+```
+
+#### 2.2.2 Upload Custom Voice (Future)
+
+```
+POST /api/v1/voices
+Content-Type: multipart/form-data
+
+audio: <audio-file>
+transcript: "Text spoken in the audio"
+name: "Custom Voice"
+```
+
+---
+
+## 3. Non-Functional Requirements
+
+### 3.1 Latency Requirements
+
+| Metric | Target | Maximum | Notes |
+|--------|--------|---------|-------|
+| First Audio Byte | < 200ms | 500ms | From text submission to first audio chunk |
+| Subsequent Chunks | < 50ms | 100ms | Inter-chunk latency |
+| End-to-End Latency | < 300ms | 800ms | Total time for short phrases |
+| Voice Prompt Loading | < 500ms | 2000ms | One-time at session start |
+
+**Measurement Points:**
+- T0: Client sends "speak" message
+- T1: First audio chunk received by client
+- T2: Last audio chunk received ("complete" message)
+- First Audio Latency = T1 - T0
+- Total Latency = T2 - T0
+
+### 3.2 Audio Quality Requirements
+
+| Specification | Value |
+|---------------|-------|
+| Output Sample Rate | 24,000 Hz |
+| Bit Depth | 16-bit (PCM16) or 32-bit float |
+| Channels | Mono (1 channel) |
+| Audio Codec | Raw PCM (WebSocket), WAV (download) |
+| Voice Similarity | > 0.90 speaker similarity score |
+
+**Quality Metrics:**
+- MOS (Mean Opinion Score): Target > 4.0
+- Speaker similarity to reference: Target > 0.90
+- No audible artifacts or glitches in streaming mode
+
+### 3.3 Hardware Requirements
+
+The TTS model runs directly in the makima process using candle.
+
+| Component | Minimum | Recommended |
+|-----------|---------|-------------|
+| GPU | CUDA-capable, 4GB VRAM (or Metal on macOS) | RTX 3060+ with 8GB+ VRAM |
+| GPU Memory | 4GB | 8GB |
+| System RAM | 8GB | 16GB |
+| Storage | 5GB (model weights) | 10GB |
+
+**GPU Memory Breakdown:**
+- Model weights (bf16): ~1.2GB
+- Speech tokenizer: ~682MB
+- KV cache during inference: ~1-2GB
+- Safety margin: ~1GB
+
+**CPU Fallback:**
+- candle supports CPU with MKL for systems without GPU
+- Latency will be higher but functional
+
+### 3.4 Scalability Requirements
+
+| Metric | Target |
+|--------|--------|
+| Concurrent Sessions | 10 per GPU |
+| Requests per Second | 50 text-to-speech requests |
+| Audio Throughput | 10 hours of audio per hour |
+
+### 3.5 Availability Requirements
+
+| Metric | Target |
+|--------|--------|
+| Service Uptime | 99.5% |
+| Recovery Time | < 30 seconds |
+| Graceful Degradation | Fall back to batch mode if streaming fails |
+
+---
+
+## 4. Architecture Specification
+
+### 4.1 System Architecture
+
+```
++-------------------------------------------------------------------------+
+| Client Application |
+| +-------------+ +-------------+ +------------------------------+ |
+| | Listen | | Speak | | UI Components | |
+| | (STT UI) | | (TTS UI) | | (Audio Player, Controls) | |
+| +------+------+ +------+------+ +------------------------------+ |
++---------|--------------------|------------------------------------------+
+ | WebSocket | WebSocket
+ | /api/v1/listen | /api/v1/speak
+ | |
++---------|--------------------|------------------------------------------+
+| | Makima Server (Rust/Axum) |
+| +------v--------------------v------+ |
+| | WebSocket Router | |
+| | (axum WebSocket handlers) | |
+| +------+--------------------+------+ |
+| | | |
+| +------v------+ +------v------+ +-----------------------------+ |
+| | Listen | | Speak | | Shared State | |
+| | Handler | | Handler | | - ML Models (STT) | |
+| | (STT/ML) | | (TTS/ML) | | - TTS Model (candle) | |
+| +-------------+ +------+------+ | - Voice Prompt Cache | |
+| | | - Session Manager | |
+| | +-----------------------------+ |
+| +------v------+ |
+| | TTS Module | |
+| | (candle) | |
+| +------+------+ |
+| | |
+| +------v------+ |
+| | Qwen3-TTS | |
+| | Components | |
+| | - LM (28L) | |
+| | - Code Pred | |
+| | - Speech Tok| |
+| +-------------+ |
++--------------------------------------------------------------------------+
+```
+
+### 4.2 TTS Module Structure
+
+```
+makima/src/tts/
+├── mod.rs // TTS trait + factory (select Chatterbox vs Qwen3)
+├── chatterbox.rs // Existing ONNX-based Chatterbox (moved from tts.rs)
+├── qwen3/
+│ ├── mod.rs // Qwen3TTS public API
+│ ├── model.rs // Qwen3 LM transformer (28 layers)
+│ ├── code_predictor.rs // MTP module (5 layers, 16 codebooks)
+│ ├── speech_tokenizer.rs // Encoder + Decoder (causal ConvNet)
+│ ├── config.rs // Model config from config.json
+│ └── generate.rs // Autoregressive generation loop with KV cache
+```
+
+#### 4.2.1 TTS Trait
+
+```rust
+// makima/src/tts/mod.rs
+
+/// Trait for text-to-speech implementations.
+#[async_trait]
+pub trait TtsEngine: Send + Sync {
+ /// Generate audio from text with a given voice prompt.
+ async fn generate(
+ &self,
+ text: &str,
+ voice_id: &str,
+ language: &str,
+ ) -> Result<Vec<AudioChunk>, TtsError>;
+
+ /// Load and cache a voice prompt from reference audio.
+ async fn load_voice(&self, voice_id: &str) -> Result<(), TtsError>;
+
+ /// Check if the engine is ready for inference.
+ fn is_ready(&self) -> bool;
+}
+
+/// Select the appropriate TTS engine based on configuration.
+pub fn create_engine(config: &TtsConfig) -> Box<dyn TtsEngine> {
+ match config.engine {
+ TtsEngineType::Qwen3 => Box::new(qwen3::Qwen3Tts::new(config)),
+ TtsEngineType::Chatterbox => Box::new(chatterbox::ChatterboxTts::new(config)),
+ }
+}
+```
+
+#### 4.2.2 Qwen3 Candle Implementation
+
+The Qwen3 module implements the three core model components using candle:
+
+1. **Language Model** — 28-layer transformer using candle-transformers' Qwen2 attention with TTS-specific modifications
+2. **Code Predictor** — 5-layer MTP module predicting 16 codebook layers
+3. **Speech Tokenizer** — GAN-based codec with Conv1d encoder/decoder
+
+**Key candle features used:**
+- `candle_core::Tensor` for all tensor operations
+- `candle_nn::Module` for model layers
+- `candle_nn::VarBuilder` for loading safetensors weights
+- `candle_core::Device` for GPU/CPU selection
+
+#### 4.2.3 Model Loading
+
+Models are loaded lazily on first TTS request, following the pattern established by `listen.rs`:
+
+```rust
+// Models held in SharedState behind async mutex
+pub struct TtsModels {
+ pub engine: Box<dyn TtsEngine>,
+ pub voice_cache: VoicePromptCache,
+}
+
+impl AppState {
+ pub async fn get_tts_models(&self) -> Result<&TtsModels, TtsError> {
+ self.tts_models.get_or_try_init(|| async {
+ // Load safetensors weights via candle
+ // Initialize voice cache with default Makima voice
+ }).await
+ }
+}
+```
+
+### 4.3 Speak Handler
+
+```rust
+// makima/src/server/handlers/speak.rs
+
+/// WebSocket handler for TTS streaming.
+/// Calls the TTS engine directly — no proxy, no external service.
+pub async fn websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_speak_socket(socket, state))
+}
+
+async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
+ let session_id = Uuid::new_v4().to_string();
+
+ // Get or lazily load TTS models
+ let tts = match state.get_tts_models().await {
+ Ok(tts) => tts,
+ Err(e) => {
+ // Send error and close
+ return;
+ }
+ };
+
+ // Handle WebSocket messages directly
+ // Parse JSON commands, run inference, stream audio chunks back
+}
+```
+
+### 4.4 Voice Prompt Caching
+
+Voice prompts are cached in-memory using an LRU cache:
+
+```rust
+// makima/src/tts/mod.rs
+
+pub struct VoicePromptCache {
+ cache: tokio::sync::Mutex<lru::LruCache<String, VoicePrompt>>,
+}
+
+impl VoicePromptCache {
+ pub fn new(max_size: usize) -> Self { /* ... */ }
+ pub async fn get(&self, voice_id: &str) -> Option<VoicePrompt> { /* ... */ }
+ pub async fn insert(&self, voice_id: String, prompt: VoicePrompt) { /* ... */ }
+}
+```
+
+### 4.5 Error Handling and Recovery
+
+#### 4.5.1 Error Categories
+
+| Error Code | Category | Recoverable | Action |
+|------------|----------|-------------|--------|
+| `MODEL_LOADING` | Initialization | Yes | Wait and retry |
+| `SYNTHESIS_ERROR` | Generation | Yes | Retry with same input |
+| `INVALID_TEXT` | Input | No | Return error to client |
+| `VOICE_NOT_FOUND` | Configuration | No | Fall back to default voice |
+| `GPU_OUT_OF_MEMORY` | Resource | Yes | Clear cache, retry on CPU |
+| `TIMEOUT` | Inference | Yes | Retry with backoff |
+
+---
+
+## 5. API Contract
+
+### 5.1 WebSocket Message Formats
+
+#### 5.1.1 Client-to-Server Messages
+
+```typescript
+// TypeScript type definitions for client implementation
+
+interface StartMessage {
+ type: "start";
+ sampleRate?: number; // Default: 24000
+ encoding?: "pcm16" | "pcm32f"; // Default: "pcm16"
+ voice?: string; // Default: "makima"
+ language?: string; // Default: "English"
+ authToken?: string; // JWT for authenticated sessions
+ contractId?: string; // Contract context
+}
+
+interface SpeakMessage {
+ type: "speak";
+ text: string; // Required: text to synthesize
+ priority?: "normal" | "high"; // Default: "normal"
+}
+
+interface CancelMessage {
+ type: "cancel";
+}
+
+interface StopMessage {
+ type: "stop";
+ reason?: string;
+}
+
+type ClientMessage = StartMessage | SpeakMessage | CancelMessage | StopMessage;
+```
+
+#### 5.1.2 Server-to-Client Messages
+
+```typescript
+interface ReadyMessage {
+ type: "ready";
+ sessionId: string;
+ voiceLoaded: string;
+ capabilities: {
+ streaming: boolean;
+ languages: string[];
+ };
+}
+
+interface StartedMessage {
+ type: "started";
+ sampleRate: number;
+ encoding: string;
+ voice: string;
+}
+
+interface AudioChunkMessage {
+ type: "audioChunk";
+ data: string; // Base64-encoded PCM audio
+ sequenceNumber: number;
+ isFinal: boolean;
+ timestampMs: number;
+}
+
+interface CompleteMessage {
+ type: "complete";
+ durationMs: number;
+ charactersProcessed: number;
+ audioLengthMs: number;
+}
+
+interface ErrorMessage {
+ type: "error";
+ code: string;
+ message: string;
+ recoverable: boolean;
+}
+
+interface StoppedMessage {
+ type: "stopped";
+ reason: string;
+}
+
+type ServerMessage =
+ | ReadyMessage
+ | StartedMessage
+ | AudioChunkMessage
+ | CompleteMessage
+ | ErrorMessage
+ | StoppedMessage;
+```
+
+### 5.2 Error Codes
+
+| Code | HTTP-like | Description | Recovery |
+|------|-----------|-------------|----------|
+| `MODEL_LOADING` | 503 | Model still loading | Wait and retry |
+| `SYNTHESIS_ERROR` | 500 | Failed to generate audio | Retry |
+| `INVALID_TEXT` | 400 | Text is empty or invalid | Fix input |
+| `VOICE_NOT_FOUND` | 404 | Requested voice doesn't exist | Use default |
+| `UNAUTHORIZED` | 401 | Invalid or missing auth token | Re-authenticate |
+| `RATE_LIMITED` | 429 | Too many requests | Back off |
+| `TIMEOUT` | 408 | Operation timed out | Retry |
+| `CANCELLED` | 499 | Client cancelled request | N/A |
+
+### 5.3 Session Management
+
+#### 5.3.1 Session Lifecycle
+
+```
+DISCONNECTED -> CONNECTING -> READY -> STARTED -> SPEAKING -> READY -> ... -> STOPPED
+ | | | |
+ v v v v
+ ERROR ERROR ERROR STOPPED
+```
+
+---
+
+## 6. Voice Asset Requirements
+
+### 6.1 Makima Voice Clip Specifications
+
+#### 6.1.1 Audio Requirements
+
+| Specification | Requirement |
+|---------------|-------------|
+| Duration | 5-10 seconds (minimum 3s) |
+| Format | WAV (PCM) |
+| Sample Rate | 24,000 Hz or higher |
+| Bit Depth | 16-bit or higher |
+| Channels | Mono (preferred) or Stereo |
+| Content | Clear speech, natural tone |
+| Background | Minimal noise/music |
+
+#### 6.1.2 Content Guidelines
+
+**DO:**
+- Use dialogue with varied intonation
+- Include multiple phonemes
+- Capture natural speaking rhythm
+- Extract from clean audio scenes
+
+**DON'T:**
+- Include background music
+- Use shouting or whispering
+- Include sound effects
+- Use heavily processed audio
+
+#### 6.1.3 Transcript Requirements
+
+| Specification | Requirement |
+|---------------|-------------|
+| Format | Plain text (.txt) or JSON |
+| Encoding | UTF-8 |
+| Content | Exact transcription of audio |
+| Language | Japanese (for Japanese reference) |
+
+### 6.2 Storage Location and Management
+
+#### 6.2.1 Directory Structure
+
+```
+models/
+└── voices/
+ ├── makima/
+ │ ├── reference.wav # Primary reference audio
+ │ ├── transcript.txt # Plain text transcript
+ │ ├── transcript.json # Structured transcript (optional)
+ │ └── metadata.json # Voice metadata
+ ├── makima-alt/ # Alternative Makima clips (future)
+ │ └── ...
+ └── custom/ # User-uploaded voices (future)
+ └── {voice_id}/
+ ├── reference.wav
+ ├── transcript.txt
+ └── metadata.json
+```
+
+---
+
+## 7. Testing Strategy
+
+### 7.1 Unit Tests
+
+#### 7.1.1 Rust TTS Module Tests
+
+```rust
+// makima/src/tts/qwen3/tests.rs
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_config_loading() {
+ let config = Qwen3Config::from_json("test_config.json").unwrap();
+ assert_eq!(config.hidden_size, 1024);
+ assert_eq!(config.num_layers, 28);
+ }
+
+ #[test]
+ fn test_voice_prompt_cache_lru() {
+ let cache = VoicePromptCache::new(2);
+ cache.insert("a", prompt_a);
+ cache.insert("b", prompt_b);
+ cache.get("a"); // access a
+ cache.insert("c", prompt_c); // should evict b
+
+ assert!(cache.get("a").is_some());
+ assert!(cache.get("b").is_none());
+ assert!(cache.get("c").is_some());
+ }
+
+ #[tokio::test]
+ async fn test_speak_handler_message_parsing() {
+ let json = r#"{"type": "start", "voice": "makima"}"#;
+ let msg: SpeakClientMessage = serde_json::from_str(json).unwrap();
+
+ match msg {
+ SpeakClientMessage::Start(start) => {
+ assert_eq!(start.voice, Some("makima".to_string()));
+ }
+ _ => panic!("Expected Start message"),
+ }
+ }
+}
+```
+
+### 7.2 Integration Tests
+
+```rust
+// tests/tts_integration.rs
+
+#[tokio::test]
+async fn test_speak_websocket_flow() {
+ // Start test server with TTS enabled
+ let state = create_test_state_with_tts().await;
+ let app = make_router(state);
+
+ // Connect WebSocket
+ let ws = connect_ws("/api/v1/speak").await;
+
+ // Send start
+ ws.send_json(json!({"type": "start", "voice": "makima"})).await;
+ let ready = ws.recv_json().await;
+ assert_eq!(ready["type"], "ready");
+
+ // Send speak
+ ws.send_json(json!({"type": "speak", "text": "Hello."})).await;
+
+ // Collect audio chunks
+ let mut chunks = vec![];
+ loop {
+ let msg = ws.recv().await;
+ match msg {
+ WsMsg::Binary(data) => chunks.push(data),
+ WsMsg::Text(json) => {
+ let data: Value = serde_json::from_str(&json).unwrap();
+ if data["type"] == "complete" { break; }
+ }
+ }
+ }
+ assert!(!chunks.is_empty());
+}
+```
+
+### 7.3 Performance Targets
+
+| Metric | Target | Acceptable | Warning |
+|--------|--------|------------|---------|
+| First Audio (short) | < 150ms | < 200ms | > 300ms |
+| First Audio (medium) | < 200ms | < 300ms | > 500ms |
+| First Audio (long) | < 300ms | < 500ms | > 800ms |
+| Inter-chunk | < 30ms | < 50ms | > 100ms |
+| Memory (GPU) | < 4GB | < 6GB | > 8GB |
+| Memory (CPU) | < 2GB | < 4GB | > 8GB |
+
+---
+
+## 8. Implementation Phases
+
+### Phase 1: Candle-Based Qwen3-TTS Module (Week 1-2)
+
+**Deliverables:**
+- [ ] `makima/src/tts/mod.rs` — TTS trait + factory
+- [ ] `makima/src/tts/chatterbox.rs` — Move existing code from tts.rs
+- [ ] `makima/src/tts/qwen3/model.rs` — 28-layer LM backbone (extend candle Qwen2)
+- [ ] `makima/src/tts/qwen3/code_predictor.rs` — MTP module (5 layers, 16 codebooks)
+- [ ] `makima/src/tts/qwen3/speech_tokenizer.rs` — ConvNet encoder/decoder + RVQ
+- [ ] `makima/src/tts/qwen3/config.rs` — Config from safetensors
+- [ ] `makima/src/tts/qwen3/generate.rs` — Autoregressive generation with KV cache
+- [ ] Add `candle-core`, `candle-nn`, `candle-transformers` to Cargo.toml
+
+**Success Criteria:**
+- Model loads safetensors weights successfully
+- Can generate audio from text via direct inference
+- First audio latency < 500ms (initial, unoptimized)
+
+### Phase 2: WebSocket Handler + Voice Assets (Week 2-3)
+
+**Deliverables:**
+- [ ] Update `makima/src/server/handlers/speak.rs` — Direct TTS handler (no proxy)
+- [ ] Lazy model loading via `SharedState`
+- [ ] Voice prompt caching
+- [ ] Makima voice asset acquisition and processing
+- [ ] Basic error handling and session management
+
+**Success Criteria:**
+- `/api/v1/speak` endpoint produces streaming audio
+- Default Makima voice works
+- Error handling matches specification
+
+### Phase 3: Optimization + Integration (Week 3-4)
+
+**Deliverables:**
+- [ ] Streaming audio generation (token-by-token decoding)
+- [ ] GPU memory optimization
+- [ ] Listen page integration for bidirectional speech
+- [ ] Session coordination between STT and TTS
+- [ ] Full test suite (unit, integration)
+- [ ] Latency benchmarks
+
+**Success Criteria:**
+- First audio latency < 200ms
+- Memory usage < 6GB
+- All tests passing
+- Documentation complete
+
+---
+
+## Appendix
+
+### A. Dependencies
+
+#### Rust (Cargo.toml additions)
+
+```toml
+[dependencies]
+candle-core = "0.8"
+candle-nn = "0.8"
+candle-transformers = "0.8"
+# Keep existing: tokenizers, hf-hub, ndarray (for compatibility)
+```
+
+### B. Environment Variables
+
+```bash
+# TTS Configuration
+TTS_ENGINE=qwen3 # "qwen3" or "chatterbox"
+TTS_MODEL_ID=Qwen/Qwen3-TTS-12Hz-0.6B-Base
+TTS_DEVICE=cuda:0 # "cuda:0", "metal", or "cpu"
+TTS_VOICES_DIR=models/voices
+TTS_DEFAULT_VOICE=makima
+```
+
+### C. References
+
+1. [Qwen3-TTS-12Hz-0.6B-Base (Hugging Face)](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
+2. [Qwen3-TTS GitHub Repository](https://github.com/QwenLM/Qwen3-TTS)
+3. [Qwen3-TTS Technical Report (arXiv:2601.15621)](https://arxiv.org/abs/2601.15621)
+4. [Candle — HuggingFace Rust ML Framework](https://github.com/huggingface/candle)
+5. [axum WebSocket Documentation](https://docs.rs/axum/latest/axum/extract/ws/index.html)
+6. [docs/research/rust-native-tts-research.md](../research/rust-native-tts-research.md) — Detailed feasibility analysis
+
+### D. Glossary
+
+| Term | Definition |
+|------|------------|
+| **TTS** | Text-to-Speech: Converting text input to audio output |
+| **STT** | Speech-to-Text: Converting audio input to text output |
+| **Voice Cloning** | Creating synthetic speech that mimics a specific speaker |
+| **Voice Prompt** | Pre-computed speaker embedding for voice cloning |
+| **Candle** | HuggingFace's minimalist Rust ML framework |
+| **SafeTensors** | Efficient, safe model weight serialization format |
+| **RVQ** | Residual Vector Quantization — multi-codebook audio tokenization |
+| **MTP** | Multi-Token Prediction — code predictor generating 16 codebook layers |
+| **bf16** | Brain floating-point 16-bit format for GPU computation |
diff --git a/makima/Cargo.toml b/makima/Cargo.toml
index 950c123..b6b12dd 100644
--- a/makima/Cargo.toml
+++ b/makima/Cargo.toml
@@ -17,6 +17,12 @@ tokenizers = "0.21"
hf-hub = "0.4"
ndarray = "0.16"
+# Candle ML framework (Qwen3-TTS native inference)
+candle-core = "0.8"
+candle-nn = "0.8"
+candle-transformers = "0.8"
+safetensors = "0.4"
+
# Web server
axum = { version = "0.8", features = ["ws", "multipart"] }
tokio = { version = "1.0", features = ["full", "signal", "process"] }
diff --git a/makima/frontend/src/components/listen/ControlPanel.tsx b/makima/frontend/src/components/listen/ControlPanel.tsx
index f0e5702..f482ec4 100644
--- a/makima/frontend/src/components/listen/ControlPanel.tsx
+++ b/makima/frontend/src/components/listen/ControlPanel.tsx
@@ -1,6 +1,7 @@
import { useState } from "react";
import { Logo } from "../Logo";
import type { MicrophoneStatus } from "../../hooks/useMicrophone";
+import type { ConnectionStatus } from "../../hooks/useWebSocket";
import { ContractPickerModal } from "./ContractPickerModal";
export interface ContractOption {
@@ -22,6 +23,8 @@ interface ControlPanelProps {
selectedContractId: string | null;
onContractChange: (contractId: string | null) => void;
contractsLoading?: boolean;
+ // Connection status for loading state
+ connectionStatus?: ConnectionStatus;
}
function getStatusText(isListening: boolean, micStatus: MicrophoneStatus): string {
@@ -54,6 +57,7 @@ export function ControlPanel({
selectedContractId,
onContractChange,
contractsLoading,
+ connectionStatus,
}: ControlPanelProps) {
const [isModalOpen, setIsModalOpen] = useState(false);
const statusText = getStatusText(isListening, micStatus);
@@ -121,18 +125,36 @@ export function ControlPanel({
{/* Connection status */}
<div
- className={`inline-flex items-center gap-1.5 px-2 py-1 border ${
+ className={`inline-flex flex-col gap-1 px-2 py-1 border ${
isConnected
? "border-[#3f6fb3] text-[#75aafc]"
+ : connectionStatus === "connecting"
+ ? "border-[#3f6fb3] text-[#9bc3ff]"
: "border-[rgba(117,170,252,0.25)] text-[#9bc3ff]"
}`}
>
- <span
- className={`w-1.5 h-1.5 rounded-full ${
- isConnected ? "bg-[#75aafc]" : "bg-[#3f6fb3]"
- }`}
- />
- {isConnected ? "CONNECTED" : "DISCONNECTED"}
+ <div className="inline-flex items-center gap-1.5">
+ <span
+ className={`w-1.5 h-1.5 rounded-full ${
+ isConnected ? "bg-[#75aafc]" : "bg-[#3f6fb3]"
+ }`}
+ />
+ {isConnected
+ ? "CONNECTED"
+ : connectionStatus === "connecting"
+ ? "LOADING MODELS..."
+ : "DISCONNECTED"}
+ </div>
+ {connectionStatus === "connecting" && (
+ <div className="w-full h-1.5 bg-[#0f1c2f] overflow-hidden">
+ <div
+ className="h-full w-1/3 bg-[#75aafc]"
+ style={{
+ animation: "loading-slide 1.5s ease-in-out infinite",
+ }}
+ />
+ </div>
+ )}
</div>
</div>
diff --git a/makima/frontend/src/hooks/useSpeakWebSocket.ts b/makima/frontend/src/hooks/useSpeakWebSocket.ts
new file mode 100644
index 0000000..3ef8851
--- /dev/null
+++ b/makima/frontend/src/hooks/useSpeakWebSocket.ts
@@ -0,0 +1,329 @@
+import { useState, useCallback, useRef, useEffect } from "react";
+import { SPEAK_ENDPOINT } from "../lib/api";
+
+export type SpeakStatus =
+ | "disconnected"
+ | "connecting"
+ | "connected"
+ | "loading_model"
+ | "speaking"
+ | "error";
+
+export interface SpeakWebSocketState {
+ status: SpeakStatus;
+ error: string | null;
+}
+
+export function useSpeakWebSocket() {
+ const [state, setState] = useState<SpeakWebSocketState>({
+ status: "disconnected",
+ error: null,
+ });
+
+ const wsRef = useRef<WebSocket | null>(null);
+ const audioContextRef = useRef<AudioContext | null>(null);
+ const audioQueueRef = useRef<Float32Array[]>([]);
+ const isPlayingRef = useRef(false);
+ const modelLoadingTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
+ const nextPlayTimeRef = useRef(0);
+
+ // Clean up on unmount
+ useEffect(() => {
+ return () => {
+ if (wsRef.current) {
+ wsRef.current.close();
+ wsRef.current = null;
+ }
+ if (audioContextRef.current) {
+ audioContextRef.current.close();
+ audioContextRef.current = null;
+ }
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+ };
+ }, []);
+
+ const getAudioContext = useCallback((): AudioContext => {
+ if (!audioContextRef.current || audioContextRef.current.state === "closed") {
+ audioContextRef.current = new AudioContext({ sampleRate: 24000 });
+ }
+ return audioContextRef.current;
+ }, []);
+
+ const playAudioQueue = useCallback(() => {
+ if (isPlayingRef.current) return;
+ isPlayingRef.current = true;
+
+ const ctx = getAudioContext();
+
+ function scheduleNext() {
+ const chunk = audioQueueRef.current.shift();
+ if (!chunk) {
+ isPlayingRef.current = false;
+ return;
+ }
+
+ const buffer = ctx.createBuffer(1, chunk.length, 24000);
+ buffer.copyToChannel(chunk, 0);
+
+ const source = ctx.createBufferSource();
+ source.buffer = buffer;
+ source.connect(ctx.destination);
+
+ // Schedule playback at the right time to avoid gaps
+ const now = ctx.currentTime;
+ const startTime = Math.max(now, nextPlayTimeRef.current);
+ source.start(startTime);
+ nextPlayTimeRef.current = startTime + buffer.duration;
+
+ source.onended = () => {
+ if (audioQueueRef.current.length > 0) {
+ scheduleNext();
+ } else {
+ isPlayingRef.current = false;
+ }
+ };
+ }
+
+ scheduleNext();
+ }, [getAudioContext]);
+
+ const connect = useCallback((): Promise<boolean> => {
+ return new Promise((resolve) => {
+ if (wsRef.current?.readyState === WebSocket.OPEN) {
+ resolve(true);
+ return;
+ }
+
+ if (wsRef.current) {
+ wsRef.current.close();
+ wsRef.current = null;
+ }
+
+ setState({ status: "connecting", error: null });
+
+ try {
+ const ws = new WebSocket(SPEAK_ENDPOINT);
+ ws.binaryType = "arraybuffer";
+ wsRef.current = ws;
+
+ ws.onopen = () => {
+ setState({ status: "connected", error: null });
+ resolve(true);
+ };
+
+ ws.onmessage = (event) => {
+ // Binary data = PCM audio chunk
+ if (event.data instanceof ArrayBuffer) {
+ // Clear model loading timer on first audio data
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ // Update status to speaking if not already
+ setState((s) => {
+ if (s.status === "loading_model" || s.status === "connected") {
+ return { ...s, status: "speaking" };
+ }
+ return s;
+ });
+
+ // Convert PCM16 LE to Float32
+ const pcm16 = new Int16Array(event.data);
+ const float32 = new Float32Array(pcm16.length);
+ for (let i = 0; i < pcm16.length; i++) {
+ float32[i] = pcm16[i] / 32768;
+ }
+
+ audioQueueRef.current.push(float32);
+ playAudioQueue();
+ return;
+ }
+
+ // Text data = JSON message
+ try {
+ const message = JSON.parse(event.data);
+
+ switch (message.type) {
+ case "audio_end":
+ // Clear model loading timer
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+ // Wait for audio queue to drain, then go back to connected
+ // Use a short delay to let buffered audio finish
+ {
+ const checkDone = () => {
+ if (audioQueueRef.current.length === 0 && !isPlayingRef.current) {
+ setState((s) => {
+ if (s.status === "speaking" || s.status === "loading_model") {
+ return { ...s, status: "connected" };
+ }
+ return s;
+ });
+ } else {
+ setTimeout(checkDone, 100);
+ }
+ };
+ checkDone();
+ }
+ break;
+
+ case "error":
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+ setState({
+ status: "error",
+ error: message.message || `Error: ${message.code}`,
+ });
+ break;
+ }
+ } catch {
+ console.error("Failed to parse speak WebSocket message:", event.data);
+ }
+ };
+
+ ws.onerror = () => {
+ setState({
+ status: "error",
+ error: "Failed to connect to speak server",
+ });
+ resolve(false);
+ };
+
+ ws.onclose = (event) => {
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ let errorMessage: string | null = null;
+ if (event.code === 1006) {
+ errorMessage = "Connection failed - server may be unavailable";
+ } else if (event.code !== 1000 && event.code !== 1001) {
+ errorMessage = `Connection closed unexpectedly (code: ${event.code})`;
+ }
+
+ setState((s) => ({
+ status: "disconnected",
+ error: errorMessage || s.error,
+ }));
+ wsRef.current = null;
+ };
+ } catch (err) {
+ const message =
+ err instanceof Error ? err.message : "Failed to create WebSocket connection";
+ setState({ status: "error", error: message });
+ resolve(false);
+ }
+ });
+ }, [playAudioQueue]);
+
+ const speak = useCallback(
+ async (text: string) => {
+ if (!text.trim()) return;
+
+ // Connect if not connected
+ if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
+ const connected = await connect();
+ if (!connected) return;
+ }
+
+ // Reset audio state
+ audioQueueRef.current = [];
+ isPlayingRef.current = false;
+ nextPlayTimeRef.current = 0;
+
+ // Resume audio context if suspended (browser autoplay policy)
+ const ctx = getAudioContext();
+ if (ctx.state === "suspended") {
+ await ctx.resume();
+ }
+
+ // Start loading timer - if no audio arrives in 2 seconds, show loading state
+ modelLoadingTimerRef.current = setTimeout(() => {
+ setState((s) => {
+ if (s.status === "connected" || s.status === "connecting") {
+ return { ...s, status: "loading_model" };
+ }
+ return s;
+ });
+ modelLoadingTimerRef.current = null;
+ }, 2000);
+
+ // Send speak request
+ wsRef.current?.send(
+ JSON.stringify({ type: "speak", text })
+ );
+
+ setState((s) => ({ ...s, error: null }));
+ },
+ [connect, getAudioContext]
+ );
+
+ const cancel = useCallback(() => {
+ // Clear audio queue
+ audioQueueRef.current = [];
+ isPlayingRef.current = false;
+ nextPlayTimeRef.current = 0;
+
+ // Clear model loading timer
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ // Send cancel message
+ if (wsRef.current?.readyState === WebSocket.OPEN) {
+ wsRef.current.send(JSON.stringify({ type: "cancel" }));
+ }
+
+ setState((s) => ({
+ ...s,
+ status: wsRef.current?.readyState === WebSocket.OPEN ? "connected" : "disconnected",
+ }));
+ }, []);
+
+ const disconnect = useCallback(() => {
+ // Clear audio queue
+ audioQueueRef.current = [];
+ isPlayingRef.current = false;
+ nextPlayTimeRef.current = 0;
+
+ if (modelLoadingTimerRef.current) {
+ clearTimeout(modelLoadingTimerRef.current);
+ modelLoadingTimerRef.current = null;
+ }
+
+ if (wsRef.current) {
+ // Send stop message before closing
+ if (wsRef.current.readyState === WebSocket.OPEN) {
+ wsRef.current.send(JSON.stringify({ type: "stop" }));
+ }
+ wsRef.current.close(1000, "User disconnected");
+ wsRef.current = null;
+ }
+
+ setState({ status: "disconnected", error: null });
+ }, []);
+
+ return {
+ ...state,
+ isConnected:
+ state.status === "connected" ||
+ state.status === "speaking" ||
+ state.status === "loading_model",
+ isSpeaking: state.status === "speaking",
+ isModelLoading: state.status === "loading_model",
+ speak,
+ cancel,
+ connect,
+ disconnect,
+ };
+}
diff --git a/makima/frontend/src/index.css b/makima/frontend/src/index.css
index 5c08006..f29873b 100644
--- a/makima/frontend/src/index.css
+++ b/makima/frontend/src/index.css
@@ -64,6 +64,12 @@ body {
background: rgba(117, 170, 252, 0.35);
}
+/* Loading bar animation for indeterminate progress */
+@keyframes loading-slide {
+ 0% { transform: translateX(-100%); }
+ 100% { transform: translateX(300%); }
+}
+
/* Grid overlay */
.grid-overlay {
position: fixed;
diff --git a/makima/frontend/src/lib/api.ts b/makima/frontend/src/lib/api.ts
index 4390b20..ca04ce7 100644
--- a/makima/frontend/src/lib/api.ts
+++ b/makima/frontend/src/lib/api.ts
@@ -99,6 +99,7 @@ async function authFetch(url: string, options: RequestInit = {}): Promise<Respon
});
}
export const LISTEN_ENDPOINT = `${WS_BASE}/api/v1/listen`;
+export const SPEAK_ENDPOINT = `${WS_BASE}/api/v1/speak`;
export const FILE_SUBSCRIBE_ENDPOINT = `${WS_BASE}/api/v1/files/subscribe`;
export const TASK_SUBSCRIBE_ENDPOINT = `${WS_BASE}/api/v1/mesh/tasks/subscribe`;
diff --git a/makima/frontend/src/main.tsx b/makima/frontend/src/main.tsx
index 383b732..ef1ba5c 100644
--- a/makima/frontend/src/main.tsx
+++ b/makima/frontend/src/main.tsx
@@ -19,6 +19,7 @@ import LoginPage from "./routes/login";
import SettingsPage from "./routes/settings";
import ContractFilePage from "./routes/contract-file";
import TemplatesPage from "./routes/templates";
+import SpeakPage from "./routes/speak";
createRoot(document.getElementById("root")!).render(
<StrictMode>
@@ -135,6 +136,14 @@ createRoot(document.getElementById("root")!).render(
</ProtectedRoute>
}
/>
+ <Route
+ path="/speak"
+ element={
+ <ProtectedRoute>
+ <SpeakPage />
+ </ProtectedRoute>
+ }
+ />
</Routes>
</BrowserRouter>
</SupervisorQuestionsProvider>
diff --git a/makima/frontend/src/routes/listen.tsx b/makima/frontend/src/routes/listen.tsx
index 55cf7e6..8af538e 100644
--- a/makima/frontend/src/routes/listen.tsx
+++ b/makima/frontend/src/routes/listen.tsx
@@ -207,6 +207,7 @@ export default function ListenPage() {
selectedContractId={selectedContractId}
onContractChange={setSelectedContractId}
contractsLoading={contractsLoading}
+ connectionStatus={ws.status}
/>
</div>
</main>
diff --git a/makima/frontend/src/routes/speak.tsx b/makima/frontend/src/routes/speak.tsx
new file mode 100644
index 0000000..c4692ff
--- /dev/null
+++ b/makima/frontend/src/routes/speak.tsx
@@ -0,0 +1,159 @@
+import { useState, useCallback } from "react";
+import { Masthead } from "../components/Masthead";
+import { useSpeakWebSocket } from "../hooks/useSpeakWebSocket";
+
+export default function SpeakPage() {
+ const [text, setText] = useState("");
+ const tts = useSpeakWebSocket();
+
+ const handleSpeak = useCallback(() => {
+ if (!text.trim()) return;
+ tts.speak(text);
+ }, [text, tts]);
+
+ const handleCancel = useCallback(() => {
+ tts.cancel();
+ }, [tts]);
+
+ const handleKeyDown = useCallback(
+ (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
+ // Ctrl/Cmd + Enter to speak
+ if ((e.ctrlKey || e.metaKey) && e.key === "Enter") {
+ e.preventDefault();
+ handleSpeak();
+ }
+ },
+ [handleSpeak]
+ );
+
+ const statusLabel = (() => {
+ switch (tts.status) {
+ case "disconnected":
+ return "DISCONNECTED";
+ case "connecting":
+ return "CONNECTING...";
+ case "connected":
+ return "CONNECTED";
+ case "loading_model":
+ return "LOADING TTS MODEL...";
+ case "speaking":
+ return "SPEAKING";
+ case "error":
+ return "ERROR";
+ default:
+ return "IDLE";
+ }
+ })();
+
+ const statusColor = (() => {
+ switch (tts.status) {
+ case "connected":
+ case "speaking":
+ return "border-[#3f6fb3] text-[#75aafc]";
+ case "error":
+ return "border-red-400/50 text-red-400";
+ default:
+ return "border-[rgba(117,170,252,0.25)] text-[#9bc3ff]";
+ }
+ })();
+
+ const dotColor = (() => {
+ switch (tts.status) {
+ case "connected":
+ case "speaking":
+ return "bg-[#75aafc]";
+ case "error":
+ return "bg-red-400";
+ default:
+ return "bg-[#3f6fb3]";
+ }
+ })();
+
+ return (
+ <div className="relative z-10 h-screen flex flex-col overflow-hidden">
+ <Masthead showTicker={false} showNav />
+
+ <main className="flex-1 flex flex-col items-center justify-center p-4 md:p-8 gap-6 min-h-0 overflow-auto">
+ {/* Text input area */}
+ <div className="w-full max-w-2xl">
+ <textarea
+ value={text}
+ onChange={(e) => setText(e.target.value)}
+ onKeyDown={handleKeyDown}
+ placeholder="Enter text to speak..."
+ disabled={tts.isSpeaking || tts.isModelLoading}
+ className="w-full h-48 p-4 font-mono text-sm text-[#dbe7ff] bg-[#0d1b2d] border border-[#0f3c78] focus:border-[#3f6fb3] focus:outline-none placeholder-[#3f6fb3] resize-none transition-colors disabled:opacity-50"
+ />
+ <div className="mt-1 text-right font-mono text-xs text-[#3f6fb3]">
+ Ctrl+Enter to speak
+ </div>
+ </div>
+
+ {/* Controls row */}
+ <div className="w-full max-w-2xl flex items-center gap-4">
+ {/* Speak / Cancel button */}
+ {tts.isSpeaking || tts.isModelLoading ? (
+ <button
+ onClick={handleCancel}
+ className="px-6 py-2 font-mono text-sm text-red-400 bg-[#0d1b2d] border border-red-400/50 hover:border-red-400 transition-colors uppercase tracking-wide"
+ >
+ Cancel
+ </button>
+ ) : (
+ <button
+ onClick={handleSpeak}
+ disabled={!text.trim()}
+ className="px-6 py-2 font-mono text-sm text-[#dbe7ff] bg-[#0d1b2d] border border-[#0f3c78] hover:border-[#3f6fb3] transition-colors uppercase tracking-wide disabled:opacity-50 disabled:cursor-not-allowed"
+ >
+ Speak
+ </button>
+ )}
+
+ {/* Status indicator */}
+ <div
+ className={`inline-flex items-center gap-1.5 px-2 py-1 border font-mono text-xs tracking-wide uppercase ${statusColor}`}
+ >
+ <span className={`w-1.5 h-1.5 rounded-full ${dotColor}`} />
+ {statusLabel}
+ </div>
+ </div>
+
+ {/* Loading bar (indeterminate) */}
+ {tts.isModelLoading && (
+ <div className="w-full max-w-2xl">
+ <div className="w-full h-1.5 bg-[#0f1c2f] overflow-hidden">
+ <div
+ className="h-full w-1/3 bg-[#75aafc]"
+ style={{
+ animation: "loading-slide 1.5s ease-in-out infinite",
+ }}
+ />
+ </div>
+ <div className="mt-2 font-mono text-xs text-[#9bc3ff] text-center tracking-wide uppercase">
+ Loading TTS model... This may take a moment on first use.
+ </div>
+ </div>
+ )}
+
+ {/* Speaking animation bar */}
+ {tts.isSpeaking && (
+ <div className="w-full max-w-2xl">
+ <div className="w-full h-1.5 bg-[#0f1c2f] overflow-hidden">
+ <div
+ className="h-full w-full bg-[#75aafc] animate-pulse"
+ />
+ </div>
+ </div>
+ )}
+
+ {/* Error display */}
+ {tts.error && (
+ <div className="w-full max-w-2xl font-mono text-xs text-red-400 text-center px-4 py-2 border border-red-400/50 bg-red-400/10">
+ {tts.error}
+ </div>
+ )}
+ </main>
+
+ </div>
+ );
+}
diff --git a/makima/src/main.rs b/makima/src/main.rs
index 2348b23..1d87106 100644
--- a/makima/src/main.rs
+++ b/makima/src/main.rs
@@ -7,21 +7,9 @@ pub mod tts;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Loading ChatterboxTTS...");
- let mut tts = ChatterboxTTS::from_pretrained(None)?;
+ let tts = ChatterboxTTS::from_pretrained(None)?;
println!("Model loaded successfully!");
- // // Voice cloning using existing audio file
- // println!("Generating TTS with voice cloning...");
- // let audio = tts.generate_tts_with_voice(
- // "Hello, this is a test of the voice cloning system.",
- // Path::new("audio.wav")
- // )?;
- //
- // println!("Generated {} samples", audio.len());
- // save_wav(&audio, Path::new("output.wav"))?;
- // println!("Saved to output.wav");
-
-
// Load reference audio from mp3
println!("Loading reference audio...");
let reference = audio::to_16k_mono_from_path(Path::new("audio.mp3"))?;
diff --git a/makima/src/server/handlers/mod.rs b/makima/src/server/handlers/mod.rs
index b496922..8207399 100644
--- a/makima/src/server/handlers/mod.rs
+++ b/makima/src/server/handlers/mod.rs
@@ -17,6 +17,7 @@ pub mod mesh_red_team;
pub mod mesh_supervisor;
pub mod mesh_ws;
pub mod repository_history;
+pub mod speak;
pub mod templates;
pub mod transcript_analysis;
pub mod users;
diff --git a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs
new file mode 100644
index 0000000..75e7780
--- /dev/null
+++ b/makima/src/server/handlers/speak.rs
@@ -0,0 +1,274 @@
+//! WebSocket handler for TTS streaming (direct in-process inference).
+//!
+//! This module implements the `/api/v1/speak` endpoint which performs
+//! text-to-speech synthesis directly using the candle-based TTS engine.
+//! No external Python service or proxy — the model runs in-process.
+//!
+//! ## Architecture
+//!
+//! The speak handler will:
+//! 1. Accept a WebSocket connection from the client
+//! 2. Lazily load the TTS model (candle) on first request
+//! 3. Parse JSON control messages (start, speak, stop, cancel)
+//! 4. Run inference directly and stream audio chunks back
+//!
+//! See `makima/src/tts/` for the TTS engine implementation.
+//! See `docs/specs/qwen3-tts-spec.md` for the full protocol specification.
+
+use axum::{
+ extract::{ws::Message, ws::WebSocket, State, WebSocketUpgrade},
+ response::Response,
+};
+use futures::{SinkExt, StreamExt};
+use serde::Deserialize;
+use uuid::Uuid;
+
+use crate::server::state::SharedState;
+
+/// Client-to-server control messages.
+#[derive(Debug, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+enum ClientMessage {
+ /// Request speech synthesis for the given text.
+ Speak {
+ text: String,
+ /// Optional voice ID (e.g., "makima"). Not yet used — reserved for future voice selection.
+ #[serde(default)]
+ #[allow(dead_code)]
+ voice: Option<String>,
+ },
+ /// Cancel any in-progress synthesis.
+ Cancel,
+ /// Graceful close.
+ Stop,
+}
+
+/// WebSocket upgrade handler for TTS streaming.
+///
+/// This endpoint accepts WebSocket connections for text-to-speech synthesis.
+/// The TTS model runs directly in-process using candle — no external service.
+#[utoipa::path(
+ get,
+ path = "/api/v1/speak",
+ responses(
+ (status = 101, description = "WebSocket connection established"),
+ (status = 503, description = "TTS engine not available"),
+ ),
+ tag = "Speak"
+)]
+pub async fn websocket_handler(
+ ws: WebSocketUpgrade,
+ State(state): State<SharedState>,
+) -> Response {
+ ws.on_upgrade(|socket| handle_speak_socket(socket, state))
+}
+
+/// Handle TTS WebSocket session with direct in-process inference.
+///
+/// Protocol:
+/// - Client sends JSON `{ "type": "speak", "text": "..." }` messages
+/// - Server responds with binary audio chunks (16-bit PCM @ 24kHz)
+/// - Server sends JSON `{ "type": "audio_end" }` when synthesis is complete
+/// - Server sends JSON `{ "type": "error", ... }` on failures
+async fn handle_speak_socket(socket: WebSocket, state: SharedState) {
+ let session_id = Uuid::new_v4().to_string();
+ tracing::info!(session_id = %session_id, "New TTS WebSocket connection");
+
+ let (mut sender, mut receiver) = socket.split();
+
+ // Process incoming messages
+ while let Some(msg) = receiver.next().await {
+ let msg = match msg {
+ Ok(m) => m,
+ Err(e) => {
+ tracing::warn!(session_id = %session_id, error = %e, "WebSocket receive error");
+ break;
+ }
+ };
+
+ match msg {
+ Message::Text(text) => {
+ let client_msg: ClientMessage = match serde_json::from_str(&text) {
+ Ok(m) => m,
+ Err(e) => {
+ let _ = send_error(
+ &mut sender,
+ "INVALID_MESSAGE",
+ &format!("Failed to parse message: {e}"),
+ )
+ .await;
+ continue;
+ }
+ };
+
+ match client_msg {
+ ClientMessage::Speak { text, .. } => {
+ tracing::info!(
+ session_id = %session_id,
+ text_len = text.len(),
+ "TTS speak request"
+ );
+
+ // Get or lazily load the TTS engine
+ let engine = match state.get_tts_engine().await {
+ Ok(e) => e,
+ Err(e) => {
+ tracing::error!(
+ session_id = %session_id,
+ error = %e,
+ "Failed to load TTS engine"
+ );
+ let _ = send_error(
+ &mut sender,
+ "TTS_LOAD_FAILED",
+ &format!("Failed to load TTS engine: {e}"),
+ )
+ .await;
+ continue;
+ }
+ };
+
+ if !engine.is_ready() {
+ let _ = send_error(
+ &mut sender,
+ "TTS_NOT_READY",
+ "TTS engine is not ready yet",
+ )
+ .await;
+ continue;
+ }
+
+ // Run TTS inference (no voice reference for now — uses default)
+ match engine.generate(&text, None, None).await {
+ Ok(chunks) => {
+ for chunk in &chunks {
+ // Send binary PCM audio data
+ let pcm_bytes = chunk.to_pcm16_bytes();
+ if sender
+ .send(Message::Binary(pcm_bytes.into()))
+ .await
+ .is_err()
+ {
+ tracing::warn!(
+ session_id = %session_id,
+ "Failed to send audio chunk — client disconnected"
+ );
+ return;
+ }
+ }
+
+ // Signal end of audio
+ let end_msg = serde_json::json!({
+ "type": "audio_end",
+ "sample_rate": engine.sample_rate(),
+ "format": "pcm_s16le",
+ "channels": 1,
+ });
+ let _ = sender
+ .send(Message::Text(end_msg.to_string().into()))
+ .await;
+ }
+ Err(e) => {
+ tracing::error!(
+ session_id = %session_id,
+ error = %e,
+ "TTS inference failed"
+ );
+ let _ = send_error(
+ &mut sender,
+ "TTS_INFERENCE_FAILED",
+ &format!("TTS inference failed: {e}"),
+ )
+ .await;
+ }
+ }
+ }
+ ClientMessage::Cancel => {
+ tracing::info!(session_id = %session_id, "TTS cancel requested");
+ // TODO: support cancellation of in-progress inference
+ }
+ ClientMessage::Stop => {
+ tracing::info!(session_id = %session_id, "TTS stop requested, closing");
+ break;
+ }
+ }
+ }
+ Message::Close(_) => {
+ tracing::info!(session_id = %session_id, "TTS WebSocket closed by client");
+ break;
+ }
+ _ => {
+ // Ignore ping/pong/binary from client
+ }
+ }
+ }
+
+ tracing::info!(session_id = %session_id, "TTS WebSocket connection closed");
+}
+
+/// Send an error message to the client.
+async fn send_error<S>(sender: &mut S, code: &str, message: &str) -> Result<(), axum::Error>
+where
+ S: SinkExt<Message> + Unpin,
+ <S as futures::Sink<Message>>::Error: std::error::Error,
+{
+ let error_msg = serde_json::json!({
+ "type": "error",
+ "code": code,
+ "message": message,
+ "recoverable": false
+ });
+
+ sender
+ .send(Message::Text(error_msg.to_string().into()))
+ .await
+ .ok();
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_error_message_format() {
+ let error = serde_json::json!({
+ "type": "error",
+ "code": "TEST_ERROR",
+ "message": "Test message",
+ "recoverable": false
+ });
+
+ assert_eq!(error["type"], "error");
+ assert_eq!(error["code"], "TEST_ERROR");
+ assert_eq!(error["message"], "Test message");
+ assert_eq!(error["recoverable"], false);
+ }
+
+ #[test]
+ fn test_client_message_parse_speak() {
+ let json = r#"{"type": "speak", "text": "Hello world"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ match msg {
+ ClientMessage::Speak { text, voice } => {
+ assert_eq!(text, "Hello world");
+ assert!(voice.is_none());
+ }
+ _ => panic!("Expected Speak message"),
+ }
+ }
+
+ #[test]
+ fn test_client_message_parse_cancel() {
+ let json = r#"{"type": "cancel"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ assert!(matches!(msg, ClientMessage::Cancel));
+ }
+
+ #[test]
+ fn test_client_message_parse_stop() {
+ let json = r#"{"type": "stop"}"#;
+ let msg: ClientMessage = serde_json::from_str(json).unwrap();
+ assert!(matches!(msg, ClientMessage::Stop));
+ }
+}
diff --git a/makima/src/server/messages.rs b/makima/src/server/messages.rs
index 9c50334..cecb622 100644
--- a/makima/src/server/messages.rs
+++ b/makima/src/server/messages.rs
@@ -103,3 +103,164 @@ impl ApiError {
}
}
}
+
+// =============================================================================
+// TTS (Text-to-Speech) Message Types
+// =============================================================================
+
+/// TTS audio encoding format for WebSocket streaming.
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema, PartialEq, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum TtsAudioEncoding {
+ /// 16-bit signed integer PCM samples
+ #[default]
+ Pcm16,
+ /// 32-bit floating point PCM samples
+ Pcm32f,
+}
+
+/// TTS synthesis priority level.
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, ToSchema, PartialEq, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum TtsPriority {
+ /// Low priority - may be queued
+ Low,
+ /// Normal priority (default)
+ #[default]
+ Normal,
+ /// High priority - processed immediately
+ High,
+}
+
+/// TTS session start message from client.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsStartMessage {
+ /// Audio sample rate in Hz (default: 24000)
+ #[serde(default = "default_tts_sample_rate")]
+ pub sample_rate: u32,
+ /// Audio encoding format
+ #[serde(default)]
+ pub encoding: TtsAudioEncoding,
+ /// Voice identifier (default: "makima")
+ #[serde(default = "default_tts_voice")]
+ pub voice: String,
+ /// Language for synthesis (default: "English")
+ #[serde(default = "default_tts_language")]
+ pub language: String,
+}
+
+fn default_tts_sample_rate() -> u32 {
+ 24000
+}
+
+fn default_tts_voice() -> String {
+ "makima".to_string()
+}
+
+fn default_tts_language() -> String {
+ "English".to_string()
+}
+
+/// TTS speak request message from client.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsSpeakMessage {
+ /// Text to synthesize (max 1000 characters)
+ pub text: String,
+ /// Synthesis priority
+ #[serde(default)]
+ pub priority: TtsPriority,
+}
+
+/// TTS stop request message from client.
+#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsStopMessage {
+ /// Optional reason for stopping
+ pub reason: Option<String>,
+}
+
+/// Wrapper for all TTS WebSocket messages from client to server.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum TtsClientMessage {
+ /// Start a new TTS session
+ Start(TtsStartMessage),
+ /// Request speech synthesis
+ Speak(TtsSpeakMessage),
+ /// Stop the current session
+ Stop(TtsStopMessage),
+}
+
+/// TTS session ready message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsReadyMessage {
+ /// Unique session identifier
+ pub session_id: String,
+ /// Confirmed sample rate
+ pub sample_rate: u32,
+ /// Confirmed encoding format
+ pub encoding: TtsAudioEncoding,
+ /// Confirmed voice
+ pub voice: String,
+}
+
+/// TTS audio chunk message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsAudioChunkMessage {
+ /// Base64-encoded audio data
+ pub data: String,
+ /// Whether this is the final chunk
+ pub is_final: bool,
+ /// Timestamp in seconds from start of audio
+ pub timestamp: f64,
+}
+
+/// TTS synthesis complete message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsCompleteMessage {
+ /// Total synthesis duration in milliseconds
+ pub duration_ms: u64,
+ /// Total number of chunks sent
+ pub total_chunks: u32,
+ /// Length of input text
+ pub text_length: u32,
+}
+
+/// TTS error message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsErrorMessage {
+ /// Error code for programmatic handling
+ pub code: String,
+ /// Human-readable error message
+ pub message: String,
+}
+
+/// TTS session stopped message sent from server to client.
+#[derive(Debug, Clone, Serialize, ToSchema)]
+#[serde(rename_all = "camelCase")]
+pub struct TtsStoppedMessage {
+ /// Reason for stopping
+ pub reason: String,
+}
+
+/// Wrapper for all TTS WebSocket messages from server to client.
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type", rename_all = "camelCase")]
+pub enum TtsServerMessage {
+ /// Session is ready for synthesis requests
+ Ready(TtsReadyMessage),
+ /// Audio chunk (streamed during synthesis)
+ AudioChunk(TtsAudioChunkMessage),
+ /// Synthesis completed
+ Complete(TtsCompleteMessage),
+ /// Error occurred
+ Error(TtsErrorMessage),
+ /// Session has been stopped
+ Stopped(TtsStoppedMessage),
+}
diff --git a/makima/src/server/mod.rs b/makima/src/server/mod.rs
index b969650..7c13f08 100644
--- a/makima/src/server/mod.rs
+++ b/makima/src/server/mod.rs
@@ -18,7 +18,7 @@ use tower_http::trace::TraceLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
-use crate::server::handlers::{api_keys, chat, contract_chat, contract_daemon, contracts, file_ws, files, history, listen, mesh, mesh_chat, mesh_daemon, mesh_merge, mesh_red_team, mesh_supervisor, mesh_ws, repository_history, templates, transcript_analysis, users, versions};
+use crate::server::handlers::{api_keys, chat, contract_chat, contract_daemon, contracts, file_ws, files, history, listen, mesh, mesh_chat, mesh_daemon, mesh_merge, mesh_red_team, mesh_supervisor, mesh_ws, repository_history, speak, templates, transcript_analysis, users, versions};
use crate::server::openapi::ApiDoc;
use crate::server::state::SharedState;
@@ -44,6 +44,7 @@ pub fn make_router(state: SharedState) -> Router {
// API v1 routes
let api_v1 = Router::new()
.route("/listen", get(listen::websocket_handler))
+ .route("/speak", get(speak::websocket_handler))
// Listen/transcript analysis endpoints
.route("/listen/analyze", post(transcript_analysis::analyze_transcript))
.route("/listen/create-contract", post(transcript_analysis::create_contract_from_analysis))
diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs
index 1bc7d7e..bf8f6f2 100644
--- a/makima/src/server/state.rs
+++ b/makima/src/server/state.rs
@@ -8,6 +8,7 @@ use uuid::Uuid;
use crate::listen::{DiarizationConfig, ParakeetEOU, ParakeetTDT, Sortformer};
use crate::server::auth::{AuthConfig, JwtVerifier};
+use crate::tts::TtsEngine;
/// Notification payload for file updates (broadcast to WebSocket subscribers).
#[derive(Debug, Clone)]
@@ -599,6 +600,8 @@ pub struct AppState {
pub jwt_verifier: Option<JwtVerifier>,
/// Pending worktree info requests awaiting daemon response (keyed by task_id)
pub pending_worktree_info: DashMap<Uuid, oneshot::Sender<WorktreeInfoResponse>>,
+ /// Lazily-loaded TTS engine (initialized on first Speak connection)
+ pub tts_engine: OnceCell<Box<dyn TtsEngine>>,
}
impl AppState {
@@ -673,9 +676,28 @@ impl AppState {
tool_keys: DashMap::new(),
jwt_verifier,
pending_worktree_info: DashMap::new(),
+ tts_engine: OnceCell::new(),
}
}
+ /// Get or initialize the TTS engine (lazy loading).
+ ///
+ /// The TTS engine is loaded on first Speak connection using the Qwen3 backend.
+ /// Returns a reference to the engine, or an error if loading fails.
+ pub async fn get_tts_engine(&self) -> Result<&dyn TtsEngine, Box<dyn std::error::Error + Send + Sync>> {
+ self.tts_engine.get_or_try_init(|| async {
+ tracing::info!("Lazy-loading TTS engine (Qwen3) on first Speak connection...");
+ let engine = crate::tts::TtsEngineFactory::create(
+ crate::tts::TtsBackend::Qwen3,
+ None, // Use default model directory
+ ).map_err(|e| -> Box<dyn std::error::Error + Send + Sync> {
+ Box::new(e)
+ })?;
+ tracing::info!("TTS engine loaded successfully");
+ Ok(engine)
+ }).await.map(|b| b.as_ref())
+ }
+
/// Get or initialize ML models (lazy loading).
///
/// Models are loaded on first call and cached for subsequent calls.
diff --git a/makima/src/tts.rs b/makima/src/tts/chatterbox.rs
index 5198938..e26bc06 100644
--- a/makima/src/tts.rs
+++ b/makima/src/tts/chatterbox.rs
@@ -1,17 +1,26 @@
-use std::path::{Path, PathBuf};
-use std::fs;
+//! Chatterbox TTS engine — ONNX-based (legacy).
+//!
+//! This is the existing Chatterbox TTS implementation moved from `tts.rs`,
+//! now implementing the `TtsEngine` trait for unified access.
-use hf_hub::api::sync::Api;
use std::borrow::Cow;
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::sync::Mutex;
-use ndarray::{ArrayD, Array2, Array3, Array4, IxDyn};
+use hf_hub::api::sync::Api;
+use ndarray::{Array2, Array3, Array4, ArrayD, IxDyn};
use ort::session::Session;
-use ort::value::{Value, DynValue};
+use ort::value::{DynValue, Value};
use tokenizers::Tokenizer;
use crate::audio;
-pub const SAMPLE_RATE: u32 = 24_000;
+use super::{
+ apply_repetition_penalty, argmax, resample_to_24k, AudioChunk, TtsEngine, TtsError,
+ SAMPLE_RATE,
+};
+
const START_SPEECH_TOKEN: i64 = 6561;
const STOP_SPEECH_TOKEN: i64 = 6562;
const SILENCE_TOKEN: i64 = 4299;
@@ -22,57 +31,6 @@ const HEAD_DIM: usize = 64;
const MODEL_ID: &str = "ResembleAI/chatterbox-turbo-ONNX";
const DEFAULT_MODEL_DIR: &str = "models/chatterbox-turbo";
-#[derive(Debug)]
-pub enum TtsError {
- ModelLoad(String),
- Inference(String),
- Tokenizer(String),
- Audio(audio::AudioError),
- Io(std::io::Error),
- VoiceRequired,
-}
-
-impl std::fmt::Display for TtsError {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- TtsError::ModelLoad(msg) => write!(f, "model load error: {msg}"),
- TtsError::Inference(msg) => write!(f, "inference error: {msg}"),
- TtsError::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"),
- TtsError::Audio(err) => write!(f, "audio error: {err}"),
- TtsError::Io(err) => write!(f, "io error: {err}"),
- TtsError::VoiceRequired => write!(f, "voice reference audio is required for chatterbox-turbo"),
- }
- }
-}
-
-impl std::error::Error for TtsError {}
-
-impl From<audio::AudioError> for TtsError {
- fn from(value: audio::AudioError) -> Self {
- TtsError::Audio(value)
- }
-}
-
-impl From<std::io::Error> for TtsError {
- fn from(value: std::io::Error) -> Self {
- TtsError::Io(value)
- }
-}
-
-impl From<ort::Error> for TtsError {
- fn from(value: ort::Error) -> Self {
- TtsError::ModelLoad(value.to_string())
- }
-}
-
-pub struct ChatterboxTTS {
- speech_encoder: Session,
- embed_tokens: Session,
- language_model: Session,
- conditional_decoder: Session,
- tokenizer: Tokenizer,
-}
-
struct VoiceCondition {
audio_features: ArrayD<f32>,
prompt_tokens: ArrayD<i64>,
@@ -100,6 +58,18 @@ fn extract_i64_tensor(value: &Value) -> Result<ArrayD<i64>, TtsError> {
.map_err(|e| TtsError::Inference(e.to_string()))
}
+pub struct ChatterboxTTS {
+ speech_encoder: Mutex<Session>,
+ embed_tokens: Mutex<Session>,
+ language_model: Mutex<Session>,
+ conditional_decoder: Mutex<Session>,
+ tokenizer: Tokenizer,
+}
+
+// SAFETY: Sessions are behind Mutex, Tokenizer is Send+Sync
+unsafe impl Send for ChatterboxTTS {}
+unsafe impl Sync for ChatterboxTTS {}
+
impl ChatterboxTTS {
pub fn from_pretrained(model_dir: Option<&str>) -> Result<Self, TtsError> {
let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
@@ -133,21 +103,20 @@ impl ChatterboxTTS {
.map_err(|e| TtsError::Tokenizer(e.to_string()))?;
Ok(Self {
- speech_encoder,
- embed_tokens,
- language_model,
- conditional_decoder,
+ speech_encoder: Mutex::new(speech_encoder),
+ embed_tokens: Mutex::new(embed_tokens),
+ language_model: Mutex::new(language_model),
+ conditional_decoder: Mutex::new(conditional_decoder),
tokenizer,
})
}
- pub fn generate_tts(&mut self, _text: &str) -> Result<Vec<f32>, TtsError> {
- // Chatterbox TTS requires voice reference audio
+ pub fn generate_tts(&self) -> Result<Vec<f32>, TtsError> {
Err(TtsError::VoiceRequired)
}
pub fn generate_tts_with_voice(
- &mut self,
+ &self,
text: &str,
sample_audio_path: &Path,
) -> Result<Vec<f32>, TtsError> {
@@ -157,7 +126,7 @@ impl ChatterboxTTS {
}
pub fn generate_tts_with_samples(
- &mut self,
+ &self,
text: &str,
samples: &[f32],
sample_rate: u32,
@@ -168,10 +137,8 @@ impl ChatterboxTTS {
samples.to_vec()
};
- // 1. Encode reference audio
let voice_condition = self.encode_voice(&resampled)?;
- // 2. Tokenize text
let encoding = self
.tokenizer
.encode(text, true)
@@ -179,24 +146,18 @@ impl ChatterboxTTS {
let text_input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
- // 3. Generate speech tokens
- let generated_tokens = self.generate_speech_tokens(
- &text_input_ids,
- &voice_condition.audio_features,
- )?;
+ let generated_tokens =
+ self.generate_speech_tokens(&text_input_ids, &voice_condition.audio_features)?;
- // 4. Prepare final speech tokens: prompt_tokens + generated + silence
let prompt_tokens: Vec<i64> = voice_condition.prompt_tokens.iter().copied().collect();
let silence_tokens = vec![SILENCE_TOKEN; 3];
- let mut final_tokens = Vec::with_capacity(
- prompt_tokens.len() + generated_tokens.len() + silence_tokens.len()
- );
+ let mut final_tokens =
+ Vec::with_capacity(prompt_tokens.len() + generated_tokens.len() + silence_tokens.len());
final_tokens.extend_from_slice(&prompt_tokens);
final_tokens.extend_from_slice(&generated_tokens);
final_tokens.extend_from_slice(&silence_tokens);
- // 5. Decode to audio
let audio_samples = self.decode_speech_tokens(
&final_tokens,
&voice_condition.speaker_embeddings,
@@ -206,15 +167,18 @@ impl ChatterboxTTS {
Ok(audio_samples)
}
- fn encode_voice(&mut self, samples: &[f32]) -> Result<VoiceCondition, TtsError> {
+ fn encode_voice(&self, samples: &[f32]) -> Result<VoiceCondition, TtsError> {
let audio_arr = Array2::from_shape_vec((1, samples.len()), samples.to_vec())
.map_err(|e| TtsError::Inference(e.to_string()))?;
let audio_tensor = Value::from_array(audio_arr)?;
- let outputs = self.speech_encoder.run(ort::inputs!["audio_values" => audio_tensor])?;
+ let mut encoder = self
+ .speech_encoder
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let outputs = encoder.run(ort::inputs!["audio_values" => audio_tensor])?;
- // Order: audio_features, audio_tokens (prompt_token), speaker_embeddings, speaker_features
let audio_features = extract_f32_tensor(&outputs[0])?;
let prompt_tokens = extract_i64_tensor(&outputs[1])?;
let speaker_embeddings = extract_f32_tensor(&outputs[2])?;
@@ -229,57 +193,56 @@ impl ChatterboxTTS {
}
fn generate_speech_tokens(
- &mut self,
+ &self,
text_input_ids: &[i64],
audio_features: &ArrayD<f32>,
) -> Result<Vec<i64>, TtsError> {
let max_new_tokens: usize = 1024;
let repetition_penalty: f32 = 1.2;
- // Start with START_SPEECH_TOKEN
let mut generate_tokens: Vec<i64> = vec![START_SPEECH_TOKEN];
- // Initialize empty KV cache (seq_len = 0)
- let mut past_key_values = self.init_kv_cache(0)?;
-
+ let mut past_key_values = Self::init_kv_cache(0);
let mut first_iteration = true;
let mut total_seq_len: usize = 0;
for _ in 0..max_new_tokens {
- // Get embeddings for current input_ids
let current_input_ids = if first_iteration {
- // First iteration: use text input_ids
text_input_ids.to_vec()
} else {
- // Subsequent iterations: use last generated token
vec![*generate_tokens.last().unwrap()]
};
- let input_ids_arr = Array2::from_shape_vec(
- (1, current_input_ids.len()),
- current_input_ids
- ).map_err(|e| TtsError::Inference(e.to_string()))?;
+ let input_ids_arr =
+ Array2::from_shape_vec((1, current_input_ids.len()), current_input_ids)
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
let input_ids_tensor = Value::from_array(input_ids_arr)?;
let inputs_embeds = {
- let embed_outputs = self.embed_tokens.run(ort::inputs![input_ids_tensor])?;
+ let mut embed = self
+ .embed_tokens
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let embed_outputs = embed.run(ort::inputs![input_ids_tensor])?;
extract_f32_tensor(&embed_outputs[0])?
};
- // On first iteration, concatenate audio features with text embeddings
let inputs_embeds = if first_iteration {
- let audio_feat_3d = audio_features.view()
+ let audio_feat_3d = audio_features
+ .view()
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
- let text_emb_3d = inputs_embeds.view()
+ let text_emb_3d = inputs_embeds
+ .view()
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
ndarray::concatenate(ndarray::Axis(1), &[audio_feat_3d, text_emb_3d])
.map_err(|e| TtsError::Inference(e.to_string()))?
} else {
- inputs_embeds.view()
+ inputs_embeds
+ .view()
.into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?
.to_owned()
@@ -287,7 +250,6 @@ impl ChatterboxTTS {
let seq_len = inputs_embeds.shape()[1];
- // Set up attention mask and position ids
let (attention_mask, position_ids) = if first_iteration {
total_seq_len = seq_len;
let attention_mask: Array2<i64> = Array2::ones((1, seq_len));
@@ -296,14 +258,12 @@ impl ChatterboxTTS {
} else {
total_seq_len += 1;
let attention_mask: Array2<i64> = Array2::ones((1, total_seq_len));
- let position_ids = Array2::from_shape_vec(
- (1, 1),
- vec![(total_seq_len - 1) as i64]
- ).map_err(|e| TtsError::Inference(e.to_string()))?;
+ let position_ids =
+ Array2::from_shape_vec((1, 1), vec![(total_seq_len - 1) as i64])
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
(attention_mask, position_ids)
};
- // Run language model
let (logits, new_kv) = self.run_language_model(
inputs_embeds,
position_ids,
@@ -313,8 +273,9 @@ impl ChatterboxTTS {
past_key_values = new_kv;
- // Get last logits
- let logits_3d = logits.view().into_dimensionality::<ndarray::Ix3>()
+ let logits_3d = logits
+ .view()
+ .into_dimensionality::<ndarray::Ix3>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
let last_idx = logits_3d.shape()[1] - 1;
@@ -324,12 +285,9 @@ impl ChatterboxTTS {
.copied()
.collect();
- // Apply repetition penalty
apply_repetition_penalty(&mut current_logits, &generate_tokens, repetition_penalty);
- // Get next token
let next_token = argmax(&current_logits);
-
generate_tokens.push(next_token);
if next_token == STOP_SPEECH_TOKEN {
@@ -339,15 +297,14 @@ impl ChatterboxTTS {
first_iteration = false;
}
- // Return tokens without START and STOP tokens: [1:-1]
if generate_tokens.len() > 2 {
- Ok(generate_tokens[1..generate_tokens.len()-1].to_vec())
+ Ok(generate_tokens[1..generate_tokens.len() - 1].to_vec())
} else {
Ok(Vec::new())
}
}
- fn init_kv_cache(&self, seq_len: usize) -> Result<Vec<Array4<f32>>, TtsError> {
+ fn init_kv_cache(seq_len: usize) -> Vec<Array4<f32>> {
let mut cache = Vec::with_capacity(NUM_LAYERS * 2);
for _ in 0..NUM_LAYERS {
let key = Array4::<f32>::zeros((1, NUM_KV_HEADS, seq_len, HEAD_DIM));
@@ -355,11 +312,11 @@ impl ChatterboxTTS {
cache.push(key);
cache.push(value);
}
- Ok(cache)
+ cache
}
fn run_language_model(
- &mut self,
+ &self,
inputs_embeds: Array3<f32>,
position_ids: Array2<i64>,
attention_mask: Array2<i64>,
@@ -367,23 +324,37 @@ impl ChatterboxTTS {
) -> Result<(ArrayD<f32>, Vec<Array4<f32>>), TtsError> {
let mut inputs: Vec<(Cow<str>, DynValue)> = Vec::new();
- inputs.push((Cow::from("inputs_embeds"), Value::from_array(inputs_embeds)?.into_dyn()));
- inputs.push((Cow::from("position_ids"), Value::from_array(position_ids)?.into_dyn()));
- inputs.push((Cow::from("attention_mask"), Value::from_array(attention_mask)?.into_dyn()));
+ inputs.push((
+ Cow::from("inputs_embeds"),
+ Value::from_array(inputs_embeds)?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("position_ids"),
+ Value::from_array(position_ids)?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("attention_mask"),
+ Value::from_array(attention_mask)?.into_dyn(),
+ ));
- // Add KV cache inputs
for layer_idx in 0..NUM_LAYERS {
let key_name = format!("past_key_values.{}.key", layer_idx);
let value_name = format!("past_key_values.{}.value", layer_idx);
- let key_tensor = Value::from_array(past_key_values[layer_idx * 2].clone())?.into_dyn();
- let value_tensor = Value::from_array(past_key_values[layer_idx * 2 + 1].clone())?.into_dyn();
+ let key_tensor =
+ Value::from_array(past_key_values[layer_idx * 2].clone())?.into_dyn();
+ let value_tensor =
+ Value::from_array(past_key_values[layer_idx * 2 + 1].clone())?.into_dyn();
inputs.push((Cow::from(key_name), key_tensor));
inputs.push((Cow::from(value_name), value_tensor));
}
- let outputs = self.language_model.run(inputs)?;
+ let mut lm = self
+ .language_model
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let outputs = lm.run(inputs)?;
let logits = extract_f32_tensor(&outputs[0])?;
@@ -395,9 +366,11 @@ impl ChatterboxTTS {
let key_arr = extract_f32_tensor(&outputs[key_idx])?;
let value_arr = extract_f32_tensor(&outputs[value_idx])?;
- let key_4d = key_arr.into_dimensionality::<ndarray::Ix4>()
+ let key_4d = key_arr
+ .into_dimensionality::<ndarray::Ix4>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
- let value_4d = value_arr.into_dimensionality::<ndarray::Ix4>()
+ let value_4d = value_arr
+ .into_dimensionality::<ndarray::Ix4>()
.map_err(|e| TtsError::Inference(e.to_string()))?;
new_kv.push(key_4d.to_owned());
@@ -408,7 +381,7 @@ impl ChatterboxTTS {
}
fn decode_speech_tokens(
- &mut self,
+ &self,
speech_tokens: &[i64],
speaker_embeddings: &ArrayD<f32>,
speaker_features: &ArrayD<f32>,
@@ -417,15 +390,29 @@ impl ChatterboxTTS {
return Ok(Vec::new());
}
- let tokens_arr = Array2::from_shape_vec((1, speech_tokens.len()), speech_tokens.to_vec())
- .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let tokens_arr =
+ Array2::from_shape_vec((1, speech_tokens.len()), speech_tokens.to_vec())
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
let mut inputs: Vec<(Cow<str>, DynValue)> = Vec::new();
- inputs.push((Cow::from("speech_tokens"), Value::from_array(tokens_arr)?.into_dyn()));
- inputs.push((Cow::from("speaker_embeddings"), Value::from_array(speaker_embeddings.clone())?.into_dyn()));
- inputs.push((Cow::from("speaker_features"), Value::from_array(speaker_features.clone())?.into_dyn()));
-
- let outputs = self.conditional_decoder.run(inputs)?;
+ inputs.push((
+ Cow::from("speech_tokens"),
+ Value::from_array(tokens_arr)?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("speaker_embeddings"),
+ Value::from_array(speaker_embeddings.clone())?.into_dyn(),
+ ));
+ inputs.push((
+ Cow::from("speaker_features"),
+ Value::from_array(speaker_features.clone())?.into_dyn(),
+ ));
+
+ let mut decoder = self
+ .conditional_decoder
+ .lock()
+ .map_err(|e| TtsError::Inference(e.to_string()))?;
+ let outputs = decoder.run(inputs)?;
let waveform = extract_f32_tensor(&outputs[0])?;
@@ -433,6 +420,34 @@ impl ChatterboxTTS {
}
}
+#[async_trait::async_trait]
+impl TtsEngine for ChatterboxTTS {
+ async fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ reference_sample_rate: Option<u32>,
+ ) -> Result<Vec<AudioChunk>, TtsError> {
+ let samples = match reference_audio {
+ Some(audio) => {
+ let sr = reference_sample_rate.unwrap_or(SAMPLE_RATE);
+ self.generate_tts_with_samples(text, audio, sr)?
+ }
+ None => return Err(TtsError::VoiceRequired),
+ };
+
+ Ok(vec![AudioChunk {
+ samples,
+ sample_rate: SAMPLE_RATE,
+ is_final: true,
+ }])
+ }
+
+ fn is_ready(&self) -> bool {
+ true
+ }
+}
+
fn download_models(target_dir: &Path) -> Result<(), TtsError> {
fs::create_dir_all(target_dir)?;
@@ -453,7 +468,9 @@ fn download_models(target_dir: &Path) -> Result<(), TtsError> {
for file in &model_files {
println!("Downloading {}...", file);
- let downloaded_path = repo.get(file).map_err(|e| TtsError::ModelLoad(e.to_string()))?;
+ let downloaded_path = repo
+ .get(file)
+ .map_err(|e| TtsError::ModelLoad(e.to_string()))?;
let filename = Path::new(file).file_name().unwrap();
let target_path = target_dir.join(filename);
@@ -466,115 +483,3 @@ fn download_models(target_dir: &Path) -> Result<(), TtsError> {
println!("Models downloaded to {:?}", target_dir);
Ok(())
}
-
-fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec<f32> {
- if input_rate == SAMPLE_RATE {
- return samples.to_vec();
- }
- if samples.is_empty() {
- return Vec::new();
- }
-
- let ratio = input_rate as f64 / SAMPLE_RATE as f64;
- let output_len = ((samples.len() as f64) / ratio).ceil() as usize;
-
- let mut output = Vec::with_capacity(output_len);
- for i in 0..output_len {
- let src_idx = (i as f64 * ratio) as usize;
- let sample = samples.get(src_idx).copied().unwrap_or(0.0);
- output.push(sample);
- }
-
- output
-}
-
-fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) {
- for &token in generated {
- if (token as usize) < logits.len() {
- let score = logits[token as usize];
- // Note: opposite of standard - if score < 0, multiply; if > 0, divide
- logits[token as usize] = if score < 0.0 {
- score * penalty
- } else {
- score / penalty
- };
- }
- }
-}
-
-fn argmax(logits: &[f32]) -> i64 {
- logits
- .iter()
- .enumerate()
- .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
- .map(|(idx, _)| idx as i64)
- .unwrap_or(0)
-}
-
-pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> {
- let mut file = fs::File::create(path)?;
- write_wav(&mut file, samples, SAMPLE_RATE)?;
- Ok(())
-}
-
-fn write_wav<W: std::io::Write>(writer: &mut W, samples: &[f32], sample_rate: u32) -> Result<(), std::io::Error> {
- let num_samples = samples.len() as u32;
- let num_channels: u16 = 1;
- let bits_per_sample: u16 = 16;
- let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
- let block_align = num_channels * bits_per_sample / 8;
- let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8;
- let file_size = 36 + data_size;
-
- writer.write_all(b"RIFF")?;
- writer.write_all(&file_size.to_le_bytes())?;
- writer.write_all(b"WAVE")?;
-
- writer.write_all(b"fmt ")?;
- writer.write_all(&16u32.to_le_bytes())?;
- writer.write_all(&1u16.to_le_bytes())?;
- writer.write_all(&num_channels.to_le_bytes())?;
- writer.write_all(&sample_rate.to_le_bytes())?;
- writer.write_all(&byte_rate.to_le_bytes())?;
- writer.write_all(&block_align.to_le_bytes())?;
- writer.write_all(&bits_per_sample.to_le_bytes())?;
-
- writer.write_all(b"data")?;
- writer.write_all(&data_size.to_le_bytes())?;
-
- for &sample in samples {
- let clamped = sample.clamp(-1.0, 1.0);
- let int_sample = (clamped * 32767.0) as i16;
- writer.write_all(&int_sample.to_le_bytes())?;
- }
-
- Ok(())
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_argmax() {
- let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
- assert_eq!(argmax(&logits), 3);
- }
-
- #[test]
- fn test_resample_same_rate() {
- let samples = vec![0.1, 0.2, 0.3];
- let resampled = resample_to_24k(&samples, SAMPLE_RATE);
- assert_eq!(resampled, samples);
- }
-
- #[test]
- fn test_repetition_penalty() {
- let mut logits = vec![1.0, 2.0, 3.0, 4.0];
- let generated = vec![1, 3];
- apply_repetition_penalty(&mut logits, &generated, 1.2);
- // score > 0 -> divide
- assert!((logits[1] - 2.0 / 1.2).abs() < 1e-6);
- assert!((logits[3] - 4.0 / 1.2).abs() < 1e-6);
- }
-}
diff --git a/makima/src/tts/mod.rs b/makima/src/tts/mod.rs
new file mode 100644
index 0000000..2cd0412
--- /dev/null
+++ b/makima/src/tts/mod.rs
@@ -0,0 +1,281 @@
+//! TTS engine abstraction and implementations.
+//!
+//! Provides a trait-based TTS engine interface with two backends:
+//! - **Chatterbox**: ONNX-based TTS (legacy)
+//! - **Qwen3**: Pure Rust candle-based Qwen3-TTS-12Hz-0.6B
+
+use std::path::Path;
+
+pub mod chatterbox;
+pub mod qwen3;
+
+// Re-export primary types
+pub use chatterbox::ChatterboxTTS;
+pub use qwen3::Qwen3Tts;
+
+/// Audio output sample rate (both engines output 24kHz).
+pub const SAMPLE_RATE: u32 = 24_000;
+
+/// A chunk of generated audio for streaming output.
+#[derive(Debug, Clone)]
+pub struct AudioChunk {
+ /// PCM f32 samples in [-1.0, 1.0].
+ pub samples: Vec<f32>,
+ /// Sample rate (always 24000 for both engines).
+ pub sample_rate: u32,
+ /// Whether this is the final chunk in the stream.
+ pub is_final: bool,
+}
+
+impl AudioChunk {
+ /// Convert to 16-bit PCM bytes (little-endian) for WebSocket streaming.
+ pub fn to_pcm16_bytes(&self) -> Vec<u8> {
+ let mut buf = Vec::with_capacity(self.samples.len() * 2);
+ for &s in &self.samples {
+ let clamped = s.clamp(-1.0, 1.0);
+ let int_sample = (clamped * 32767.0) as i16;
+ buf.extend_from_slice(&int_sample.to_le_bytes());
+ }
+ buf
+ }
+}
+
+/// Errors that can occur during TTS operations.
+#[derive(Debug)]
+pub enum TtsError {
+ ModelLoad(String),
+ Inference(String),
+ Tokenizer(String),
+ Audio(crate::audio::AudioError),
+ Io(std::io::Error),
+ VoiceRequired,
+ Config(String),
+ Candle(String),
+}
+
+impl std::fmt::Display for TtsError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TtsError::ModelLoad(msg) => write!(f, "model load error: {msg}"),
+ TtsError::Inference(msg) => write!(f, "inference error: {msg}"),
+ TtsError::Tokenizer(msg) => write!(f, "tokenizer error: {msg}"),
+ TtsError::Audio(err) => write!(f, "audio error: {err}"),
+ TtsError::Io(err) => write!(f, "io error: {err}"),
+ TtsError::VoiceRequired => {
+ write!(f, "voice reference audio is required")
+ }
+ TtsError::Config(msg) => write!(f, "config error: {msg}"),
+ TtsError::Candle(msg) => write!(f, "candle error: {msg}"),
+ }
+ }
+}
+
+impl std::error::Error for TtsError {}
+
+impl From<crate::audio::AudioError> for TtsError {
+ fn from(value: crate::audio::AudioError) -> Self {
+ TtsError::Audio(value)
+ }
+}
+
+impl From<std::io::Error> for TtsError {
+ fn from(value: std::io::Error) -> Self {
+ TtsError::Io(value)
+ }
+}
+
+impl From<ort::Error> for TtsError {
+ fn from(value: ort::Error) -> Self {
+ TtsError::ModelLoad(value.to_string())
+ }
+}
+
+impl From<candle_core::Error> for TtsError {
+ fn from(value: candle_core::Error) -> Self {
+ TtsError::Candle(value.to_string())
+ }
+}
+
+/// Which TTS backend to use.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum TtsBackend {
+ /// ONNX-based Chatterbox TTS (legacy).
+ Chatterbox,
+ /// Candle-based Qwen3-TTS (preferred).
+ Qwen3,
+}
+
+/// TTS engine trait — implemented by both Chatterbox and Qwen3.
+#[async_trait::async_trait]
+pub trait TtsEngine: Send + Sync {
+ /// Generate complete audio from text with a voice reference.
+ async fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ reference_sample_rate: Option<u32>,
+ ) -> Result<Vec<AudioChunk>, TtsError>;
+
+ /// Check if the engine is loaded and ready.
+ fn is_ready(&self) -> bool;
+
+ /// Get the engine's output sample rate.
+ fn sample_rate(&self) -> u32 {
+ SAMPLE_RATE
+ }
+}
+
+/// Factory for creating TTS engines.
+pub struct TtsEngineFactory;
+
+impl TtsEngineFactory {
+ /// Create a TTS engine of the specified backend type.
+ pub fn create(backend: TtsBackend, model_dir: Option<&str>) -> Result<Box<dyn TtsEngine>, TtsError> {
+ match backend {
+ TtsBackend::Chatterbox => {
+ let engine = ChatterboxTTS::from_pretrained(model_dir)?;
+ Ok(Box::new(engine))
+ }
+ TtsBackend::Qwen3 => {
+ let device = candle_core::Device::Cpu; // Default to CPU; GPU selection happens at higher level
+ let engine = Qwen3Tts::from_pretrained(model_dir, &device)?;
+ Ok(Box::new(engine))
+ }
+ }
+ }
+}
+
+/// Save audio samples to a WAV file.
+pub fn save_wav(samples: &[f32], path: &Path) -> Result<(), TtsError> {
+ let mut file = std::fs::File::create(path)?;
+ write_wav(&mut file, samples, SAMPLE_RATE)?;
+ Ok(())
+}
+
+fn write_wav<W: std::io::Write>(
+ writer: &mut W,
+ samples: &[f32],
+ sample_rate: u32,
+) -> Result<(), std::io::Error> {
+ let num_samples = samples.len() as u32;
+ let num_channels: u16 = 1;
+ let bits_per_sample: u16 = 16;
+ let byte_rate = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
+ let block_align = num_channels * bits_per_sample / 8;
+ let data_size = num_samples * num_channels as u32 * bits_per_sample as u32 / 8;
+ let file_size = 36 + data_size;
+
+ writer.write_all(b"RIFF")?;
+ writer.write_all(&file_size.to_le_bytes())?;
+ writer.write_all(b"WAVE")?;
+
+ writer.write_all(b"fmt ")?;
+ writer.write_all(&16u32.to_le_bytes())?;
+ writer.write_all(&1u16.to_le_bytes())?;
+ writer.write_all(&num_channels.to_le_bytes())?;
+ writer.write_all(&sample_rate.to_le_bytes())?;
+ writer.write_all(&byte_rate.to_le_bytes())?;
+ writer.write_all(&block_align.to_le_bytes())?;
+ writer.write_all(&bits_per_sample.to_le_bytes())?;
+
+ writer.write_all(b"data")?;
+ writer.write_all(&data_size.to_le_bytes())?;
+
+ for &sample in samples {
+ let clamped = sample.clamp(-1.0, 1.0);
+ let int_sample = (clamped * 32767.0) as i16;
+ writer.write_all(&int_sample.to_le_bytes())?;
+ }
+
+ Ok(())
+}
+
+/// Resample audio to 24kHz using simple linear interpolation.
+pub fn resample_to_24k(samples: &[f32], input_rate: u32) -> Vec<f32> {
+ if input_rate == SAMPLE_RATE {
+ return samples.to_vec();
+ }
+ if samples.is_empty() {
+ return Vec::new();
+ }
+
+ let ratio = input_rate as f64 / SAMPLE_RATE as f64;
+ let output_len = ((samples.len() as f64) / ratio).ceil() as usize;
+
+ let mut output = Vec::with_capacity(output_len);
+ for i in 0..output_len {
+ let src_idx = (i as f64 * ratio) as usize;
+ let sample = samples.get(src_idx).copied().unwrap_or(0.0);
+ output.push(sample);
+ }
+
+ output
+}
+
+/// Apply repetition penalty to logits based on previously generated tokens.
+pub fn apply_repetition_penalty(logits: &mut [f32], generated: &[i64], penalty: f32) {
+ for &token in generated {
+ if (token as usize) < logits.len() {
+ let score = logits[token as usize];
+ logits[token as usize] = if score < 0.0 {
+ score * penalty
+ } else {
+ score / penalty
+ };
+ }
+ }
+}
+
+/// Return the index of the maximum value in logits.
+pub fn argmax(logits: &[f32]) -> i64 {
+ logits
+ .iter()
+ .enumerate()
+ .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
+ .map(|(idx, _)| idx as i64)
+ .unwrap_or(0)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_argmax() {
+ let logits = vec![0.1, 0.5, 0.3, 0.8, 0.2];
+ assert_eq!(argmax(&logits), 3);
+ }
+
+ #[test]
+ fn test_resample_same_rate() {
+ let samples = vec![0.1, 0.2, 0.3];
+ let resampled = resample_to_24k(&samples, SAMPLE_RATE);
+ assert_eq!(resampled, samples);
+ }
+
+ #[test]
+ fn test_repetition_penalty() {
+ let mut logits = vec![1.0, 2.0, 3.0, 4.0];
+ let generated = vec![1, 3];
+ apply_repetition_penalty(&mut logits, &generated, 1.2);
+ assert!((logits[1] - 2.0 / 1.2).abs() < 1e-6);
+ assert!((logits[3] - 4.0 / 1.2).abs() < 1e-6);
+ }
+
+ #[test]
+ fn test_audio_chunk_to_pcm16() {
+ let chunk = AudioChunk {
+ samples: vec![0.0, 1.0, -1.0],
+ sample_rate: 24_000,
+ is_final: true,
+ };
+ let bytes = chunk.to_pcm16_bytes();
+ assert_eq!(bytes.len(), 6);
+ // 0.0 -> 0i16
+ assert_eq!(i16::from_le_bytes([bytes[0], bytes[1]]), 0);
+ // 1.0 -> 32767i16
+ assert_eq!(i16::from_le_bytes([bytes[2], bytes[3]]), 32767);
+ // -1.0 -> -32767i16
+ assert_eq!(i16::from_le_bytes([bytes[4], bytes[5]]), -32767);
+ }
+}
diff --git a/makima/src/tts/qwen3/code_predictor.rs b/makima/src/tts/qwen3/code_predictor.rs
new file mode 100644
index 0000000..0ef8a1d
--- /dev/null
+++ b/makima/src/tts/qwen3/code_predictor.rs
@@ -0,0 +1,261 @@
+//! Multi-Token Prediction (MTP) code predictor.
+//!
+//! After the main LM predicts the zeroth codebook token, this module
+//! predicts the remaining 15 codebook layers in parallel from the
+//! LM's hidden states.
+//!
+//! Architecture:
+//! - 5 transformer layers (same structure as main LM layers)
+//! - 16 output heads, one per codebook (vocab 2048 each)
+//! - Input: last hidden state from main LM + zeroth codebook embedding
+//! - Output: 16 codebook token predictions
+
+use candle_core::{Device, Module, Result, Tensor, D};
+use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+
+use super::config::{CodePredictorConfig, Qwen3LmConfig};
+use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding};
+
+/// A single code predictor transformer layer.
+///
+/// Uses the same pre-norm residual structure as the main LM layers.
+pub struct CodePredictorLayer {
+ self_attn: Qwen3Attention,
+ mlp: Qwen3Mlp,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl CodePredictorLayer {
+ pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> {
+ // Construct a Qwen3LmConfig-like view for the attention/MLP constructors
+ let lm_config = Qwen3LmConfig {
+ hidden_size: config.hidden_size,
+ num_hidden_layers: config.num_layers,
+ num_attention_heads: config.num_attention_heads,
+ num_key_value_heads: config.num_attention_heads, // No GQA in predictor
+ intermediate_size: config.hidden_size * 3, // 3072 for hidden=1024
+ head_dim: config.hidden_size / config.num_attention_heads,
+ rms_norm_eps: config.rms_norm_eps,
+ ..Qwen3LmConfig::default()
+ };
+
+ let self_attn = Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?;
+ let mlp = Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?;
+ let input_layernorm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ vb.pp("input_layernorm"),
+ )?;
+ let post_attention_layernorm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ hidden_states: &Tensor,
+ rope: &RotaryEmbedding,
+ kv_cache: &mut KvCache,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let residual = hidden_states;
+ let hidden_states = self.input_layernorm.forward(hidden_states)?;
+ let hidden_states =
+ self.self_attn
+ .forward(&hidden_states, rope, kv_cache, attention_mask)?;
+ let hidden_states = (residual + hidden_states)?;
+
+ let residual = &hidden_states;
+ let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
+ let hidden_states = self.mlp.forward(&hidden_states)?;
+ let output = (residual + hidden_states)?;
+
+ Ok(output)
+ }
+}
+
+/// Multi-token prediction code predictor.
+///
+/// Takes the hidden states from the main LM and predicts all 16 codebook
+/// tokens. The zeroth codebook is predicted by the main LM head; this
+/// module predicts the remaining 15 residual codebooks.
+pub struct CodePredictor {
+ /// Embedding layer for codebook tokens (shared across groups).
+ code_embeddings: Vec<Embedding>,
+ /// Projection from LM hidden + code embedding to predictor hidden.
+ input_proj: Linear,
+ /// 5 transformer layers.
+ layers: Vec<CodePredictorLayer>,
+ /// Final normalization.
+ norm: RmsNorm,
+ /// Per-codebook output heads (16 heads, each projecting to codebook_vocab_size).
+ output_heads: Vec<Linear>,
+ /// RoPE for the predictor's attention layers.
+ rope: RotaryEmbedding,
+ config: CodePredictorConfig,
+}
+
+impl CodePredictor {
+ pub fn new(
+ config: &CodePredictorConfig,
+ lm_config: &Qwen3LmConfig,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let predictor_vb = vb.pp("code_predictor");
+
+ // Code embeddings for each codebook group
+ let mut code_embeddings = Vec::with_capacity(config.num_code_groups);
+ for i in 0..config.num_code_groups {
+ let emb = embedding(
+ config.codebook_vocab_size,
+ config.hidden_size,
+ predictor_vb.pp(format!("code_embeddings.{i}")),
+ )?;
+ code_embeddings.push(emb);
+ }
+
+ // Input projection: LM hidden (1024) + code embedding (1024) -> predictor hidden (1024)
+ let input_proj = linear_no_bias(
+ config.hidden_size * 2,
+ config.hidden_size,
+ predictor_vb.pp("input_proj"),
+ )?;
+
+ // Transformer layers
+ let mut layers = Vec::with_capacity(config.num_layers);
+ for i in 0..config.num_layers {
+ let layer =
+ CodePredictorLayer::new(config, predictor_vb.pp(format!("layers.{i}")))?;
+ layers.push(layer);
+ }
+
+ let norm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ predictor_vb.pp("norm"),
+ )?;
+
+ // Output heads for each codebook
+ let mut output_heads = Vec::with_capacity(config.num_code_groups);
+ for i in 0..config.num_code_groups {
+ let head = linear_no_bias(
+ config.hidden_size,
+ config.codebook_vocab_size,
+ predictor_vb.pp(format!("output_heads.{i}")),
+ )?;
+ output_heads.push(head);
+ }
+
+ // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim)
+ let predictor_head_dim = config.hidden_size / config.num_attention_heads;
+ let rope_config = Qwen3LmConfig {
+ head_dim: predictor_head_dim,
+ rope_theta: lm_config.rope_theta,
+ max_position_embeddings: lm_config.max_position_embeddings,
+ ..Qwen3LmConfig::default()
+ };
+ let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?;
+
+ Ok(Self {
+ code_embeddings,
+ input_proj,
+ layers,
+ norm,
+ output_heads,
+ rope,
+ config: config.clone(),
+ })
+ }
+
+ /// Predict all 16 codebook tokens from the LM hidden state.
+ ///
+ /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM
+ /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook)
+ ///
+ /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code.
+ pub fn predict(
+ &self,
+ lm_hidden: &Tensor,
+ zeroth_code: u32,
+ device: &Device,
+ ) -> Result<Vec<u32>> {
+ let mut all_codes = Vec::with_capacity(self.config.num_code_groups);
+ all_codes.push(zeroth_code);
+
+ // The code predictor iterates through codebook groups.
+ // For each group i (1..16), it:
+ // 1. Embeds the previous codebook token
+ // 2. Concatenates with LM hidden state
+ // 3. Projects through the predictor layers
+ // 4. Predicts the next codebook token via output_head[i]
+ let mut prev_code = zeroth_code;
+
+ for group_idx in 1..self.config.num_code_groups {
+ // Embed the previous codebook token
+ let code_tensor = Tensor::from_vec(
+ vec![prev_code],
+ (1, 1),
+ device,
+ )?;
+ let code_emb = self.code_embeddings[group_idx - 1].forward(&code_tensor)?;
+
+ // Concatenate LM hidden state with code embedding
+ let combined = Tensor::cat(&[lm_hidden, &code_emb], D::Minus1)?;
+
+ // Project to predictor hidden size
+ let mut hidden = self.input_proj.forward(&combined)?;
+
+ // Run through predictor transformer layers (no KV cache needed — single step)
+ let mut kv_caches: Vec<KvCache> =
+ (0..self.config.num_layers).map(|_| KvCache::new()).collect();
+ for (i, layer) in self.layers.iter().enumerate() {
+ hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?;
+ }
+
+ hidden = self.norm.forward(&hidden)?;
+
+ // Predict codebook token
+ let logits = self.output_heads[group_idx].forward(&hidden)?;
+
+ // Greedy decode: argmax
+ let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size]
+ let next_code = logits_flat
+ .argmax(0)?
+ .to_scalar::<u32>()?;
+
+ all_codes.push(next_code);
+ prev_code = next_code;
+ }
+
+ Ok(all_codes)
+ }
+
+ /// Number of codebook groups.
+ pub fn num_code_groups(&self) -> usize {
+ self.config.num_code_groups
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_code_predictor_config() {
+ let config = CodePredictorConfig::default();
+ assert_eq!(config.num_layers, 5);
+ assert_eq!(config.num_code_groups, 16);
+ assert_eq!(config.codebook_vocab_size, 2048);
+ assert_eq!(config.hidden_size, 1024);
+ }
+}
diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs
new file mode 100644
index 0000000..6fb55d7
--- /dev/null
+++ b/makima/src/tts/qwen3/config.rs
@@ -0,0 +1,271 @@
+//! Qwen3-TTS model configuration.
+//!
+//! Parses config.json from the HuggingFace model repository to configure
+//! the language model, code predictor, and speech tokenizer.
+
+use serde::Deserialize;
+
+use crate::tts::TtsError;
+
+/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base.
+#[derive(Debug, Clone, Deserialize)]
+pub struct Qwen3TtsConfig {
+ /// Language model (talker) configuration.
+ #[serde(default = "Qwen3LmConfig::default")]
+ pub lm: Qwen3LmConfig,
+
+ /// Code predictor (multi-token prediction) configuration.
+ #[serde(default = "CodePredictorConfig::default")]
+ pub code_predictor: CodePredictorConfig,
+
+ /// Speech tokenizer configuration.
+ #[serde(default = "SpeechTokenizerConfig::default")]
+ pub speech_tokenizer: SpeechTokenizerConfig,
+}
+
+impl Default for Qwen3TtsConfig {
+ fn default() -> Self {
+ Self {
+ lm: Qwen3LmConfig::default(),
+ code_predictor: CodePredictorConfig::default(),
+ speech_tokenizer: SpeechTokenizerConfig::default(),
+ }
+ }
+}
+
+impl Qwen3TtsConfig {
+ /// Load from a config.json file path.
+ pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> {
+ let content = std::fs::read_to_string(path)
+ .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?;
+ Self::from_json_str(&content)
+ }
+
+ /// Load from a JSON string.
+ pub fn from_json_str(json: &str) -> Result<Self, TtsError> {
+ // Try to parse the full HuggingFace config.json format first
+ if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) {
+ return Ok(Self::from_hf_config(&hf_config));
+ }
+ // Fall back to direct deserialization
+ serde_json::from_str(json)
+ .map_err(|e| TtsError::Config(format!("failed to parse config: {e}")))
+ }
+
+ /// Convert from HuggingFace's config.json format.
+ fn from_hf_config(hf: &HfConfig) -> Self {
+ Self {
+ lm: Qwen3LmConfig {
+ hidden_size: hf.hidden_size.unwrap_or(1024),
+ num_hidden_layers: hf.num_hidden_layers.unwrap_or(28),
+ num_attention_heads: hf.num_attention_heads.unwrap_or(16),
+ num_key_value_heads: hf.num_key_value_heads.unwrap_or(8),
+ intermediate_size: hf.intermediate_size.unwrap_or(3072),
+ head_dim: hf.head_dim.unwrap_or(128),
+ vocab_size: hf.vocab_size.unwrap_or(151_936),
+ max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768),
+ rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
+ rope_theta: hf.rope_theta.unwrap_or(1_000_000.0),
+ use_sliding_window: hf.use_sliding_window.unwrap_or(false),
+ sliding_window: hf.sliding_window,
+ hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()),
+ },
+ code_predictor: CodePredictorConfig {
+ hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024),
+ num_layers: hf.code_predictor_num_layers.unwrap_or(5),
+ num_attention_heads: hf
+ .code_predictor_num_attention_heads
+ .unwrap_or(16),
+ num_code_groups: hf.num_code_groups.unwrap_or(16),
+ codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048),
+ rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6),
+ },
+ speech_tokenizer: SpeechTokenizerConfig::default(),
+ }
+ }
+}
+
+/// Language model configuration (28-layer Qwen3 transformer).
+#[derive(Debug, Clone, Deserialize)]
+pub struct Qwen3LmConfig {
+ /// Hidden dimension of transformer layers.
+ pub hidden_size: usize,
+ /// Number of transformer layers.
+ pub num_hidden_layers: usize,
+ /// Number of attention heads.
+ pub num_attention_heads: usize,
+ /// Number of key-value heads (GQA).
+ pub num_key_value_heads: usize,
+ /// Feed-forward intermediate size.
+ pub intermediate_size: usize,
+ /// Dimension per attention head.
+ pub head_dim: usize,
+ /// Text vocabulary size.
+ pub vocab_size: usize,
+ /// Maximum sequence length for RoPE.
+ pub max_position_embeddings: usize,
+ /// RMS normalization epsilon.
+ pub rms_norm_eps: f64,
+ /// RoPE theta parameter.
+ pub rope_theta: f64,
+ /// Whether to use sliding window attention.
+ pub use_sliding_window: bool,
+ /// Sliding window size (if enabled).
+ pub sliding_window: Option<usize>,
+ /// Activation function name.
+ pub hidden_act: String,
+}
+
+impl Default for Qwen3LmConfig {
+ fn default() -> Self {
+ Self {
+ hidden_size: 1024,
+ num_hidden_layers: 28,
+ num_attention_heads: 16,
+ num_key_value_heads: 8,
+ intermediate_size: 3072,
+ head_dim: 128,
+ vocab_size: 151_936,
+ max_position_embeddings: 32_768,
+ rms_norm_eps: 1e-6,
+ rope_theta: 1_000_000.0,
+ use_sliding_window: false,
+ sliding_window: None,
+ hidden_act: "silu".to_string(),
+ }
+ }
+}
+
+impl Qwen3LmConfig {
+ /// Number of key-value head groups for GQA.
+ pub fn num_kv_groups(&self) -> usize {
+ self.num_attention_heads / self.num_key_value_heads
+ }
+}
+
+/// Code predictor (multi-token prediction) configuration.
+#[derive(Debug, Clone, Deserialize)]
+pub struct CodePredictorConfig {
+ /// Hidden size (matches LM hidden size).
+ pub hidden_size: usize,
+ /// Number of predictor transformer layers.
+ pub num_layers: usize,
+ /// Number of attention heads.
+ pub num_attention_heads: usize,
+ /// Number of codebook groups (residual codebooks).
+ pub num_code_groups: usize,
+ /// Vocabulary size per codebook.
+ pub codebook_vocab_size: usize,
+ /// RMS norm epsilon.
+ pub rms_norm_eps: f64,
+}
+
+impl Default for CodePredictorConfig {
+ fn default() -> Self {
+ Self {
+ hidden_size: 1024,
+ num_layers: 5,
+ num_attention_heads: 16,
+ num_code_groups: 16,
+ codebook_vocab_size: 2048,
+ rms_norm_eps: 1e-6,
+ }
+ }
+}
+
+/// Speech tokenizer (ConvNet codec) configuration.
+#[derive(Debug, Clone, Deserialize)]
+pub struct SpeechTokenizerConfig {
+ /// Number of RVQ codebooks.
+ pub num_codebooks: usize,
+ /// Codebook embedding dimension.
+ pub codebook_dim: usize,
+ /// Codebook vocabulary size per layer.
+ pub codebook_size: usize,
+ /// Encoder/decoder hidden channels.
+ pub hidden_channels: usize,
+ /// Output sample rate.
+ pub sample_rate: u32,
+ /// Token frame rate (Hz).
+ pub frame_rate: f32,
+ /// HuggingFace model ID for the speech tokenizer.
+ pub model_id: String,
+}
+
+impl Default for SpeechTokenizerConfig {
+ fn default() -> Self {
+ Self {
+ num_codebooks: 16,
+ codebook_dim: 256,
+ codebook_size: 2048,
+ hidden_channels: 512,
+ sample_rate: 24_000,
+ frame_rate: 12.5,
+ model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(),
+ }
+ }
+}
+
+/// HuggingFace config.json format (partial, fields we need).
+#[derive(Debug, Deserialize)]
+struct HfConfig {
+ hidden_size: Option<usize>,
+ num_hidden_layers: Option<usize>,
+ num_attention_heads: Option<usize>,
+ num_key_value_heads: Option<usize>,
+ intermediate_size: Option<usize>,
+ head_dim: Option<usize>,
+ vocab_size: Option<usize>,
+ max_position_embeddings: Option<usize>,
+ rms_norm_eps: Option<f64>,
+ rope_theta: Option<f64>,
+ use_sliding_window: Option<bool>,
+ sliding_window: Option<usize>,
+ hidden_act: Option<String>,
+ // Code predictor specific fields
+ code_predictor_hidden_size: Option<usize>,
+ code_predictor_num_layers: Option<usize>,
+ code_predictor_num_attention_heads: Option<usize>,
+ num_code_groups: Option<usize>,
+ codebook_vocab_size: Option<usize>,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_default_config() {
+ let config = Qwen3TtsConfig::default();
+ assert_eq!(config.lm.hidden_size, 1024);
+ assert_eq!(config.lm.num_hidden_layers, 28);
+ assert_eq!(config.lm.num_attention_heads, 16);
+ assert_eq!(config.lm.num_key_value_heads, 8);
+ assert_eq!(config.lm.head_dim, 128);
+ assert_eq!(config.lm.num_kv_groups(), 2);
+ assert_eq!(config.code_predictor.num_layers, 5);
+ assert_eq!(config.code_predictor.num_code_groups, 16);
+ assert_eq!(config.speech_tokenizer.num_codebooks, 16);
+ }
+
+ #[test]
+ fn test_config_from_json() {
+ let json = r#"{
+ "hidden_size": 1024,
+ "num_hidden_layers": 28,
+ "num_attention_heads": 16,
+ "num_key_value_heads": 8,
+ "intermediate_size": 3072,
+ "vocab_size": 151936,
+ "max_position_embeddings": 32768,
+ "rms_norm_eps": 1e-6,
+ "rope_theta": 1000000.0,
+ "hidden_act": "silu"
+ }"#;
+
+ let config = Qwen3TtsConfig::from_json_str(json).unwrap();
+ assert_eq!(config.lm.hidden_size, 1024);
+ assert_eq!(config.lm.num_hidden_layers, 28);
+ assert_eq!(config.lm.vocab_size, 151_936);
+ }
+}
diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs
new file mode 100644
index 0000000..02161e6
--- /dev/null
+++ b/makima/src/tts/qwen3/generate.rs
@@ -0,0 +1,426 @@
+//! Autoregressive generation loop for Qwen3-TTS.
+//!
+//! Orchestrates the full inference pipeline:
+//! 1. Encode reference audio → speaker embedding via speech tokenizer
+//! 2. Tokenize text → token IDs
+//! 3. Autoregressive LM generation → zeroth codebook tokens
+//! 4. Code predictor → remaining 15 codebook tokens per frame
+//! 5. Speech tokenizer decoder → waveform audio
+
+use candle_core::{DType, Device, IndexOp, Result, Tensor};
+use tokenizers::Tokenizer;
+
+use super::code_predictor::CodePredictor;
+use super::model::{KvCache, Qwen3Model};
+use super::speech_tokenizer::SpeechTokenizer;
+use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE};
+
+/// Special tokens for the Qwen3-TTS vocabulary.
+pub const BOS_TOKEN_ID: u32 = 151_643;
+pub const EOS_TOKEN_ID: u32 = 151_645;
+pub const PAD_TOKEN_ID: u32 = 151_643;
+
+/// Speech-specific control tokens.
+/// These are placeholders — actual values come from the tokenizer config.
+pub const START_OF_SPEECH: u32 = 151_668;
+pub const END_OF_SPEECH: u32 = 151_669;
+
+/// Generation configuration.
+#[derive(Debug, Clone)]
+pub struct GenerationConfig {
+ /// Maximum number of speech tokens to generate.
+ pub max_new_tokens: usize,
+ /// Temperature for sampling (1.0 = greedy if top_k=1).
+ pub temperature: f32,
+ /// Top-k sampling (0 = disabled, use greedy argmax).
+ pub top_k: usize,
+ /// Repetition penalty.
+ pub repetition_penalty: f32,
+ /// Whether to generate audio chunks incrementally (streaming).
+ pub streaming: bool,
+}
+
+impl Default for GenerationConfig {
+ fn default() -> Self {
+ Self {
+ max_new_tokens: 2048,
+ temperature: 1.0,
+ top_k: 0, // Greedy by default
+ repetition_penalty: 1.2,
+ streaming: false,
+ }
+ }
+}
+
+/// Manages the full generation pipeline.
+pub struct GenerationContext<'a> {
+ model: &'a Qwen3Model,
+ code_predictor: &'a CodePredictor,
+ speech_tokenizer: &'a SpeechTokenizer,
+ tokenizer: &'a Tokenizer,
+ device: &'a Device,
+ config: GenerationConfig,
+}
+
+impl<'a> GenerationContext<'a> {
+ pub fn new(
+ model: &'a Qwen3Model,
+ code_predictor: &'a CodePredictor,
+ speech_tokenizer: &'a SpeechTokenizer,
+ tokenizer: &'a Tokenizer,
+ device: &'a Device,
+ config: GenerationConfig,
+ ) -> Self {
+ Self {
+ model,
+ code_predictor,
+ speech_tokenizer,
+ tokenizer,
+ device,
+ config,
+ }
+ }
+
+ /// Generate audio from text, optionally with a voice reference.
+ ///
+ /// Returns a list of audio chunks. If `streaming` is false, returns
+ /// a single chunk with the complete audio.
+ pub fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
+ // 1. Encode reference audio if provided
+ let reference_codes = match reference_audio {
+ Some(audio) => Some(
+ self.speech_tokenizer
+ .encode(audio)
+ .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?,
+ ),
+ None => None,
+ };
+
+ // 2. Tokenize text
+ let encoding = self
+ .tokenizer
+ .encode(text, true)
+ .map_err(|e| TtsError::Tokenizer(e.to_string()))?;
+
+ let text_token_ids: Vec<u32> = encoding.get_ids().to_vec();
+
+ // 3. Prepare input sequence
+ // Format: [BOS] [text_tokens] [START_OF_SPEECH]
+ let mut input_ids = Vec::new();
+ input_ids.push(BOS_TOKEN_ID);
+ input_ids.extend_from_slice(&text_token_ids);
+ input_ids.push(START_OF_SPEECH);
+
+ // 4. Run autoregressive generation
+ let generated_frames = self
+ .autoregressive_generate(&input_ids, reference_codes.as_deref())
+ .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?;
+
+ if generated_frames.is_empty() {
+ return Ok(vec![AudioChunk {
+ samples: vec![],
+ sample_rate: SAMPLE_RATE,
+ is_final: true,
+ }]);
+ }
+
+ // 5. Decode all frames to audio
+ if self.config.streaming {
+ self.decode_streaming(&generated_frames)
+ } else {
+ self.decode_batch(&generated_frames)
+ }
+ }
+
+ /// Autoregressive generation loop.
+ ///
+ /// Generates zeroth codebook tokens one at a time, then uses the code
+ /// predictor to fill in the remaining 15 codebooks per frame.
+ ///
+ /// Returns: Vec of frames, each frame is [num_codebooks] tokens.
+ fn autoregressive_generate(
+ &self,
+ input_ids: &[u32],
+ _reference_codes: Option<&[Vec<u32>]>,
+ ) -> Result<Vec<Vec<u32>>> {
+ let _num_codebooks = self.code_predictor.num_code_groups();
+ let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers())
+ .map(|_| KvCache::new())
+ .collect();
+
+ let mut generated_frames: Vec<Vec<u32>> = Vec::new();
+ let mut past_zeroth_tokens: Vec<u32> = Vec::new();
+
+ // === First iteration: process the full input sequence ===
+ let input_tensor = Tensor::from_vec(
+ input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(),
+ (1, input_ids.len()),
+ self.device,
+ )?
+ .to_dtype(DType::I64)?;
+
+ let seq_len = input_ids.len();
+ let attention_mask =
+ Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?;
+
+ let logits =
+ self.model
+ .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?;
+
+ // Get the logits for the last position
+ let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size]
+ let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?;
+
+ if first_token == END_OF_SPEECH as u32 {
+ return Ok(generated_frames);
+ }
+
+ // Use code predictor for all codebooks
+ let lm_hidden = self
+ .model
+ .last_hidden_state()
+ .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
+ let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?;
+
+ let frame_codes = self
+ .code_predictor
+ .predict(&last_hidden, first_token, self.device)?;
+ generated_frames.push(frame_codes);
+ past_zeroth_tokens.push(first_token);
+
+ // === Subsequent iterations: one token at a time ===
+ for _step in 1..self.config.max_new_tokens {
+ let past_len = kv_caches[0].seq_len();
+
+ // Input: just the last generated zeroth codebook token
+ let last_token = *past_zeroth_tokens.last().unwrap();
+ let token_tensor = Tensor::from_vec(
+ vec![last_token as i64],
+ (1, 1),
+ self.device,
+ )?
+ .to_dtype(DType::I64)?;
+
+ // Single-token attention mask
+ let attention_mask =
+ Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?;
+
+ let logits =
+ self.model
+ .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?;
+
+ let next_logits = logits.i((0, 0, ..))?; // [vocab_size]
+ let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?;
+
+ if next_token == END_OF_SPEECH as u32 {
+ break;
+ }
+
+ // Predict all codebooks for this frame
+ let lm_hidden = self
+ .model
+ .last_hidden_state()
+ .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?;
+
+ let frame_codes = self
+ .code_predictor
+ .predict(&lm_hidden, next_token, self.device)?;
+ generated_frames.push(frame_codes);
+ past_zeroth_tokens.push(next_token);
+ }
+
+ Ok(generated_frames)
+ }
+
+ /// Sample a token from logits.
+ fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> {
+ let mut logits_vec: Vec<f32> = logits.to_vec1()?;
+
+ // Apply repetition penalty
+ if self.config.repetition_penalty != 1.0 {
+ for &token in past_tokens {
+ let idx = token as usize;
+ if idx < logits_vec.len() {
+ let score = logits_vec[idx];
+ logits_vec[idx] = if score < 0.0 {
+ score * self.config.repetition_penalty
+ } else {
+ score / self.config.repetition_penalty
+ };
+ }
+ }
+ }
+
+ if self.config.top_k == 0 || self.config.temperature == 0.0 {
+ // Greedy: argmax
+ let (max_idx, _) = logits_vec
+ .iter()
+ .enumerate()
+ .max_by(|(_, a), (_, b)| {
+ a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
+ })
+ .unwrap_or((0, &0.0));
+ Ok(max_idx as u32)
+ } else {
+ // Top-k sampling with temperature
+ let temperature = self.config.temperature;
+
+ // Apply temperature
+ for v in logits_vec.iter_mut() {
+ *v /= temperature;
+ }
+
+ // Sort indices by logit value (descending)
+ let mut indexed: Vec<(usize, f32)> =
+ logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect();
+ indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+
+ // Keep only top-k
+ let k = self.config.top_k.min(indexed.len());
+ let top_k = &indexed[..k];
+
+ // Softmax over top-k
+ let max_val = top_k[0].1;
+ let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum();
+ let probs: Vec<(usize, f32)> = top_k
+ .iter()
+ .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum))
+ .collect();
+
+ // Sample from distribution (simple linear scan)
+ let r: f32 = random_float();
+ let mut cumulative = 0.0;
+ for (idx, prob) in &probs {
+ cumulative += prob;
+ if cumulative >= r {
+ return Ok(*idx as u32);
+ }
+ }
+
+ // Fallback to highest probability
+ Ok(probs[0].0 as u32)
+ }
+ }
+
+ /// Decode all frames in batch (non-streaming).
+ fn decode_batch(
+ &self,
+ frames: &[Vec<u32>],
+ ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
+ let num_codebooks = self.speech_tokenizer.num_codebooks();
+
+ // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames]
+ let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
+ for frame in frames {
+ for (cb_idx, &code) in frame.iter().enumerate() {
+ if cb_idx < num_codebooks {
+ codes_by_codebook[cb_idx].push(code);
+ }
+ }
+ }
+
+ let samples = self
+ .speech_tokenizer
+ .decode(&codes_by_codebook)
+ .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?;
+
+ Ok(vec![AudioChunk {
+ samples,
+ sample_rate: SAMPLE_RATE,
+ is_final: true,
+ }])
+ }
+
+ /// Decode frames incrementally (streaming).
+ fn decode_streaming(
+ &self,
+ frames: &[Vec<u32>],
+ ) -> std::result::Result<Vec<AudioChunk>, TtsError> {
+ let mut chunks = Vec::new();
+
+ // Decode in groups of frames for efficiency
+ let chunk_size = 10; // ~800ms per chunk at 12.5Hz
+ let num_codebooks = self.speech_tokenizer.num_codebooks();
+
+ for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() {
+ let is_last = (chunk_idx + 1) * chunk_size >= frames.len();
+
+ // Transpose chunk frames
+ let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks];
+ for frame in frame_chunk {
+ for (cb_idx, &code) in frame.iter().enumerate() {
+ if cb_idx < num_codebooks {
+ codes_by_codebook[cb_idx].push(code);
+ }
+ }
+ }
+
+ let samples = self
+ .speech_tokenizer
+ .decode(&codes_by_codebook)
+ .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?;
+
+ chunks.push(AudioChunk {
+ samples,
+ sample_rate: SAMPLE_RATE,
+ is_final: is_last,
+ });
+ }
+
+ Ok(chunks)
+ }
+}
+
+/// Simple pseudo-random float in [0, 1) using thread-local state.
+/// Uses a basic xorshift for reproducibility without external deps.
+fn random_float() -> f32 {
+ use std::cell::Cell;
+ thread_local! {
+ static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0);
+ }
+
+ STATE.with(|s| {
+ let mut x = s.get();
+ x ^= x << 13;
+ x ^= x >> 7;
+ x ^= x << 17;
+ s.set(x);
+ (x as f32) / (u64::MAX as f32)
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_generation_config_default() {
+ let config = GenerationConfig::default();
+ assert_eq!(config.max_new_tokens, 2048);
+ assert_eq!(config.top_k, 0);
+ assert_eq!(config.temperature, 1.0);
+ assert_eq!(config.repetition_penalty, 1.2);
+ assert!(!config.streaming);
+ }
+
+ #[test]
+ fn test_random_float_range() {
+ for _ in 0..100 {
+ let r = random_float();
+ assert!(r >= 0.0);
+ assert!(r < 1.0);
+ }
+ }
+
+ #[test]
+ fn test_special_tokens() {
+ assert_eq!(BOS_TOKEN_ID, 151_643);
+ assert_eq!(EOS_TOKEN_ID, 151_645);
+ assert_eq!(START_OF_SPEECH, 151_668);
+ assert_eq!(END_OF_SPEECH, 151_669);
+ }
+}
diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs
new file mode 100644
index 0000000..c55c118
--- /dev/null
+++ b/makima/src/tts/qwen3/mod.rs
@@ -0,0 +1,287 @@
+//! Qwen3-TTS — Pure Rust implementation using candle.
+//!
+//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis
+//! with voice cloning support. No Python, no ONNX — pure Rust inference
+//! via the candle ML framework.
+//!
+//! # Architecture
+//!
+//! The model has three components:
+//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens
+//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers
+//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes
+//!
+//! # Usage
+//!
+//! ```rust,no_run
+//! use makima::tts::qwen3::Qwen3Tts;
+//! use candle_core::Device;
+//!
+//! let device = Device::Cpu;
+//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap();
+//! // Use via TtsEngine trait or direct API
+//! ```
+
+pub mod code_predictor;
+pub mod config;
+pub mod generate;
+pub mod model;
+pub mod speech_tokenizer;
+
+use std::path::{Path, PathBuf};
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use candle_core::{DType, Device};
+use candle_nn::VarBuilder;
+use hf_hub::api::sync::Api;
+use tokenizers::Tokenizer;
+
+use self::code_predictor::CodePredictor;
+use self::config::Qwen3TtsConfig;
+use self::generate::{GenerationConfig, GenerationContext};
+use self::model::Qwen3Model;
+use self::speech_tokenizer::SpeechTokenizer;
+use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE};
+
+/// HuggingFace model IDs.
+const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base";
+const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
+const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts";
+
+/// Qwen3-TTS engine — pure Rust candle-based inference.
+pub struct Qwen3Tts {
+ /// The 28-layer language model.
+ model: Qwen3Model,
+ /// Multi-token prediction code predictor.
+ code_predictor: CodePredictor,
+ /// Speech tokenizer (encoder + decoder + RVQ).
+ speech_tokenizer: SpeechTokenizer,
+ /// Text tokenizer.
+ tokenizer: Tokenizer,
+ /// Model configuration.
+ config: Qwen3TtsConfig,
+ /// Compute device (CPU/CUDA/Metal).
+ device: Device,
+ /// Whether the model is fully loaded and ready.
+ ready: AtomicBool,
+}
+
+// SAFETY: All fields are either Send+Sync or behind appropriate synchronization.
+// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync.
+unsafe impl Send for Qwen3Tts {}
+unsafe impl Sync for Qwen3Tts {}
+
+impl Qwen3Tts {
+ /// Load from a local directory or download from HuggingFace.
+ pub fn from_pretrained(
+ model_dir: Option<&str>,
+ device: &Device,
+ ) -> Result<Self, TtsError> {
+ let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR));
+
+ if !model_path.exists() {
+ Self::download_models(&model_path)?;
+ }
+
+ Self::load_from_path(&model_path, device)
+ }
+
+ /// Load all model components from a local directory.
+ pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> {
+ let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU
+
+ // Load configuration
+ let config_path = model_dir.join("config.json");
+ let config = if config_path.exists() {
+ Qwen3TtsConfig::from_json_path(&config_path)?
+ } else {
+ Qwen3TtsConfig::default()
+ };
+
+ // Load text tokenizer
+ let tokenizer_path = model_dir.join("tokenizer.json");
+ let tokenizer = Tokenizer::from_file(&tokenizer_path)
+ .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))?;
+
+ // Load LM weights from safetensors
+ let lm_weights_path = model_dir.join("model.safetensors");
+ let lm_data = std::fs::read(&lm_weights_path).map_err(|e| {
+ TtsError::ModelLoad(format!(
+ "failed to read LM weights from {}: {e}",
+ lm_weights_path.display()
+ ))
+ })?;
+ let lm_vb = VarBuilder::from_buffered_safetensors(
+ lm_data,
+ dtype,
+ device,
+ ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?;
+
+ // Build language model
+ let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| {
+ TtsError::ModelLoad(format!("failed to build LM model: {e}"))
+ })?;
+
+ // Build code predictor (weights are in the same safetensors file)
+ let code_predictor =
+ CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| {
+ TtsError::ModelLoad(format!("failed to build code predictor: {e}"))
+ })?;
+
+ // Load speech tokenizer from separate safetensors
+ let st_weights_path = model_dir.join("speech_tokenizer.safetensors");
+ let st_data = std::fs::read(&st_weights_path).map_err(|e| {
+ TtsError::ModelLoad(format!(
+ "failed to read speech tokenizer weights from {}: {e}",
+ st_weights_path.display()
+ ))
+ })?;
+ let st_vb = VarBuilder::from_buffered_safetensors(
+ st_data,
+ dtype,
+ device,
+ ).map_err(|e| {
+ TtsError::ModelLoad(format!(
+ "failed to create speech tokenizer VarBuilder: {e}"
+ ))
+ })?;
+
+ let speech_tokenizer =
+ SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| {
+ TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}"))
+ })?;
+
+ Ok(Self {
+ model,
+ code_predictor,
+ speech_tokenizer,
+ tokenizer,
+ config,
+ device: device.clone(),
+ ready: AtomicBool::new(true),
+ })
+ }
+
+ /// Generate audio from text with optional voice reference.
+ pub fn generate_speech(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ gen_config: Option<GenerationConfig>,
+ ) -> Result<Vec<AudioChunk>, TtsError> {
+ let config = gen_config.unwrap_or_default();
+
+ let ctx = GenerationContext::new(
+ &self.model,
+ &self.code_predictor,
+ &self.speech_tokenizer,
+ &self.tokenizer,
+ &self.device,
+ config,
+ );
+
+ ctx.generate(text, reference_audio)
+ }
+
+ /// Download model files from HuggingFace Hub.
+ fn download_models(target_dir: &Path) -> Result<(), TtsError> {
+ std::fs::create_dir_all(target_dir)?;
+
+ let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?;
+
+ // Download LM model files
+ println!("Downloading Qwen3-TTS language model...");
+ let lm_repo = api.model(LM_MODEL_ID.to_string());
+
+ let lm_files = [
+ "model.safetensors",
+ "config.json",
+ "tokenizer.json",
+ "tokenizer_config.json",
+ ];
+
+ for file in &lm_files {
+ println!(" Downloading {file}...");
+ let downloaded = lm_repo
+ .get(file)
+ .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?;
+
+ let target = target_dir.join(file);
+ if !target.exists() {
+ std::fs::copy(&downloaded, &target)?;
+ }
+ }
+
+ // Download speech tokenizer
+ println!("Downloading Qwen3-TTS speech tokenizer...");
+ let st_repo = api.model(TOKENIZER_MODEL_ID.to_string());
+
+ let st_file = "model.safetensors";
+ let downloaded = st_repo
+ .get(st_file)
+ .map_err(|e| {
+ TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}"))
+ })?;
+
+ let target = target_dir.join("speech_tokenizer.safetensors");
+ if !target.exists() {
+ std::fs::copy(&downloaded, &target)?;
+ }
+
+ println!("All models downloaded to {}", target_dir.display());
+ Ok(())
+ }
+
+ /// Get the model configuration.
+ pub fn config(&self) -> &Qwen3TtsConfig {
+ &self.config
+ }
+
+ /// Get the compute device.
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+}
+
+#[async_trait::async_trait]
+impl TtsEngine for Qwen3Tts {
+ async fn generate(
+ &self,
+ text: &str,
+ reference_audio: Option<&[f32]>,
+ _reference_sample_rate: Option<u32>,
+ ) -> Result<Vec<AudioChunk>, TtsError> {
+ // Note: reference audio should already be resampled to 24kHz
+ // by the caller. If a different sample rate is provided,
+ // the caller should resample using `resample_to_24k()`.
+ self.generate_speech(text, reference_audio, None)
+ }
+
+ fn is_ready(&self) -> bool {
+ self.ready.load(Ordering::Relaxed)
+ }
+
+ fn sample_rate(&self) -> u32 {
+ SAMPLE_RATE
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_default_config() {
+ let config = Qwen3TtsConfig::default();
+ assert_eq!(config.lm.hidden_size, 1024);
+ assert_eq!(config.lm.num_hidden_layers, 28);
+ assert_eq!(config.code_predictor.num_code_groups, 16);
+ assert_eq!(config.speech_tokenizer.sample_rate, 24_000);
+ }
+
+ #[test]
+ fn test_model_ids() {
+ assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base");
+ assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz");
+ }
+}
diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs
new file mode 100644
index 0000000..551893b
--- /dev/null
+++ b/makima/src/tts/qwen3/model.rs
@@ -0,0 +1,581 @@
+//! Qwen3 Language Model transformer backbone.
+//!
+//! Implements the 28-layer transformer with:
+//! - Rotary Position Embeddings (RoPE)
+//! - Grouped Query Attention (GQA) — 16 heads, 8 KV heads
+//! - SiLU-gated MLP
+//! - RMS normalization
+//! - KV cache for autoregressive generation
+//!
+//! Based on the candle-transformers Qwen2 model architecture,
+//! extended for Qwen3-TTS.
+
+use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+
+use super::config::Qwen3LmConfig;
+
+// ---------------------------------------------------------------------------
+// Rotary Position Embeddings
+// ---------------------------------------------------------------------------
+
+/// Precomputed RoPE sin/cos tables.
+#[derive(Debug, Clone)]
+pub struct RotaryEmbedding {
+ cos: Tensor,
+ sin: Tensor,
+}
+
+impl RotaryEmbedding {
+ pub fn new(config: &Qwen3LmConfig, dtype: DType, device: &Device) -> Result<Self> {
+ let head_dim = config.head_dim;
+ let max_seq = config.max_position_embeddings;
+ let theta = config.rope_theta;
+
+ let inv_freq: Vec<f32> = (0..head_dim)
+ .step_by(2)
+ .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32))
+ .collect();
+
+ let inv_freq_tensor =
+ Tensor::from_vec(inv_freq, (head_dim / 2,), device)?.to_dtype(DType::F32)?;
+
+ let positions: Vec<f32> = (0..max_seq).map(|p| p as f32).collect();
+ let positions_tensor = Tensor::from_vec(positions, (max_seq, 1), device)?;
+
+ // [max_seq, head_dim/2]
+ let freqs = positions_tensor.matmul(&inv_freq_tensor.unsqueeze(0)?)?;
+ // [max_seq, head_dim] by repeating
+ let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
+
+ let cos = emb.cos()?.to_dtype(dtype)?;
+ let sin = emb.sin()?.to_dtype(dtype)?;
+
+ Ok(Self { cos, sin })
+ }
+
+ /// Apply RoPE to query and key tensors.
+ /// Input shape: [batch, heads, seq_len, head_dim]
+ pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
+ let seq_len = q.dim(2)?;
+ let cos = self.cos.narrow(0, offset, seq_len)?;
+ let sin = self.sin.narrow(0, offset, seq_len)?;
+
+ let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, seq, dim]
+ let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
+
+ let q_rotated = Self::rotate_half(q, &cos, &sin)?;
+ let k_rotated = Self::rotate_half(k, &cos, &sin)?;
+
+ Ok((q_rotated, k_rotated))
+ }
+
+ fn rotate_half(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+ let half_dim = x.dim(D::Minus1)? / 2;
+ let x1 = x.narrow(D::Minus1, 0, half_dim)?;
+ let x2 = x.narrow(D::Minus1, half_dim, half_dim)?;
+
+ // [-x2, x1] concatenated
+ let neg_x2 = x2.neg()?;
+ let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?;
+
+ // x * cos + rotated * sin
+ let result = x.broadcast_mul(cos)?.broadcast_add(&rotated.broadcast_mul(sin)?)?;
+ Ok(result)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// KV Cache
+// ---------------------------------------------------------------------------
+
+/// Per-layer key-value cache for autoregressive generation.
+#[derive(Debug, Clone)]
+pub struct KvCache {
+ key: Option<Tensor>,
+ value: Option<Tensor>,
+}
+
+impl KvCache {
+ pub fn new() -> Self {
+ Self {
+ key: None,
+ value: None,
+ }
+ }
+
+ /// Append new key/value tensors and return the full cached sequence.
+ /// Input shapes: [batch, num_kv_heads, new_seq_len, head_dim]
+ pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
+ let (full_key, full_value) = match (&self.key, &self.value) {
+ (Some(prev_k), Some(prev_v)) => {
+ let k = Tensor::cat(&[prev_k, key], 2)?;
+ let v = Tensor::cat(&[prev_v, value], 2)?;
+ (k, v)
+ }
+ _ => (key.clone(), value.clone()),
+ };
+
+ self.key = Some(full_key.clone());
+ self.value = Some(full_value.clone());
+
+ Ok((full_key, full_value))
+ }
+
+ /// Current cached sequence length.
+ pub fn seq_len(&self) -> usize {
+ self.key
+ .as_ref()
+ .map(|k| k.dim(2).unwrap_or(0))
+ .unwrap_or(0)
+ }
+
+ /// Reset the cache.
+ pub fn reset(&mut self) {
+ self.key = None;
+ self.value = None;
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Attention
+// ---------------------------------------------------------------------------
+
+/// Multi-head attention with GQA and RoPE.
+pub struct Qwen3Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ q_norm: RmsNorm,
+ k_norm: RmsNorm,
+ num_heads: usize,
+ num_kv_heads: usize,
+ head_dim: usize,
+ num_kv_groups: usize,
+}
+
+impl Qwen3Attention {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let hidden = config.hidden_size;
+ let num_heads = config.num_attention_heads;
+ let num_kv_heads = config.num_key_value_heads;
+ let head_dim = config.head_dim;
+
+ let q_proj = linear_no_bias(hidden, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let o_proj = linear_no_bias(num_heads * head_dim, hidden, vb.pp("o_proj"))?;
+
+ let q_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("q_norm"))?;
+ let k_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("k_norm"))?;
+
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ q_norm,
+ k_norm,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ num_kv_groups: config.num_kv_groups(),
+ })
+ }
+
+ /// Forward pass with KV cache and RoPE.
+ /// Input: [batch, seq_len, hidden_size]
+ /// Returns: [batch, seq_len, hidden_size]
+ pub fn forward(
+ &self,
+ hidden_states: &Tensor,
+ rope: &RotaryEmbedding,
+ kv_cache: &mut KvCache,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (batch, seq_len, _) = hidden_states.dims3()?;
+ let offset = kv_cache.seq_len();
+
+ // Project Q, K, V
+ let q = self.q_proj.forward(hidden_states)?;
+ let k = self.k_proj.forward(hidden_states)?;
+ let v = self.v_proj.forward(hidden_states)?;
+
+ // Reshape: [batch, seq, heads*dim] -> [batch, heads, seq, dim]
+ let q = q
+ .reshape((batch, seq_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let k = k
+ .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let v = v
+ .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ // Apply QK normalization (Qwen3 specific)
+ let q = self.apply_head_norm(&q, &self.q_norm)?;
+ let k = self.apply_head_norm(&k, &self.k_norm)?;
+
+ // Apply RoPE
+ let (q, k) = rope.apply(&q, &k, offset)?;
+
+ // Update KV cache
+ let (k, v) = kv_cache.append(&k, &v)?;
+
+ // Expand KV heads for GQA: [batch, kv_heads, seq, dim] -> [batch, heads, seq, dim]
+ let k = self.repeat_kv(&k)?;
+ let v = self.repeat_kv(&v)?;
+
+ // Scaled dot-product attention
+ let scale = (self.head_dim as f64).sqrt();
+ let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / scale)?;
+
+ let attn_weights = match attention_mask {
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ None => attn_weights,
+ };
+
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+
+ // Attention output
+ let attn_output = attn_weights.matmul(&v)?;
+
+ // [batch, heads, seq, dim] -> [batch, seq, heads*dim]
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((batch, seq_len, self.num_heads * self.head_dim))?;
+
+ self.o_proj.forward(&attn_output)
+ }
+
+ /// Apply RMS norm per-head.
+ fn apply_head_norm(&self, x: &Tensor, norm: &RmsNorm) -> Result<Tensor> {
+ let (b, h, s, d) = x.dims4()?;
+ // Reshape to [b*h*s, d] for norm, then back
+ let flat = x.reshape((b * h * s, d))?;
+ let normed = norm.forward(&flat)?;
+ normed.reshape((b, h, s, d))
+ }
+
+ /// Repeat KV heads for GQA.
+ fn repeat_kv(&self, x: &Tensor) -> Result<Tensor> {
+ if self.num_kv_groups == 1 {
+ return Ok(x.clone());
+ }
+ let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?;
+ let x = x
+ .unsqueeze(2)?
+ .expand((batch, num_kv_heads, self.num_kv_groups, seq_len, head_dim))?
+ .reshape((batch, self.num_heads, seq_len, head_dim))?;
+ Ok(x)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// MLP
+// ---------------------------------------------------------------------------
+
+/// SiLU-gated feed-forward network.
+pub struct Qwen3Mlp {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+}
+
+impl Qwen3Mlp {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let hidden = config.hidden_size;
+ let intermediate = config.intermediate_size;
+
+ let gate_proj = linear_no_bias(hidden, intermediate, vb.pp("gate_proj"))?;
+ let up_proj = linear_no_bias(hidden, intermediate, vb.pp("up_proj"))?;
+ let down_proj = linear_no_bias(intermediate, hidden, vb.pp("down_proj"))?;
+
+ Ok(Self {
+ gate_proj,
+ up_proj,
+ down_proj,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let gate = self.gate_proj.forward(x)?;
+ let gate = candle_nn::Activation::Silu.forward(&gate)?;
+ let up = self.up_proj.forward(x)?;
+ let hidden = (gate * up)?;
+ self.down_proj.forward(&hidden)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Transformer Layer
+// ---------------------------------------------------------------------------
+
+/// A single Qwen3 transformer decoder layer.
+pub struct Qwen3DecoderLayer {
+ self_attn: Qwen3Attention,
+ mlp: Qwen3Mlp,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl Qwen3DecoderLayer {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Qwen3Attention::new(config, vb.pp("self_attn"))?;
+ let mlp = Qwen3Mlp::new(config, vb.pp("mlp"))?;
+ let input_layernorm =
+ rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = rms_norm(
+ config.hidden_size,
+ config.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ hidden_states: &Tensor,
+ rope: &RotaryEmbedding,
+ kv_cache: &mut KvCache,
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ // Pre-norm attention
+ let residual = hidden_states;
+ let hidden_states = self.input_layernorm.forward(hidden_states)?;
+ let hidden_states =
+ self.self_attn
+ .forward(&hidden_states, rope, kv_cache, attention_mask)?;
+ let hidden_states = (residual + hidden_states)?;
+
+ // Pre-norm MLP
+ let residual = &hidden_states;
+ let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
+ let hidden_states = self.mlp.forward(&hidden_states)?;
+ let output = (residual + hidden_states)?;
+
+ Ok(output)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Full Model
+// ---------------------------------------------------------------------------
+
+/// The complete Qwen3 language model for TTS.
+///
+/// Architecture:
+/// - Token embedding layer
+/// - 28 transformer decoder layers
+/// - Final RMS normalization
+/// - LM head (projects to vocab)
+pub struct Qwen3Model {
+ embed_tokens: Embedding,
+ layers: Vec<Qwen3DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ rope: RotaryEmbedding,
+ config: Qwen3LmConfig,
+ /// Last hidden states (before lm_head), used by code predictor.
+ last_hidden: std::cell::RefCell<Option<Tensor>>,
+}
+
+impl Qwen3Model {
+ pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
+ let model_vb = vb.pp("model");
+
+ let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("embed_tokens"))?;
+
+ let mut layers = Vec::with_capacity(config.num_hidden_layers);
+ for i in 0..config.num_hidden_layers {
+ let layer = Qwen3DecoderLayer::new(config, model_vb.pp(format!("layers.{i}")))?;
+ layers.push(layer);
+ }
+
+ let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?;
+
+ // LM head — may or may not share weights with embed_tokens
+ let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
+
+ let dtype = vb.dtype();
+ let device = vb.device().clone();
+ let rope = RotaryEmbedding::new(config, dtype, &device)?;
+
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ rope,
+ config: config.clone(),
+ last_hidden: std::cell::RefCell::new(None),
+ })
+ }
+
+ /// Forward pass through the full model.
+ ///
+ /// `input_ids`: [batch, seq_len] — token IDs
+ /// `kv_caches`: per-layer KV caches
+ /// `attention_mask`: optional causal mask [batch, 1, seq_len, total_seq_len]
+ ///
+ /// Returns logits: [batch, seq_len, vocab_size]
+ pub fn forward(
+ &self,
+ input_ids: &Tensor,
+ kv_caches: &mut [KvCache],
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut hidden_states = self.embed_tokens.forward(input_ids)?;
+
+ for (i, layer) in self.layers.iter().enumerate() {
+ hidden_states =
+ layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?;
+ }
+
+ hidden_states = self.norm.forward(&hidden_states)?;
+
+ // Store last hidden state for code predictor
+ *self.last_hidden.borrow_mut() = Some(hidden_states.clone());
+
+ let logits = self.lm_head.forward(&hidden_states)?;
+ Ok(logits)
+ }
+
+ /// Forward pass with pre-computed embeddings (for first iteration where
+ /// text embeddings are concatenated with audio features).
+ ///
+ /// `inputs_embeds`: [batch, seq_len, hidden_size]
+ pub fn forward_embeds(
+ &self,
+ inputs_embeds: &Tensor,
+ kv_caches: &mut [KvCache],
+ attention_mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut hidden_states = inputs_embeds.clone();
+
+ for (i, layer) in self.layers.iter().enumerate() {
+ hidden_states =
+ layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?;
+ }
+
+ hidden_states = self.norm.forward(&hidden_states)?;
+
+ *self.last_hidden.borrow_mut() = Some(hidden_states.clone());
+
+ let logits = self.lm_head.forward(&hidden_states)?;
+ Ok(logits)
+ }
+
+ /// Get the last hidden states (for the code predictor).
+ pub fn last_hidden_state(&self) -> Option<Tensor> {
+ self.last_hidden.borrow().clone()
+ }
+
+ /// Number of transformer layers.
+ pub fn num_layers(&self) -> usize {
+ self.config.num_hidden_layers
+ }
+
+ /// Hidden size.
+ pub fn hidden_size(&self) -> usize {
+ self.config.hidden_size
+ }
+
+ /// Get token embedding layer (for input preparation).
+ pub fn embed_tokens(&self) -> &Embedding {
+ &self.embed_tokens
+ }
+
+ /// Create a causal attention mask.
+ pub fn make_causal_mask(
+ seq_len: usize,
+ past_len: usize,
+ dtype: DType,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let total_len = past_len + seq_len;
+
+ if seq_len == 1 {
+ // Single token: no masking needed (can attend to everything)
+ return Tensor::zeros((1, 1, 1, total_len), dtype, device);
+ }
+
+ // Full causal mask: lower triangular
+ let mask: Vec<f32> = (0..seq_len)
+ .flat_map(|i| {
+ (0..total_len).map(move |j| {
+ if j <= past_len + i {
+ 0.0
+ } else {
+ f32::NEG_INFINITY
+ }
+ })
+ })
+ .collect();
+
+ Tensor::from_vec(mask, (1, 1, seq_len, total_len), device)?.to_dtype(dtype)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_kv_cache() {
+ let device = Device::Cpu;
+ let mut cache = KvCache::new();
+ assert_eq!(cache.seq_len(), 0);
+
+ let k = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap();
+ let v = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap();
+ let (fk, _fv) = cache.append(&k, &v).unwrap();
+ assert_eq!(cache.seq_len(), 5);
+ assert_eq!(fk.dim(2).unwrap(), 5);
+
+ let k2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap();
+ let v2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap();
+ let (fk2, _fv2) = cache.append(&k2, &v2).unwrap();
+ assert_eq!(cache.seq_len(), 6);
+ assert_eq!(fk2.dim(2).unwrap(), 6);
+
+ cache.reset();
+ assert_eq!(cache.seq_len(), 0);
+ }
+
+ #[test]
+ fn test_causal_mask_single_token() {
+ let mask = Qwen3Model::make_causal_mask(1, 10, DType::F32, &Device::Cpu).unwrap();
+ assert_eq!(mask.dims(), &[1, 1, 1, 11]);
+ // All zeros — single token can attend to everything
+ let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap();
+ assert_eq!(sum, 0.0);
+ }
+
+ #[test]
+ fn test_causal_mask_multi_token() {
+ let mask = Qwen3Model::make_causal_mask(3, 0, DType::F32, &Device::Cpu).unwrap();
+ assert_eq!(mask.dims(), &[1, 1, 3, 3]);
+ // Upper triangle should be -inf
+ let data: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap();
+ // Row 0: [0, -inf, -inf]
+ assert_eq!(data[0], 0.0);
+ assert!(data[1].is_infinite() && data[1] < 0.0);
+ assert!(data[2].is_infinite() && data[2] < 0.0);
+ // Row 1: [0, 0, -inf]
+ assert_eq!(data[3], 0.0);
+ assert_eq!(data[4], 0.0);
+ assert!(data[5].is_infinite() && data[5] < 0.0);
+ // Row 2: [0, 0, 0]
+ assert_eq!(data[6], 0.0);
+ assert_eq!(data[7], 0.0);
+ assert_eq!(data[8], 0.0);
+ }
+}
diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs
new file mode 100644
index 0000000..752050a
--- /dev/null
+++ b/makima/src/tts/qwen3/speech_tokenizer.rs
@@ -0,0 +1,612 @@
+//! Speech Tokenizer — ConvNet encoder/decoder with RVQ codebooks.
+//!
+//! Two sub-components:
+//!
+//! **Encoder** (voice cloning): converts reference audio waveform to discrete
+//! multi-codebook tokens via a causal 1D ConvNet + RVQ.
+//!
+//! **Decoder** (audio synthesis): reconstructs waveform from discrete codebook
+//! indices via embedding lookup + causal 1D ConvNet.
+//!
+//! The speech tokenizer is a separate model (~682MB) loaded from
+//! `Qwen/Qwen3-TTS-Tokenizer-12Hz`.
+
+use candle_core::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{
+ conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder,
+};
+
+use super::config::SpeechTokenizerConfig;
+
+// ---------------------------------------------------------------------------
+// Weight-Normalized Conv1d
+// ---------------------------------------------------------------------------
+
+/// A 1D convolution with optional weight normalization and activation.
+pub struct ConvBlock {
+ conv: Conv1d,
+ activation: ConvActivation,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum ConvActivation {
+ None,
+ Elu,
+ Tanh,
+}
+
+impl ConvBlock {
+ pub fn new(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ stride: usize,
+ padding: usize,
+ dilation: usize,
+ activation: ConvActivation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let config = Conv1dConfig {
+ stride,
+ padding,
+ dilation,
+ groups: 1,
+ };
+ let conv = conv1d(in_channels, out_channels, kernel_size, config, vb.pp("conv"))?;
+
+ Ok(Self { conv, activation })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let out = self.conv.forward(x)?;
+ match self.activation {
+ ConvActivation::None => Ok(out),
+ ConvActivation::Elu => elu(&out, 1.0),
+ ConvActivation::Tanh => out.tanh(),
+ }
+ }
+}
+
+/// ELU activation: x if x >= 0, alpha * (exp(x) - 1) if x < 0
+fn elu(x: &Tensor, alpha: f64) -> Result<Tensor> {
+ let zeros = x.zeros_like()?;
+ let positive = x.maximum(&zeros)?;
+ let negative_mask = x.lt(&zeros)?.to_dtype(x.dtype())?;
+ let exp_x = x.exp()?;
+ let one = Tensor::ones_like(&exp_x)?;
+ let negative = ((exp_x - one)? * alpha)?.broadcast_mul(&negative_mask)?;
+ positive + negative
+}
+
+// ---------------------------------------------------------------------------
+// Residual Unit
+// ---------------------------------------------------------------------------
+
+/// Residual convolutional unit with dilated convolutions.
+pub struct ResidualUnit {
+ conv1: ConvBlock,
+ conv2: ConvBlock,
+}
+
+impl ResidualUnit {
+ pub fn new(
+ channels: usize,
+ dilation: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ // Dilated causal conv (kernel=7, dilation varies)
+ let padding = (7 - 1) * dilation / 2; // causal-ish padding
+ let conv1 = ConvBlock::new(
+ channels,
+ channels,
+ 7,
+ 1,
+ padding,
+ dilation,
+ ConvActivation::Elu,
+ vb.pp("block.0"),
+ )?;
+
+ // Pointwise conv (kernel=1)
+ let conv2 = ConvBlock::new(
+ channels,
+ channels,
+ 1,
+ 1,
+ 0,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("block.1"),
+ )?;
+
+ Ok(Self { conv1, conv2 })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let residual = x;
+ let out = self.conv1.forward(x)?;
+ let out = self.conv2.forward(&out)?;
+ // Match sequence lengths if needed (causal conv may change length)
+ let out_len = out.dim(D::Minus1)?;
+ let res_len = residual.dim(D::Minus1)?;
+ if out_len != res_len {
+ let start = res_len.saturating_sub(out_len);
+ let residual = residual.narrow(D::Minus1, start, out_len)?;
+ residual + out
+ } else {
+ residual + out
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Encoder Block
+// ---------------------------------------------------------------------------
+
+/// Encoder downsampling block: residual units + strided conv.
+pub struct EncoderBlock {
+ residual_units: Vec<ResidualUnit>,
+ downsample: ConvBlock,
+}
+
+impl EncoderBlock {
+ pub fn new(
+ in_channels: usize,
+ out_channels: usize,
+ stride: usize,
+ num_residuals: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let mut residual_units = Vec::with_capacity(num_residuals);
+ for i in 0..num_residuals {
+ let dilation = 3usize.pow(i as u32); // 1, 3, 9
+ let unit = ResidualUnit::new(in_channels, dilation, vb.pp(format!("residuals.{i}")))?;
+ residual_units.push(unit);
+ }
+
+ // Strided downsampling convolution
+ let kernel_size = stride * 2;
+ let padding = stride / 2;
+ let downsample = ConvBlock::new(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("downsample"),
+ )?;
+
+ Ok(Self {
+ residual_units,
+ downsample,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let mut out = x.clone();
+ for unit in &self.residual_units {
+ out = unit.forward(&out)?;
+ }
+ self.downsample.forward(&out)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Decoder Block
+// ---------------------------------------------------------------------------
+
+/// Decoder upsampling block: transposed conv + residual units.
+pub struct DecoderBlock {
+ upsample: ConvBlock,
+ residual_units: Vec<ResidualUnit>,
+}
+
+impl DecoderBlock {
+ pub fn new(
+ in_channels: usize,
+ out_channels: usize,
+ stride: usize,
+ num_residuals: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ // Strided upsampling (transpose conv simulated by regular conv + padding)
+ let kernel_size = stride * 2;
+ let padding = stride / 2;
+ let upsample = ConvBlock::new(
+ in_channels,
+ out_channels,
+ kernel_size,
+ 1, // stride=1 for output; upsample via repeat/interpolation
+ padding,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("upsample"),
+ )?;
+
+ let mut residual_units = Vec::with_capacity(num_residuals);
+ for i in 0..num_residuals {
+ let dilation = 3usize.pow(i as u32);
+ let unit =
+ ResidualUnit::new(out_channels, dilation, vb.pp(format!("residuals.{i}")))?;
+ residual_units.push(unit);
+ }
+
+ Ok(Self {
+ upsample,
+ residual_units,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let mut out = self.upsample.forward(x)?;
+ for unit in &self.residual_units {
+ out = unit.forward(&out)?;
+ }
+ Ok(out)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// RVQ Codebook
+// ---------------------------------------------------------------------------
+
+/// Residual Vector Quantization codebook.
+///
+/// Contains `num_codebooks` embedding tables, each mapping
+/// `codebook_size` indices to `codebook_dim`-dimensional vectors.
+pub struct RvqCodebook {
+ codebooks: Vec<Embedding>,
+ num_codebooks: usize,
+ codebook_dim: usize,
+}
+
+impl RvqCodebook {
+ pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder) -> Result<Self> {
+ let mut codebooks = Vec::with_capacity(config.num_codebooks);
+ for i in 0..config.num_codebooks {
+ let cb = embedding(
+ config.codebook_size,
+ config.codebook_dim,
+ vb.pp(format!("codebooks.{i}")),
+ )?;
+ codebooks.push(cb);
+ }
+
+ Ok(Self {
+ codebooks,
+ num_codebooks: config.num_codebooks,
+ codebook_dim: config.codebook_dim,
+ })
+ }
+
+ /// Look up codebook embeddings for all codebook layers.
+ ///
+ /// `codes`: [num_codebooks, seq_len] — codebook indices per layer
+ /// Returns: [1, codebook_dim, seq_len] — sum of all codebook embeddings
+ pub fn decode(&self, codes: &[Vec<u32>], device: &Device) -> Result<Tensor> {
+ assert_eq!(codes.len(), self.num_codebooks, "Expected {} codebook layers", self.num_codebooks);
+
+ let seq_len = codes[0].len();
+ let mut sum: Option<Tensor> = None;
+
+ for (i, code_layer) in codes.iter().enumerate() {
+ assert_eq!(code_layer.len(), seq_len, "Codebook layer {i} length mismatch");
+
+ let indices = Tensor::from_vec(
+ code_layer.clone(),
+ (1, seq_len),
+ device,
+ )?;
+
+ // [1, seq_len, codebook_dim]
+ let emb = self.codebooks[i].forward(&indices)?;
+
+ sum = Some(match sum {
+ Some(prev) => (prev + emb)?,
+ None => emb,
+ });
+ }
+
+ // [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
+ let result = sum.unwrap().transpose(1, 2)?;
+ Ok(result)
+ }
+
+ /// Number of codebooks.
+ pub fn num_codebooks(&self) -> usize {
+ self.num_codebooks
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Speech Tokenizer (Encoder + Decoder)
+// ---------------------------------------------------------------------------
+
+/// The complete speech tokenizer with encoder and decoder.
+pub struct SpeechTokenizer {
+ /// Encoder: waveform -> latent (for voice cloning).
+ encoder_input_conv: ConvBlock,
+ encoder_blocks: Vec<EncoderBlock>,
+ encoder_output_conv: ConvBlock,
+
+ /// RVQ codebooks for quantization.
+ codebook: RvqCodebook,
+
+ /// Decoder: codes -> waveform.
+ decoder_input_conv: ConvBlock,
+ decoder_blocks: Vec<DecoderBlock>,
+ decoder_output_conv: ConvBlock,
+
+ /// Projection from codebook dim to decoder hidden channels.
+ decoder_proj: Linear,
+
+ config: SpeechTokenizerConfig,
+ device: Device,
+}
+
+impl SpeechTokenizer {
+ /// Load the speech tokenizer from safetensors.
+ pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder, device: &Device) -> Result<Self> {
+ let hidden = config.hidden_channels; // 512
+
+ // ===== Encoder =====
+ // Input: [batch, 1, samples] -> [batch, hidden/8, ...]
+ let encoder_input_conv = ConvBlock::new(
+ 1,
+ hidden / 8, // 64
+ 7,
+ 1,
+ 3,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("encoder.input_conv"),
+ )?;
+
+ // Downsampling blocks with increasing channels
+ let strides = [8, 5, 4, 3]; // Total downsampling: 8*5*4*3 = 480
+ let channels = [hidden / 8, hidden / 4, hidden / 2, hidden]; // 64, 128, 256, 512
+ let mut encoder_blocks = Vec::with_capacity(strides.len());
+ for (i, (&stride, &out_ch)) in strides.iter().zip(channels.iter().skip(0)).enumerate() {
+ let in_ch = if i == 0 { hidden / 8 } else { channels[i - 1] };
+ let block = EncoderBlock::new(
+ in_ch,
+ out_ch,
+ stride,
+ 3, // 3 residual units per block
+ vb.pp(format!("encoder.blocks.{i}")),
+ )?;
+ encoder_blocks.push(block);
+ }
+
+ // Encoder output projection to codebook dim
+ let encoder_output_conv = ConvBlock::new(
+ hidden,
+ config.codebook_dim,
+ 3,
+ 1,
+ 1,
+ 1,
+ ConvActivation::None,
+ vb.pp("encoder.output_conv"),
+ )?;
+
+ // ===== RVQ Codebook =====
+ let codebook = RvqCodebook::new(config, vb.pp("quantizer"))?;
+
+ // ===== Decoder =====
+ // Projection from codebook dim to decoder hidden
+ let decoder_proj = linear_no_bias(
+ config.codebook_dim,
+ hidden,
+ vb.pp("decoder.proj"),
+ )?;
+
+ // Input conv
+ let decoder_input_conv = ConvBlock::new(
+ hidden,
+ hidden,
+ 7,
+ 1,
+ 3,
+ 1,
+ ConvActivation::Elu,
+ vb.pp("decoder.input_conv"),
+ )?;
+
+ // Upsampling blocks (reverse order of encoder)
+ let dec_strides = [3, 4, 5, 8];
+ let dec_channels = [hidden, hidden / 2, hidden / 4, hidden / 8]; // 512, 256, 128, 64
+ let mut decoder_blocks = Vec::with_capacity(dec_strides.len());
+ for (i, (&stride, &out_ch)) in dec_strides.iter().zip(dec_channels.iter().skip(0)).enumerate()
+ {
+ let in_ch = if i == 0 { hidden } else { dec_channels[i - 1] };
+ let block = DecoderBlock::new(
+ in_ch,
+ out_ch,
+ stride,
+ 3,
+ vb.pp(format!("decoder.blocks.{i}")),
+ )?;
+ decoder_blocks.push(block);
+ }
+
+ // Output conv: hidden/8 -> 1 channel (waveform)
+ let decoder_output_conv = ConvBlock::new(
+ hidden / 8,
+ 1,
+ 7,
+ 1,
+ 3,
+ 1,
+ ConvActivation::Tanh,
+ vb.pp("decoder.output_conv"),
+ )?;
+
+ Ok(Self {
+ encoder_input_conv,
+ encoder_blocks,
+ encoder_output_conv,
+ codebook,
+ decoder_input_conv,
+ decoder_blocks,
+ decoder_output_conv,
+ decoder_proj,
+ config: config.clone(),
+ device: device.clone(),
+ })
+ }
+
+ /// Encode reference audio waveform to discrete codebook tokens.
+ ///
+ /// `audio`: [num_samples] — mono 24kHz audio
+ /// Returns: Vec of `num_codebooks` vectors, each containing token indices.
+ pub fn encode(&self, audio: &[f32]) -> Result<Vec<Vec<u32>>> {
+ // [1, 1, num_samples]
+ let x = Tensor::from_vec(audio.to_vec(), (1, 1, audio.len()), &self.device)?;
+
+ // Run encoder
+ let mut hidden = self.encoder_input_conv.forward(&x)?;
+ for block in &self.encoder_blocks {
+ hidden = block.forward(&hidden)?;
+ }
+ let latent = self.encoder_output_conv.forward(&hidden)?;
+
+ // latent: [1, codebook_dim, seq_len]
+ // Quantize via nearest-neighbor lookup in each codebook
+ let seq_len = latent.dim(D::Minus1)?;
+ let mut all_codes = Vec::with_capacity(self.config.num_codebooks);
+
+ // Residual quantization: subtract each codebook's contribution
+ let mut residual = latent.clone();
+
+ for cb_idx in 0..self.config.num_codebooks {
+ // residual: [1, codebook_dim, seq_len] -> find nearest codebook entry per timestep
+ let codes = self.quantize_layer(&residual, cb_idx, seq_len)?;
+
+ // Look up the quantized vectors and subtract from residual
+ let code_indices =
+ Tensor::from_vec(codes.clone(), (1, seq_len), &self.device)?;
+ let quantized = self.codebook.codebooks[cb_idx].forward(&code_indices)?;
+ // quantized: [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len]
+ let quantized = quantized.transpose(1, 2)?;
+ residual = (residual - quantized)?;
+
+ all_codes.push(codes);
+ }
+
+ Ok(all_codes)
+ }
+
+ /// Quantize a single RVQ layer by finding the nearest codebook entry.
+ fn quantize_layer(
+ &self,
+ residual: &Tensor,
+ codebook_idx: usize,
+ _seq_len: usize,
+ ) -> Result<Vec<u32>> {
+ // residual: [1, codebook_dim, seq_len]
+ // codebook weights: [codebook_size, codebook_dim]
+ let cb_weight = self.codebook.codebooks[codebook_idx]
+ .embeddings()
+ .clone(); // [codebook_size, codebook_dim]
+
+ // Transpose residual: [1, seq_len, codebook_dim]
+ let residual_t = residual.transpose(1, 2)?.squeeze(0)?; // [seq_len, codebook_dim]
+
+ // Compute L2 distances: ||r - c||^2 = ||r||^2 - 2*r*c^T + ||c||^2
+ let r_sq = residual_t.sqr()?.sum(D::Minus1)?; // [seq_len]
+ let c_sq = cb_weight.sqr()?.sum(D::Minus1)?; // [codebook_size]
+ let rc = residual_t.matmul(&cb_weight.t()?)?; // [seq_len, codebook_size]
+
+ let r_sq = r_sq.unsqueeze(1)?; // [seq_len, 1]
+ let c_sq = c_sq.unsqueeze(0)?; // [1, codebook_size]
+
+ let distances = (r_sq.broadcast_add(&c_sq)? - (rc * 2.0)?)?; // [seq_len, codebook_size]
+
+ // Argmin per timestep
+ let indices = distances.argmin(D::Minus1)?; // [seq_len]
+ let codes: Vec<u32> = indices.to_vec1()?;
+
+ Ok(codes)
+ }
+
+ /// Decode discrete codebook tokens to audio waveform.
+ ///
+ /// `codes`: Vec of `num_codebooks` vectors of token indices.
+ /// Returns: Vec<f32> — mono 24kHz audio samples.
+ pub fn decode(&self, codes: &[Vec<u32>]) -> Result<Vec<f32>> {
+ // Look up and sum all codebook embeddings
+ let embeddings = self.codebook.decode(codes, &self.device)?;
+ // embeddings: [1, codebook_dim, seq_len]
+
+ // Project to decoder hidden size: [1, seq_len, codebook_dim] -> [1, seq_len, hidden]
+ let emb_t = embeddings.transpose(1, 2)?; // [1, seq_len, codebook_dim]
+ let projected = self.decoder_proj.forward(&emb_t)?; // [1, seq_len, hidden]
+ let mut hidden = projected.transpose(1, 2)?; // [1, hidden, seq_len]
+
+ // Run decoder
+ hidden = self.decoder_input_conv.forward(&hidden)?;
+ for block in &self.decoder_blocks {
+ hidden = block.forward(&hidden)?;
+ }
+ let waveform = self.decoder_output_conv.forward(&hidden)?;
+
+ // [1, 1, num_samples] -> Vec<f32>
+ let samples: Vec<f32> = waveform.flatten_all()?.to_vec1()?;
+ Ok(samples)
+ }
+
+ /// Decode a single frame's codes to audio samples (for streaming).
+ ///
+ /// `frame_codes`: [num_codebooks] — one token per codebook for a single frame
+ /// Returns: audio samples for this frame (~1920 samples at 24kHz / 12.5Hz)
+ pub fn decode_frame(&self, frame_codes: &[u32]) -> Result<Vec<f32>> {
+ let codes: Vec<Vec<u32>> = frame_codes.iter().map(|&c| vec![c]).collect();
+ self.decode(&codes)
+ }
+
+ /// Get the number of codebooks.
+ pub fn num_codebooks(&self) -> usize {
+ self.config.num_codebooks
+ }
+
+ /// Get the output sample rate.
+ pub fn sample_rate(&self) -> u32 {
+ self.config.sample_rate
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_elu_positive() {
+ let device = Device::Cpu;
+ let x = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap();
+ let result = elu(&x, 1.0).unwrap();
+ let values: Vec<f32> = result.to_vec1().unwrap();
+ assert!((values[0] - 1.0).abs() < 1e-5);
+ assert!((values[1] - 2.0).abs() < 1e-5);
+ }
+
+ #[test]
+ fn test_elu_negative() {
+ let device = Device::Cpu;
+ let x = Tensor::from_vec(vec![-1.0f32], (1,), &device).unwrap();
+ let result = elu(&x, 1.0).unwrap();
+ let values: Vec<f32> = result.to_vec1().unwrap();
+ // ELU(-1) = exp(-1) - 1 ≈ -0.6321
+ assert!((values[0] - (-0.6321)).abs() < 0.01);
+ }
+
+ #[test]
+ fn test_speech_tokenizer_config() {
+ let config = SpeechTokenizerConfig::default();
+ assert_eq!(config.num_codebooks, 16);
+ assert_eq!(config.codebook_size, 2048);
+ assert_eq!(config.sample_rate, 24_000);
+ }
+}
diff --git a/voices/makima/manifest.json b/voices/makima/manifest.json
new file mode 100644
index 0000000..ec93fae
--- /dev/null
+++ b/voices/makima/manifest.json
@@ -0,0 +1,12 @@
+{
+ "name": "Makima",
+ "id": "makima",
+ "description": "Makima's Japanese-accented English voice for TTS synthesis.",
+ "language": "en",
+ "accent": "ja",
+ "sample_rate": 24000,
+ "format": "pcm_f32",
+ "model_backend": "qwen3",
+ "reference_audio": "reference.wav",
+ "notes": "Default voice for the Makima system. Reference audio should be a short (5-15s) clip of the target voice at 24kHz mono. Place the WAV file as reference.wav in this directory."
+}