# Rust-Native Qwen3-TTS Integration Research

## Executive Summary

This document evaluates integrating **Qwen3-TTS-12Hz-0.6B-Base** directly into the makima Rust codebase, replacing the current Chatterbox TTS implementation. The goal is a **pure Rust** solution: no Python and no separate microservice.

**Bottom line:** A Rust-native integration is feasible but requires significant implementation work. The most viable path is using **candle** (HuggingFace's Rust ML framework) to implement the model architecture natively, loading safetensors weights directly. The existing ONNX-based approach used for Chatterbox TTS is **not viable** for Qwen3-TTS due to architecture incompatibilities with ONNX exporters.

---

## 1. Current TTS Implementation Analysis

The existing Chatterbox TTS in `makima/src/tts.rs` uses:

- **ONNX Runtime** via the `ort` crate (v2.0.0-rc.10)
- **Four ONNX model files**: `speech_encoder.onnx`, `embed_tokens.onnx`, `language_model.onnx`, `conditional_decoder.onnx`
- **tokenizers** crate for text tokenization
- **ndarray** for tensor manipulation
- **hf-hub** for model downloading
- **Pipeline**: encode voice → tokenize text → autoregressive token generation with KV cache → decode tokens to waveform
- **Architecture constants**: 24 layers, 16 KV heads, 64 head dim, 24kHz sample rate

The pattern is well-established: download ONNX models from HuggingFace, load sessions, run inference with manual KV cache management.
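
For reference, the "manual KV cache management" in the Chatterbox path amounts to growing a per-layer key/value buffer on every decoding step. A minimal, framework-independent sketch using the Chatterbox constants (24 layers, 16 KV heads, head dim 64); the `KvCache` type and its `[seq, kv_heads, head_dim]` layout are illustrative assumptions, not the code in `tts.rs` (ONNX graphs commonly expect `[batch, heads, seq, head_dim]`, which would need a transpose):

```rust
/// Per-layer KV cache with keys/values stored as flat [seq, kv_heads, head_dim]
/// buffers, so appending a decoding step is a contiguous `extend_from_slice`.
pub struct KvCache {
    kv_heads: usize,
    head_dim: usize,
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl KvCache {
    pub fn new(layers: usize, kv_heads: usize, head_dim: usize) -> Self {
        Self {
            kv_heads,
            head_dim,
            keys: vec![Vec::new(); layers],
            values: vec![Vec::new(); layers],
        }
    }

    /// f32 values that one position adds to a layer's K (or V) buffer.
    fn values_per_position(&self) -> usize {
        self.kv_heads * self.head_dim
    }

    /// Append K/V for one or more new positions at `layer`.
    pub fn append(&mut self, layer: usize, k: &[f32], v: &[f32]) -> Result<(), String> {
        let n = self.values_per_position();
        if k.len() != v.len() || k.len() % n != 0 {
            return Err(format!(
                "K/V lengths {}/{} must match and be multiples of {n}",
                k.len(),
                v.len()
            ));
        }
        if layer >= self.keys.len() {
            return Err(format!("layer {layer} out of range"));
        }
        self.keys[layer].extend_from_slice(k);
        self.values[layer].extend_from_slice(v);
        Ok(())
    }

    /// Positions cached so far for `layer` (0 for an unknown layer).
    pub fn seq_len(&self, layer: usize) -> usize {
        self.keys
            .get(layer)
            .map_or(0, |k| k.len() / self.values_per_position())
    }

    /// Drop all cached positions, e.g. between utterances.
    pub fn clear(&mut self) {
        for buf in self.keys.iter_mut().chain(self.values.iter_mut()) {
            buf.clear();
        }
    }
}
```

The same structure carries over to Qwen3-TTS with different constants; only the per-position width changes.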

### STT Pattern (listen.rs)

The Listen/STT handler in `makima/src/server/handlers/listen.rs` demonstrates the broader ML pattern:
- WebSocket-based streaming
- Lazy model loading via `SharedState::get_ml_models()`
- Models held behind `tokio::sync::Mutex` for async access
- `parakeet-rs` local crate for STT, `sortformer` for diarization
- All models are Rust-native with ONNX backends

---

## 2. Qwen3-TTS-12Hz-0.6B-Base Architecture

### Model Overview

| Property | Value |
|----------|-------|
| **Parameters** | 0.6B |
| **Architecture** | `Qwen3TTSForConditionalGeneration` |
| **Output Sample Rate** | 24,000 Hz |
| **Token Frame Rate** | 12.5 Hz (~80ms per token) |
| **Model Format** | SafeTensors (1.83 GB main + 682 MB tokenizer) |
| **Total Size** | ~2.52 GB |
| **Precision** | bfloat16/float16 |
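
The relationship between the 24 kHz output rate and the 12.5 Hz token rate is worth pinning down, since it sets how many autoregressive steps an utterance needs and how much audio each generated frame yields. A small sketch with hypothetical constant and function names:

```rust
/// Output sample rate of the speech tokenizer decoder.
pub const SAMPLE_RATE_HZ: u64 = 24_000;
/// Codec frame rate of 12.5 Hz, kept as the integer ratio 25/2 to avoid float rounding.
pub const FRAME_RATE_NUM: u64 = 25;
pub const FRAME_RATE_DEN: u64 = 2;

/// Waveform samples covered by one codec frame: 24_000 / 12.5 = 1_920 (80 ms).
pub const fn samples_per_frame() -> u64 {
    SAMPLE_RATE_HZ * FRAME_RATE_DEN / FRAME_RATE_NUM
}

/// Codec frames needed to cover `millis` milliseconds of audio, rounded up.
pub const fn frames_for_duration_ms(millis: u64) -> u64 {
    // frames = millis / 1000 * 12.5 = millis * 25 / 2000
    let num = millis * FRAME_RATE_NUM;
    let den = 1000 * FRAME_RATE_DEN;
    (num + den - 1) / den
}
```

At 1,920 samples per frame, a 10 s utterance is 125 frames, so the Talker runs roughly 125 generation steps for that clip, each followed by a code predictor pass.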

### Components

The model has **three distinct components**:

#### A. Main Language Model (Talker) — 1.83 GB safetensors
- Hidden size: 1024
- Layers: 28
- Attention heads: 16 (8 KV heads)
- Intermediate size: 3072
- Head dimension: 128
- Text vocab size: 151,936
- Max position embeddings: 32,768
- Autoregressive transformer predicting speech token sequences from text
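
One detail in this list is easy to get wrong when reusing existing code: the head dimension is not hidden_size / num_heads (1024 / 16 = 64). With 16 query heads of dimension 128, the Q projection is 2048 wide, wider than the hidden size, and with 8 KV heads each pair of query heads shares one KV head. A hypothetical `config.rs` could therefore carry `head_dim` explicitly. The field names below mirror common Hugging Face config keys, but the actual keys in Qwen3-TTS's `config.json` are an assumption to verify:

```rust
/// Talker LM hyperparameters (values from the list above).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TalkerConfig {
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub intermediate_size: usize,
    pub head_dim: usize,
    pub vocab_size: usize,
    pub max_position_embeddings: usize,
}

impl TalkerConfig {
    /// Values as documented for Qwen3-TTS-12Hz-0.6B-Base.
    pub fn qwen3_tts_12hz_0_6b() -> Self {
        Self {
            hidden_size: 1024,
            num_hidden_layers: 28,
            num_attention_heads: 16,
            num_key_value_heads: 8,
            intermediate_size: 3072,
            head_dim: 128,
            vocab_size: 151_936,
            max_position_embeddings: 32_768,
        }
    }

    /// Query heads that share one KV head under grouped-query attention.
    pub fn gqa_group_size(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }

    /// Output width of the Q projection (num_heads * head_dim).
    pub fn q_proj_dim(&self) -> usize {
        self.num_attention_heads * self.head_dim
    }

    /// Output width of each of the K and V projections (num_kv_heads * head_dim).
    pub fn kv_proj_dim(&self) -> usize {
        self.num_key_value_heads * self.head_dim
    }

    /// Reject configs whose head counts cannot form GQA groups.
    pub fn validate(&self) -> Result<(), String> {
        if self.num_key_value_heads == 0
            || self.num_attention_heads % self.num_key_value_heads != 0
        {
            return Err(format!(
                "{} attention heads are not divisible by {} KV heads",
                self.num_attention_heads, self.num_key_value_heads
            ));
        }
        Ok(())
    }
}
```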

#### B. Code Predictor (Multi-Token Prediction) — embedded in main model
- Hidden size: 1024
- Layers: 5
- Attention heads: 16
- Number of code groups: 16
- Codebook vocab size: 2048
- Predicts the remaining 15 residual codebooks after the main LM predicts the zeroth codebook (16 codebooks per frame in total)

#### C. Speech Tokenizer (Qwen3-TTS-Tokenizer-12Hz) — 682 MB safetensors
- Separate model in `speech_tokenizer/` directory
- GAN-based codec: encoder + decoder
- 16-layer multi-codebook RVQ (Residual Vector Quantization)
- First codebook: semantic (WavLM-guided)
- Remaining 15: acoustic details
- **Decoder**: lightweight causal ConvNet (no DiT/diffusion needed)
- Encodes reference audio → discrete codes, decodes codes → waveform
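
In standard residual vector quantization, the decode side looks up one vector per codebook and sums them into the latent that the decoder consumes. A framework-independent sketch of that lookup follows. Whether Qwen3-TTS-Tokenizer-12Hz sums all 16 codebooks in one space, keeps the semantic codebook separate, or applies extra projections is an assumption to check against the reference PyTorch code, and `ResidualVq` is a hypothetical name:

```rust
/// Residual vector quantizer (decode side only). Codebook `q` is a flat,
/// row-major table of `codebook_size * dim` values.
pub struct ResidualVq {
    dim: usize,
    codebook_size: usize,
    codebooks: Vec<Vec<f32>>,
}

impl ResidualVq {
    pub fn new(dim: usize, codebook_size: usize, codebooks: Vec<Vec<f32>>) -> Result<Self, String> {
        for (q, cb) in codebooks.iter().enumerate() {
            if cb.len() != codebook_size * dim {
                return Err(format!(
                    "codebook {q} has {} values, expected {}",
                    cb.len(),
                    codebook_size * dim
                ));
            }
        }
        Ok(Self { dim, codebook_size, codebooks })
    }

    /// Dequantize one frame: sum the selected vector from every codebook.
    pub fn decode_frame(&self, codes: &[u32]) -> Result<Vec<f32>, String> {
        if codes.len() != self.codebooks.len() {
            return Err(format!(
                "expected {} codes, got {}",
                self.codebooks.len(),
                codes.len()
            ));
        }
        let mut latent = vec![0.0f32; self.dim];
        for (q, (cb, &code)) in self.codebooks.iter().zip(codes).enumerate() {
            let idx = code as usize;
            if idx >= self.codebook_size {
                return Err(format!("codebook {q}: code {code} out of range"));
            }
            let row = &cb[idx * self.dim..(idx + 1) * self.dim];
            for (acc, &v) in latent.iter_mut().zip(row) {
                *acc += v;
            }
        }
        Ok(latent)
    }
}
```

For the real tokenizer this would be built with 16 codebooks of 2048 entries each, with the tables loaded from the speech tokenizer's safetensors.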

### Inference Pipeline

```
Text Input + Reference Audio
         ↓
[Speech Tokenizer Encoder] → reference audio codes + speaker embedding
         ↓
[Text Tokenizer] → text token IDs
         ↓
[Language Model] → autoregressive generation of zeroth codebook tokens
         ↓
[Code Predictor / MTP] → predict remaining 15 codebook layers
         ↓
[Speech Tokenizer Decoder / Causal ConvNet] → waveform output (24kHz)
```
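
The two generation stages of the pipeline above can be sketched as a loop in which the Talker proposes codebook 0 for each frame and the code predictor completes the other 15 from the Talker's hidden state. The `Talker` and `CodePredictor` traits, `TalkerStep`, and feeding complete frames back as history are assumptions made for illustration; sampling parameters and text conditioning are omitted:

```rust
/// Codebooks per frame: codebook 0 from the Talker plus 15 from the code predictor.
pub const NUM_CODEBOOKS: usize = 16;

/// One Talker step: the sampled codebook-0 token and the hidden state that
/// conditions the code predictor.
pub struct TalkerStep {
    pub codebook0: u32,
    pub hidden: Vec<f32>,
}

/// Stand-in for the Talker LM. Returns `None` when it emits end-of-speech.
pub trait Talker {
    fn step(&mut self, history: &[[u32; NUM_CODEBOOKS]]) -> Option<TalkerStep>;
}

/// Stand-in for the MTP code predictor: fills codebooks 1..16 of one frame.
pub trait CodePredictor {
    fn predict_residuals(&mut self, hidden: &[f32], codebook0: u32) -> [u32; NUM_CODEBOOKS - 1];
}

/// Two-stage generation: the Talker proposes codebook 0, the predictor completes the frame.
pub fn generate_frames<T: Talker, P: CodePredictor>(
    talker: &mut T,
    predictor: &mut P,
    max_frames: usize,
) -> Vec<[u32; NUM_CODEBOOKS]> {
    let mut frames = Vec::new();
    while frames.len() < max_frames {
        let Some(step) = talker.step(&frames) else { break };
        let residuals = predictor.predict_residuals(&step.hidden, step.codebook0);
        let mut frame = [0u32; NUM_CODEBOOKS];
        frame[0] = step.codebook0;
        frame[1..].copy_from_slice(&residuals);
        frames.push(frame);
    }
    frames
}
```

Keeping both stages behind traits lets the loop be unit-tested with dummy models before any candle weights are loaded.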

---

## 3. ONNX Export Feasibility — NOT VIABLE

### Status: No ONNX support exists

- **No official ONNX export** from Qwen team
- **No community ONNX conversion** for Qwen3-TTS
- The Qwen3 architecture is **not supported** by HuggingFace Optimum's ONNX exporter
- Users attempting export get: `ValueError: Trying to export a qwen3 model, that is a custom or unsupported architecture, but no custom onnx configuration was passed`
- Even for base Qwen3 LLMs (non-TTS), ONNX export has significant issues with MoE routing, hybrid attention, and novel architecture components

### Why ONNX Won't Work for Qwen3-TTS

1. **Custom architecture** — `Qwen3TTSForConditionalGeneration` is not a standard transformer; it combines LM + code predictor + speech tokenizer
2. **Multi-codebook MTP module** — the code predictor generates the 15 residual codebooks of every frame, a non-standard operation
3. **Causal ConvNet decoder** — the speech tokenizer's decoder is a custom GAN-trained ConvNet, not a standard vocoder
4. **Dynamic control flow** — dual-track streaming architecture with conditional branching
5. **No Optimum support** — would require writing a custom ONNX config from scratch for each sub-component

**Verdict: The ONNX path (matching our Chatterbox approach) is a dead end for Qwen3-TTS.**

---

## 4. Rust-Native Inference Options

### Option A: Candle (HuggingFace) — RECOMMENDED

[candle](https://github.com/huggingface/candle) is HuggingFace's minimalist Rust ML framework.

**Why candle is the best fit:**

| Factor | Assessment |
|--------|------------|
| **Qwen model support** | ✅ Has `qwen2` module in candle-transformers; Qwen3 variants supported |
| **SafeTensors loading** | ✅ Native first-class support (safetensors is a Rust crate) |
| **GPU support** | ✅ CUDA backend, Metal (macOS), CPU with MKL |
| **Tokenizer support** | ✅ Uses the same `tokenizers` crate makima already depends on |
| **Audio models** | ✅ Supports EnCodec, Whisper, MetaVoice, Parler-TTS |
| **KV cache** | ✅ Well-established patterns in existing model implementations |
| **Community** | ✅ Active; Crane project already lists Qwen3-TTS as "highest priority" |
| **Binary size** | ✅ Compiles to single binary, no Python dependency |

**What needs to be implemented:**

1. **Qwen3-TTS transformer layers** — extend existing `qwen2` model code for the 28-layer LM with TTS-specific modifications (speaker encoder concatenation, code predictor output heads)
2. **Code Predictor (MTP)** — 5-layer module that predicts the 15 residual codebooks of each frame from the LM hidden states
3. **Speech Tokenizer Encoder** — ConvNet encoder that converts reference audio to discrete multi-codebook tokens + speaker embeddings
4. **Speech Tokenizer Decoder** — causal ConvNet that reconstructs waveforms from discrete codes
5. **Multi-codebook handling** — manage 16 parallel codebook sequences
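
Multi-codebook handling involves two orderings of the same codes: generation produces them frame by frame, while a codec decoder typically consumes a `[num_codebooks, T]` grid. A sketch of the conversion plus a range check against the 2048-entry codebooks; the layout expected by the Qwen3-TTS decoder is an assumption, and the function names are hypothetical:

```rust
/// Convert frame-major codes (one `[u32; Q]` per 80 ms frame, in generation
/// order) into codebook-major rows (`Q` sequences of length T).
pub fn to_codebook_major<const Q: usize>(frames: &[[u32; Q]]) -> Vec<Vec<u32>> {
    let mut rows: Vec<Vec<u32>> = (0..Q).map(|_| Vec::with_capacity(frames.len())).collect();
    for frame in frames {
        for (row, &code) in rows.iter_mut().zip(frame.iter()) {
            row.push(code);
        }
    }
    rows
}

/// Flatten codebook-major rows into one row-major [Q, T] buffer, the form a
/// tensor constructor would typically take.
pub fn flatten_rows(rows: &[Vec<u32>]) -> Vec<u32> {
    rows.iter().flat_map(|r| r.iter().copied()).collect()
}

/// Check that every code indexes into a codebook of `codebook_size` entries.
pub fn validate_codes<const Q: usize>(frames: &[[u32; Q]], codebook_size: u32) -> Result<(), String> {
    for (t, frame) in frames.iter().enumerate() {
        for (q, &c) in frame.iter().enumerate() {
            if c >= codebook_size {
                return Err(format!("frame {t}, codebook {q}: code {c} >= {codebook_size}"));
            }
        }
    }
    Ok(())
}
```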

**Estimated effort:** Medium-High. The LM backbone can reuse existing Qwen2/3 code. The speech tokenizer (encoder + decoder) is the most novel component.

**Key crate dependencies to add:**
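```toml
[dependencies]
# The candle crates are released together; keep all three on the same version and
# check crates.io for the latest release before pinning.
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"
# Keep existing: tokenizers, hf-hub, ndarray (for compatibility)
```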

### Option B: Crane (Candle-based TTS Engine)

[Crane](https://github.com/lucasjinreal/Crane) is a pure Rust LLM inference engine built on candle, specifically designed for multi-modal models including TTS.

**Key facts:**
- Already supports Spark-TTS (codec-based TTS with similar architecture)
- **Qwen3-TTS is listed as "Highest Priority" on their roadmap**
- Handles multi-module architectures (codec + LLM pipelines)
- Supports Qwen2.5, Moonshine ASR
- Claims 50x faster than PyTorch on Apple Silicon

**Strategy:** Monitor Crane's Qwen3-TTS implementation. If they ship it, we could either:
- Use Crane as a dependency directly
- Port their implementation into makima's codebase
- Contribute to Crane and depend on it

**Risk:** Crane is a relatively new project; depending on it adds supply chain risk.

### Option C: qwen3-rs (Educational Reference)

[qwen3-rs](https://github.com/reinterpretcat/qwen3-rs) is an educational project implementing Qwen3 inference from scratch in Rust.

**Useful for:** Reference implementation of Qwen3 transformer layers, tokenization, KV cache, and safetensors loading — all without heavy ML framework dependencies. However, it only implements the base LLM, not the TTS-specific components.

### Option D: Direct ort (ONNX Runtime) with Custom Export — FALLBACK

If we could manually export each sub-component to ONNX:

1. Export the 28-layer LM backbone (similar complexity to Chatterbox)
2. Export the code predictor separately
3. Export the speech tokenizer encoder/decoder

This would match our existing Chatterbox pattern but requires Python scripting for the one-time export, and the Qwen3 architecture is explicitly unsupported by standard exporters. **Not recommended unless ONNX support materializes upstream.**

### Option E: PyTorch C++ (libtorch) via FFI — NOT RECOMMENDED

Using libtorch via Rust FFI bindings (the `tch` crate from the tch-rs project). This would:
- Add a ~2GB libtorch dependency
- Require complex build setup
- Introduce C++ dependency management
- Defeat the purpose of a pure Rust solution

---

## 5. Recommended Approach

### Phase 1: Candle-Based Implementation

**Architecture:**

```
makima/src/tts/
├── mod.rs              // TTS trait + factory (select Chatterbox vs Qwen3)
├── chatterbox.rs       // Existing ONNX-based Chatterbox (moved from tts.rs)
├── qwen3/
│   ├── mod.rs          // Qwen3TTS public API
│   ├── model.rs        // Qwen3 LM transformer (28 layers)
│   ├── code_predictor.rs   // MTP module (5 layers, 16 codebooks)
│   ├── speech_tokenizer.rs // Encoder + Decoder (causal ConvNet)
│   ├── config.rs       // Model config from config.json
│   └── generate.rs     // Autoregressive generation loop with KV cache
```
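
The `mod.rs` "TTS trait + factory" could look roughly like the following. `TtsEngine`, `TtsBackend`, `TtsError`, and `create_engine` are hypothetical names, and `PlaceholderEngine` stands in for the real Chatterbox and Qwen3 modules (it returns one 80 ms frame of silence at 24 kHz):

```rust
use std::fmt;
use std::str::FromStr;

/// Errors surfaced by TTS backends (kept minimal for the sketch).
#[derive(Debug)]
pub enum TtsError {
    UnknownBackend(String),
    Inference(String),
}

impl fmt::Display for TtsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TtsError::UnknownBackend(name) => write!(f, "unknown TTS backend: {name}"),
            TtsError::Inference(msg) => write!(f, "TTS inference failed: {msg}"),
        }
    }
}

impl std::error::Error for TtsError {}

/// Common interface both backends would implement.
pub trait TtsEngine: Send {
    /// Synthesize mono f32 PCM for `text` at `sample_rate()`.
    fn synthesize(&mut self, text: &str) -> Result<Vec<f32>, TtsError>;
    fn sample_rate(&self) -> u32;
    fn name(&self) -> &'static str;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtsBackend {
    Chatterbox,
    Qwen3,
}

impl FromStr for TtsBackend {
    type Err = TtsError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "chatterbox" => Ok(TtsBackend::Chatterbox),
            "qwen3" | "qwen3-tts" => Ok(TtsBackend::Qwen3),
            other => Err(TtsError::UnknownBackend(other.to_string())),
        }
    }
}

/// Placeholder for chatterbox.rs / qwen3::Qwen3TTS.
struct PlaceholderEngine {
    name: &'static str,
}

impl TtsEngine for PlaceholderEngine {
    fn synthesize(&mut self, text: &str) -> Result<Vec<f32>, TtsError> {
        if text.trim().is_empty() {
            return Err(TtsError::Inference("empty input text".into()));
        }
        Ok(vec![0.0; 1_920])
    }
    fn sample_rate(&self) -> u32 {
        24_000
    }
    fn name(&self) -> &'static str {
        self.name
    }
}

/// Factory: in the real module each arm would load its own model.
pub fn create_engine(backend: TtsBackend) -> Box<dyn TtsEngine> {
    match backend {
        TtsBackend::Chatterbox => Box::new(PlaceholderEngine { name: "chatterbox" }),
        TtsBackend::Qwen3 => Box::new(PlaceholderEngine { name: "qwen3" }),
    }
}
```

Parsing the backend from a string keeps the choice configurable (for example from an environment variable) while the rest of the server only sees `Box<dyn TtsEngine>`.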

**Key implementation details:**

1. **Load safetensors directly** — candle's `safetensors` support reads the 1.83GB main model and 682MB speech tokenizer
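2. **Reuse Qwen2 attention** — candle-transformers already has `qwen2::Model` with RoPE, GQA, and KV cache. Qwen3 layers additionally apply per-head RMSNorm to queries and keys (QK-Norm) and drop the QKV projection bias, so the attention block needs small changes rather than verbatim reuse.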
3. **Implement ConvNet codec** — the speech tokenizer's encoder/decoder is a causal 1D ConvNet; candle has `Conv1d` layers
4. **Multi-codebook RVQ** — implement the 16-codebook residual vector quantization lookup
5. **Speaker embedding** — extract from reference audio via the speech tokenizer encoder
6. **Streaming support** — the 12Hz model's causal architecture enables token-by-token waveform generation
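
Items 3 and 6 both rely on causality: each output sample may depend only on current and past inputs. With a symmetrically padded `Conv1d`, this is usually achieved by left-padding the time axis with `(kernel_len - 1) * dilation` zeros and convolving with no padding. The single-channel loop below shows that rule in plain Rust; in candle the same effect would presumably come from padding the input tensor (for example with `Tensor::pad_with_zeros`) before a zero-padding `Conv1d`. The function name is hypothetical, and real layers are multi-channel with bias:

```rust
/// Single-channel causal 1D convolution (cross-correlation, as in PyTorch/candle).
/// Output length equals input length; out[t] depends only on input[..=t].
pub fn causal_conv1d(input: &[f32], kernel: &[f32], dilation: usize) -> Vec<f32> {
    assert!(!kernel.is_empty(), "kernel must not be empty");
    assert!(dilation > 0, "dilation must be at least 1");
    // Zeros that a causal layer prepends on the left of the time axis.
    let pad = (kernel.len() - 1) * dilation;
    (0..input.len())
        .map(|t| {
            kernel
                .iter()
                .enumerate()
                .map(|(k, &w)| {
                    // Index into the virtually left-padded signal.
                    let padded = t + k * dilation;
                    if padded >= pad {
                        w * input[padded - pad]
                    } else {
                        0.0
                    }
                })
                .sum::<f32>()
        })
        .collect()
}
```

Because no output looks ahead, a streaming decoder can emit waveform samples for a frame as soon as that frame's codes exist, carrying the last `pad` input samples of each layer over to the next chunk.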

### Phase 2: Voice Assets

The model supports voice cloning with reference audio. For the default Makima voice:
- A 5-15 second clip of Japanese-accented English speech is needed
- Reference audio + transcript fed to speech tokenizer encoder
- Speaker embedding cached for reuse
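
Because the speaker conditioning depends only on the reference clip and its transcript, it can be computed once per voice and shared across requests. A minimal cache sketch, assuming the encoder output is a speaker embedding plus the reference clip's codes; `VoicePrompt`, `VoiceCache`, and the `encode` closure are hypothetical placeholders for the speech tokenizer encoder call:

```rust
use std::collections::HashMap;
use std::sync::Arc;

/// Conditioning derived from one reference clip (fields are assumptions).
#[derive(Debug, Clone, PartialEq)]
pub struct VoicePrompt {
    pub speaker_embedding: Vec<f32>,
    /// Reference-audio codes, one 16-codebook frame per 80 ms.
    pub reference_codes: Vec<[u32; 16]>,
    pub transcript: String,
}

/// Per-voice cache; `Arc` lets concurrent requests share one prompt without copying.
#[derive(Default)]
pub struct VoiceCache {
    entries: HashMap<String, Arc<VoicePrompt>>,
}

impl VoiceCache {
    pub fn new() -> Self {
        Self::default()
    }

    /// Return the cached prompt for `voice_id`, running `encode` only on a miss.
    /// Encoder errors are returned and nothing is cached, so a later call retries.
    pub fn get_or_encode<E>(
        &mut self,
        voice_id: &str,
        encode: impl FnOnce() -> Result<VoicePrompt, E>,
    ) -> Result<Arc<VoicePrompt>, E> {
        if let Some(prompt) = self.entries.get(voice_id) {
            return Ok(Arc::clone(prompt));
        }
        let prompt = Arc::new(encode()?);
        self.entries.insert(voice_id.to_owned(), Arc::clone(&prompt));
        Ok(prompt)
    }

    pub fn len(&self) -> usize {
        self.entries.len()
    }

    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
```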

### Phase 3: Integration with Listen Page

Following the pattern in `listen.rs`:
- TTS model loaded lazily via `SharedState`
- Protected behind `tokio::sync::Mutex` (or `RwLock` for concurrent reads)
- WebSocket endpoint for streaming TTS (emit audio chunks as tokens are generated)
- Bidirectional: STT (listen) → process → TTS (speak) loop
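
The lazy-loading pattern from `listen.rs` (load on first use, then share behind a mutex) can be sketched with the standard library alone. `LazyModel` is a hypothetical name, and `std::sync::OnceLock`/`Mutex` stand in for the tokio primitives the server uses:

```rust
use std::sync::{Mutex, OnceLock};

/// Lazily loaded, mutex-guarded model slot.
pub struct LazyModel<T> {
    cell: OnceLock<Mutex<T>>,
}

impl<T> LazyModel<T> {
    pub const fn new() -> Self {
        Self { cell: OnceLock::new() }
    }

    /// Returns the model, running `load` only on first access. Concurrent first
    /// callers block until a single initialization finishes.
    pub fn get_or_load<F: FnOnce() -> T>(&self, load: F) -> &Mutex<T> {
        self.cell.get_or_init(|| Mutex::new(load()))
    }

    pub fn is_loaded(&self) -> bool {
        self.cell.get().is_some()
    }
}
```

In the async handlers, `tokio::sync::Mutex` would replace `std::sync::Mutex` so the lock can be held across `.await`, and a fallible loader would need something like `tokio::sync::OnceCell::get_or_try_init`, since `OnceLock::get_or_init` cannot return an error. Because autoregressive generation mutates the KV cache, an `RwLock` only helps if the weights are shared immutably and each request owns its own cache.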

---

## 6. Comparison Matrix

| Criteria | ONNX (current pattern) | Candle | Crane | libtorch |
|----------|----------------------|--------|-------|----------|
| Pure Rust | ⚠️ (ort wraps the ONNX Runtime C++ library) | ✅ | ✅ | ❌ (C++ FFI) |
| Qwen3-TTS support | ❌ No export | ⚠️ Needs impl | ⚠️ Planned | ✅ (full PyTorch) |
| Single binary | ✅ | ✅ | ✅ | ❌ |
| GPU acceleration | ✅ | ✅ | ✅ | ✅ |
| SafeTensors loading | ❌ (needs ONNX) | ✅ | ✅ | ✅ |
| Streaming TTS | ✅ | ✅ | ✅ | ✅ |
| Maintenance burden | Low | Medium | Low (if adopted) | High |
| Implementation effort | N/A (blocked) | Medium-High | Low (if available) | Medium |
| Dependency size | ~50MB | ~5MB | ~5MB | ~2GB |

---

## 7. Risk Assessment

| Risk | Likelihood | Impact | Mitigation |
|------|-----------|--------|------------|
| Candle implementation takes longer than expected | Medium | Medium | Reference Crane's Spark-TTS implementation; use qwen3-rs as LM reference |
| Speech tokenizer ConvNet is complex to port | Medium | High | Study the PyTorch source in qwen-tts package; ConvNet layers are simpler than transformers |
| Model quality differs from reference PyTorch | Low | High | Validate with reference audio samples; ensure bfloat16 precision |
| Crane ships Qwen3-TTS before we finish | Medium | Positive | Adopt their implementation |
| GPU memory issues on target hardware | Low | Medium | 0.6B model is small (~2.5GB); fits in 4GB VRAM with float16 |
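
The VRAM estimate can be sanity-checked with KV-cache arithmetic. Each Talker position stores K and V for 28 layers × 8 KV heads × 128 dims, at 2 bytes per value in bf16/f16. The sketch assumes one Talker position per 12.5 Hz codec frame (text prompt tokens add to this) and ignores activations and the code predictor's cache; the function names are hypothetical:

```rust
/// Bytes of KV cache per sequence position: K and V (2) for every layer,
/// KV head and head dimension, at `bytes_per_elem` (2 for bf16/f16).
pub const fn kv_bytes_per_position(layers: u64, kv_heads: u64, head_dim: u64, bytes_per_elem: u64) -> u64 {
    2 * layers * kv_heads * head_dim * bytes_per_elem
}

/// Talker KV cache for `positions` positions with the documented config
/// (28 layers, 8 KV heads, head dim 128) in 16-bit precision.
pub const fn talker_kv_bytes(positions: u64) -> u64 {
    kv_bytes_per_position(28, 8, 128, 2) * positions
}

/// Positions for `seconds` of audio at 12.5 frames/s, assuming one Talker
/// position per codec frame.
pub const fn frames_for_seconds(seconds: u64) -> u64 {
    seconds * 25 / 2
}
```

This gives 112 KiB per position, about 82 MiB for a one-minute utterance (750 frames), but about 3.5 GiB if the full 32,768-position context were filled. Under these assumptions the 4 GB figure holds for utterance-length contexts rather than the maximum context length.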

---

## 8. Next Steps

1. **Immediate:** Add `candle-core`, `candle-nn`, `candle-transformers` to Cargo.toml
2. **Week 1:** Implement Qwen3 LM backbone in candle (extend existing qwen2 model)
3. **Week 2:** Implement speech tokenizer encoder/decoder (ConvNet + RVQ)
4. **Week 2:** Implement code predictor (MTP module)
5. **Week 3:** Integration testing with reference audio; validate output quality
6. **Week 3:** Wire into makima server as TTS endpoint
7. **Ongoing:** Monitor Crane project for Qwen3-TTS implementation

---

## Sources

- [Qwen3-TTS-12Hz-0.6B-Base on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
- [Qwen3-TTS Technical Report (arXiv)](https://arxiv.org/html/2601.15621v1)
- [Qwen3-TTS GitHub Repository](https://github.com/QwenLM/Qwen3-TTS)
- [Candle — HuggingFace Rust ML Framework](https://github.com/huggingface/candle)
- [Crane — Rust LLM Inference Engine](https://github.com/lucasjinreal/Crane)
- [qwen3-rs — Educational Qwen3 Rust Implementation](https://github.com/reinterpretcat/qwen3-rs)
- [candle-transformers Qwen2 model](https://docs.rs/candle-transformers/latest/candle_transformers/models/qwen2/index.html)
- [Qwen3-TTS-Tokenizer-12Hz on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-Tokenizer-12Hz)
- [ONNX export issues for Qwen3](https://huggingface.co/onnx-community/Qwen3-1.7B-ONNX/discussions/1)