# Rust-Native Qwen3-TTS Integration Research

## Executive Summary

This document researches integrating **Qwen3-TTS-12Hz-0.6B-Base** directly into the makima Rust codebase, replacing the current Chatterbox TTS implementation. The goal is a **pure Rust** solution — no Python, no separate microservice.

**Bottom line:** A Rust-native integration is feasible but requires significant implementation work. The most viable path is using **candle** (HuggingFace's Rust ML framework) to implement the model architecture natively, loading safetensors weights directly. The existing ONNX-based approach used for Chatterbox TTS is **not viable** for Qwen3-TTS due to architecture incompatibilities with ONNX exporters.

---

## 1. Current TTS Implementation Analysis

The existing Chatterbox TTS in `makima/src/tts.rs` uses:

- **ONNX Runtime** via the `ort` crate (v2.0.0-rc.10)
- **Four ONNX model files**: `speech_encoder.onnx`, `embed_tokens.onnx`, `language_model.onnx`, `conditional_decoder.onnx`
- **tokenizers** crate for text tokenization
- **ndarray** for tensor manipulation
- **hf-hub** for model downloading
- **Pipeline**: encode voice → tokenize text → autoregressive token generation with KV cache → decode tokens to waveform
- **Architecture constants**: 24 layers, 16 KV heads, 64 head dim, 24kHz sample rate

The pattern is well-established: download ONNX models from HuggingFace, load sessions, run inference with manual KV cache management.
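In practice, the manual KV cache management is a per-layer buffer. It grows by one position for each generated token and is fed back into the next forward pass. The sketch below is a std-only illustration that assumes hypothetical names (`KvCache`, `append`) and a flat `[seq][kv_heads][head_dim]` layout. It is not the code in `tts.rs`, which passes `ndarray` tensors to ONNX sessions. The test uses the Chatterbox constants listed above: 24 layers, 16 KV heads, and a head dim of 64.

```rust
/// Keys and values for one layer, stored flat as [seq][kv_heads][head_dim].
struct LayerCache {
    keys: Vec<f32>,
    values: Vec<f32>,
}

/// Hypothetical KV cache for an autoregressive decoder.
pub struct KvCache {
    layers: Vec<LayerCache>,
    kv_heads: usize,
    head_dim: usize,
}

impl KvCache {
    pub fn new(num_layers: usize, kv_heads: usize, head_dim: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| LayerCache { keys: Vec::new(), values: Vec::new() })
            .collect();
        Self { layers, kv_heads, head_dim }
    }

    /// Floats occupied by one position of K (or V) in one layer.
    fn stride(&self) -> usize {
        self.kv_heads * self.head_dim
    }

    /// Append keys/values for one or more new positions at `layer`.
    pub fn append(&mut self, layer: usize, k: &[f32], v: &[f32]) {
        let stride = self.stride();
        assert_eq!(k.len() % stride, 0, "keys must cover whole positions");
        assert_eq!(k.len(), v.len(), "keys and values must have equal length");
        let cache = &mut self.layers[layer];
        cache.keys.extend_from_slice(k);
        cache.values.extend_from_slice(v);
    }

    /// Cached positions, read from layer 0 (all layers advance in lockstep).
    pub fn seq_len(&self) -> usize {
        self.layers.first().map_or(0, |l| l.keys.len() / self.stride())
    }

    /// Full key history for `layer`, fed back as the "past" input.
    pub fn keys(&self, layer: usize) -> &[f32] {
        &self.layers[layer].keys
    }

    /// Full value history for `layer`.
    pub fn values(&self, layer: usize) -> &[f32] {
        &self.layers[layer].values
    }

    /// Reset between utterances.
    pub fn clear(&mut self) {
        for layer in &mut self.layers {
            layer.keys.clear();
            layer.values.clear();
        }
    }
}
```

A prefill pass appends every prompt position at once. Each decode step then appends exactly one position per layer.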

### STT Pattern (listen.rs)

The Listen/STT handler in `makima/src/server/handlers/listen.rs` demonstrates the broader ML pattern:
- WebSocket-based streaming
- Lazy model loading via `SharedState::get_ml_models()`
- Models held behind `tokio::sync::Mutex` for async access
- `parakeet-rs` local crate for STT, `sortformer` for diarization
- All models are Rust-native with ONNX backends

---

## 2. Qwen3-TTS-12Hz-0.6B-Base Architecture

### Model Overview

| Property | Value |
|----------|-------|
| **Parameters** | 0.6B |
| **Architecture** | `Qwen3TTSForConditionalGeneration` |
| **Output Sample Rate** | 24,000 Hz |
| **Token Frame Rate** | 12.5 Hz (~80ms per token) |
| **Model Format** | SafeTensors (1.83 GB main + 682 MB tokenizer) |
| **Total Size** | ~2.52 GB |
| **Precision** | bfloat16/float16 |

### Components

The model has **three distinct components**:

#### A. Main Language Model (Talker) — 1.83 GB safetensors
- Hidden size: 1024
- Layers: 28
- Attention heads: 16 (8 KV heads)
- Intermediate size: 3072
- Head dimension: 128
- Text vocab size: 151,936
- Max position embeddings: 32,768
- Autoregressive transformer predicting speech token sequences from text
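Note that 16 heads × 128 head dim = 2,048, which is twice the 1,024 hidden size. The `head_dim` therefore has to be read from the config rather than derived from the hidden size. The sketch below uses a hypothetical `TalkerDims` type that holds the values listed above. It computes the projection widths and the grouped-query mapping, in which each of the 8 KV heads serves 2 query heads.

```rust
/// Talker attention hyperparameters as listed above (copied from this
/// document, not parsed from the model's config.json).
pub struct TalkerDims {
    pub hidden_size: usize,
    pub num_attention_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
}

impl TalkerDims {
    pub fn qwen3_tts_0_6b() -> Self {
        Self { hidden_size: 1024, num_attention_heads: 16, num_kv_heads: 8, head_dim: 128 }
    }

    /// Output width of q_proj: heads * head_dim.
    pub fn q_proj_out(&self) -> usize {
        self.num_attention_heads * self.head_dim
    }

    /// Output width of k_proj (and of v_proj): kv_heads * head_dim.
    pub fn kv_proj_out(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }

    /// Query heads that share each KV head under grouped-query attention.
    pub fn gqa_group_size(&self) -> usize {
        assert_eq!(self.num_attention_heads % self.num_kv_heads, 0);
        self.num_attention_heads / self.num_kv_heads
    }

    /// Index of the KV head read by query head `q_head`.
    pub fn kv_head_for(&self, q_head: usize) -> usize {
        q_head / self.gqa_group_size()
    }

    /// What code deriving head_dim as hidden_size / heads would compute.
    pub fn derived_head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}
```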

#### B. Code Predictor (Multi-Token Prediction) — embedded in main model
- Hidden size: 1024
- Layers: 5
- Attention heads: 16
- Number of code groups: 16
- Codebook vocab size: 2048
- Predicts the remaining 15 residual codebooks after the main LM predicts the zeroth codebook (16 codebooks per frame in total)
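Within each frame, this division of labour reduces to a short inner loop. The talker emits codebook 0, and the code predictor then fills codebooks 1 to 15 one level at a time. The sketch assumes that each residual level conditions on the talker's hidden state and on the codes already chosen for the same frame. The `Talker` and `CodePredictor` traits are hypothetical interfaces, not the Qwen3-TTS API.

```rust
/// Codebooks per frame (1 from the talker + 15 residual from the code predictor).
pub const NUM_CODEBOOKS: usize = 16;
/// Entries per codebook.
pub const CODEBOOK_SIZE: u32 = 2048;

/// Hypothetical talker interface: emits the codebook-0 token for the next
/// frame plus the hidden state the code predictor conditions on.
pub trait Talker {
    fn next_zeroth_code(&mut self) -> (u32, Vec<f32>);
}

/// Hypothetical code predictor interface: chooses the code for residual
/// `level` (1..16) given the hidden state and this frame's earlier codes.
pub trait CodePredictor {
    fn predict(&mut self, hidden: &[f32], level: usize, frame_so_far: &[u32]) -> u32;
}

/// Produce one complete 16-code frame.
pub fn generate_frame<T: Talker, P: CodePredictor>(
    talker: &mut T,
    predictor: &mut P,
) -> [u32; NUM_CODEBOOKS] {
    let (code0, hidden) = talker.next_zeroth_code();
    let mut frame = [0u32; NUM_CODEBOOKS];
    frame[0] = code0;
    for level in 1..NUM_CODEBOOKS {
        let code = predictor.predict(&hidden, level, &frame[..level]);
        assert!(code < CODEBOOK_SIZE, "code out of codebook range");
        frame[level] = code;
    }
    frame
}
```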

#### C. Speech Tokenizer (Qwen3-TTS-Tokenizer-12Hz) — 682 MB safetensors
- Separate model in `speech_tokenizer/` directory
- GAN-based codec: encoder + decoder
- 16-layer multi-codebook RVQ (Residual Vector Quantization)
- First codebook: semantic (WavLM-guided)
- Remaining 15: acoustic details
- **Decoder**: lightweight causal ConvNet (no DiT/diffusion needed)
- Encodes reference audio → discrete codes, decodes codes → waveform
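Residual vector quantization works level by level. Each codebook quantizes whatever the previous levels failed to capture, and decoding sums the chosen codewords. The sketch below is plain nearest-neighbour RVQ on toy 2-D vectors, using a hypothetical `Rvq` type. It does not model the real tokenizer's learned codebooks, any projections, or the semantic first level.

```rust
/// Hypothetical RVQ: `codebooks[level][index]` is a codeword of length `dim`.
pub struct Rvq {
    pub codebooks: Vec<Vec<Vec<f32>>>,
}

fn sq_dist(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

impl Rvq {
    /// Encode one latent frame: at each level pick the codeword nearest to the
    /// current residual, then subtract it before moving to the next level.
    pub fn encode(&self, latent: &[f32]) -> Vec<usize> {
        let mut residual = latent.to_vec();
        let mut codes = Vec::with_capacity(self.codebooks.len());
        for book in &self.codebooks {
            let (best, _) = book
                .iter()
                .enumerate()
                .map(|(i, cw)| (i, sq_dist(&residual, cw)))
                .fold((0, f32::INFINITY), |acc, cur| if cur.1 < acc.1 { cur } else { acc });
            for (r, c) in residual.iter_mut().zip(&book[best]) {
                *r -= *c;
            }
            codes.push(best);
        }
        codes
    }

    /// Decode: the reconstruction is the sum of the selected codewords.
    pub fn decode(&self, codes: &[usize]) -> Vec<f32> {
        let dim = self.codebooks[0][0].len();
        let mut out = vec![0.0f32; dim];
        for (book, &idx) in self.codebooks.iter().zip(codes) {
            for (o, c) in out.iter_mut().zip(&book[idx]) {
                *o += *c;
            }
        }
        out
    }
}
```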

### Inference Pipeline

```
Text Input + Reference Audio
         ↓
[Speech Tokenizer Encoder] → reference audio codes + speaker embedding
         ↓
[Text Tokenizer] → text token IDs
         ↓
[Language Model] → autoregressive generation of zeroth codebook tokens
         ↓
[Code Predictor / MTP] → predict remaining 15 codebook layers
         ↓
[Speech Tokenizer Decoder / Causal ConvNet] → waveform output (24kHz)
```

---

## 3. ONNX Export Feasibility — NOT VIABLE

### Status: No ONNX support exists

- **No official ONNX export** from Qwen team
- **No community ONNX conversion** for Qwen3-TTS
- The Qwen3 architecture is **not supported** by HuggingFace Optimum's ONNX exporter
- Users attempting export get: `ValueError: Trying to export a qwen3 model, that is a custom or unsupported architecture, but no custom onnx configuration was passed`
- Even for base Qwen3 LLMs (non-TTS), ONNX export has significant issues with MoE routing, hybrid attention, and novel architecture components

### Why ONNX Won't Work for Qwen3-TTS

1. **Custom architecture** — `Qwen3TTSForConditionalGeneration` is not a standard transformer; it combines LM + code predictor + speech tokenizer
2. **Multi-codebook MTP module** — the code predictor fills in the 15 residual codebooks of every 16-codebook frame, a non-standard operation
3. **Causal ConvNet decoder** — the speech tokenizer's decoder is a custom GAN-trained ConvNet, not a standard vocoder
4. **Dynamic control flow** — dual-track streaming architecture with conditional branching
5. **No Optimum support** — would require writing a custom ONNX config from scratch for each sub-component

**Verdict: The ONNX path (matching our Chatterbox approach) is a dead end for Qwen3-TTS.**

---

## 4. Rust-Native Inference Options

### Option A: Candle (HuggingFace) — RECOMMENDED

[candle](https://github.com/huggingface/candle) is HuggingFace's minimalist Rust ML framework.

**Why candle is the best fit:**

| Factor | Assessment |
|--------|------------|
| **Qwen model support** | ✅ Has a `qwen2` module in candle-transformers; recent releases also include a `qwen3` module |
| **SafeTensors loading** | ✅ Native first-class support (safetensors is a Rust crate) |
| **GPU support** | ✅ CUDA backend, Metal (macOS), CPU with MKL |
| **Tokenizer support** | ✅ Uses the same `tokenizers` crate makima already depends on |
| **Audio models** | ✅ Supports EnCodec, Whisper, MetaVoice, Parler-TTS |
| **KV cache** | ✅ Well-established patterns in existing model implementations |
| **Community** | ✅ Active; Crane project already lists Qwen3-TTS as "highest priority" |
| **Binary size** | ✅ Compiles to single binary, no Python dependency |

**What needs to be implemented:**

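1. **Qwen3-TTS transformer layers** — extend existing Qwen2/Qwen3 model code for the 28-layer LM with TTS-specific modifications (speaker encoder concatenation, code predictor output heads). Qwen3-style attention differs from Qwen2 in two ways: it applies RMSNorm to queries and keys (QK-norm), and it uses an explicit `head_dim` (128) instead of `hidden_size / num_attention_heads` (64).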
2. **Code Predictor (MTP)** — 5-layer module that predicts each frame's 15 residual codebooks from the LM hidden states (the LM itself predicts codebook 0)
3. **Speech Tokenizer Encoder** — ConvNet encoder that converts reference audio to discrete multi-codebook tokens + speaker embeddings
4. **Speech Tokenizer Decoder** — causal ConvNet that reconstructs waveforms from discrete codes
5. **Multi-codebook handling** — manage 16 parallel codebook sequences
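For item 5, one straightforward representation stores codes frame-major, as one `[u32; 16]` per 80 ms frame, which matches the order in which generation produces them. Codes are transposed to one sequence per codebook only when a consumer needs that layout. The `CodeFrames` type below is a hypothetical sketch of this approach.

```rust
pub const NUM_CODEBOOKS: usize = 16;

/// Hypothetical container for generated codes, stored frame-major:
/// `frames[t][k]` is the code for codebook `k` at frame `t`.
#[derive(Default)]
pub struct CodeFrames {
    frames: Vec<[u32; NUM_CODEBOOKS]>,
}

impl CodeFrames {
    pub fn push(&mut self, frame: [u32; NUM_CODEBOOKS]) {
        self.frames.push(frame);
    }

    pub fn len(&self) -> usize {
        self.frames.len()
    }

    /// Codebook-major view `[k][t]`: one sequence per codebook.
    pub fn to_codebook_major(&self) -> Vec<Vec<u32>> {
        (0..NUM_CODEBOOKS)
            .map(|k| self.frames.iter().map(|f| f[k]).collect())
            .collect()
    }

    /// Flat row-major buffer of shape [frames, 16], e.g. for building a tensor.
    pub fn to_flat(&self) -> Vec<u32> {
        self.frames.iter().flat_map(|f| f.iter().copied()).collect()
    }

    /// Frames from `start` onward, for incremental (streaming) decoding.
    pub fn since(&self, start: usize) -> &[[u32; NUM_CODEBOOKS]] {
        &self.frames[start.min(self.frames.len())..]
    }
}
```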

**Estimated effort:** Medium-High. The LM backbone can reuse existing Qwen2/3 code. The speech tokenizer (encoder + decoder) is the most novel component.

**Key crate dependencies to add:**
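```toml
[dependencies]
# Verify that the chosen release ships candle_transformers::models::qwen3.
# Enable candle-core's "cuda" or "metal" feature for GPU backends.
candle-core = "0.9"
candle-nn = "0.9"
candle-transformers = "0.9"
# Keep existing: tokenizers, hf-hub, ndarray (for compatibility)
```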

### Option B: Crane (Candle-based TTS Engine)

[Crane](https://github.com/lucasjinreal/Crane) is a pure Rust LLM inference engine built on candle, specifically designed for multi-modal models including TTS.

**Key facts:**
- Already supports Spark-TTS (codec-based TTS with similar architecture)
- **Qwen3-TTS is listed as "Highest Priority" on their roadmap**
- Handles multi-module architectures (codec + LLM pipelines)
- Supports Qwen2.5, Moonshine ASR
- Claims a 50x speedup over PyTorch on Apple Silicon

**Strategy:** Monitor Crane's Qwen3-TTS implementation. If they ship it, we could either:
- Use Crane as a dependency directly
- Port their implementation into makima's codebase
- Contribute to Crane and depend on it

**Risk:** Crane is a relatively new project; depending on it adds supply chain risk.

### Option C: qwen3-rs (Educational Reference)

[qwen3-rs](https://github.com/reinterpretcat/qwen3-rs) is an educational project implementing Qwen3 inference from scratch in Rust.

**Useful for:** Reference implementation of Qwen3 transformer layers, tokenization, KV cache, and safetensors loading — all without heavy ML framework dependencies. However, it only implements the base LLM, not the TTS-specific components.

### Option D: Direct ort (ONNX Runtime) with Custom Export — FALLBACK

If we could manually export each sub-component to ONNX:

1. Export the 28-layer LM backbone (similar complexity to Chatterbox)
2. Export the code predictor separately
3. Export the speech tokenizer encoder/decoder

This would match our existing Chatterbox pattern but requires Python scripting for the one-time export, and the Qwen3 architecture is explicitly unsupported by standard exporters. **Not recommended unless ONNX support materializes upstream.**

### Option E: PyTorch C++ (libtorch) via FFI — NOT RECOMMENDED

Using libtorch via Rust FFI bindings (the `tch` crate, from the tch-rs project). This would:
- Add a ~2GB libtorch dependency
- Require complex build setup
- Introduce C++ dependency management
- Defeat the purpose of a pure Rust solution

---

## 5. Recommended Approach

### Phase 1: Candle-Based Implementation

**Architecture:**

```
makima/src/tts/
├── mod.rs              // TTS trait + factory (select Chatterbox vs Qwen3)
├── chatterbox.rs       // Existing ONNX-based Chatterbox (moved from tts.rs)
├── qwen3/
│   ├── mod.rs          // Qwen3TTS public API
│   ├── model.rs        // Qwen3 LM transformer (28 layers)
│   ├── code_predictor.rs   // MTP module (5 layers, 16 codebooks)
│   ├── speech_tokenizer.rs // Encoder + Decoder (causal ConvNet)
│   ├── config.rs       // Model config from config.json
│   └── generate.rs     // Autoregressive generation loop with KV cache
```
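The `mod.rs` factory in the layout above could be as small as a trait plus an enum. The sketch below uses hypothetical names (`TtsEngine`, `TtsBackend`, `create_engine`) and placeholder engines in place of the real ONNX and candle implementations. A production trait would likely also need streaming and voice-selection methods.

```rust
use std::str::FromStr;

/// Hypothetical common interface selected by the factory in `tts/mod.rs`.
pub trait TtsEngine: Send {
    fn name(&self) -> &'static str;
    fn sample_rate(&self) -> u32;
    /// Synthesize mono f32 PCM for `text`.
    fn synthesize(&mut self, text: &str) -> Result<Vec<f32>, String>;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TtsBackend {
    Chatterbox,
    Qwen3,
}

impl FromStr for TtsBackend {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "chatterbox" => Ok(Self::Chatterbox),
            "qwen3" | "qwen3-tts" => Ok(Self::Qwen3),
            other => Err(format!("unknown TTS backend: {other}")),
        }
    }
}

// Placeholders for the real ONNX (Chatterbox) and candle (Qwen3) engines.
struct ChatterboxStub;
struct Qwen3Stub;

impl TtsEngine for ChatterboxStub {
    fn name(&self) -> &'static str {
        "chatterbox"
    }
    fn sample_rate(&self) -> u32 {
        24_000
    }
    fn synthesize(&mut self, _text: &str) -> Result<Vec<f32>, String> {
        Err("placeholder: call into the ONNX pipeline here".to_string())
    }
}

impl TtsEngine for Qwen3Stub {
    fn name(&self) -> &'static str {
        "qwen3"
    }
    fn sample_rate(&self) -> u32 {
        24_000
    }
    fn synthesize(&mut self, _text: &str) -> Result<Vec<f32>, String> {
        // Placeholder output: one 80 ms frame (1920 samples) of silence.
        Ok(vec![0.0; 1920])
    }
}

pub fn create_engine(backend: TtsBackend) -> Box<dyn TtsEngine> {
    match backend {
        TtsBackend::Chatterbox => Box::new(ChatterboxStub),
        TtsBackend::Qwen3 => Box::new(Qwen3Stub),
    }
}
```

Selecting the backend from a string, such as a config value, keeps the Chatterbox path available while the Qwen3 implementation matures.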

**Key implementation details:**

1. **Load safetensors directly** — candle's `safetensors` support reads the 1.83GB main model and 682MB speech tokenizer
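2. **Reuse Qwen2 attention** — candle-transformers' `qwen2::Model` already implements RoPE, GQA, and a KV cache; the Qwen3 talker additionally needs QK-norm and an explicit `head_dim` of 128 rather than the `hidden_size / num_heads` (64) that Qwen2-style code derives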
3. **Implement ConvNet codec** — the speech tokenizer's encoder/decoder is a causal 1D ConvNet; candle has `Conv1d` layers
4. **Multi-codebook RVQ** — implement the 16-codebook residual vector quantization lookup
5. **Speaker embedding** — extract from reference audio via the speech tokenizer encoder
6. **Streaming support** — the 12Hz model's causal architecture enables token-by-token waveform generation
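Items 3 and 6 rest on the same property. A causal convolution's output at time t depends only on inputs up to t. If the last K-1 inputs are carried across chunk boundaries, frame-by-frame decoding reproduces the offline output exactly. The sketch below demonstrates this for a single-channel kernel with no stride or dilation. The real decoder stacks many multi-channel layers and must upsample 12.5 Hz frames to 24 kHz audio, so treat the sketch only as an illustration of the state-carrying idea.

```rust
/// Single-channel causal 1D convolution (stride 1, no dilation).
/// Output at time t uses only inputs at t and earlier; positions before
/// the start are treated as zeros (left padding of k - 1).
pub fn causal_conv1d(input: &[f32], weights: &[f32]) -> Vec<f32> {
    let k = weights.len();
    (0..input.len())
        .map(|t| {
            (0..k)
                .filter_map(|j| {
                    // Input index aligned with weights[j]; the newest tap is weights[k - 1].
                    let idx = t as isize - (k as isize - 1) + j as isize;
                    (idx >= 0).then(|| weights[j] * input[idx as usize])
                })
                .sum::<f32>()
        })
        .collect()
}

/// Streaming form: carries the last k - 1 inputs between chunks so the
/// concatenated outputs equal a single offline call over all input.
pub struct StreamingCausalConv {
    weights: Vec<f32>,
    history: Vec<f32>,
}

impl StreamingCausalConv {
    pub fn new(weights: Vec<f32>) -> Self {
        // Start with k - 1 zeros, matching the offline left padding.
        let pad = weights.len().saturating_sub(1);
        Self { weights, history: vec![0.0; pad] }
    }

    pub fn process(&mut self, chunk: &[f32]) -> Vec<f32> {
        let k = self.weights.len();
        let mut buf = self.history.clone();
        buf.extend_from_slice(chunk);
        let out: Vec<f32> = (0..chunk.len())
            .map(|t| (0..k).map(|j| self.weights[j] * buf[t + j]).sum::<f32>())
            .collect();
        // Keep the trailing k - 1 inputs as context for the next chunk.
        let keep = k.saturating_sub(1);
        self.history = buf[buf.len() - keep..].to_vec();
        out
    }
}
```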

### Phase 2: Voice Assets

The model supports voice cloning with reference audio. For the default Makima voice:
- A 5-15 second clip of Japanese-accented English speech is needed
- The reference audio is encoded by the speech tokenizer encoder; its transcript is passed to the language model together with the text to synthesize
- Speaker embedding cached for reuse
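The cache can be keyed by a voice identifier so that the speech tokenizer encoder runs only once per reference clip. In the sketch below, `VoicePrompt`, `VoiceCache`, and the `encode` closure are hypothetical stand-ins for the encoder output and the encoding call.

```rust
use std::collections::HashMap;

/// Hypothetical encoder output for one reference clip.
#[derive(Clone, Debug)]
pub struct VoicePrompt {
    /// Discrete codes, 16 per frame.
    pub codes: Vec<[u32; 16]>,
    pub speaker_embedding: Vec<f32>,
    pub transcript: String,
}

/// Encoded voice prompts keyed by voice id.
#[derive(Default)]
pub struct VoiceCache {
    prompts: HashMap<String, VoicePrompt>,
}

impl VoiceCache {
    /// Returns the cached prompt, running `encode` only on the first request.
    pub fn get_or_encode<F>(&mut self, voice_id: &str, encode: F) -> Result<&VoicePrompt, String>
    where
        F: FnOnce() -> Result<VoicePrompt, String>,
    {
        if !self.prompts.contains_key(voice_id) {
            let prompt = encode()?;
            self.prompts.insert(voice_id.to_string(), prompt);
        }
        Ok(&self.prompts[voice_id])
    }
}
```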

### Phase 3: Integration with Listen Page

Following the pattern in `listen.rs`:
- TTS model loaded lazily via `SharedState`
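- Protected behind `tokio::sync::Mutex`; an `RwLock` only helps if generation can run through `&self` with a per-request KV cache, since autoregressive decoding otherwise mutates shared state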
- WebSocket endpoint for streaming TTS (emit audio chunks as tokens are generated)
- Bidirectional: STT (listen) → process → TTS (speak) loop
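Lazy loading can follow a load-once pattern. The first request pays the cost of reading the safetensors, and later requests reuse the same instance. This sketch uses `std::sync::OnceLock` and `std::sync::Mutex` to stay dependency-free. The existing handlers use `tokio::sync::Mutex` behind `SharedState::get_ml_models()`, and the `get_tts` method and `TtsModel` type here are hypothetical.

```rust
use std::sync::{Mutex, OnceLock};

/// Stand-in for a loaded Qwen3-TTS model (hypothetical).
pub struct TtsModel {
    pub weights_path: String,
}

/// Hypothetical slice of the server's shared state.
#[derive(Default)]
pub struct SharedState {
    tts: OnceLock<Mutex<TtsModel>>,
}

impl SharedState {
    /// Runs `load` on first use only; later calls return the same instance.
    pub fn get_tts<F>(&self, load: F) -> &Mutex<TtsModel>
    where
        F: FnOnce() -> TtsModel,
    {
        self.tts.get_or_init(|| Mutex::new(load()))
    }
}
```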

---

## 6. Comparison Matrix

| Criteria | ONNX (current pattern) | Candle | Crane | libtorch |
|----------|----------------------|--------|-------|----------|
| Pure Rust | ⚠️ (`ort` binds the C++ ONNX Runtime) | ✅ | ✅ | ❌ (C++ FFI) |
| Qwen3-TTS support | ❌ No export | ⚠️ Needs impl | ⚠️ Planned | ✅ (full PyTorch) |
| Single binary | ✅ | ✅ | ✅ | ❌ |
| GPU acceleration | ✅ | ✅ | ✅ | ✅ |
| SafeTensors loading | ❌ (needs ONNX) | ✅ | ✅ | ✅ |
| Streaming TTS | ✅ | ✅ | ✅ | ✅ |
| Maintenance burden | Low | Medium | Low (if adopted) | High |
| Implementation effort | N/A (blocked) | Medium-High | Low (if available) | Medium |
| Dependency size | ~50MB | ~5MB | ~5MB | ~2GB |

---

## 7. Risk Assessment

| Risk | Likelihood | Impact | Mitigation |
|------|-----------|--------|------------|
| Candle implementation takes longer than expected | Medium | Medium | Reference Crane's Spark-TTS implementation; use qwen3-rs as LM reference |
| Speech tokenizer ConvNet is complex to port | Medium | High | Study the PyTorch source in qwen-tts package; ConvNet layers are simpler than transformers |
| Model quality differs from reference PyTorch | Low | High | Validate with reference audio samples; ensure bfloat16 precision |
| Crane ships Qwen3-TTS before we finish | Medium | Positive | Adopt their implementation |
| GPU memory issues on target hardware | Low | Medium | 0.6B model is small (~2.5GB); fits in 4GB VRAM with float16 |
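The 4 GB estimate must also leave room for the KV cache, which grows linearly with sequence length. Assuming 16-bit K and V for every talker layer, each position costs 2 × 28 × 8 × 128 × 2 = 114,688 bytes (112 KiB). One minute of audio (750 frames at 12.5 Hz) therefore adds about 86 MB before the text and reference-audio prompt are counted. The hypothetical helper below reproduces that arithmetic. It ignores the code predictor's own cache and activation memory.

```rust
/// KV cache bytes per position for a decoder-only transformer:
/// 2 (K and V) * layers * kv_heads * head_dim * bytes_per_elem.
pub fn kv_bytes_per_token(layers: usize, kv_heads: usize, head_dim: usize, bytes_per_elem: usize) -> usize {
    2 * layers * kv_heads * head_dim * bytes_per_elem
}

/// Talker KV cache for `prompt_tokens` of context plus `frames` generated frames,
/// assuming 28 layers, 8 KV heads, head_dim 128, and 2-byte (bf16/f16) elements.
pub fn talker_kv_bytes(prompt_tokens: usize, frames: usize) -> usize {
    kv_bytes_per_token(28, 8, 128, 2) * (prompt_tokens + frames)
}
```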

---

## 8. Next Steps

1. **Immediate:** Add `candle-core`, `candle-nn`, `candle-transformers` to Cargo.toml
2. **Week 1:** Implement Qwen3 LM backbone in candle (starting from the existing qwen2/qwen3 model code)
3. **Week 2:** Implement speech tokenizer encoder/decoder (ConvNet + RVQ)
4. **Week 2:** Implement code predictor (MTP module)
5. **Week 3:** Integration testing with reference audio; validate output quality
6. **Week 3:** Wire into makima server as TTS endpoint
7. **Ongoing:** Monitor Crane project for Qwen3-TTS implementation

---

## Sources

- [Qwen3-TTS-12Hz-0.6B-Base on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-0.6B-Base)
- [Qwen3-TTS Technical Report (arXiv)](https://arxiv.org/html/2601.15621v1)
- [Qwen3-TTS GitHub Repository](https://github.com/QwenLM/Qwen3-TTS)
- [Candle — HuggingFace Rust ML Framework](https://github.com/huggingface/candle)
- [Crane — Rust LLM Inference Engine](https://github.com/lucasjinreal/Crane)
- [qwen3-rs — Educational Qwen3 Rust Implementation](https://github.com/reinterpretcat/qwen3-rs)
- [candle-transformers Qwen2 model](https://docs.rs/candle-transformers/latest/candle_transformers/models/qwen2/index.html)
- [Qwen3-TTS-Tokenizer-12Hz on HuggingFace](https://huggingface.co/Qwen/Qwen3-TTS-Tokenizer-12Hz)
- [ONNX export issues for Qwen3](https://huggingface.co/onnx-community/Qwen3-1.7B-ONNX/discussions/1)