diff options
| -rw-r--r-- | .github/workflows/release.yml | 58 | ||||
| -rw-r--r-- | Cargo.lock | 688 | ||||
| -rw-r--r-- | makima/Cargo.toml | 6 | ||||
| -rw-r--r-- | makima/src/server/handlers/speak.rs | 2 | ||||
| -rw-r--r-- | makima/src/server/state.rs | 10 | ||||
| -rw-r--r-- | makima/src/tts/mod.rs | 44 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/code_predictor.rs | 253 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/config.rs | 271 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/generate.rs | 456 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/mod.rs | 317 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/model.rs | 584 | ||||
| -rw-r--r-- | makima/src/tts/qwen3/speech_tokenizer.rs | 613 |
12 files changed, 77 insertions, 3225 deletions
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 84d340d..ca9aae0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -83,7 +83,7 @@ jobs: - name: List artifacts run: find artifacts -type f - - name: Create Release + - name: Create Release (soryu repo) uses: softprops/action-gh-release@v2 with: draft: false @@ -130,3 +130,59 @@ jobs: ``` files: | artifacts/**/*.tar.gz + + - name: Create Release (makima repo) + env: + GH_TOKEN: ${{ secrets.MAKIMA_RELEASE_TOKEN }} + run: | + # Create release notes file + cat > release_notes.md << 'EOF' + ## Makima CLI ${{ github.ref_name }} + + Release of the Makima CLI - a unified command-line interface for the Makima platform. + + ### Available Commands + + - **`makima server`** - Run the Makima server for audio processing and API endpoints + - **`makima daemon`** - Run the daemon that connects to the server and executes tasks + - **`makima supervisor`** - Supervisor commands for managing tasks and contracts + - **`makima contract`** - Contract-related commands for task tracking and reporting + + ### Installation + + Download the appropriate binary for your platform and add it to your PATH: + + ```bash + # Linux x86_64 + curl -LO https://github.com/soryu-co/makima/releases/download/${{ github.ref_name }}/makima-${{ github.ref_name }}-linux-x86_64.tar.gz + tar xzf makima-${{ github.ref_name }}-linux-x86_64.tar.gz + sudo mv makima /usr/local/bin/ + + # macOS Intel + curl -LO https://github.com/soryu-co/makima/releases/download/${{ github.ref_name }}/makima-${{ github.ref_name }}-macos-x86_64.tar.gz + tar xzf makima-${{ github.ref_name }}-macos-x86_64.tar.gz + sudo mv makima /usr/local/bin/ + + # macOS Apple Silicon + curl -LO https://github.com/soryu-co/makima/releases/download/${{ github.ref_name }}/makima-${{ github.ref_name }}-macos-arm64.tar.gz + tar xzf makima-${{ github.ref_name }}-macos-arm64.tar.gz + sudo mv makima /usr/local/bin/ + ``` + + ### Verification + + After installation, verify with: + ```bash + makima --help + ``` + EOF + + # Collect all artifact files + FILES=$(find artifacts -name "*.tar.gz" -type f | tr '\n' ' ') + + # Create release in soryu-co/makima repo + gh release create "${{ github.ref_name }}" \ + --repo soryu-co/makima \ + --title "Makima CLI ${{ github.ref_name }}" \ + --notes-file release_notes.md \ + $FILES @@ -250,21 +250,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] -name = "bit-set" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" -dependencies = [ - "bit-vec", -] - -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - -[[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -299,20 +284,6 @@ name = "bytemuck" version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" -dependencies = [ - "bytemuck_derive", -] - -[[package]] -name = "bytemuck_derive" -version = "1.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] [[package]] name = "byteorder" @@ -327,62 +298,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] -name = "candle-core" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1" -dependencies = [ - "byteorder", - "gemm 0.17.1", - "half", - "memmap2", - "num-traits", - "num_cpus", - "rand 0.9.2", - "rand_distr", - "rayon", - "safetensors", - "thiserror 1.0.69", - "ug", - "yoke 0.7.5", - "zip 1.1.4", -] - -[[package]] -name = "candle-nn" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be1160c3b63f47d40d91110a3e1e1e566ae38edddbbf492a60b40ffc3bc1ff38" -dependencies = [ - "candle-core", - "half", - "num-traits", - "rayon", - "safetensors", - "serde", - "thiserror 1.0.69", -] - -[[package]] -name = "candle-transformers" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94a0900d49f8605e0e7e6693a1f560e6271279de98e5fa369e7abf3aac245020" -dependencies = [ - "byteorder", - "candle-core", - "candle-nn", - "fancy-regex", - "num-traits", - "rand 0.9.2", - "rayon", - "serde", - "serde_json", - "serde_plain", - "tracing", -] - -[[package]] name = "cassowary" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -970,32 +885,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] -name = "dyn-stack" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" -dependencies = [ - "bytemuck", - "reborrow", -] - -[[package]] -name = "dyn-stack" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8" -dependencies = [ - "bytemuck", - "dyn-stack-macros", -] - -[[package]] -name = "dyn-stack-macros" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d926b4d407d372f141f93bb444696142c29d32962ccbd3531117cf3aa0bfa9" - -[[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1020,18 +909,6 @@ dependencies = [ ] [[package]] -name = "enum-as-inner" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1107,17 +984,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] -name = "fancy-regex" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" -dependencies = [ - "bit-set", - "regex-automata", - "regex-syntax", -] - -[[package]] name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1320,243 +1186,6 @@ dependencies = [ ] [[package]] -name = "gemm" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" -dependencies = [ - "dyn-stack 0.10.0", - "gemm-c32 0.17.1", - "gemm-c64 0.17.1", - "gemm-common 0.17.1", - "gemm-f16 0.17.1", - "gemm-f32 0.17.1", - "gemm-f64 0.17.1", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 10.7.0", - "seq-macro", -] - -[[package]] -name = "gemm" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" -dependencies = [ - "dyn-stack 0.13.2", - "gemm-c32 0.18.2", - "gemm-c64 0.18.2", - "gemm-common 0.18.2", - "gemm-f16 0.18.2", - "gemm-f32 0.18.2", - "gemm-f64 0.18.2", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 11.6.0", - "seq-macro", -] - -[[package]] -name = "gemm-c32" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" -dependencies = [ - "dyn-stack 0.10.0", - "gemm-common 0.17.1", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 10.7.0", - "seq-macro", -] - -[[package]] -name = "gemm-c32" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" -dependencies = [ - "dyn-stack 0.13.2", - "gemm-common 0.18.2", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 11.6.0", - "seq-macro", -] - -[[package]] -name = "gemm-c64" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" -dependencies = [ - "dyn-stack 0.10.0", - "gemm-common 0.17.1", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 10.7.0", - "seq-macro", -] - -[[package]] -name = "gemm-c64" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" -dependencies = [ - "dyn-stack 0.13.2", - "gemm-common 0.18.2", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 11.6.0", - "seq-macro", -] - -[[package]] -name = "gemm-common" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" -dependencies = [ - "bytemuck", - "dyn-stack 0.10.0", - "half", - "num-complex", - "num-traits", - "once_cell", - "paste", - "pulp 0.18.22", - "raw-cpuid 10.7.0", - "rayon", - "seq-macro", - "sysctl 0.5.5", -] - -[[package]] -name = "gemm-common" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" -dependencies = [ - "bytemuck", - "dyn-stack 0.13.2", - "half", - "libm", - "num-complex", - "num-traits", - "once_cell", - "paste", - "pulp 0.21.5", - "raw-cpuid 11.6.0", - "rayon", - "seq-macro", - "sysctl 0.6.0", -] - -[[package]] -name = "gemm-f16" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" -dependencies = [ - "dyn-stack 0.10.0", - "gemm-common 0.17.1", - "gemm-f32 0.17.1", - "half", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 10.7.0", - "rayon", - "seq-macro", -] - -[[package]] -name = "gemm-f16" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" -dependencies = [ - "dyn-stack 0.13.2", - "gemm-common 0.18.2", - "gemm-f32 0.18.2", - "half", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 11.6.0", - "rayon", - "seq-macro", -] - -[[package]] -name = "gemm-f32" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" -dependencies = [ - "dyn-stack 0.10.0", - "gemm-common 0.17.1", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 10.7.0", - "seq-macro", -] - -[[package]] -name = "gemm-f32" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" -dependencies = [ - "dyn-stack 0.13.2", - "gemm-common 0.18.2", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 11.6.0", - "seq-macro", -] - -[[package]] -name = "gemm-f64" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" -dependencies = [ - "dyn-stack 0.10.0", - "gemm-common 0.17.1", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 10.7.0", - "seq-macro", -] - -[[package]] -name = "gemm-f64" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" -dependencies = [ - "dyn-stack 0.13.2", - "gemm-common 0.18.2", - "num-complex", - "num-traits", - "paste", - "raw-cpuid 11.6.0", - "seq-macro", -] - -[[package]] name = "generic-array" version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1617,21 +1246,6 @@ dependencies = [ ] [[package]] -name = "half" -version = "2.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" -dependencies = [ - "bytemuck", - "cfg-if", - "crunchy", - "num-traits", - "rand 0.9.2", - "rand_distr", - "zerocopy", -] - -[[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1935,7 +1549,7 @@ checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" dependencies = [ "displaydoc", "potential_utf", - "yoke 0.8.1", + "yoke", "zerofrom", "zerovec", ] @@ -2002,7 +1616,7 @@ dependencies = [ "displaydoc", "icu_locale_core", "writeable", - "yoke 0.8.1", + "yoke", "zerofrom", "zerotrie", "zerovec", @@ -2282,16 +1896,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] -name = "libloading" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" -dependencies = [ - "cfg-if", - "windows-link", -] - -[[package]] name = "libm" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2397,9 +2001,6 @@ dependencies = [ "backoff", "base64 0.22.1", "bytes", - "candle-core", - "candle-nn", - "candle-transformers", "chrono", "clap", "config", @@ -2429,7 +2030,6 @@ dependencies = [ "regex", "reqwest", "rusqlite", - "safetensors", "serde", "serde_json", "sha2", @@ -2492,16 +2092,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] -name = "memmap2" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490" -dependencies = [ - "libc", - "stable_deref_trait", -] - -[[package]] name = "memoffset" version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2671,20 +2261,6 @@ dependencies = [ ] [[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - -[[package]] name = "num-bigint" version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2716,7 +2292,6 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ - "bytemuck", "num-traits", ] @@ -2747,17 +2322,6 @@ dependencies = [ ] [[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - -[[package]] name = "num-traits" version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2778,28 +2342,6 @@ dependencies = [ ] [[package]] -name = "num_enum" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" -dependencies = [ - "num_enum_derive", - "rustversion", -] - -[[package]] -name = "num_enum_derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" -dependencies = [ - "proc-macro-crate", - "proc-macro2", - "quote", - "syn", -] - -[[package]] name = "number_prefix" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3157,15 +2699,6 @@ dependencies = [ ] [[package]] -name = "proc-macro-crate" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" -dependencies = [ - "toml_edit 0.23.10+spec-1.0.0", -] - -[[package]] name = "proc-macro2" version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3175,32 +2708,6 @@ dependencies = [ ] [[package]] -name = "pulp" -version = "0.18.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" -dependencies = [ - "bytemuck", - "libm", - "num-complex", - "reborrow", -] - -[[package]] -name = "pulp" -version = "0.21.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907" -dependencies = [ - "bytemuck", - "cfg-if", - "libm", - "num-complex", - "reborrow", - "version_check", -] - -[[package]] name = "quote" version = "1.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3275,16 +2782,6 @@ dependencies = [ ] [[package]] -name = "rand_distr" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" -dependencies = [ - "num-traits", - "rand 0.9.2", -] - -[[package]] name = "ratatui" version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3306,24 +2803,6 @@ dependencies = [ ] [[package]] -name = "raw-cpuid" -version = "10.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" -dependencies = [ - "bitflags 1.3.2", -] - -[[package]] -name = "raw-cpuid" -version = "11.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" -dependencies = [ - "bitflags 2.10.0", -] - -[[package]] name = "rawpointer" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3372,12 +2851,6 @@ dependencies = [ ] [[package]] -name = "reborrow" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" - -[[package]] name = "redox_syscall" version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3681,16 +3154,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] -name = "safetensors" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] name = "same-file" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3738,12 +3201,6 @@ dependencies = [ ] [[package]] -name = "seq-macro" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" - -[[package]] name = "serde" version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3798,15 +3255,6 @@ dependencies = [ ] [[package]] -name = "serde_plain" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" -dependencies = [ - "serde", -] - -[[package]] name = "serde_spanned" version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4523,34 +3971,6 @@ dependencies = [ ] [[package]] -name = "sysctl" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" -dependencies = [ - "bitflags 2.10.0", - "byteorder", - "enum-as-inner", - "libc", - "thiserror 1.0.69", - "walkdir", -] - -[[package]] -name = "sysctl" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" -dependencies = [ - "bitflags 2.10.0", - "byteorder", - "enum-as-inner", - "libc", - "thiserror 1.0.69", - "walkdir", -] - -[[package]] name = "system-configuration" version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4890,8 +4310,8 @@ checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", "serde_spanned", - "toml_datetime 0.6.11", - "toml_edit 0.22.27", + "toml_datetime", + "toml_edit", ] [[package]] @@ -4904,15 +4324,6 @@ dependencies = [ ] [[package]] -name = "toml_datetime" -version = "0.7.5+spec-1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" -dependencies = [ - "serde_core", -] - -[[package]] name = "toml_edit" version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4921,33 +4332,12 @@ dependencies = [ "indexmap", "serde", "serde_spanned", - "toml_datetime 0.6.11", + "toml_datetime", "toml_write", "winnow", ] [[package]] -name = "toml_edit" -version = "0.23.10+spec-1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269" -dependencies = [ - "indexmap", - "toml_datetime 0.7.5+spec-1.1.0", - "toml_parser", - "winnow", -] - -[[package]] -name = "toml_parser" -version = "1.0.6+spec-1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" -dependencies = [ - "winnow", -] - -[[package]] name = "toml_write" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5140,27 +4530,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" [[package]] -name = "ug" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03719c61a91b51541f076dfdba45caacf750b230cefaa4b32d6f5411c3f7f437" -dependencies = [ - "gemm 0.18.2", - "half", - "libloading", - "memmap2", - "num", - "num-traits", - "num_cpus", - "rayon", - "safetensors", - "serde", - "thiserror 1.0.69", - "tracing", - "yoke 0.7.5", -] - -[[package]] name = "unicase" version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5369,7 +4738,7 @@ dependencies = [ "serde_json", "url", "utoipa", - "zip 3.0.0", + "zip", ] [[package]] @@ -5955,41 +5324,17 @@ dependencies = [ [[package]] name = "yoke" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive 0.7.5", - "zerofrom", -] - -[[package]] -name = "yoke" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" dependencies = [ "stable_deref_trait", - "yoke-derive 0.8.1", + "yoke-derive", "zerofrom", ] [[package]] name = "yoke-derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "yoke-derive" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" @@ -6054,7 +5399,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" dependencies = [ "displaydoc", - "yoke 0.8.1", + "yoke", "zerofrom", ] @@ -6064,7 +5409,7 @@ version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" dependencies = [ - "yoke 0.8.1", + "yoke", "zerofrom", "zerovec-derive", ] @@ -6082,21 +5427,6 @@ dependencies = [ [[package]] name = "zip" -version = "1.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164" -dependencies = [ - "arbitrary", - "crc32fast", - "crossbeam-utils", - "displaydoc", - "indexmap", - "num_enum", - "thiserror 1.0.69", -] - -[[package]] -name = "zip" version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12598812502ed0105f607f941c386f43d441e00148fce9dec3ca5ffb0bde9308" diff --git a/makima/Cargo.toml b/makima/Cargo.toml index b6b12dd..950c123 100644 --- a/makima/Cargo.toml +++ b/makima/Cargo.toml @@ -17,12 +17,6 @@ tokenizers = "0.21" hf-hub = "0.4" ndarray = "0.16" -# Candle ML framework (Qwen3-TTS native inference) -candle-core = "0.8" -candle-nn = "0.8" -candle-transformers = "0.8" -safetensors = "0.4" - # Web server axum = { version = "0.8", features = ["ws", "multipart"] } tokio = { version = "1.0", features = ["full", "signal", "process"] } diff --git a/makima/src/server/handlers/speak.rs b/makima/src/server/handlers/speak.rs index b235c65..0f94b40 100644 --- a/makima/src/server/handlers/speak.rs +++ b/makima/src/server/handlers/speak.rs @@ -48,7 +48,7 @@ enum ClientMessage { /// WebSocket upgrade handler for TTS streaming. /// /// This endpoint accepts WebSocket connections for text-to-speech synthesis. -/// The TTS model runs directly in-process using candle — no external service. +/// The TTS model runs directly in-process using ONNX — no external service. #[utoipa::path( get, path = "/api/v1/speak", diff --git a/makima/src/server/state.rs b/makima/src/server/state.rs index bd6864f..ba9f9cf 100644 --- a/makima/src/server/state.rs +++ b/makima/src/server/state.rs @@ -700,12 +700,10 @@ impl AppState { model_dir = ?tts_dir, "Lazy-loading TTS engine (Chatterbox) on first Speak connection..." ); - let engine = crate::tts::TtsEngineFactory::create( - crate::tts::TtsBackend::Chatterbox, - tts_dir, - ).map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { - Box::new(e) - })?; + let engine = crate::tts::TtsEngineFactory::create(tts_dir) + .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { + Box::new(e) + })?; tracing::info!("TTS engine loaded successfully"); Ok(engine) }).await.map(|b| b.as_ref()) diff --git a/makima/src/tts/mod.rs b/makima/src/tts/mod.rs index b66f4a5..31f4204 100644 --- a/makima/src/tts/mod.rs +++ b/makima/src/tts/mod.rs @@ -1,19 +1,15 @@ //! TTS engine abstraction and implementations. //! -//! Provides a trait-based TTS engine interface with two backends: -//! - **Chatterbox**: ONNX-based TTS (legacy) -//! - **Qwen3**: Pure Rust candle-based Qwen3-TTS-12Hz-0.6B +//! Provides a trait-based TTS engine interface using Chatterbox ONNX-based TTS. use std::path::Path; use std::sync::atomic::AtomicBool; use std::sync::Arc; pub mod chatterbox; -pub mod qwen3; // Re-export primary types pub use chatterbox::ChatterboxTTS; -pub use qwen3::Qwen3Tts; /// Audio output sample rate (both engines output 24kHz). pub const SAMPLE_RATE: u32 = 24_000; @@ -51,8 +47,6 @@ pub enum TtsError { Audio(crate::audio::AudioError), Io(std::io::Error), VoiceRequired, - Config(String), - Candle(String), } impl std::fmt::Display for TtsError { @@ -66,8 +60,6 @@ impl std::fmt::Display for TtsError { TtsError::VoiceRequired => { write!(f, "voice reference audio is required") } - TtsError::Config(msg) => write!(f, "config error: {msg}"), - TtsError::Candle(msg) => write!(f, "candle error: {msg}"), } } } @@ -92,22 +84,7 @@ impl From<ort::Error> for TtsError { } } -impl From<candle_core::Error> for TtsError { - fn from(value: candle_core::Error) -> Self { - TtsError::Candle(value.to_string()) - } -} - -/// Which TTS backend to use. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum TtsBackend { - /// ONNX-based Chatterbox TTS (legacy). - Chatterbox, - /// Candle-based Qwen3-TTS (preferred). - Qwen3, -} - -/// TTS engine trait — implemented by both Chatterbox and Qwen3. +/// TTS engine trait for text-to-speech synthesis. #[async_trait::async_trait] pub trait TtsEngine: Send + Sync { /// Generate complete audio from text with a voice reference. @@ -137,19 +114,10 @@ pub trait TtsEngine: Send + Sync { pub struct TtsEngineFactory; impl TtsEngineFactory { - /// Create a TTS engine of the specified backend type. - pub fn create(backend: TtsBackend, model_dir: Option<&str>) -> Result<Box<dyn TtsEngine>, TtsError> { - match backend { - TtsBackend::Chatterbox => { - let engine = ChatterboxTTS::from_pretrained(model_dir)?; - Ok(Box::new(engine)) - } - TtsBackend::Qwen3 => { - let device = candle_core::Device::Cpu; // Default to CPU; GPU selection happens at higher level - let engine = Qwen3Tts::from_pretrained(model_dir, &device)?; - Ok(Box::new(engine)) - } - } + /// Create a Chatterbox TTS engine. + pub fn create(model_dir: Option<&str>) -> Result<Box<dyn TtsEngine>, TtsError> { + let engine = ChatterboxTTS::from_pretrained(model_dir)?; + Ok(Box::new(engine)) } } diff --git a/makima/src/tts/qwen3/code_predictor.rs b/makima/src/tts/qwen3/code_predictor.rs deleted file mode 100644 index 363105f..0000000 --- a/makima/src/tts/qwen3/code_predictor.rs +++ /dev/null @@ -1,253 +0,0 @@ -//! Multi-Token Prediction (MTP) code predictor. -//! -//! After the main LM predicts the zeroth codebook token, this module -//! predicts the remaining 15 codebook layers in parallel from the -//! LM's hidden states. -//! -//! Architecture: -//! - 5 transformer layers (same structure as main LM layers) -//! - 16 output heads, one per codebook (vocab 2048 each) -//! - Input: last hidden state from main LM + zeroth codebook embedding -//! - Output: 16 codebook token predictions - -use candle_core::{Device, Module, Result, Tensor}; -use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; - -use super::config::{CodePredictorConfig, Qwen3LmConfig}; -use super::model::{KvCache, Qwen3Attention, Qwen3Mlp, RotaryEmbedding}; - -/// A single code predictor transformer layer. -/// -/// Uses the same pre-norm residual structure as the main LM layers. -pub struct CodePredictorLayer { - self_attn: Qwen3Attention, - mlp: Qwen3Mlp, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, -} - -impl CodePredictorLayer { - pub fn new(config: &CodePredictorConfig, vb: VarBuilder) -> Result<Self> { - // Construct a Qwen3LmConfig-like view for the attention/MLP constructors - let lm_config = Qwen3LmConfig { - hidden_size: config.hidden_size, - num_hidden_layers: config.num_layers, - num_attention_heads: config.num_attention_heads, - num_key_value_heads: config.num_attention_heads, // No GQA in predictor - intermediate_size: config.hidden_size * 3, // 3072 for hidden=1024 - head_dim: config.hidden_size / config.num_attention_heads, - rms_norm_eps: config.rms_norm_eps, - ..Qwen3LmConfig::default() - }; - - let self_attn = Qwen3Attention::new(&lm_config, vb.pp("self_attn"))?; - let mlp = Qwen3Mlp::new(&lm_config, vb.pp("mlp"))?; - let input_layernorm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - vb.pp("input_layernorm"), - )?; - let post_attention_layernorm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - - Ok(Self { - self_attn, - mlp, - input_layernorm, - post_attention_layernorm, - }) - } - - pub fn forward( - &self, - hidden_states: &Tensor, - rope: &RotaryEmbedding, - kv_cache: &mut KvCache, - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let residual = hidden_states; - let hidden_states = self.input_layernorm.forward(hidden_states)?; - let hidden_states = - self.self_attn - .forward(&hidden_states, rope, kv_cache, attention_mask)?; - let hidden_states = (residual + hidden_states)?; - - let residual = &hidden_states; - let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; - let hidden_states = self.mlp.forward(&hidden_states)?; - let output = (residual + hidden_states)?; - - Ok(output) - } -} - -/// Multi-token prediction code predictor. -/// -/// Takes the hidden states from the main LM and predicts all 16 codebook -/// tokens. The zeroth codebook is predicted by the main LM head; this -/// module predicts the remaining 15 residual codebooks. -pub struct CodePredictor { - /// Embedding layer for codebook tokens (one per residual codebook group, 0-14). - code_embeddings: Vec<Embedding>, - /// 5 transformer layers. - layers: Vec<CodePredictorLayer>, - /// Final normalization. - norm: RmsNorm, - /// Per-codebook output heads (15 heads for residual codebooks). - output_heads: Vec<Linear>, - /// RoPE for the predictor's attention layers. - rope: RotaryEmbedding, - config: CodePredictorConfig, -} - -impl CodePredictor { - pub fn new( - config: &CodePredictorConfig, - lm_config: &Qwen3LmConfig, - vb: VarBuilder, - ) -> Result<Self> { - // HuggingFace Qwen3-TTS uses "talker.code_predictor.*" prefix - let predictor_vb = vb.pp("talker").pp("code_predictor"); - let model_vb = predictor_vb.pp("model"); - - // Code embeddings for residual codebook groups (15 groups, indices 0-14) - // HF names them "codec_embedding" not "code_embeddings" - let num_residual_groups = config.num_code_groups - 1; // 15, not 16 - let mut code_embeddings = Vec::with_capacity(num_residual_groups); - for i in 0..num_residual_groups { - let emb = embedding( - config.codebook_vocab_size, - config.hidden_size, - model_vb.pp(format!("codec_embedding.{i}")), - )?; - code_embeddings.push(emb); - } - - // Transformer layers - let mut layers = Vec::with_capacity(config.num_layers); - for i in 0..config.num_layers { - let layer = - CodePredictorLayer::new(config, model_vb.pp(format!("layers.{i}")))?; - layers.push(layer); - } - - let norm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - model_vb.pp("norm"), - )?; - - // Output heads for residual codebooks (15 heads, indices 0-14) - // HF names them "lm_head" not "output_heads" - let mut output_heads = Vec::with_capacity(num_residual_groups); - for i in 0..num_residual_groups { - let head = linear_no_bias( - config.hidden_size, - config.codebook_vocab_size, - predictor_vb.pp(format!("lm_head.{i}")), - )?; - output_heads.push(head); - } - - // RoPE for predictor attention (uses same theta/dim as main LM but with predictor head_dim) - let predictor_head_dim = config.hidden_size / config.num_attention_heads; - let rope_config = Qwen3LmConfig { - head_dim: predictor_head_dim, - rope_theta: lm_config.rope_theta, - max_position_embeddings: lm_config.max_position_embeddings, - ..Qwen3LmConfig::default() - }; - let rope = RotaryEmbedding::new(&rope_config, vb.dtype(), vb.device())?; - - Ok(Self { - code_embeddings, - layers, - norm, - output_heads, - rope, - config: config.clone(), - }) - } - - /// Predict all 16 codebook tokens from the LM hidden state. - /// - /// `lm_hidden`: [batch, 1, hidden_size] — last hidden state from main LM - /// `zeroth_code`: the token predicted by the main LM head (zeroth codebook) - /// - /// Returns: Vec of 16 token indices (one per codebook), starting with zeroth_code. - pub fn predict( - &self, - lm_hidden: &Tensor, - zeroth_code: u32, - device: &Device, - ) -> Result<Vec<u32>> { - let mut all_codes = Vec::with_capacity(self.config.num_code_groups); - all_codes.push(zeroth_code); - - // The code predictor iterates through the 15 residual codebook groups. - // For each group i (0..15), it: - // 1. Embeds the previous codebook token - // 2. Adds to LM hidden state - // 3. Runs through predictor layers - // 4. Predicts the next codebook token via lm_head[i] - let mut prev_code = zeroth_code; - - for group_idx in 0..self.code_embeddings.len() { - // Embed the previous codebook token - let code_tensor = Tensor::from_vec( - vec![prev_code], - (1, 1), - device, - )?; - let code_emb = self.code_embeddings[group_idx].forward(&code_tensor)?; - - // Add code embedding to LM hidden state (no concatenation, no projection) - let mut hidden = (lm_hidden + &code_emb)?; - - // Run through predictor transformer layers (no KV cache needed — single step) - let mut kv_caches: Vec<KvCache> = - (0..self.config.num_layers).map(|_| KvCache::new()).collect(); - for (i, layer) in self.layers.iter().enumerate() { - hidden = layer.forward(&hidden, &self.rope, &mut kv_caches[i], None)?; - } - - hidden = self.norm.forward(&hidden)?; - - // Predict codebook token - let logits = self.output_heads[group_idx].forward(&hidden)?; - - // Greedy decode: argmax - let logits_flat = logits.squeeze(0)?.squeeze(0)?; // [codebook_vocab_size] - let next_code = logits_flat - .argmax(0)? - .to_scalar::<u32>()?; - - all_codes.push(next_code); - prev_code = next_code; - } - - Ok(all_codes) - } - - /// Number of codebook groups. - pub fn num_code_groups(&self) -> usize { - self.config.num_code_groups - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_code_predictor_config() { - let config = CodePredictorConfig::default(); - assert_eq!(config.num_layers, 5); - assert_eq!(config.num_code_groups, 16); - assert_eq!(config.codebook_vocab_size, 2048); - assert_eq!(config.hidden_size, 1024); - } -} diff --git a/makima/src/tts/qwen3/config.rs b/makima/src/tts/qwen3/config.rs deleted file mode 100644 index 6fb55d7..0000000 --- a/makima/src/tts/qwen3/config.rs +++ /dev/null @@ -1,271 +0,0 @@ -//! Qwen3-TTS model configuration. -//! -//! Parses config.json from the HuggingFace model repository to configure -//! the language model, code predictor, and speech tokenizer. - -use serde::Deserialize; - -use crate::tts::TtsError; - -/// Top-level configuration for Qwen3-TTS-12Hz-0.6B-Base. -#[derive(Debug, Clone, Deserialize)] -pub struct Qwen3TtsConfig { - /// Language model (talker) configuration. - #[serde(default = "Qwen3LmConfig::default")] - pub lm: Qwen3LmConfig, - - /// Code predictor (multi-token prediction) configuration. - #[serde(default = "CodePredictorConfig::default")] - pub code_predictor: CodePredictorConfig, - - /// Speech tokenizer configuration. - #[serde(default = "SpeechTokenizerConfig::default")] - pub speech_tokenizer: SpeechTokenizerConfig, -} - -impl Default for Qwen3TtsConfig { - fn default() -> Self { - Self { - lm: Qwen3LmConfig::default(), - code_predictor: CodePredictorConfig::default(), - speech_tokenizer: SpeechTokenizerConfig::default(), - } - } -} - -impl Qwen3TtsConfig { - /// Load from a config.json file path. - pub fn from_json_path(path: &std::path::Path) -> Result<Self, TtsError> { - let content = std::fs::read_to_string(path) - .map_err(|e| TtsError::Config(format!("failed to read config: {e}")))?; - Self::from_json_str(&content) - } - - /// Load from a JSON string. - pub fn from_json_str(json: &str) -> Result<Self, TtsError> { - // Try to parse the full HuggingFace config.json format first - if let Ok(hf_config) = serde_json::from_str::<HfConfig>(json) { - return Ok(Self::from_hf_config(&hf_config)); - } - // Fall back to direct deserialization - serde_json::from_str(json) - .map_err(|e| TtsError::Config(format!("failed to parse config: {e}"))) - } - - /// Convert from HuggingFace's config.json format. - fn from_hf_config(hf: &HfConfig) -> Self { - Self { - lm: Qwen3LmConfig { - hidden_size: hf.hidden_size.unwrap_or(1024), - num_hidden_layers: hf.num_hidden_layers.unwrap_or(28), - num_attention_heads: hf.num_attention_heads.unwrap_or(16), - num_key_value_heads: hf.num_key_value_heads.unwrap_or(8), - intermediate_size: hf.intermediate_size.unwrap_or(3072), - head_dim: hf.head_dim.unwrap_or(128), - vocab_size: hf.vocab_size.unwrap_or(151_936), - max_position_embeddings: hf.max_position_embeddings.unwrap_or(32_768), - rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), - rope_theta: hf.rope_theta.unwrap_or(1_000_000.0), - use_sliding_window: hf.use_sliding_window.unwrap_or(false), - sliding_window: hf.sliding_window, - hidden_act: hf.hidden_act.clone().unwrap_or_else(|| "silu".to_string()), - }, - code_predictor: CodePredictorConfig { - hidden_size: hf.code_predictor_hidden_size.unwrap_or(1024), - num_layers: hf.code_predictor_num_layers.unwrap_or(5), - num_attention_heads: hf - .code_predictor_num_attention_heads - .unwrap_or(16), - num_code_groups: hf.num_code_groups.unwrap_or(16), - codebook_vocab_size: hf.codebook_vocab_size.unwrap_or(2048), - rms_norm_eps: hf.rms_norm_eps.unwrap_or(1e-6), - }, - speech_tokenizer: SpeechTokenizerConfig::default(), - } - } -} - -/// Language model configuration (28-layer Qwen3 transformer). -#[derive(Debug, Clone, Deserialize)] -pub struct Qwen3LmConfig { - /// Hidden dimension of transformer layers. - pub hidden_size: usize, - /// Number of transformer layers. - pub num_hidden_layers: usize, - /// Number of attention heads. - pub num_attention_heads: usize, - /// Number of key-value heads (GQA). - pub num_key_value_heads: usize, - /// Feed-forward intermediate size. - pub intermediate_size: usize, - /// Dimension per attention head. - pub head_dim: usize, - /// Text vocabulary size. - pub vocab_size: usize, - /// Maximum sequence length for RoPE. - pub max_position_embeddings: usize, - /// RMS normalization epsilon. - pub rms_norm_eps: f64, - /// RoPE theta parameter. - pub rope_theta: f64, - /// Whether to use sliding window attention. - pub use_sliding_window: bool, - /// Sliding window size (if enabled). - pub sliding_window: Option<usize>, - /// Activation function name. - pub hidden_act: String, -} - -impl Default for Qwen3LmConfig { - fn default() -> Self { - Self { - hidden_size: 1024, - num_hidden_layers: 28, - num_attention_heads: 16, - num_key_value_heads: 8, - intermediate_size: 3072, - head_dim: 128, - vocab_size: 151_936, - max_position_embeddings: 32_768, - rms_norm_eps: 1e-6, - rope_theta: 1_000_000.0, - use_sliding_window: false, - sliding_window: None, - hidden_act: "silu".to_string(), - } - } -} - -impl Qwen3LmConfig { - /// Number of key-value head groups for GQA. - pub fn num_kv_groups(&self) -> usize { - self.num_attention_heads / self.num_key_value_heads - } -} - -/// Code predictor (multi-token prediction) configuration. -#[derive(Debug, Clone, Deserialize)] -pub struct CodePredictorConfig { - /// Hidden size (matches LM hidden size). - pub hidden_size: usize, - /// Number of predictor transformer layers. - pub num_layers: usize, - /// Number of attention heads. - pub num_attention_heads: usize, - /// Number of codebook groups (residual codebooks). - pub num_code_groups: usize, - /// Vocabulary size per codebook. - pub codebook_vocab_size: usize, - /// RMS norm epsilon. - pub rms_norm_eps: f64, -} - -impl Default for CodePredictorConfig { - fn default() -> Self { - Self { - hidden_size: 1024, - num_layers: 5, - num_attention_heads: 16, - num_code_groups: 16, - codebook_vocab_size: 2048, - rms_norm_eps: 1e-6, - } - } -} - -/// Speech tokenizer (ConvNet codec) configuration. -#[derive(Debug, Clone, Deserialize)] -pub struct SpeechTokenizerConfig { - /// Number of RVQ codebooks. - pub num_codebooks: usize, - /// Codebook embedding dimension. - pub codebook_dim: usize, - /// Codebook vocabulary size per layer. - pub codebook_size: usize, - /// Encoder/decoder hidden channels. - pub hidden_channels: usize, - /// Output sample rate. - pub sample_rate: u32, - /// Token frame rate (Hz). - pub frame_rate: f32, - /// HuggingFace model ID for the speech tokenizer. - pub model_id: String, -} - -impl Default for SpeechTokenizerConfig { - fn default() -> Self { - Self { - num_codebooks: 16, - codebook_dim: 256, - codebook_size: 2048, - hidden_channels: 512, - sample_rate: 24_000, - frame_rate: 12.5, - model_id: "Qwen/Qwen3-TTS-Tokenizer-12Hz".to_string(), - } - } -} - -/// HuggingFace config.json format (partial, fields we need). -#[derive(Debug, Deserialize)] -struct HfConfig { - hidden_size: Option<usize>, - num_hidden_layers: Option<usize>, - num_attention_heads: Option<usize>, - num_key_value_heads: Option<usize>, - intermediate_size: Option<usize>, - head_dim: Option<usize>, - vocab_size: Option<usize>, - max_position_embeddings: Option<usize>, - rms_norm_eps: Option<f64>, - rope_theta: Option<f64>, - use_sliding_window: Option<bool>, - sliding_window: Option<usize>, - hidden_act: Option<String>, - // Code predictor specific fields - code_predictor_hidden_size: Option<usize>, - code_predictor_num_layers: Option<usize>, - code_predictor_num_attention_heads: Option<usize>, - num_code_groups: Option<usize>, - codebook_vocab_size: Option<usize>, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = Qwen3TtsConfig::default(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.lm.num_attention_heads, 16); - assert_eq!(config.lm.num_key_value_heads, 8); - assert_eq!(config.lm.head_dim, 128); - assert_eq!(config.lm.num_kv_groups(), 2); - assert_eq!(config.code_predictor.num_layers, 5); - assert_eq!(config.code_predictor.num_code_groups, 16); - assert_eq!(config.speech_tokenizer.num_codebooks, 16); - } - - #[test] - fn test_config_from_json() { - let json = r#"{ - "hidden_size": 1024, - "num_hidden_layers": 28, - "num_attention_heads": 16, - "num_key_value_heads": 8, - "intermediate_size": 3072, - "vocab_size": 151936, - "max_position_embeddings": 32768, - "rms_norm_eps": 1e-6, - "rope_theta": 1000000.0, - "hidden_act": "silu" - }"#; - - let config = Qwen3TtsConfig::from_json_str(json).unwrap(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.lm.vocab_size, 151_936); - } -} diff --git a/makima/src/tts/qwen3/generate.rs b/makima/src/tts/qwen3/generate.rs deleted file mode 100644 index 30d165b..0000000 --- a/makima/src/tts/qwen3/generate.rs +++ /dev/null @@ -1,456 +0,0 @@ -//! Autoregressive generation loop for Qwen3-TTS. -//! -//! Orchestrates the full inference pipeline: -//! 1. Encode reference audio → speaker embedding via speech tokenizer -//! 2. Tokenize text → token IDs -//! 3. Autoregressive LM generation → zeroth codebook tokens -//! 4. Code predictor → remaining 15 codebook tokens per frame -//! 5. Speech tokenizer decoder → waveform audio - -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use candle_core::{DType, Device, IndexOp, Result, Tensor}; -use tokenizers::Tokenizer; - -use super::code_predictor::CodePredictor; -use super::model::{KvCache, Qwen3Model}; -use super::speech_tokenizer::SpeechTokenizer; -use crate::tts::{AudioChunk, TtsError, SAMPLE_RATE}; - -/// Special tokens for the Qwen3-TTS vocabulary. -pub const BOS_TOKEN_ID: u32 = 151_643; -pub const EOS_TOKEN_ID: u32 = 151_645; -pub const PAD_TOKEN_ID: u32 = 151_643; - -/// Speech-specific control tokens. -/// These are placeholders — actual values come from the tokenizer config. -pub const START_OF_SPEECH: u32 = 151_668; -pub const END_OF_SPEECH: u32 = 151_669; - -/// Generation configuration. -#[derive(Debug, Clone)] -pub struct GenerationConfig { - /// Maximum number of speech tokens to generate. - pub max_new_tokens: usize, - /// Temperature for sampling (1.0 = greedy if top_k=1). - pub temperature: f32, - /// Top-k sampling (0 = disabled, use greedy argmax). - pub top_k: usize, - /// Repetition penalty. - pub repetition_penalty: f32, - /// Whether to generate audio chunks incrementally (streaming). - pub streaming: bool, -} - -impl Default for GenerationConfig { - fn default() -> Self { - Self { - max_new_tokens: 2048, - temperature: 1.0, - top_k: 0, // Greedy by default - repetition_penalty: 1.2, - streaming: false, - } - } -} - -/// Manages the full generation pipeline. -pub struct GenerationContext<'a> { - model: &'a Qwen3Model, - code_predictor: &'a CodePredictor, - speech_tokenizer: &'a SpeechTokenizer, - tokenizer: &'a Tokenizer, - device: &'a Device, - config: GenerationConfig, - /// Optional cancellation flag. When set to `true`, the generation loop - /// will break early and return whatever audio has been produced so far. - cancel_flag: Option<Arc<AtomicBool>>, -} - -impl<'a> GenerationContext<'a> { - pub fn new( - model: &'a Qwen3Model, - code_predictor: &'a CodePredictor, - speech_tokenizer: &'a SpeechTokenizer, - tokenizer: &'a Tokenizer, - device: &'a Device, - config: GenerationConfig, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Self { - Self { - model, - code_predictor, - speech_tokenizer, - tokenizer, - device, - config, - cancel_flag, - } - } - - /// Check whether cancellation has been requested. - fn is_cancelled(&self) -> bool { - self.cancel_flag - .as_ref() - .map_or(false, |f| f.load(Ordering::Relaxed)) - } - - /// Generate audio from text, optionally with a voice reference. - /// - /// Returns a list of audio chunks. If `streaming` is false, returns - /// a single chunk with the complete audio. - pub fn generate( - &self, - text: &str, - reference_audio: Option<&[f32]>, - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - // 1. Encode reference audio if provided - let reference_codes = match reference_audio { - Some(audio) => Some( - self.speech_tokenizer - .encode(audio) - .map_err(|e| TtsError::Inference(format!("speech encoder failed: {e}")))?, - ), - None => None, - }; - - // 2. Tokenize text - let encoding = self - .tokenizer - .encode(text, true) - .map_err(|e| TtsError::Tokenizer(e.to_string()))?; - - let text_token_ids: Vec<u32> = encoding.get_ids().to_vec(); - - // 3. Prepare input sequence - // Format: [BOS] [text_tokens] [START_OF_SPEECH] - let mut input_ids = Vec::new(); - input_ids.push(BOS_TOKEN_ID); - input_ids.extend_from_slice(&text_token_ids); - input_ids.push(START_OF_SPEECH); - - // 4. Run autoregressive generation - let generated_frames = self - .autoregressive_generate(&input_ids, reference_codes.as_deref()) - .map_err(|e| TtsError::Inference(format!("generation failed: {e}")))?; - - if generated_frames.is_empty() { - return Ok(vec![AudioChunk { - samples: vec![], - sample_rate: SAMPLE_RATE, - is_final: true, - }]); - } - - // 5. Decode all frames to audio - if self.config.streaming { - self.decode_streaming(&generated_frames) - } else { - self.decode_batch(&generated_frames) - } - } - - /// Autoregressive generation loop. - /// - /// Generates zeroth codebook tokens one at a time, then uses the code - /// predictor to fill in the remaining 15 codebooks per frame. - /// - /// Returns: Vec of frames, each frame is [num_codebooks] tokens. - fn autoregressive_generate( - &self, - input_ids: &[u32], - _reference_codes: Option<&[Vec<u32>]>, - ) -> Result<Vec<Vec<u32>>> { - let _num_codebooks = self.code_predictor.num_code_groups(); - let mut kv_caches: Vec<KvCache> = (0..self.model.num_layers()) - .map(|_| KvCache::new()) - .collect(); - - let mut generated_frames: Vec<Vec<u32>> = Vec::new(); - let mut past_zeroth_tokens: Vec<u32> = Vec::new(); - - // === First iteration: process the full input sequence === - let input_tensor = Tensor::from_vec( - input_ids.iter().map(|&x| x as i64).collect::<Vec<_>>(), - (1, input_ids.len()), - self.device, - )? - .to_dtype(DType::I64)?; - - let seq_len = input_ids.len(); - let attention_mask = - Qwen3Model::make_causal_mask(seq_len, 0, DType::F32, self.device)?; - - let logits = - self.model - .forward(&input_tensor, &mut kv_caches, Some(&attention_mask))?; - - // Get the logits for the last position - let last_logits = logits.i((0, seq_len - 1, ..))?; // [vocab_size] - let first_token = self.sample_token(&last_logits, &past_zeroth_tokens)?; - - if first_token == END_OF_SPEECH as u32 { - return Ok(generated_frames); - } - - // Use code predictor for all codebooks - let lm_hidden = self - .model - .last_hidden_state() - .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; - let last_hidden = lm_hidden.i((0..1, (seq_len - 1)..seq_len, ..))?; - - let frame_codes = self - .code_predictor - .predict(&last_hidden, first_token, self.device)?; - generated_frames.push(frame_codes); - past_zeroth_tokens.push(first_token); - - // === Subsequent iterations: one token at a time === - for _step in 1..self.config.max_new_tokens { - // Check for cancellation each iteration - if self.is_cancelled() { - tracing::info!("TTS generation cancelled after {} frames", generated_frames.len()); - break; - } - - let past_len = kv_caches[0].seq_len(); - - // Input: just the last generated zeroth codebook token - let last_token = *past_zeroth_tokens.last().unwrap(); - let token_tensor = Tensor::from_vec( - vec![last_token as i64], - (1, 1), - self.device, - )? - .to_dtype(DType::I64)?; - - // Single-token attention mask - let attention_mask = - Qwen3Model::make_causal_mask(1, past_len, DType::F32, self.device)?; - - let logits = - self.model - .forward(&token_tensor, &mut kv_caches, Some(&attention_mask))?; - - let next_logits = logits.i((0, 0, ..))?; // [vocab_size] - let next_token = self.sample_token(&next_logits, &past_zeroth_tokens)?; - - if next_token == END_OF_SPEECH as u32 { - break; - } - - // Predict all codebooks for this frame - let lm_hidden = self - .model - .last_hidden_state() - .ok_or_else(|| candle_core::Error::Msg("no hidden state".to_string()))?; - - let frame_codes = self - .code_predictor - .predict(&lm_hidden, next_token, self.device)?; - generated_frames.push(frame_codes); - past_zeroth_tokens.push(next_token); - } - - Ok(generated_frames) - } - - /// Sample a token from logits. - fn sample_token(&self, logits: &Tensor, past_tokens: &[u32]) -> Result<u32> { - let mut logits_vec: Vec<f32> = logits.to_vec1()?; - - // Apply repetition penalty - if self.config.repetition_penalty != 1.0 { - for &token in past_tokens { - let idx = token as usize; - if idx < logits_vec.len() { - let score = logits_vec[idx]; - logits_vec[idx] = if score < 0.0 { - score * self.config.repetition_penalty - } else { - score / self.config.repetition_penalty - }; - } - } - } - - if self.config.top_k == 0 || self.config.temperature == 0.0 { - // Greedy: argmax - let (max_idx, _) = logits_vec - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| { - a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal) - }) - .unwrap_or((0, &0.0)); - Ok(max_idx as u32) - } else { - // Top-k sampling with temperature - let temperature = self.config.temperature; - - // Apply temperature - for v in logits_vec.iter_mut() { - *v /= temperature; - } - - // Sort indices by logit value (descending) - let mut indexed: Vec<(usize, f32)> = - logits_vec.iter().enumerate().map(|(i, &v)| (i, v)).collect(); - indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - // Keep only top-k - let k = self.config.top_k.min(indexed.len()); - let top_k = &indexed[..k]; - - // Softmax over top-k - let max_val = top_k[0].1; - let exp_sum: f32 = top_k.iter().map(|(_, v)| (*v - max_val).exp()).collect::<Vec<_>>().iter().sum(); - let probs: Vec<(usize, f32)> = top_k - .iter() - .map(|(i, v)| (*i, (*v - max_val).exp() / exp_sum)) - .collect(); - - // Sample from distribution (simple linear scan) - let r: f32 = random_float(); - let mut cumulative = 0.0; - for (idx, prob) in &probs { - cumulative += prob; - if cumulative >= r { - return Ok(*idx as u32); - } - } - - // Fallback to highest probability - Ok(probs[0].0 as u32) - } - } - - /// Decode all frames in batch (non-streaming). - fn decode_batch( - &self, - frames: &[Vec<u32>], - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let num_codebooks = self.speech_tokenizer.num_codebooks(); - - // Transpose frames: [num_frames, num_codebooks] -> [num_codebooks, num_frames] - let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; - for frame in frames { - for (cb_idx, &code) in frame.iter().enumerate() { - if cb_idx < num_codebooks { - codes_by_codebook[cb_idx].push(code); - } - } - } - - let samples = self - .speech_tokenizer - .decode(&codes_by_codebook) - .map_err(|e| TtsError::Inference(format!("speech decoder failed: {e}")))?; - - Ok(vec![AudioChunk { - samples, - sample_rate: SAMPLE_RATE, - is_final: true, - }]) - } - - /// Decode frames incrementally (streaming). - fn decode_streaming( - &self, - frames: &[Vec<u32>], - ) -> std::result::Result<Vec<AudioChunk>, TtsError> { - let mut chunks: Vec<AudioChunk> = Vec::new(); - - // Decode in groups of frames for efficiency - let chunk_size = 10; // ~800ms per chunk at 12.5Hz - let num_codebooks = self.speech_tokenizer.num_codebooks(); - - for (chunk_idx, frame_chunk) in frames.chunks(chunk_size).enumerate() { - // Check for cancellation between streaming chunks - if self.is_cancelled() { - tracing::info!("TTS streaming decode cancelled after {} chunks", chunks.len()); - if let Some(last) = chunks.last_mut() { - last.is_final = true; - } - return Ok(chunks); - } - - let is_last = (chunk_idx + 1) * chunk_size >= frames.len(); - - // Transpose chunk frames - let mut codes_by_codebook: Vec<Vec<u32>> = vec![Vec::new(); num_codebooks]; - for frame in frame_chunk { - for (cb_idx, &code) in frame.iter().enumerate() { - if cb_idx < num_codebooks { - codes_by_codebook[cb_idx].push(code); - } - } - } - - let samples = self - .speech_tokenizer - .decode(&codes_by_codebook) - .map_err(|e| TtsError::Inference(format!("streaming decode failed: {e}")))?; - - chunks.push(AudioChunk { - samples, - sample_rate: SAMPLE_RATE, - is_final: is_last, - }); - } - - Ok(chunks) - } -} - -/// Simple pseudo-random float in [0, 1) using thread-local state. -/// Uses a basic xorshift for reproducibility without external deps. -fn random_float() -> f32 { - use std::cell::Cell; - thread_local! { - static STATE: Cell<u64> = Cell::new(0x12345678_9ABCDEF0); - } - - STATE.with(|s| { - let mut x = s.get(); - x ^= x << 13; - x ^= x >> 7; - x ^= x << 17; - s.set(x); - (x as f32) / (u64::MAX as f32) - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generation_config_default() { - let config = GenerationConfig::default(); - assert_eq!(config.max_new_tokens, 2048); - assert_eq!(config.top_k, 0); - assert_eq!(config.temperature, 1.0); - assert_eq!(config.repetition_penalty, 1.2); - assert!(!config.streaming); - } - - #[test] - fn test_random_float_range() { - for _ in 0..100 { - let r = random_float(); - assert!(r >= 0.0); - assert!(r < 1.0); - } - } - - #[test] - fn test_special_tokens() { - assert_eq!(BOS_TOKEN_ID, 151_643); - assert_eq!(EOS_TOKEN_ID, 151_645); - assert_eq!(START_OF_SPEECH, 151_668); - assert_eq!(END_OF_SPEECH, 151_669); - } -} diff --git a/makima/src/tts/qwen3/mod.rs b/makima/src/tts/qwen3/mod.rs deleted file mode 100644 index fc6c472..0000000 --- a/makima/src/tts/qwen3/mod.rs +++ /dev/null @@ -1,317 +0,0 @@ -//! Qwen3-TTS — Pure Rust implementation using candle. -//! -//! Implements Qwen3-TTS-12Hz-0.6B-Base for text-to-speech synthesis -//! with voice cloning support. No Python, no ONNX — pure Rust inference -//! via the candle ML framework. -//! -//! # Architecture -//! -//! The model has three components: -//! - **Language Model** (28-layer transformer): generates zeroth codebook tokens -//! - **Code Predictor** (5-layer MTP): predicts remaining 15 codebook layers -//! - **Speech Tokenizer** (ConvNet codec): encodes/decodes audio ↔ codes -//! -//! # Usage -//! -//! ```rust,no_run -//! use makima::tts::qwen3::Qwen3Tts; -//! use candle_core::Device; -//! -//! let device = Device::Cpu; -//! let tts = Qwen3Tts::from_pretrained(None, &device).unwrap(); -//! // Use via TtsEngine trait or direct API -//! ``` - -pub mod code_predictor; -pub mod config; -pub mod generate; -pub mod model; -pub mod speech_tokenizer; - -use std::path::{Path, PathBuf}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use candle_core::{DType, Device}; -use candle_nn::VarBuilder; -use hf_hub::api::sync::Api; -use tokenizers::Tokenizer; - -use self::code_predictor::CodePredictor; -use self::config::Qwen3TtsConfig; -use self::generate::{GenerationConfig, GenerationContext}; -use self::model::Qwen3Model; -use self::speech_tokenizer::SpeechTokenizer; -use crate::tts::{AudioChunk, TtsEngine, TtsError, SAMPLE_RATE}; - -/// HuggingFace model IDs. -const LM_MODEL_ID: &str = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"; -const TOKENIZER_MODEL_ID: &str = "Qwen/Qwen3-TTS-Tokenizer-12Hz"; -const DEFAULT_MODEL_DIR: &str = "models/qwen3-tts"; - -/// Qwen3-TTS engine — pure Rust candle-based inference. -pub struct Qwen3Tts { - /// The 28-layer language model. - model: Qwen3Model, - /// Multi-token prediction code predictor. - code_predictor: CodePredictor, - /// Speech tokenizer (encoder + decoder + RVQ). - speech_tokenizer: SpeechTokenizer, - /// Text tokenizer. - tokenizer: Tokenizer, - /// Model configuration. - config: Qwen3TtsConfig, - /// Compute device (CPU/CUDA/Metal). - device: Device, - /// Whether the model is fully loaded and ready. - ready: AtomicBool, -} - -// SAFETY: All fields are either Send+Sync or behind appropriate synchronization. -// candle tensors are Send+Sync, Tokenizer is Send+Sync, AtomicBool is Send+Sync. -unsafe impl Send for Qwen3Tts {} -unsafe impl Sync for Qwen3Tts {} - -impl Qwen3Tts { - /// Load from a local directory or download from HuggingFace. - pub fn from_pretrained( - model_dir: Option<&str>, - device: &Device, - ) -> Result<Self, TtsError> { - let model_path = PathBuf::from(model_dir.unwrap_or(DEFAULT_MODEL_DIR)); - - if !model_path.exists() { - Self::download_models(&model_path)?; - } - - Self::load_from_path(&model_path, device) - } - - /// Load all model components from a local directory. - pub fn load_from_path(model_dir: &Path, device: &Device) -> Result<Self, TtsError> { - let dtype = DType::F32; // Use F32 for CPU; BF16/F16 for GPU - - // Load configuration - let config_path = model_dir.join("config.json"); - let config = if config_path.exists() { - Qwen3TtsConfig::from_json_path(&config_path)? - } else { - Qwen3TtsConfig::default() - }; - - // Load text tokenizer (supports both tokenizer.json and vocab.json+merges.txt formats) - let tokenizer_json_path = model_dir.join("tokenizer.json"); - let tokenizer = if tokenizer_json_path.exists() { - Tokenizer::from_file(&tokenizer_json_path) - .map_err(|e| TtsError::Tokenizer(format!("failed to load tokenizer.json: {e}")))? - } else { - // Fall back to vocab.json + merges.txt (HuggingFace Qwen3-TTS format) - let vocab_path = model_dir.join("vocab.json"); - let merges_path = model_dir.join("merges.txt"); - - if !vocab_path.exists() || !merges_path.exists() { - return Err(TtsError::Tokenizer(format!( - "tokenizer files not found: need either tokenizer.json or vocab.json+merges.txt in {}", - model_dir.display() - ))); - } - - tokenizers::Tokenizer::from_file(&vocab_path) - .or_else(|_| { - // Build BPE tokenizer from vocab and merges - use tokenizers::models::bpe::BPE; - let bpe = BPE::from_file(&vocab_path.to_string_lossy(), &merges_path.to_string_lossy()) - .build() - .map_err(|e| TtsError::Tokenizer(format!("failed to build BPE tokenizer: {e}")))?; - Ok(Tokenizer::new(bpe)) - }) - .map_err(|e: TtsError| TtsError::Tokenizer(format!("failed to load tokenizer: {e}")))? - }; - - // Load LM weights from safetensors - let lm_weights_path = model_dir.join("model.safetensors"); - let lm_data = std::fs::read(&lm_weights_path).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to read LM weights from {}: {e}", - lm_weights_path.display() - )) - })?; - let lm_vb = VarBuilder::from_buffered_safetensors( - lm_data, - dtype, - device, - ).map_err(|e| TtsError::ModelLoad(format!("failed to create LM VarBuilder: {e}")))?; - - // Build language model - let model = Qwen3Model::new(&config.lm, lm_vb.clone()).map_err(|e| { - TtsError::ModelLoad(format!("failed to build LM model: {e}")) - })?; - - // Build code predictor (weights are in the same safetensors file) - let code_predictor = - CodePredictor::new(&config.code_predictor, &config.lm, lm_vb).map_err(|e| { - TtsError::ModelLoad(format!("failed to build code predictor: {e}")) - })?; - - // Load speech tokenizer from separate safetensors - let st_weights_path = model_dir.join("speech_tokenizer.safetensors"); - let st_data = std::fs::read(&st_weights_path).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to read speech tokenizer weights from {}: {e}", - st_weights_path.display() - )) - })?; - let st_vb = VarBuilder::from_buffered_safetensors( - st_data, - dtype, - device, - ).map_err(|e| { - TtsError::ModelLoad(format!( - "failed to create speech tokenizer VarBuilder: {e}" - )) - })?; - - let speech_tokenizer = - SpeechTokenizer::new(&config.speech_tokenizer, st_vb, device).map_err(|e| { - TtsError::ModelLoad(format!("failed to build speech tokenizer: {e}")) - })?; - - Ok(Self { - model, - code_predictor, - speech_tokenizer, - tokenizer, - config, - device: device.clone(), - ready: AtomicBool::new(true), - }) - } - - /// Generate audio from text with optional voice reference. - pub fn generate_speech( - &self, - text: &str, - reference_audio: Option<&[f32]>, - gen_config: Option<GenerationConfig>, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Result<Vec<AudioChunk>, TtsError> { - let config = gen_config.unwrap_or_default(); - - let ctx = GenerationContext::new( - &self.model, - &self.code_predictor, - &self.speech_tokenizer, - &self.tokenizer, - &self.device, - config, - cancel_flag, - ); - - ctx.generate(text, reference_audio) - } - - /// Download model files from HuggingFace Hub. - fn download_models(target_dir: &Path) -> Result<(), TtsError> { - std::fs::create_dir_all(target_dir)?; - - let api = Api::new().map_err(|e| TtsError::ModelLoad(e.to_string()))?; - - // Download LM model files - println!("Downloading Qwen3-TTS language model..."); - let lm_repo = api.model(LM_MODEL_ID.to_string()); - - // Note: HuggingFace repo has vocab.json + merges.txt instead of tokenizer.json - let lm_files = [ - "model.safetensors", - "config.json", - "vocab.json", - "merges.txt", - "tokenizer_config.json", - ]; - - for file in &lm_files { - println!(" Downloading {file}..."); - let downloaded = lm_repo - .get(file) - .map_err(|e| TtsError::ModelLoad(format!("failed to download {file}: {e}")))?; - - let target = target_dir.join(file); - if !target.exists() { - std::fs::copy(&downloaded, &target)?; - } - } - - // Download speech tokenizer - println!("Downloading Qwen3-TTS speech tokenizer..."); - let st_repo = api.model(TOKENIZER_MODEL_ID.to_string()); - - let st_file = "model.safetensors"; - let downloaded = st_repo - .get(st_file) - .map_err(|e| { - TtsError::ModelLoad(format!("failed to download speech tokenizer: {e}")) - })?; - - let target = target_dir.join("speech_tokenizer.safetensors"); - if !target.exists() { - std::fs::copy(&downloaded, &target)?; - } - - println!("All models downloaded to {}", target_dir.display()); - Ok(()) - } - - /// Get the model configuration. - pub fn config(&self) -> &Qwen3TtsConfig { - &self.config - } - - /// Get the compute device. - pub fn device(&self) -> &Device { - &self.device - } -} - -#[async_trait::async_trait] -impl TtsEngine for Qwen3Tts { - async fn generate( - &self, - text: &str, - reference_audio: Option<&[f32]>, - _reference_sample_rate: Option<u32>, - cancel_flag: Option<Arc<AtomicBool>>, - ) -> Result<Vec<AudioChunk>, TtsError> { - // Note: reference audio should already be resampled to 24kHz - // by the caller. If a different sample rate is provided, - // the caller should resample using `resample_to_24k()`. - self.generate_speech(text, reference_audio, None, cancel_flag) - } - - fn is_ready(&self) -> bool { - self.ready.load(Ordering::Relaxed) - } - - fn sample_rate(&self) -> u32 { - SAMPLE_RATE - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let config = Qwen3TtsConfig::default(); - assert_eq!(config.lm.hidden_size, 1024); - assert_eq!(config.lm.num_hidden_layers, 28); - assert_eq!(config.code_predictor.num_code_groups, 16); - assert_eq!(config.speech_tokenizer.sample_rate, 24_000); - } - - #[test] - fn test_model_ids() { - assert_eq!(LM_MODEL_ID, "Qwen/Qwen3-TTS-12Hz-0.6B-Base"); - assert_eq!(TOKENIZER_MODEL_ID, "Qwen/Qwen3-TTS-Tokenizer-12Hz"); - } -} diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs deleted file mode 100644 index e19e5f9..0000000 --- a/makima/src/tts/qwen3/model.rs +++ /dev/null @@ -1,584 +0,0 @@ -//! Qwen3 Language Model transformer backbone. -//! -//! Implements the 28-layer transformer with: -//! - Rotary Position Embeddings (RoPE) -//! - Grouped Query Attention (GQA) — 16 heads, 8 KV heads -//! - SiLU-gated MLP -//! - RMS normalization -//! - KV cache for autoregressive generation -//! -//! Based on the candle-transformers Qwen2 model architecture, -//! extended for Qwen3-TTS. - -use candle_core::{DType, Device, Module, Result, Tensor, D}; -use candle_nn::{embedding, linear_no_bias, rms_norm, Embedding, Linear, RmsNorm, VarBuilder}; - -use super::config::Qwen3LmConfig; - -// --------------------------------------------------------------------------- -// Rotary Position Embeddings -// --------------------------------------------------------------------------- - -/// Precomputed RoPE sin/cos tables. -#[derive(Debug, Clone)] -pub struct RotaryEmbedding { - cos: Tensor, - sin: Tensor, -} - -impl RotaryEmbedding { - pub fn new(config: &Qwen3LmConfig, dtype: DType, device: &Device) -> Result<Self> { - let head_dim = config.head_dim; - let max_seq = config.max_position_embeddings; - let theta = config.rope_theta; - - let inv_freq: Vec<f32> = (0..head_dim) - .step_by(2) - .map(|i| 1.0 / (theta as f32).powf(i as f32 / head_dim as f32)) - .collect(); - - let inv_freq_tensor = - Tensor::from_vec(inv_freq, (head_dim / 2,), device)?.to_dtype(DType::F32)?; - - let positions: Vec<f32> = (0..max_seq).map(|p| p as f32).collect(); - let positions_tensor = Tensor::from_vec(positions, (max_seq, 1), device)?; - - // [max_seq, head_dim/2] - let freqs = positions_tensor.matmul(&inv_freq_tensor.unsqueeze(0)?)?; - // [max_seq, head_dim] by repeating - let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; - - let cos = emb.cos()?.to_dtype(dtype)?; - let sin = emb.sin()?.to_dtype(dtype)?; - - Ok(Self { cos, sin }) - } - - /// Apply RoPE to query and key tensors. - /// Input shape: [batch, heads, seq_len, head_dim] - pub fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { - let seq_len = q.dim(2)?; - let cos = self.cos.narrow(0, offset, seq_len)?; - let sin = self.sin.narrow(0, offset, seq_len)?; - - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // [1, 1, seq, dim] - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; - - let q_rotated = Self::rotate_half(q, &cos, &sin)?; - let k_rotated = Self::rotate_half(k, &cos, &sin)?; - - Ok((q_rotated, k_rotated)) - } - - fn rotate_half(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> { - let half_dim = x.dim(D::Minus1)? / 2; - let x1 = x.narrow(D::Minus1, 0, half_dim)?; - let x2 = x.narrow(D::Minus1, half_dim, half_dim)?; - - // [-x2, x1] concatenated - let neg_x2 = x2.neg()?; - let rotated = Tensor::cat(&[&neg_x2, &x1], D::Minus1)?; - - // x * cos + rotated * sin - let result = x.broadcast_mul(cos)?.broadcast_add(&rotated.broadcast_mul(sin)?)?; - Ok(result) - } -} - -// --------------------------------------------------------------------------- -// KV Cache -// --------------------------------------------------------------------------- - -/// Per-layer key-value cache for autoregressive generation. -#[derive(Debug, Clone)] -pub struct KvCache { - key: Option<Tensor>, - value: Option<Tensor>, -} - -impl KvCache { - pub fn new() -> Self { - Self { - key: None, - value: None, - } - } - - /// Append new key/value tensors and return the full cached sequence. - /// Input shapes: [batch, num_kv_heads, new_seq_len, head_dim] - pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> { - let (full_key, full_value) = match (&self.key, &self.value) { - (Some(prev_k), Some(prev_v)) => { - let k = Tensor::cat(&[prev_k, key], 2)?; - let v = Tensor::cat(&[prev_v, value], 2)?; - (k, v) - } - _ => (key.clone(), value.clone()), - }; - - self.key = Some(full_key.clone()); - self.value = Some(full_value.clone()); - - Ok((full_key, full_value)) - } - - /// Current cached sequence length. - pub fn seq_len(&self) -> usize { - self.key - .as_ref() - .map(|k| k.dim(2).unwrap_or(0)) - .unwrap_or(0) - } - - /// Reset the cache. - pub fn reset(&mut self) { - self.key = None; - self.value = None; - } -} - -// --------------------------------------------------------------------------- -// Attention -// --------------------------------------------------------------------------- - -/// Multi-head attention with GQA and RoPE. -pub struct Qwen3Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, - q_norm: RmsNorm, - k_norm: RmsNorm, - num_heads: usize, - num_kv_heads: usize, - head_dim: usize, - num_kv_groups: usize, -} - -impl Qwen3Attention { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let hidden = config.hidden_size; - let num_heads = config.num_attention_heads; - let num_kv_heads = config.num_key_value_heads; - let head_dim = config.head_dim; - - let q_proj = linear_no_bias(hidden, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden, vb.pp("o_proj"))?; - - let q_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("q_norm"))?; - let k_norm = rms_norm(head_dim, config.rms_norm_eps, vb.pp("k_norm"))?; - - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - q_norm, - k_norm, - num_heads, - num_kv_heads, - head_dim, - num_kv_groups: config.num_kv_groups(), - }) - } - - /// Forward pass with KV cache and RoPE. - /// Input: [batch, seq_len, hidden_size] - /// Returns: [batch, seq_len, hidden_size] - pub fn forward( - &self, - hidden_states: &Tensor, - rope: &RotaryEmbedding, - kv_cache: &mut KvCache, - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let (batch, seq_len, _) = hidden_states.dims3()?; - let offset = kv_cache.seq_len(); - - // Project Q, K, V - let q = self.q_proj.forward(hidden_states)?; - let k = self.k_proj.forward(hidden_states)?; - let v = self.v_proj.forward(hidden_states)?; - - // Reshape: [batch, seq, heads*dim] -> [batch, heads, seq, dim] - let q = q - .reshape((batch, seq_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((batch, seq_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - - // Apply QK normalization (Qwen3 specific) - let q = self.apply_head_norm(&q, &self.q_norm)?; - let k = self.apply_head_norm(&k, &self.k_norm)?; - - // Apply RoPE - let (q, k) = rope.apply(&q, &k, offset)?; - - // Update KV cache - let (k, v) = kv_cache.append(&k, &v)?; - - // Expand KV heads for GQA: [batch, kv_heads, seq, dim] -> [batch, heads, seq, dim] - let k = self.repeat_kv(&k)?; - let v = self.repeat_kv(&v)?; - - // Scaled dot-product attention - let scale = (self.head_dim as f64).sqrt(); - let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? / scale)?; - - let attn_weights = match attention_mask { - Some(mask) => attn_weights.broadcast_add(mask)?, - None => attn_weights, - }; - - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - - // Attention output - let attn_output = attn_weights.matmul(&v)?; - - // [batch, heads, seq, dim] -> [batch, seq, heads*dim] - let attn_output = attn_output - .transpose(1, 2)? - .reshape((batch, seq_len, self.num_heads * self.head_dim))?; - - self.o_proj.forward(&attn_output) - } - - /// Apply RMS norm per-head. - fn apply_head_norm(&self, x: &Tensor, norm: &RmsNorm) -> Result<Tensor> { - let (b, h, s, d) = x.dims4()?; - // Reshape to [b*h*s, d] for norm, then back - let flat = x.reshape((b * h * s, d))?; - let normed = norm.forward(&flat)?; - normed.reshape((b, h, s, d)) - } - - /// Repeat KV heads for GQA. - fn repeat_kv(&self, x: &Tensor) -> Result<Tensor> { - if self.num_kv_groups == 1 { - return Ok(x.clone()); - } - let (batch, num_kv_heads, seq_len, head_dim) = x.dims4()?; - let x = x - .unsqueeze(2)? - .expand((batch, num_kv_heads, self.num_kv_groups, seq_len, head_dim))? - .reshape((batch, self.num_heads, seq_len, head_dim))?; - Ok(x) - } -} - -// --------------------------------------------------------------------------- -// MLP -// --------------------------------------------------------------------------- - -/// SiLU-gated feed-forward network. -pub struct Qwen3Mlp { - gate_proj: Linear, - up_proj: Linear, - down_proj: Linear, -} - -impl Qwen3Mlp { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let hidden = config.hidden_size; - let intermediate = config.intermediate_size; - - let gate_proj = linear_no_bias(hidden, intermediate, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden, intermediate, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate, hidden, vb.pp("down_proj"))?; - - Ok(Self { - gate_proj, - up_proj, - down_proj, - }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let gate = self.gate_proj.forward(x)?; - let gate = candle_nn::Activation::Silu.forward(&gate)?; - let up = self.up_proj.forward(x)?; - let hidden = (gate * up)?; - self.down_proj.forward(&hidden) - } -} - -// --------------------------------------------------------------------------- -// Transformer Layer -// --------------------------------------------------------------------------- - -/// A single Qwen3 transformer decoder layer. -pub struct Qwen3DecoderLayer { - self_attn: Qwen3Attention, - mlp: Qwen3Mlp, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, -} - -impl Qwen3DecoderLayer { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let self_attn = Qwen3Attention::new(config, vb.pp("self_attn"))?; - let mlp = Qwen3Mlp::new(config, vb.pp("mlp"))?; - let input_layernorm = - rms_norm(config.hidden_size, config.rms_norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = rms_norm( - config.hidden_size, - config.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - - Ok(Self { - self_attn, - mlp, - input_layernorm, - post_attention_layernorm, - }) - } - - pub fn forward( - &self, - hidden_states: &Tensor, - rope: &RotaryEmbedding, - kv_cache: &mut KvCache, - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - // Pre-norm attention - let residual = hidden_states; - let hidden_states = self.input_layernorm.forward(hidden_states)?; - let hidden_states = - self.self_attn - .forward(&hidden_states, rope, kv_cache, attention_mask)?; - let hidden_states = (residual + hidden_states)?; - - // Pre-norm MLP - let residual = &hidden_states; - let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?; - let hidden_states = self.mlp.forward(&hidden_states)?; - let output = (residual + hidden_states)?; - - Ok(output) - } -} - -// --------------------------------------------------------------------------- -// Full Model -// --------------------------------------------------------------------------- - -/// The complete Qwen3 language model for TTS. -/// -/// Architecture: -/// - Token embedding layer -/// - 28 transformer decoder layers -/// - Final RMS normalization -/// - LM head (projects to vocab) -pub struct Qwen3Model { - embed_tokens: Embedding, - layers: Vec<Qwen3DecoderLayer>, - norm: RmsNorm, - lm_head: Linear, - rope: RotaryEmbedding, - config: Qwen3LmConfig, - /// Last hidden states (before lm_head), used by code predictor. - last_hidden: std::cell::RefCell<Option<Tensor>>, -} - -impl Qwen3Model { - pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - // HuggingFace Qwen3-TTS uses "talker.model.*" prefix - let talker_vb = vb.pp("talker"); - let model_vb = talker_vb.pp("model"); - - // Text embedding (called "text_embedding" in HF, not "embed_tokens") - let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?; - - let mut layers = Vec::with_capacity(config.num_hidden_layers); - for i in 0..config.num_hidden_layers { - let layer = Qwen3DecoderLayer::new(config, model_vb.pp(format!("layers.{i}")))?; - layers.push(layer); - } - - let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?; - - // Codec head (called "codec_head" in HF, not "lm_head") - let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?; - - let dtype = vb.dtype(); - let device = vb.device().clone(); - let rope = RotaryEmbedding::new(config, dtype, &device)?; - - Ok(Self { - embed_tokens, - layers, - norm, - lm_head, - rope, - config: config.clone(), - last_hidden: std::cell::RefCell::new(None), - }) - } - - /// Forward pass through the full model. - /// - /// `input_ids`: [batch, seq_len] — token IDs - /// `kv_caches`: per-layer KV caches - /// `attention_mask`: optional causal mask [batch, 1, seq_len, total_seq_len] - /// - /// Returns logits: [batch, seq_len, vocab_size] - pub fn forward( - &self, - input_ids: &Tensor, - kv_caches: &mut [KvCache], - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let mut hidden_states = self.embed_tokens.forward(input_ids)?; - - for (i, layer) in self.layers.iter().enumerate() { - hidden_states = - layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?; - } - - hidden_states = self.norm.forward(&hidden_states)?; - - // Store last hidden state for code predictor - *self.last_hidden.borrow_mut() = Some(hidden_states.clone()); - - let logits = self.lm_head.forward(&hidden_states)?; - Ok(logits) - } - - /// Forward pass with pre-computed embeddings (for first iteration where - /// text embeddings are concatenated with audio features). - /// - /// `inputs_embeds`: [batch, seq_len, hidden_size] - pub fn forward_embeds( - &self, - inputs_embeds: &Tensor, - kv_caches: &mut [KvCache], - attention_mask: Option<&Tensor>, - ) -> Result<Tensor> { - let mut hidden_states = inputs_embeds.clone(); - - for (i, layer) in self.layers.iter().enumerate() { - hidden_states = - layer.forward(&hidden_states, &self.rope, &mut kv_caches[i], attention_mask)?; - } - - hidden_states = self.norm.forward(&hidden_states)?; - - *self.last_hidden.borrow_mut() = Some(hidden_states.clone()); - - let logits = self.lm_head.forward(&hidden_states)?; - Ok(logits) - } - - /// Get the last hidden states (for the code predictor). - pub fn last_hidden_state(&self) -> Option<Tensor> { - self.last_hidden.borrow().clone() - } - - /// Number of transformer layers. - pub fn num_layers(&self) -> usize { - self.config.num_hidden_layers - } - - /// Hidden size. - pub fn hidden_size(&self) -> usize { - self.config.hidden_size - } - - /// Get token embedding layer (for input preparation). - pub fn embed_tokens(&self) -> &Embedding { - &self.embed_tokens - } - - /// Create a causal attention mask. - pub fn make_causal_mask( - seq_len: usize, - past_len: usize, - dtype: DType, - device: &Device, - ) -> Result<Tensor> { - let total_len = past_len + seq_len; - - if seq_len == 1 { - // Single token: no masking needed (can attend to everything) - return Tensor::zeros((1, 1, 1, total_len), dtype, device); - } - - // Full causal mask: lower triangular - let mask: Vec<f32> = (0..seq_len) - .flat_map(|i| { - (0..total_len).map(move |j| { - if j <= past_len + i { - 0.0 - } else { - f32::NEG_INFINITY - } - }) - }) - .collect(); - - Tensor::from_vec(mask, (1, 1, seq_len, total_len), device)?.to_dtype(dtype) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_kv_cache() { - let device = Device::Cpu; - let mut cache = KvCache::new(); - assert_eq!(cache.seq_len(), 0); - - let k = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap(); - let v = Tensor::zeros((1, 8, 5, 128), DType::F32, &device).unwrap(); - let (fk, _fv) = cache.append(&k, &v).unwrap(); - assert_eq!(cache.seq_len(), 5); - assert_eq!(fk.dim(2).unwrap(), 5); - - let k2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap(); - let v2 = Tensor::zeros((1, 8, 1, 128), DType::F32, &device).unwrap(); - let (fk2, _fv2) = cache.append(&k2, &v2).unwrap(); - assert_eq!(cache.seq_len(), 6); - assert_eq!(fk2.dim(2).unwrap(), 6); - - cache.reset(); - assert_eq!(cache.seq_len(), 0); - } - - #[test] - fn test_causal_mask_single_token() { - let mask = Qwen3Model::make_causal_mask(1, 10, DType::F32, &Device::Cpu).unwrap(); - assert_eq!(mask.dims(), &[1, 1, 1, 11]); - // All zeros — single token can attend to everything - let sum: f32 = mask.sum_all().unwrap().to_scalar().unwrap(); - assert_eq!(sum, 0.0); - } - - #[test] - fn test_causal_mask_multi_token() { - let mask = Qwen3Model::make_causal_mask(3, 0, DType::F32, &Device::Cpu).unwrap(); - assert_eq!(mask.dims(), &[1, 1, 3, 3]); - // Upper triangle should be -inf - let data: Vec<f32> = mask.flatten_all().unwrap().to_vec1().unwrap(); - // Row 0: [0, -inf, -inf] - assert_eq!(data[0], 0.0); - assert!(data[1].is_infinite() && data[1] < 0.0); - assert!(data[2].is_infinite() && data[2] < 0.0); - // Row 1: [0, 0, -inf] - assert_eq!(data[3], 0.0); - assert_eq!(data[4], 0.0); - assert!(data[5].is_infinite() && data[5] < 0.0); - // Row 2: [0, 0, 0] - assert_eq!(data[6], 0.0); - assert_eq!(data[7], 0.0); - assert_eq!(data[8], 0.0); - } -} diff --git a/makima/src/tts/qwen3/speech_tokenizer.rs b/makima/src/tts/qwen3/speech_tokenizer.rs deleted file mode 100644 index 86e00f2..0000000 --- a/makima/src/tts/qwen3/speech_tokenizer.rs +++ /dev/null @@ -1,613 +0,0 @@ -//! Speech Tokenizer — ConvNet encoder/decoder with RVQ codebooks. -//! -//! Two sub-components: -//! -//! **Encoder** (voice cloning): converts reference audio waveform to discrete -//! multi-codebook tokens via a causal 1D ConvNet + RVQ. -//! -//! **Decoder** (audio synthesis): reconstructs waveform from discrete codebook -//! indices via embedding lookup + causal 1D ConvNet. -//! -//! The speech tokenizer is a separate model (~682MB) loaded from -//! `Qwen/Qwen3-TTS-Tokenizer-12Hz`. - -use candle_core::{Device, Module, Result, Tensor, D}; -use candle_nn::{ - conv1d, embedding, linear_no_bias, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder, -}; - -use super::config::SpeechTokenizerConfig; - -// --------------------------------------------------------------------------- -// Weight-Normalized Conv1d -// --------------------------------------------------------------------------- - -/// A 1D convolution with optional weight normalization and activation. -pub struct ConvBlock { - conv: Conv1d, - activation: ConvActivation, -} - -#[derive(Debug, Clone, Copy)] -pub enum ConvActivation { - None, - Elu, - Tanh, -} - -impl ConvBlock { - pub fn new( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - stride: usize, - padding: usize, - dilation: usize, - activation: ConvActivation, - vb: VarBuilder, - ) -> Result<Self> { - let config = Conv1dConfig { - stride, - padding, - dilation, - groups: 1, - }; - let conv = conv1d(in_channels, out_channels, kernel_size, config, vb.pp("conv"))?; - - Ok(Self { conv, activation }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let out = self.conv.forward(x)?; - match self.activation { - ConvActivation::None => Ok(out), - ConvActivation::Elu => elu(&out, 1.0), - ConvActivation::Tanh => out.tanh(), - } - } -} - -/// ELU activation: x if x >= 0, alpha * (exp(x) - 1) if x < 0 -fn elu(x: &Tensor, alpha: f64) -> Result<Tensor> { - let zeros = x.zeros_like()?; - let positive = x.maximum(&zeros)?; - let negative_mask = x.lt(&zeros)?.to_dtype(x.dtype())?; - let exp_x = x.exp()?; - let one = Tensor::ones_like(&exp_x)?; - let negative = ((exp_x - one)? * alpha)?.broadcast_mul(&negative_mask)?; - positive + negative -} - -// --------------------------------------------------------------------------- -// Residual Unit -// --------------------------------------------------------------------------- - -/// Residual convolutional unit with dilated convolutions. -pub struct ResidualUnit { - conv1: ConvBlock, - conv2: ConvBlock, -} - -impl ResidualUnit { - pub fn new( - channels: usize, - dilation: usize, - vb: VarBuilder, - ) -> Result<Self> { - // Dilated causal conv (kernel=7, dilation varies) - let padding = (7 - 1) * dilation / 2; // causal-ish padding - let conv1 = ConvBlock::new( - channels, - channels, - 7, - 1, - padding, - dilation, - ConvActivation::Elu, - vb.pp("block.0"), - )?; - - // Pointwise conv (kernel=1) - let conv2 = ConvBlock::new( - channels, - channels, - 1, - 1, - 0, - 1, - ConvActivation::Elu, - vb.pp("block.1"), - )?; - - Ok(Self { conv1, conv2 }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let residual = x; - let out = self.conv1.forward(x)?; - let out = self.conv2.forward(&out)?; - // Match sequence lengths if needed (causal conv may change length) - let out_len = out.dim(D::Minus1)?; - let res_len = residual.dim(D::Minus1)?; - if out_len != res_len { - let start = res_len.saturating_sub(out_len); - let residual = residual.narrow(D::Minus1, start, out_len)?; - residual + out - } else { - residual + out - } - } -} - -// --------------------------------------------------------------------------- -// Encoder Block -// --------------------------------------------------------------------------- - -/// Encoder downsampling block: residual units + strided conv. -pub struct EncoderBlock { - residual_units: Vec<ResidualUnit>, - downsample: ConvBlock, -} - -impl EncoderBlock { - pub fn new( - in_channels: usize, - out_channels: usize, - stride: usize, - num_residuals: usize, - vb: VarBuilder, - ) -> Result<Self> { - let mut residual_units = Vec::with_capacity(num_residuals); - for i in 0..num_residuals { - let dilation = 3usize.pow(i as u32); // 1, 3, 9 - let unit = ResidualUnit::new(in_channels, dilation, vb.pp(format!("residuals.{i}")))?; - residual_units.push(unit); - } - - // Strided downsampling convolution - let kernel_size = stride * 2; - let padding = stride / 2; - let downsample = ConvBlock::new( - in_channels, - out_channels, - kernel_size, - stride, - padding, - 1, - ConvActivation::Elu, - vb.pp("downsample"), - )?; - - Ok(Self { - residual_units, - downsample, - }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let mut out = x.clone(); - for unit in &self.residual_units { - out = unit.forward(&out)?; - } - self.downsample.forward(&out) - } -} - -// --------------------------------------------------------------------------- -// Decoder Block -// --------------------------------------------------------------------------- - -/// Decoder upsampling block: transposed conv + residual units. -pub struct DecoderBlock { - upsample: ConvBlock, - residual_units: Vec<ResidualUnit>, -} - -impl DecoderBlock { - pub fn new( - in_channels: usize, - out_channels: usize, - stride: usize, - num_residuals: usize, - vb: VarBuilder, - ) -> Result<Self> { - // Strided upsampling (transpose conv simulated by regular conv + padding) - let kernel_size = stride * 2; - let padding = stride / 2; - let upsample = ConvBlock::new( - in_channels, - out_channels, - kernel_size, - 1, // stride=1 for output; upsample via repeat/interpolation - padding, - 1, - ConvActivation::Elu, - vb.pp("upsample"), - )?; - - let mut residual_units = Vec::with_capacity(num_residuals); - for i in 0..num_residuals { - let dilation = 3usize.pow(i as u32); - let unit = - ResidualUnit::new(out_channels, dilation, vb.pp(format!("residuals.{i}")))?; - residual_units.push(unit); - } - - Ok(Self { - upsample, - residual_units, - }) - } - - pub fn forward(&self, x: &Tensor) -> Result<Tensor> { - let mut out = self.upsample.forward(x)?; - for unit in &self.residual_units { - out = unit.forward(&out)?; - } - Ok(out) - } -} - -// --------------------------------------------------------------------------- -// RVQ Codebook -// --------------------------------------------------------------------------- - -/// Residual Vector Quantization codebook. -/// -/// Contains `num_codebooks` embedding tables, each mapping -/// `codebook_size` indices to `codebook_dim`-dimensional vectors. -pub struct RvqCodebook { - codebooks: Vec<Embedding>, - num_codebooks: usize, - #[allow(dead_code)] - codebook_dim: usize, -} - -impl RvqCodebook { - pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder) -> Result<Self> { - let mut codebooks = Vec::with_capacity(config.num_codebooks); - for i in 0..config.num_codebooks { - let cb = embedding( - config.codebook_size, - config.codebook_dim, - vb.pp(format!("codebooks.{i}")), - )?; - codebooks.push(cb); - } - - Ok(Self { - codebooks, - num_codebooks: config.num_codebooks, - codebook_dim: config.codebook_dim, - }) - } - - /// Look up codebook embeddings for all codebook layers. - /// - /// `codes`: [num_codebooks, seq_len] — codebook indices per layer - /// Returns: [1, codebook_dim, seq_len] — sum of all codebook embeddings - pub fn decode(&self, codes: &[Vec<u32>], device: &Device) -> Result<Tensor> { - assert_eq!(codes.len(), self.num_codebooks, "Expected {} codebook layers", self.num_codebooks); - - let seq_len = codes[0].len(); - let mut sum: Option<Tensor> = None; - - for (i, code_layer) in codes.iter().enumerate() { - assert_eq!(code_layer.len(), seq_len, "Codebook layer {i} length mismatch"); - - let indices = Tensor::from_vec( - code_layer.clone(), - (1, seq_len), - device, - )?; - - // [1, seq_len, codebook_dim] - let emb = self.codebooks[i].forward(&indices)?; - - sum = Some(match sum { - Some(prev) => (prev + emb)?, - None => emb, - }); - } - - // [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len] - let result = sum.unwrap().transpose(1, 2)?; - Ok(result) - } - - /// Number of codebooks. - pub fn num_codebooks(&self) -> usize { - self.num_codebooks - } -} - -// --------------------------------------------------------------------------- -// Speech Tokenizer (Encoder + Decoder) -// --------------------------------------------------------------------------- - -/// The complete speech tokenizer with encoder and decoder. -pub struct SpeechTokenizer { - /// Encoder: waveform -> latent (for voice cloning). - encoder_input_conv: ConvBlock, - encoder_blocks: Vec<EncoderBlock>, - encoder_output_conv: ConvBlock, - - /// RVQ codebooks for quantization. - codebook: RvqCodebook, - - /// Decoder: codes -> waveform. - decoder_input_conv: ConvBlock, - decoder_blocks: Vec<DecoderBlock>, - decoder_output_conv: ConvBlock, - - /// Projection from codebook dim to decoder hidden channels. - decoder_proj: Linear, - - config: SpeechTokenizerConfig, - device: Device, -} - -impl SpeechTokenizer { - /// Load the speech tokenizer from safetensors. - pub fn new(config: &SpeechTokenizerConfig, vb: VarBuilder, device: &Device) -> Result<Self> { - let hidden = config.hidden_channels; // 512 - - // ===== Encoder ===== - // Input: [batch, 1, samples] -> [batch, hidden/8, ...] - let encoder_input_conv = ConvBlock::new( - 1, - hidden / 8, // 64 - 7, - 1, - 3, - 1, - ConvActivation::Elu, - vb.pp("encoder.input_conv"), - )?; - - // Downsampling blocks with increasing channels - let strides = [8, 5, 4, 3]; // Total downsampling: 8*5*4*3 = 480 - let channels = [hidden / 8, hidden / 4, hidden / 2, hidden]; // 64, 128, 256, 512 - let mut encoder_blocks = Vec::with_capacity(strides.len()); - for (i, (&stride, &out_ch)) in strides.iter().zip(channels.iter().skip(0)).enumerate() { - let in_ch = if i == 0 { hidden / 8 } else { channels[i - 1] }; - let block = EncoderBlock::new( - in_ch, - out_ch, - stride, - 3, // 3 residual units per block - vb.pp(format!("encoder.blocks.{i}")), - )?; - encoder_blocks.push(block); - } - - // Encoder output projection to codebook dim - let encoder_output_conv = ConvBlock::new( - hidden, - config.codebook_dim, - 3, - 1, - 1, - 1, - ConvActivation::None, - vb.pp("encoder.output_conv"), - )?; - - // ===== RVQ Codebook ===== - let codebook = RvqCodebook::new(config, vb.pp("quantizer"))?; - - // ===== Decoder ===== - // Projection from codebook dim to decoder hidden - let decoder_proj = linear_no_bias( - config.codebook_dim, - hidden, - vb.pp("decoder.proj"), - )?; - - // Input conv - let decoder_input_conv = ConvBlock::new( - hidden, - hidden, - 7, - 1, - 3, - 1, - ConvActivation::Elu, - vb.pp("decoder.input_conv"), - )?; - - // Upsampling blocks (reverse order of encoder) - let dec_strides = [3, 4, 5, 8]; - let dec_channels = [hidden, hidden / 2, hidden / 4, hidden / 8]; // 512, 256, 128, 64 - let mut decoder_blocks = Vec::with_capacity(dec_strides.len()); - for (i, (&stride, &out_ch)) in dec_strides.iter().zip(dec_channels.iter().skip(0)).enumerate() - { - let in_ch = if i == 0 { hidden } else { dec_channels[i - 1] }; - let block = DecoderBlock::new( - in_ch, - out_ch, - stride, - 3, - vb.pp(format!("decoder.blocks.{i}")), - )?; - decoder_blocks.push(block); - } - - // Output conv: hidden/8 -> 1 channel (waveform) - let decoder_output_conv = ConvBlock::new( - hidden / 8, - 1, - 7, - 1, - 3, - 1, - ConvActivation::Tanh, - vb.pp("decoder.output_conv"), - )?; - - Ok(Self { - encoder_input_conv, - encoder_blocks, - encoder_output_conv, - codebook, - decoder_input_conv, - decoder_blocks, - decoder_output_conv, - decoder_proj, - config: config.clone(), - device: device.clone(), - }) - } - - /// Encode reference audio waveform to discrete codebook tokens. - /// - /// `audio`: [num_samples] — mono 24kHz audio - /// Returns: Vec of `num_codebooks` vectors, each containing token indices. - pub fn encode(&self, audio: &[f32]) -> Result<Vec<Vec<u32>>> { - // [1, 1, num_samples] - let x = Tensor::from_vec(audio.to_vec(), (1, 1, audio.len()), &self.device)?; - - // Run encoder - let mut hidden = self.encoder_input_conv.forward(&x)?; - for block in &self.encoder_blocks { - hidden = block.forward(&hidden)?; - } - let latent = self.encoder_output_conv.forward(&hidden)?; - - // latent: [1, codebook_dim, seq_len] - // Quantize via nearest-neighbor lookup in each codebook - let seq_len = latent.dim(D::Minus1)?; - let mut all_codes = Vec::with_capacity(self.config.num_codebooks); - - // Residual quantization: subtract each codebook's contribution - let mut residual = latent.clone(); - - for cb_idx in 0..self.config.num_codebooks { - // residual: [1, codebook_dim, seq_len] -> find nearest codebook entry per timestep - let codes = self.quantize_layer(&residual, cb_idx, seq_len)?; - - // Look up the quantized vectors and subtract from residual - let code_indices = - Tensor::from_vec(codes.clone(), (1, seq_len), &self.device)?; - let quantized = self.codebook.codebooks[cb_idx].forward(&code_indices)?; - // quantized: [1, seq_len, codebook_dim] -> [1, codebook_dim, seq_len] - let quantized = quantized.transpose(1, 2)?; - residual = (residual - quantized)?; - - all_codes.push(codes); - } - - Ok(all_codes) - } - - /// Quantize a single RVQ layer by finding the nearest codebook entry. - fn quantize_layer( - &self, - residual: &Tensor, - codebook_idx: usize, - _seq_len: usize, - ) -> Result<Vec<u32>> { - // residual: [1, codebook_dim, seq_len] - // codebook weights: [codebook_size, codebook_dim] - let cb_weight = self.codebook.codebooks[codebook_idx] - .embeddings() - .clone(); // [codebook_size, codebook_dim] - - // Transpose residual: [1, seq_len, codebook_dim] - let residual_t = residual.transpose(1, 2)?.squeeze(0)?; // [seq_len, codebook_dim] - - // Compute L2 distances: ||r - c||^2 = ||r||^2 - 2*r*c^T + ||c||^2 - let r_sq = residual_t.sqr()?.sum(D::Minus1)?; // [seq_len] - let c_sq = cb_weight.sqr()?.sum(D::Minus1)?; // [codebook_size] - let rc = residual_t.matmul(&cb_weight.t()?)?; // [seq_len, codebook_size] - - let r_sq = r_sq.unsqueeze(1)?; // [seq_len, 1] - let c_sq = c_sq.unsqueeze(0)?; // [1, codebook_size] - - let distances = (r_sq.broadcast_add(&c_sq)? - (rc * 2.0)?)?; // [seq_len, codebook_size] - - // Argmin per timestep - let indices = distances.argmin(D::Minus1)?; // [seq_len] - let codes: Vec<u32> = indices.to_vec1()?; - - Ok(codes) - } - - /// Decode discrete codebook tokens to audio waveform. - /// - /// `codes`: Vec of `num_codebooks` vectors of token indices. - /// Returns: Vec<f32> — mono 24kHz audio samples. - pub fn decode(&self, codes: &[Vec<u32>]) -> Result<Vec<f32>> { - // Look up and sum all codebook embeddings - let embeddings = self.codebook.decode(codes, &self.device)?; - // embeddings: [1, codebook_dim, seq_len] - - // Project to decoder hidden size: [1, seq_len, codebook_dim] -> [1, seq_len, hidden] - let emb_t = embeddings.transpose(1, 2)?; // [1, seq_len, codebook_dim] - let projected = self.decoder_proj.forward(&emb_t)?; // [1, seq_len, hidden] - let mut hidden = projected.transpose(1, 2)?; // [1, hidden, seq_len] - - // Run decoder - hidden = self.decoder_input_conv.forward(&hidden)?; - for block in &self.decoder_blocks { - hidden = block.forward(&hidden)?; - } - let waveform = self.decoder_output_conv.forward(&hidden)?; - - // [1, 1, num_samples] -> Vec<f32> - let samples: Vec<f32> = waveform.flatten_all()?.to_vec1()?; - Ok(samples) - } - - /// Decode a single frame's codes to audio samples (for streaming). - /// - /// `frame_codes`: [num_codebooks] — one token per codebook for a single frame - /// Returns: audio samples for this frame (~1920 samples at 24kHz / 12.5Hz) - pub fn decode_frame(&self, frame_codes: &[u32]) -> Result<Vec<f32>> { - let codes: Vec<Vec<u32>> = frame_codes.iter().map(|&c| vec![c]).collect(); - self.decode(&codes) - } - - /// Get the number of codebooks. - pub fn num_codebooks(&self) -> usize { - self.config.num_codebooks - } - - /// Get the output sample rate. - pub fn sample_rate(&self) -> u32 { - self.config.sample_rate - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_elu_positive() { - let device = Device::Cpu; - let x = Tensor::from_vec(vec![1.0f32, 2.0, 3.0], (3,), &device).unwrap(); - let result = elu(&x, 1.0).unwrap(); - let values: Vec<f32> = result.to_vec1().unwrap(); - assert!((values[0] - 1.0).abs() < 1e-5); - assert!((values[1] - 2.0).abs() < 1e-5); - } - - #[test] - fn test_elu_negative() { - let device = Device::Cpu; - let x = Tensor::from_vec(vec![-1.0f32], (1,), &device).unwrap(); - let result = elu(&x, 1.0).unwrap(); - let values: Vec<f32> = result.to_vec1().unwrap(); - // ELU(-1) = exp(-1) - 1 ≈ -0.6321 - assert!((values[0] - (-0.6321)).abs() < 0.01); - } - - #[test] - fn test_speech_tokenizer_config() { - let config = SpeechTokenizerConfig::default(); - assert_eq!(config.num_codebooks, 16); - assert_eq!(config.codebook_size, 2048); - assert_eq!(config.sample_rate, 24_000); - } -} |
