diff options
| author | soryu <soryu@soryu.co> | 2025-12-21 01:27:02 +0000 |
|---|---|---|
| committer | soryu <soryu@soryu.co> | 2025-12-23 14:47:18 +0000 |
| commit | 3c696cfc9005e73be5ed46f8941dfc8f0aca7102 (patch) | |
| tree | 497bffd67001501a003739cfe0bb790502ffd50a /vendor | |
| parent | 55cacf6e1a087c0fa6950a1ddeb09060f787e541 (diff) | |
| download | soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.tar.gz soryu-3c696cfc9005e73be5ed46f8941dfc8f0aca7102.zip | |
Create container image and move parakeet fork to vendor dir
Diffstat (limited to 'vendor')
28 files changed, 5892 insertions, 0 deletions
diff --git a/vendor/parakeet-rs/.cargo-ok b/vendor/parakeet-rs/.cargo-ok new file mode 100644 index 0000000..5f8b795 --- /dev/null +++ b/vendor/parakeet-rs/.cargo-ok @@ -0,0 +1 @@ +{"v":1}
\ No newline at end of file diff --git a/vendor/parakeet-rs/.github/workflows/rust.yml b/vendor/parakeet-rs/.github/workflows/rust.yml new file mode 100644 index 0000000..c7f9726 --- /dev/null +++ b/vendor/parakeet-rs/.github/workflows/rust.yml @@ -0,0 +1,43 @@ +name: Rust + +on: + push: + branches: [ "master" ] + tags: + - 'v*' + pull_request: + branches: [ "master" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose + + release: + needs: test + if: startsWith(github.ref, 'refs/tags/v') + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Create Release + uses: softprops/action-gh-release@v2 + with: + generate_release_notes: true + draft: false + prerelease: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/vendor/parakeet-rs/.gitignore b/vendor/parakeet-rs/.gitignore new file mode 100644 index 0000000..fd045f6 --- /dev/null +++ b/vendor/parakeet-rs/.gitignore @@ -0,0 +1,12 @@ +/target +DS_Store +*.DS_Store +.idea +.vscode +*.log +*.onnx +*.json +*.onnx_data +*.wav +*.txt +*.onnx.data
\ No newline at end of file diff --git a/vendor/parakeet-rs/Cargo.lock b/vendor/parakeet-rs/Cargo.lock new file mode 100644 index 0000000..7f0b9f8 --- /dev/null +++ b/vendor/parakeet-rs/Cargo.lock @@ -0,0 +1,1688 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + +[[package]] +name = "bitflags" +version = "2.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + +[[package]] +name = "cc" +version = "1.2.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac9fe6cdbb24b6ade63616c0a0688e45bb56732262c158df3c0c4bea4ca47cb7" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn", +] + +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "filetime" +version = "0.2.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.60.2", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" + +[[package]] +name = "flate2" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "generic-array" +version = "0.14.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "indenter" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5" + +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "js-sys" +version = "0.3.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" + +[[package]] +name = "libredox" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +dependencies = [ + "bitflags", + "libc", + "redox_syscall", +] + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "log" +version = "0.4.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" + +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "openssl" +version = "0.10.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "ort" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec 2.0.0-alpha.10", + "tracing", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" +dependencies = [ + "flate2", + "glob", + "pkg-config", + "sha2", + "tar", + "ureq", +] + +[[package]] +name = "parakeet-rs" +version = "0.2.5" +dependencies = [ + "eyre", + "hound", + "ndarray", + "ort", + "rustfft", + "serde", + "serde_json", + "tokenizers", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + +[[package]] +name = "proc-macro2" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + +[[package]] +name = "rustfft" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokenizers" +version = "0.20.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom 0.2.16", + "indicatif", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +dependencies = [ + "once_cell", +] + +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" + +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec 1.15.1", +] + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "ureq" +version = "3.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99ba1025f18a4a3fc3e9b48c868e9beb4f24f4b4b1a325bada26bd4119f46537" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pemfile", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b4531c118335662134346048ddb0e54cc86bd7e81866757873055f0e38f5d2" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05d651ec480de84b762e7be71e6efa7461699c19d9e2c272c8d93455f567786e" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + +[[package]] +name = "zerocopy" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" diff --git a/vendor/parakeet-rs/Cargo.toml b/vendor/parakeet-rs/Cargo.toml new file mode 100644 index 0000000..d3f83a6 --- /dev/null +++ b/vendor/parakeet-rs/Cargo.toml @@ -0,0 +1,97 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +name = "parakeet-rs" +version = "0.2.5" +authors = ["altunenes"] +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "Fast ASR & Speaker Diarization with NVIDIA Parakeet via ONNX" +readme = "README.md" +keywords = [ + "speech-recognition", + "asr", + "parakeet", + "onnx", + "nvidia", +] +categories = [ + "multimedia::audio", + "science", +] +license = "MIT OR Apache-2.0" +repository = "https://github.com/altunenes/parakeet-rs" + +[features] +coreml = ["ort/coreml"] +cpu = [] +cuda = ["ort/cuda"] +default = ["cpu"] +directml = ["ort/directml"] +openvino = ["ort/openvino"] +rocm = ["ort/rocm"] +sortformer = [] +tensorrt = ["ort/tensorrt"] +webgpu = ["ort/webgpu"] + +[lib] +name = "parakeet_rs" +path = "src/lib.rs" + +[[example]] +name = "diarization" +path = "examples/diarization.rs" + +[[example]] +name = "raw" +path = "examples/raw.rs" + +[[example]] +name = "streaming" +path = "examples/streaming.rs" + +[[example]] +name = "transcribe" +path = "examples/transcribe.rs" + +[dependencies.eyre] +version = "0.6" + +[dependencies.hound] +version = "3.5" + +[dependencies.ndarray] +version = "0.16" + +[dependencies.ort] +version = "2.0.0-rc.10" +features = ["download-binaries"] + +[dependencies.rustfft] +version = "6.4" + +[dependencies.serde] +version = "1.0" +features = ["derive"] + +[dependencies.serde_json] +version = "1.0" + +[dependencies.tokenizers] +version = "0.20" + +[dev-dependencies] diff --git a/vendor/parakeet-rs/Cargo.toml.orig b/vendor/parakeet-rs/Cargo.toml.orig new file mode 100644 index 0000000..4d91e18 --- /dev/null +++ b/vendor/parakeet-rs/Cargo.toml.orig @@ -0,0 +1,54 @@ +[package] +name = "parakeet-rs" +version = "0.2.5" +edition = "2021" +authors = ["altunenes"] +description = "Fast ASR & Speaker Diarization with NVIDIA Parakeet via ONNX" +repository = "https://github.com/altunenes/parakeet-rs" +license = "MIT OR Apache-2.0" +keywords = ["speech-recognition", "asr", "parakeet", "onnx", "nvidia"] +categories = ["multimedia::audio", "science"] + +[lib] +name = "parakeet_rs" +path = "src/lib.rs" + +[[example]] +name = "transcribe" +path = "examples/transcribe.rs" + +[[example]] +name = "diarization" +path = "examples/diarization.rs" + +[[example]] +name = "raw" +path = "examples/raw.rs" + +[[example]] +name = "streaming" +path = "examples/streaming.rs" + +[dependencies] +ort = { version = "2.0.0-rc.10", features = ["download-binaries"] } +hound = "3.5" +eyre = "0.6" +ndarray = "0.16" +tokenizers = "0.20" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +rustfft = "6.4" + +[dev-dependencies] + +[features] +default = ["cpu"] +cpu = [] +cuda = ["ort/cuda"] +tensorrt = ["ort/tensorrt"] +coreml = ["ort/coreml"] +directml = ["ort/directml"] +rocm = ["ort/rocm"] +openvino = ["ort/openvino"] +webgpu = ["ort/webgpu"] +sortformer = []
\ No newline at end of file diff --git a/vendor/parakeet-rs/LICENSE b/vendor/parakeet-rs/LICENSE new file mode 100644 index 0000000..31ce7ce --- /dev/null +++ b/vendor/parakeet-rs/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Enes Altun + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/parakeet-rs/README.md b/vendor/parakeet-rs/README.md new file mode 100644 index 0000000..75dfe85 --- /dev/null +++ b/vendor/parakeet-rs/README.md @@ -0,0 +1,122 @@ +# parakeet-rs +[](https://github.com/altunenes/parakeet-rs/actions/workflows/rust.yml) +[](https://crates.io/crates/parakeet-rs) + +Fast speech recognition with NVIDIA's Parakeet models via ONNX Runtime. +Note: CoreML doesn't stable with this model - stick w/ CPU (or other GPU EP like CUDA). But its incredible fast in my Mac M3 16gb' CPU compared to Whisper metal! :-) + +## Models + +**CTC (English-only)**: Fast & accurate +```rust +use parakeet_rs::Parakeet; + +let mut parakeet = Parakeet::from_pretrained(".", None)?; +let result = parakeet.transcribe_file("audio.wav")?; +println!("{}", result.text); + +// Or transcribe in-memory audio +// let result = parakeet.transcribe_samples(audio, 16000, 1)?; + +// Token-level timestamps +for token in result.tokens { + println!("[{:.3}s - {:.3}s] {}", token.start, token.end, token.text); +} +``` + +**TDT (Multilingual)**: 25 languages with auto-detection +```rust +use parakeet_rs::ParakeetTDT; + +let mut parakeet = ParakeetTDT::from_pretrained("./tdt", None)?; +let result = parakeet.transcribe_file("audio.wav")?; +println!("{}", result.text); + +// Or transcribe in-memory audio +// let result = parakeet.transcribe_samples(audio, 16000, 1)?; + +// Token-level timestamps +for token in result.tokens { + println!("[{:.3}s - {:.3}s] {}", token.start, token.end, token.text); +} +``` + +**EOU (Streaming)**: Real-time ASR with end-of-utterance detection +```rust +use parakeet_rs::ParakeetEOU; + +let mut parakeet = ParakeetEOU::from_pretrained("./eou", None)?; + +// Prepare your audio (Vec<f32>, 16kHz mono, normalized) +let audio: Vec<f32> = /* your audio samples */; + +// Process in 160ms chunks for streaming +const CHUNK_SIZE: usize = 2560; // 160ms at 16kHz +for chunk in audio.chunks(CHUNK_SIZE) { + let text = parakeet.transcribe(chunk, false)?; + print!("{}", text); +} +``` + +**Sortformer v2 (Speaker Diarization)**: Streaming 4-speaker diarization +```toml +parakeet-rs = { version = "0.2", features = ["sortformer"] } +``` +```rust +use parakeet_rs::sortformer::{Sortformer, DiarizationConfig}; + +let mut sortformer = Sortformer::with_config( + "diar_streaming_sortformer_4spk-v2.onnx", + None, + DiarizationConfig::callhome(), // or dihard3(),custom() +)?; +let segments = sortformer.diarize(audio, 16000, 1)?; +for seg in segments { + println!("Speaker {} [{:.2}s - {:.2}s]", seg.speaker_id, seg.start, seg.end); +} +``` +See `examples/diarization.rs` for combining with TDT transcription. + + +## Setup + +**CTC**: Download from [HuggingFace](https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main/onnx): `model.onnx`, `model.onnx_data`, `tokenizer.json` + +**TDT**: Download from [HuggingFace](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx): `encoder-model.onnx`, `encoder-model.onnx.data`, `decoder_joint-model.onnx`, `vocab.txt` + +**EOU**: Download from [HuggingFace](https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx): `encoder.onnx`, `decoder_joint.onnx`, `tokenizer.json` + +**Diarization (Sortformer v2)**: Download from [HuggingFace](https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx): `diar_streaming_sortformer_4spk-v2.onnx` + +Quantized versions available (int8). All files must be in the same directory. + +GPU support (auto-falls back to CPU if fails): +```toml +parakeet-rs = { version = "0.1", features = ["cuda"] } # or tensorrt, webgpu, directml, rocm +``` + +```rust +use parakeet_rs::{Parakeet, ExecutionConfig, ExecutionProvider}; + +let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::Cuda); +let mut parakeet = Parakeet::from_pretrained(".", Some(config))?; +``` + + +## Features + +- [CTC: English with punctuation & capitalization](https://huggingface.co/nvidia/parakeet-ctc-0.6b) +- [TDT: Multilingual (auto lang detection) ](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) +- [EOU: Streaming ASR with end-of-utterance detection](https://huggingface.co/nvidia/parakeet_realtime_eou_120m-v1) +- [Sortformer v2: Streaming speaker diarization (up to 4 speakers)](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2) +- Token-level timestamps (CTC, TDT) + +## Notes + +- Audio: 16kHz mono WAV (16-bit PCM or 32-bit float) + +## License + +Code: MIT OR Apache-2.0 + +FYI: The Parakeet ONNX models (downloaded separately from HuggingFace) are licensed under **CC-BY-4.0** by NVIDIA. This library does not distribute the models. diff --git a/vendor/parakeet-rs/examples/diarization.rs b/vendor/parakeet-rs/examples/diarization.rs new file mode 100644 index 0000000..5982ecb --- /dev/null +++ b/vendor/parakeet-rs/examples/diarization.rs @@ -0,0 +1,137 @@ +/* +Speaker Diarization with NVIDIA Sortformer v2 (Streaming) + +Download the Sortformer v2 model: +https://huggingface.co/altunenes/parakeet-rs/blob/main/diar_streaming_sortformer_4spk-v2.onnx +Download test audio: +wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav + +Usage: +cargo run --example diarization --features sortformer 6_speakers.wav + +NOTE: This example combines two NVIDIA models: +- Parakeet-TDT: Provides transcription with sentence-level timestamps +- Sortformer v2: Provides streaming speaker identification (4 speakers max) +- We use TDT's sentence timestamps + Sortformer's speaker IDs +- Even if Sortformer can't detect a segment, we still get the transcription (marked UNKNOWN) +- For more information: +https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 + +WARNING: Sortformer handles long audio natively (streaming), but TDT has sequence +length limitations (~8-10 minutes max). For production use with long audio files, +run Sortformer on the full audio for diarization, then chunk the audio into +~5-minute segments for TDT transcription, and map the results back together. +*/ + +#[cfg(feature = "sortformer")] +use parakeet_rs::sortformer::{DiarizationConfig, Sortformer}; +#[cfg(feature = "sortformer")] +use parakeet_rs::TimestampMode; +#[cfg(feature = "sortformer")] +use hound; +#[cfg(feature = "sortformer")] +use std::env; +#[cfg(feature = "sortformer")] +use std::time::Instant; + +#[allow(unreachable_code)] +fn main() -> Result<(), Box<dyn std::error::Error>> { + #[cfg(not(feature = "sortformer"))] + { + eprintln!("Error: This example requires the 'sortformer' feature."); + eprintln!("Please run with: cargo run --example diarization --features sortformer <audio.wav>"); + return Err("sortformer feature not enabled".into()); + } + + #[cfg(feature = "sortformer")] + { + let start_time = Instant::now(); + let args: Vec<String> = env::args().collect(); + let audio_path = args.get(1) + .expect("Please specify audio file: cargo run --example diarization --features sortformer <audio.wav>"); + + println!("{}", "=".repeat(80)); + println!("Step 1/3: Loading audio..."); + + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + let audio: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + + let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32; + println!("Loaded {} samples ({} Hz, {} channels, {:.1}s)", + audio.len(), spec.sample_rate, spec.channels, duration); + + println!("{}", "=".repeat(80)); + println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)..."); + + // Create Sortformer with default config (callhome) + let mut sortformer = Sortformer::with_config( + "diar_streaming_sortformer_4spk-v2.onnx", + None, // default exec config + DiarizationConfig::callhome(), + )?; + + let speaker_segments = sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?; + + println!("Found {} speaker segments from Sortformer", speaker_segments.len()); + + // Print raw diarization segments + println!("\nRaw diarization segments:"); + for seg in &speaker_segments { + println!(" [{:06.2}s - {:06.2}s] Speaker {}", seg.start, seg.end, seg.speaker_id); + } + + println!("\n{}", "=".repeat(80)); + println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n"); + + // Use TDT for transcription with sentence-level timestamps + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; + + // Transcribe with Sentences mode (TDT provides punctuation for proper segmentation) + if let Ok(result) = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences)) { + // For each sentence from TDT, find the corresponding speaker from Sortformer + for segment in &result.tokens { + // Find speaker with maximum overlap + let speaker = speaker_segments + .iter() + .filter_map(|s| { + // Calculate overlap between transcription and diarization segment + let overlap_start = segment.start.max(s.start); + let overlap_end = segment.end.min(s.end); + let overlap = (overlap_end - overlap_start).max(0.0); + if overlap > 0.0 { + Some((s.speaker_id, overlap)) + } else { + None + } + }) + .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .map(|(id, _)| format!("Speaker {}", id)) + .unwrap_or_else(|| "UNKNOWN".to_string()); + + println!("[{:.2}s - {:.2}s] {}: {}", + segment.start, segment.end, speaker, segment.text); + } + } + + println!("\n{}", "=".repeat(80)); + let elapsed = start_time.elapsed(); + println!("\n✓ Diarization and transcription completed in {:.2}s", elapsed.as_secs_f32()); + println!("• UNKNOWN: Segments where no speaker was detected by Sortformer"); + println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)"); + + Ok(()) + } + + #[cfg(not(feature = "sortformer"))] + unreachable!() +} diff --git a/vendor/parakeet-rs/examples/raw.rs b/vendor/parakeet-rs/examples/raw.rs new file mode 100644 index 0000000..a1a2adc --- /dev/null +++ b/vendor/parakeet-rs/examples/raw.rs @@ -0,0 +1,86 @@ +/* +Demonstrates using transcribe_samples() + +This example shows manual audio loading and calling transcribe_samples() directly +with sample_rate and channels instead of using transcribe_file() + +Usage: +cargo run --example raw 6_speakers.wav +cargo run --example raw 6_speakers.wav tdt + +WARNING: TDT model has sequence length limitations (~8-10 minutes max). +For longer audio files, you must split into chunks (e.g., 5-minute segments) +and transcribe each chunk separately. Attempting to transcribe 25+ minute +audio files in one call will cause ONNX runtime errors. +Otherwise you will likely get a error like: +"Error: Ort(Error { code: RuntimeException, msg: "Non-zero status code returned while running Add node. Name:'/layers.0/self_attn/Add_2' Status Message: /Users/runner/work/ort-artifacts/ort-artifacts/onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. })" +*/ + +use parakeet_rs::{Parakeet, ParakeetTDT, TimestampMode}; +use std::env; +use std::time::Instant; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let start_time = Instant::now(); + let args: Vec<String> = env::args().collect(); + let audio_path = if args.len() > 1 { + &args[1] + } else { + "6_speakers.wav" + }; + + let use_tdt = args.len() > 2 && args[2] == "tdt"; + + // Load audio manually using hound (or any other audio library) + // remember if you use raw audio API, you need to handle audio preprocessing yourself! + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + println!("Audio info: {}Hz, {} channel(s)", spec.sample_rate, spec.channels); + + let audio: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + + if use_tdt { + println!("Loading TDT model..."); + let mut parakeet = ParakeetTDT::from_pretrained("./tdt", None)?; + + // Use transcribe_samples() with raw parameters and timestamp mode + let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Sentences))?; + + println!("{}", result.text); + println!("\nSentencess:"); + for segment in result.tokens.iter() { + println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); + } + } else { + println!("Loading CTC model..."); + let mut parakeet = Parakeet::from_pretrained(".", None)?; + + // CTC model doesn't predict punctuation (lowercase alphabet only) + // This means no sentence boundaries. we use Words mode instead of Sentences + let result = parakeet.transcribe_samples(audio, spec.sample_rate, spec.channels, Some(TimestampMode::Words))?; + + println!("{}", result.text); + + // Access word-level timestamps (showing first 10 for brevity) + // Note: CTC generates word-level timestamps but cannot segment into sentences + // due to lack of punctuation prediction - this is a model limitation if I not mistake + println!("\nWords (first 10):"); + for word in result.tokens.iter().take(10) { + println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text); + } + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + + Ok(()) +} diff --git a/vendor/parakeet-rs/examples/streaming.rs b/vendor/parakeet-rs/examples/streaming.rs new file mode 100644 index 0000000..f5d36c9 --- /dev/null +++ b/vendor/parakeet-rs/examples/streaming.rs @@ -0,0 +1,129 @@ +/* +Demonstrates streaming ASR with Parakeet RealTime EOU + +Download models files from: +https://huggingface.co/altunenes/parakeet-rs/tree/main/realtime_eou_120m-v1-onnx + +This example +- Maintains 4-second ring buffer for feature extraction context +- Processes 160ms chunks (2560 samples at 16kHz) +- Extracts features from full buffer, then slices last 25 frames +- Encoder receives: 9 frames (pre-encode cache) + 16 frames (new) = 25 total +- Cache states (cache_last_channel/time) maintain temporal context + +Model files required in ./fullstr/: + - encoder.onnx (cache_aware_stream_step export) + - decoder_joint.onnx + - tokenizer.json + +Additional notes: +let reset_on_eou: bool = false; +I must admit that this is not work very well on my real world tests :/ + + +Usage: +cargo run --release --example streaming <audio.wav> +*/ + +use hound; +use parakeet_rs::ParakeetEOU; +use std::env; +use std::time::Instant; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let start_time = Instant::now(); + + let args: Vec<String> = env::args().collect(); + let audio_path = args + .get(1) + .expect("Usage: cargo run --release --example streaming <audio.wav>"); + + println!("Loading model from ./fullstr..."); + let mut parakeet = ParakeetEOU::from_pretrained("./fullstr", None)?; + + println!("Loading audio: {}", audio_path); + let mut reader = hound::WavReader::open(audio_path)?; + let spec = reader.spec(); + + let mut audio: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<Result<Vec<_>, _>>()?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<Result<Vec<_>, _>>()?, + }; + + if spec.sample_rate != 16000 { + return Err(format!( + "Expected 16kHz audio, got {}Hz. Please resample first.", + spec.sample_rate + ) + .into()); + } + + if spec.channels > 1 { + audio = audio + .chunks(spec.channels as usize) + .map(|chunk| chunk.iter().sum::<f32>() / spec.channels as f32) + .collect(); + } + + let max_val = audio.iter().fold(0.0f32, |a, &b| a.max(b.abs())); + if max_val > 1e-6 { + let norm_factor = max_val + 1e-5; + for sample in &mut audio { + *sample /= norm_factor; + } + } + + let duration = audio.len() as f32 / 16000.0; + // 160ms at 16kHz + const CHUNK_SIZE: usize = 2560; + let reset_on_eou: bool = false; + + println!("Streaming transcription (160ms chunks with 4s buffer)...\n"); + + let mut full_text = String::new(); + + for chunk in audio.chunks(CHUNK_SIZE) { + let chunk_vec = if chunk.len() < CHUNK_SIZE { + let mut padded = chunk.to_vec(); + padded.resize(CHUNK_SIZE, 0.0); + padded + } else { + chunk.to_vec() + }; + + let text = parakeet.transcribe(&chunk_vec, reset_on_eou)?; + if !text.is_empty() { + print!("{}", text); + std::io::Write::flush(&mut std::io::stdout())?; + full_text.push_str(&text); + } + } + + println!("\n\nFlushing decoder..."); + let silence = vec![0.0f32; CHUNK_SIZE]; + for _ in 0..3 { + let text = parakeet.transcribe(&silence, reset_on_eou)?; + if !text.is_empty() { + print!("{}", text); + std::io::Write::flush(&mut std::io::stdout())?; + full_text.push_str(&text); + } + } + + println!("\n\nFinal Transcription:\n{}", full_text.trim()); + + let elapsed = start_time.elapsed(); + println!( + "\nTranscription completed in {:.2}s (audio: {:.2}s, RTF: {:.2}x)", + elapsed.as_secs_f32(), + duration, + duration / elapsed.as_secs_f32() + ); + + Ok(()) +} diff --git a/vendor/parakeet-rs/examples/transcribe.rs b/vendor/parakeet-rs/examples/transcribe.rs new file mode 100644 index 0000000..685e8de --- /dev/null +++ b/vendor/parakeet-rs/examples/transcribe.rs @@ -0,0 +1,106 @@ +/* +transcribes entire audio, no diarization +wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav + +CTC (English-only): +cargo run --example transcribe 6_speakers.wav + +TDT (Multilingual): +cargo run --example transcribe 6_speakers.wav tdt + +NOTE: For manual audio loading without using transcribe_file(), see examples/raw.rs +- Shows transcribe_samples(audio, sample_rate, channels, timestamps) usage + +WARNING: This may fail on very long audio files (>8 min). +For longer audio, use the pyannote example which processes segments, or split your audio into chunks. + +Note: The coreml feature flag is only for reproducing a known ONNX Runtime bug. +Just ignore it :). See: https://github.com/microsoft/onnxruntime/issues/26355 +*/ +use parakeet_rs::{Parakeet, TimestampMode}; +use std::env; +use std::time::Instant; + +#[cfg(feature = "coreml")] +use parakeet_rs::{ExecutionConfig, ExecutionProvider}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let start_time = Instant::now(); + let args: Vec<String> = env::args().collect(); + let audio_path = if args.len() > 1 { + &args[1] + } else { + "6_speakers.wav" + }; + + let use_tdt = args.len() > 2 && args[2] == "tdt"; + + // TDT model (multilingual, 25 languages) + if use_tdt { + #[cfg(feature = "coreml")] + { + let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML); + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", Some(config))?; + let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?; + println!("{}", result.text); + + println!("\nSentencess:"); + for segment in result.tokens.iter() { + println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + return Ok(()); + } + + #[cfg(not(feature = "coreml"))] + { + let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?; + let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Sentences))?; + println!("{}", result.text); + + println!("\nSentencess:"); + for segment in result.tokens.iter() { + println!("[{:.2}s - {:.2}s]: {}", segment.start, segment.end, segment.text); + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + return Ok(()); + } + } + + // CTC model (English-only) + #[cfg(feature = "coreml")] + let mut parakeet = { + let config = ExecutionConfig::new().with_execution_provider(ExecutionProvider::CoreML); + Parakeet::from_pretrained(".", Some(config))? + }; + + // Default: CPU execution provider (works correctly) + // Auto-detects model with priority: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx + // Or specify exact model: Parakeet::from_pretrained("model_q4.onnx", None)? + #[cfg(not(feature = "coreml"))] + let mut parakeet = Parakeet::from_pretrained(".", None)?; + + // CTC model doesn't predict punctuation (lowercase alphabet only) + // This means no sentence boundaries - use Words mode instead of Sentences + let result = parakeet.transcribe_file(audio_path, Some(TimestampMode::Words))?; + + // Print transcription + println!("{}", result.text); + + // Access word-level timestamps (showing first 10 for brevity) + // Note: CTC generates word-level timestamps but cannot segment into sentences + // due to lack of punctuation prediction - this is a model limitation + println!("\nWords (first 10):"); + for word in result.tokens.iter().take(10) { + println!("[{:.2}s - {:.2}s]: {}", word.start, word.end, word.text); + } + + let elapsed = start_time.elapsed(); + println!("\n✓ Transcription completed in {:.2}s", elapsed.as_secs_f32()); + + Ok(()) +} diff --git a/vendor/parakeet-rs/src/audio.rs b/vendor/parakeet-rs/src/audio.rs new file mode 100644 index 0000000..84d2616 --- /dev/null +++ b/vendor/parakeet-rs/src/audio.rs @@ -0,0 +1,179 @@ +use crate::config::PreprocessorConfig; +use crate::error::{Error, Result}; +use hound::{WavReader, WavSpec}; +use ndarray::Array2; +use std::f32::consts::PI; +use std::path::Path; + +pub fn load_audio<P: AsRef<Path>>(path: P) -> Result<(Vec<f32>, WavSpec)> { + let mut reader = WavReader::open(path)?; + let spec = reader.spec(); + + let samples: Vec<f32> = match spec.sample_format { + hound::SampleFormat::Float => reader + .samples::<f32>() + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| Error::Audio(format!("Failed to read float samples: {e}")))?, + hound::SampleFormat::Int => reader + .samples::<i16>() + .map(|s| s.map(|s| s as f32 / 32768.0)) + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| Error::Audio(format!("Failed to read int samples: {e}")))?, + }; + + Ok((samples, spec)) +} + +pub fn apply_preemphasis(audio: &[f32], coef: f32) -> Vec<f32> { + let mut result = Vec::with_capacity(audio.len()); + result.push(audio[0]); + + for i in 1..audio.len() { + result.push(audio[i] - coef * audio[i - 1]); + } + + result +} + +fn hann_window(window_length: usize) -> Vec<f32> { + (0..window_length) + .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / (window_length as f32 - 1.0)).cos()) + .collect() +} + +// We use proper FFT here instead of naive DFT because the model was trained +// on correctly computed spectrograms. Naive DFT produces wrong frequency bins +// and the model outputs all blank tokens. RustFFT gives us O(n log n) performance +// and numerically correct results that match what the model expects. +pub fn stft(audio: &[f32], n_fft: usize, hop_length: usize, win_length: usize) -> Array2<f32> { + use rustfft::{num_complex::Complex, FftPlanner}; + + let window = hann_window(win_length); + let num_frames = (audio.len() - win_length) / hop_length + 1; + let freq_bins = n_fft / 2 + 1; + let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames)); + + let mut planner = FftPlanner::<f32>::new(); + let fft = planner.plan_fft_forward(n_fft); + + for frame_idx in 0..num_frames { + let start = frame_idx * hop_length; + + let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft]; + for i in 0..win_length.min(audio.len() - start) { + frame[i] = Complex::new(audio[start + i] * window[i], 0.0); + } + + fft.process(&mut frame); + + for k in 0..freq_bins { + let magnitude = frame[k].norm(); + spectrogram[[k, frame_idx]] = magnitude * magnitude; + } + } + + spectrogram +} + +fn hz_to_mel(freq: f32) -> f32 { + 2595.0 * (1.0 + freq / 700.0).log10() +} + +fn mel_to_hz(mel: f32) -> f32 { + 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0) +} + +fn create_mel_filterbank(n_fft: usize, n_mels: usize, sample_rate: usize) -> Array2<f32> { + let freq_bins = n_fft / 2 + 1; + let mut filterbank = Array2::<f32>::zeros((n_mels, freq_bins)); + + let min_mel = hz_to_mel(0.0); + let max_mel = hz_to_mel(sample_rate as f32 / 2.0); + + let mel_points: Vec<f32> = (0..=n_mels + 1) + .map(|i| mel_to_hz(min_mel + (max_mel - min_mel) * i as f32 / (n_mels + 1) as f32)) + .collect(); + + let freq_bin_width = sample_rate as f32 / n_fft as f32; + + for mel_idx in 0..n_mels { + let left = mel_points[mel_idx]; + let center = mel_points[mel_idx + 1]; + let right = mel_points[mel_idx + 2]; + + for freq_idx in 0..freq_bins { + let freq = freq_idx as f32 * freq_bin_width; + + if freq >= left && freq <= center { + filterbank[[mel_idx, freq_idx]] = (freq - left) / (center - left); + } else if freq > center && freq <= right { + filterbank[[mel_idx, freq_idx]] = (right - freq) / (right - center); + } + } + } + + filterbank +} + +/// Extract mel spectrogram features from raw audio samples. +/// +/// # Arguments +/// +/// * `audio` - Audio samples as f32 values +/// * `sample_rate` - Sample rate in Hz +/// * `channels` - Number of audio channels +/// * `config` - Preprocessor configuration +/// +/// # Returns +/// +/// 2D array of mel spectrogram features (time_steps x feature_size) +pub fn extract_features_raw( + mut audio: Vec<f32>, + sample_rate: u32, + channels: u16, + config: &PreprocessorConfig, +) -> Result<Array2<f32>> { + if sample_rate != config.sampling_rate as u32 { + return Err(Error::Audio(format!( + "Audio sample rate {} doesn't match expected {}. Please resample your audio first.", + sample_rate, config.sampling_rate + ))); + } + + if channels > 1 { + let mono: Vec<f32> = audio + .chunks(channels as usize) + .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) + .collect(); + audio = mono; + } + + audio = apply_preemphasis(&audio, config.preemphasis); + + let spectrogram = stft(&audio, config.n_fft, config.hop_length, config.win_length); + + let mel_filterbank = + create_mel_filterbank(config.n_fft, config.feature_size, config.sampling_rate); + let mel_spectrogram = mel_filterbank.dot(&spectrogram); + let mel_spectrogram = mel_spectrogram.mapv(|x| (x.max(1e-10)).ln()); + + let mut mel_spectrogram = mel_spectrogram.t().to_owned(); + + // Normalize each feature dimension to mean=0, std=1 + let num_frames = mel_spectrogram.shape()[0]; + let num_features = mel_spectrogram.shape()[1]; + + for feat_idx in 0..num_features { + let mut column = mel_spectrogram.column_mut(feat_idx); + let mean: f32 = column.iter().sum::<f32>() / num_frames as f32; + let variance: f32 = + column.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_frames as f32; + let std = variance.sqrt().max(1e-10); + + for val in column.iter_mut() { + *val = (*val - mean) / std; + } + } + + Ok(mel_spectrogram) +} diff --git a/vendor/parakeet-rs/src/config.rs b/vendor/parakeet-rs/src/config.rs new file mode 100644 index 0000000..1dae890 --- /dev/null +++ b/vendor/parakeet-rs/src/config.rs @@ -0,0 +1,51 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PreprocessorConfig { + pub feature_extractor_type: String, + pub feature_size: usize, + pub hop_length: usize, + pub n_fft: usize, + pub padding_side: String, + pub padding_value: f32, + pub preemphasis: f32, + pub processor_class: String, + pub return_attention_mask: bool, + pub sampling_rate: usize, + pub win_length: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelConfig { + pub architectures: Vec<String>, + pub vocab_size: usize, + pub pad_token_id: usize, +} + +impl Default for PreprocessorConfig { + fn default() -> Self { + Self { + feature_extractor_type: "ParakeetFeatureExtractor".to_string(), + feature_size: 80, + hop_length: 160, + n_fft: 512, + padding_side: "right".to_string(), + padding_value: 0.0, + preemphasis: 0.97, + processor_class: "ParakeetProcessor".to_string(), + return_attention_mask: true, + sampling_rate: 16000, + win_length: 400, + } + } +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + architectures: vec!["ParakeetForCTC".to_string()], + vocab_size: 1025, + pad_token_id: 1024, + } + } +} diff --git a/vendor/parakeet-rs/src/decoder.rs b/vendor/parakeet-rs/src/decoder.rs new file mode 100644 index 0000000..6da6d65 --- /dev/null +++ b/vendor/parakeet-rs/src/decoder.rs @@ -0,0 +1,211 @@ +use crate::error::{Error, Result}; +use ndarray::Array2; +use std::path::Path; + +// Token with its timestamp information +// start and end are in seconds +#[derive(Debug, Clone)] +pub struct TimedToken { + pub text: String, + pub start: f32, + pub end: f32, +} + +#[derive(Debug, Clone)] +pub struct TranscriptionResult { + pub text: String, + pub tokens: Vec<TimedToken>, +} + +// CTC decoder for parakeet-ctc-0.6b model with token-level timestamps +pub struct ParakeetDecoder { + tokenizer: tokenizers::Tokenizer, + pad_token_id: usize, +} + +impl ParakeetDecoder { + pub fn from_pretrained<P: AsRef<Path>>(tokenizer_path: P) -> Result<Self> { + let tokenizer_path = tokenizer_path.as_ref(); + + let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path) + .map_err(|e| Error::Tokenizer(format!("Failed to load tokenizer: {e}")))?; + + // Hardcoded pad_token_id for Parakeet-CTC-0.6b (constant across all models: please see def configs jsons: https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main) + let pad_token_id = 1024; + + Ok(Self { + tokenizer, + pad_token_id, + }) + } + + pub fn decode(&self, logits: &Array2<f32>) -> Result<String> { + let time_steps = logits.shape()[0]; + + let mut token_ids = Vec::new(); + for t in 0..time_steps { + let logits_t = logits.row(t); + let max_idx = logits_t + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx) + .unwrap_or(0); + + token_ids.push(max_idx as u32); + } + + let collapsed = self.ctc_collapse(&token_ids); + + let text = self + .tokenizer + .decode(&collapsed, true) + .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?; + + Ok(text) + } + + fn ctc_collapse(&self, token_ids: &[u32]) -> Vec<u32> { + let mut result = Vec::new(); + let mut prev_token: Option<u32> = None; + + for &token_id in token_ids { + if token_id == self.pad_token_id as u32 { + prev_token = Some(token_id); + continue; + } + + if Some(token_id) != prev_token { + result.push(token_id); + } + + prev_token = Some(token_id); + } + + result + } + + // CTC collapse with frame tracking for timestamps + fn ctc_collapse_with_frames(&self, token_ids: &[(u32, usize)]) -> Vec<(u32, usize, usize)> { + let mut result: Vec<(u32, usize, usize)> = Vec::new(); + let mut prev_token: Option<u32> = None; + + for &(token_id, frame) in token_ids.iter() { + if token_id == self.pad_token_id as u32 { + prev_token = Some(token_id); + continue; + } + + if Some(token_id) != prev_token { + if let Some(prev) = prev_token { + if prev != self.pad_token_id as u32 { + // End previous token + if let Some(last) = result.last_mut() { + last.2 = frame; + } + } + } + // Start new token + result.push((token_id, frame, frame)); + } + + prev_token = Some(token_id); + } + + // Close last token + if let Some(last) = result.last_mut() { + last.2 = token_ids.len(); + } + + result + } + + // Decode with token-level timestamps + // hop_length and sample_rate are needed to convert frames to seconds + pub fn decode_with_timestamps( + &self, + logits: &Array2<f32>, + hop_length: usize, + sample_rate: usize, + ) -> Result<TranscriptionResult> { + let time_steps = logits.shape()[0]; + + let mut token_ids_with_frames = Vec::new(); + for t in 0..time_steps { + let logits_t = logits.row(t); + let max_idx = logits_t + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx) + .unwrap_or(0); + + token_ids_with_frames.push((max_idx as u32, t)); + } + + // CTC collapse with frame tracking + let collapsed_with_frames = self.ctc_collapse_with_frames(&token_ids_with_frames); + + // Extract just token IDs for decoding + let token_ids: Vec<u32> = collapsed_with_frames.iter().map(|(id, _, _)| *id).collect(); + + // Decode full text + let full_text = self + .tokenizer + .decode(&token_ids, true) + .map_err(|e| Error::Tokenizer(format!("Failed to decode: {e}")))?; + + // Progressive decode to detect word boundaries + // BPE tokenizers only add spaces when decoding sequences, not individual tokens + let mut timed_tokens = Vec::new(); + let mut prev_decode = String::new(); + + for (i, (_token_id, start_frame, end_frame)) in collapsed_with_frames.iter().enumerate() { + // Decode from start up to and including current token + let token_ids_so_far: Vec<u32> = collapsed_with_frames[0..=i] + .iter() + .map(|(id, _, _)| *id) + .collect(); + + if let Ok(curr_decode) = self.tokenizer.decode(&token_ids_so_far, true) { + // Find what this token added + let added_text = if curr_decode.len() > prev_decode.len() { + &curr_decode[prev_decode.len()..] + } else { + "" + }; + + if !added_text.is_empty() { + let start_time = (*start_frame * hop_length) as f32 / sample_rate as f32; + let end_time = (*end_frame * hop_length) as f32 / sample_rate as f32; + + timed_tokens.push(TimedToken { + text: added_text.to_string(), + start: start_time, + end: end_time, + }); + } + + prev_decode = curr_decode; + } + } + + Ok(TranscriptionResult { + text: full_text, + tokens: timed_tokens, + }) + } + + // Stub - falls back to greedy decoding. Full beam search with language model is TODO. + pub fn decode_with_beam_search( + &self, + logits: &Array2<f32>, + _beam_width: usize, + ) -> Result<String> { + self.decode(logits) + } + + pub fn pad_token_id(&self) -> usize { + self.pad_token_id + } +} diff --git a/vendor/parakeet-rs/src/decoder_tdt.rs b/vendor/parakeet-rs/src/decoder_tdt.rs new file mode 100644 index 0000000..65f576d --- /dev/null +++ b/vendor/parakeet-rs/src/decoder_tdt.rs @@ -0,0 +1,63 @@ +use crate::decoder::TranscriptionResult; +use crate::error::Result; +use crate::vocab::Vocabulary; + +/// TDT greedy decoder for Parakeet TDT models +#[derive(Debug)] +pub struct ParakeetTDTDecoder { + vocab: Vocabulary, +} + +impl ParakeetTDTDecoder { + /// Load decoder from vocab file + pub fn from_vocab(vocab: Vocabulary) -> Self { + Self { vocab } + } + + /// Decode tokens with timestamps + /// For TDT models, greedy decoding is done in the model, here we just convert to text + pub fn decode_with_timestamps( + &self, + tokens: &[usize], + frame_indices: &[usize], + _durations: &[usize], + hop_length: usize, + sample_rate: usize, + ) -> Result<TranscriptionResult> { + let mut result_tokens = Vec::new(); + let mut full_text = String::new(); + // TDT encoder does 8x subsampling + let encoder_stride = 8; + + for (i, &token_id) in tokens.iter().enumerate() { + if let Some(token_text) = self.vocab.id_to_text(token_id) { + let frame = frame_indices[i]; + let start = (frame * encoder_stride * hop_length) as f32 / sample_rate as f32; + let end = if i + 1 < frame_indices.len() { + (frame_indices[i + 1] * encoder_stride * hop_length) as f32 / sample_rate as f32 + } else { + start + 0.01 + }; + + // Handle SentencePiece format (▁ prefix for word start) + let display_text = token_text.replace('▁', " "); + + // Skip special tokens + if !(token_text.starts_with('<') && token_text.ends_with('>') && token_text != "<unk>") { + full_text.push_str(&display_text); + + result_tokens.push(crate::decoder::TimedToken { + text: display_text, + start, + end, + }); + } + } + } + + Ok(TranscriptionResult { + text: full_text.trim().to_string(), + tokens: result_tokens, + }) + } +} diff --git a/vendor/parakeet-rs/src/error.rs b/vendor/parakeet-rs/src/error.rs new file mode 100644 index 0000000..690e0e5 --- /dev/null +++ b/vendor/parakeet-rs/src/error.rs @@ -0,0 +1,52 @@ +use std::fmt; + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(Debug)] +pub enum Error { + Io(std::io::Error), + Ort(ort::Error), + Audio(String), + Model(String), + Tokenizer(String), + Config(String), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::Io(e) => write!(f, "IO error: {e}"), + Error::Ort(e) => write!(f, "ONNX Runtime error: {e}"), + Error::Audio(msg) => write!(f, "Audio processing error: {msg}"), + Error::Model(msg) => write!(f, "Model error: {msg}"), + Error::Tokenizer(msg) => write!(f, "Tokenizer error: {msg}"), + Error::Config(msg) => write!(f, "Config error: {msg}"), + } + } +} + +impl std::error::Error for Error {} + +impl From<std::io::Error> for Error { + fn from(e: std::io::Error) -> Self { + Error::Io(e) + } +} + +impl From<ort::Error> for Error { + fn from(e: ort::Error) -> Self { + Error::Ort(e) + } +} + +impl From<serde_json::Error> for Error { + fn from(e: serde_json::Error) -> Self { + Error::Config(e.to_string()) + } +} + +impl From<hound::Error> for Error { + fn from(e: hound::Error) -> Self { + Error::Audio(e.to_string()) + } +} diff --git a/vendor/parakeet-rs/src/execution.rs b/vendor/parakeet-rs/src/execution.rs new file mode 100644 index 0000000..e29aa1d --- /dev/null +++ b/vendor/parakeet-rs/src/execution.rs @@ -0,0 +1,141 @@ +use crate::error::Result; +use ort::session::builder::SessionBuilder; + +// Hardware acceleration options. CPU is default and most reliable. +// GPU providers (CUDA, TensorRT, ROCm) offer 5-10x speedup but require specific hardware. +// All GPU providers automatically fall back to CPU if they fail. +// +// Note: CoreML currently fails with this model due to unsupported operations. +// WebGPU is experimental and may produce incorrect results. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum ExecutionProvider { + #[default] + Cpu, + #[cfg(feature = "cuda")] + Cuda, + #[cfg(feature = "tensorrt")] + TensorRT, + #[cfg(feature = "coreml")] + CoreML, + #[cfg(feature = "directml")] + DirectML, + #[cfg(feature = "rocm")] + ROCm, + #[cfg(feature = "openvino")] + OpenVINO, + #[cfg(feature = "webgpu")] + WebGPU, +} + +#[derive(Debug, Clone)] +pub struct ModelConfig { + pub execution_provider: ExecutionProvider, + pub intra_threads: usize, + pub inter_threads: usize, +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + execution_provider: ExecutionProvider::default(), + intra_threads: 4, + inter_threads: 1, + } + } +} + +impl ModelConfig { + pub fn new() -> Self { + Self::default() + } + + pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self { + self.execution_provider = provider; + self + } + + pub fn with_intra_threads(mut self, threads: usize) -> Self { + self.intra_threads = threads; + self + } + + pub fn with_inter_threads(mut self, threads: usize) -> Self { + self.inter_threads = threads; + self + } + + pub(crate) fn apply_to_session_builder( + &self, + builder: SessionBuilder, + ) -> Result<SessionBuilder> { + use ort::session::builder::GraphOptimizationLevel; + #[cfg(any( + feature = "cuda", + feature = "tensorrt", + feature = "coreml", + feature = "directml", + feature = "rocm", + feature = "openvino", + feature = "webgpu" + ))] + use ort::execution_providers::CPUExecutionProvider; + + let mut builder = builder + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(self.intra_threads)? + .with_inter_threads(self.inter_threads)?; + + builder = match self.execution_provider { + ExecutionProvider::Cpu => builder, + + #[cfg(feature = "cuda")] + ExecutionProvider::Cuda => builder.with_execution_providers([ + ort::execution_providers::CUDAExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "tensorrt")] + ExecutionProvider::TensorRT => builder.with_execution_providers([ + ort::execution_providers::TensorRTExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "coreml")] + ExecutionProvider::CoreML => { + use ort::execution_providers::coreml::{CoreMLComputeUnits, CoreMLExecutionProvider}; + builder.with_execution_providers([ + CoreMLExecutionProvider::default() + .with_compute_units(CoreMLComputeUnits::CPUAndGPU) + .build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])? + } + + #[cfg(feature = "directml")] + ExecutionProvider::DirectML => builder.with_execution_providers([ + ort::execution_providers::DirectMLExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "rocm")] + ExecutionProvider::ROCm => builder.with_execution_providers([ + ort::execution_providers::ROCMExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "openvino")] + ExecutionProvider::OpenVINO => builder.with_execution_providers([ + ort::execution_providers::OpenVINOExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + + #[cfg(feature = "webgpu")] + ExecutionProvider::WebGPU => builder.with_execution_providers([ + ort::execution_providers::WebGPUExecutionProvider::default().build(), + CPUExecutionProvider::default().build().error_on_failure(), + ])?, + }; + + Ok(builder) + } +} diff --git a/vendor/parakeet-rs/src/lib.rs b/vendor/parakeet-rs/src/lib.rs new file mode 100644 index 0000000..0aaefd1 --- /dev/null +++ b/vendor/parakeet-rs/src/lib.rs @@ -0,0 +1,74 @@ +//! # parakeet-rs +//! +//! Rust bindings for NVIDIA's Parakeet speech recognition model using ONNX Runtime. +//! +//! Parakeet is a state-of-the-art automatic speech recognition (ASR) model developed by NVIDIA, +//! based on the FastConformer-TDT architecture with 600 million parameters. +//! +//! ## Features +//! +//! - Easy-to-use API for speech-to-text transcription +//! - Support for ONNX format models +//! - 16kHz mono audio input +//! - Punctuation and capitalization included in output +//! - Fast inference using ONNX Runtime +//! +//! ## Quick Start +//! +//! ```ignore +//! use parakeet_rs::Parakeet; +//! +//! // Load the model +//! let parakeet = Parakeet::from_pretrained(".")?; +//! +//! // Transcribe audio file +//! let text = parakeet.transcribe_file("audio.wav")?; +//! println!("Transcription: {}", text); +//! ``` +//! +//! ## Model Requirements +//! +//! Your model directory should contain: +//! - `model.onnx` - The ONNX model file +//! - `model.onnx_data` - External model weights +//! - `config.json` - Model configuration +//! - `preprocessor_config.json` - Audio preprocessing configuration +//! - `tokenizer.json` - Tokenizer vocabulary +//! - `tokenizer_config.json` - Tokenizer configuration +//! +//! ## Audio Requirements +//! +//! - Format: WAV +//! - Sample Rate: 16kHz +//! - Channels: Mono (stereo will be converted automatically) +//! - Bit Depth: 16-bit PCM or 32-bit float + +mod audio; +mod config; +mod decoder; +mod decoder_tdt; +mod error; +mod execution; +mod model; +mod model_tdt; +mod parakeet; +mod parakeet_tdt; +mod timestamps; +mod vocab; +mod model_eou; +mod parakeet_eou; +#[cfg(feature = "sortformer")] +pub mod sortformer; + +pub use error::{Error, Result}; +pub use execution::{ExecutionProvider, ModelConfig as ExecutionConfig}; +pub use parakeet::Parakeet; +pub use parakeet_tdt::ParakeetTDT; +pub use timestamps::TimestampMode; + +pub use config::{ModelConfig as ModelConfigJson, PreprocessorConfig}; + +pub use decoder::{ParakeetDecoder, TimedToken, TranscriptionResult}; +pub use model::ParakeetModel; +pub use model_eou::ParakeetEOUModel; +pub use parakeet_eou::ParakeetEOU;
\ No newline at end of file diff --git a/vendor/parakeet-rs/src/model.rs b/vendor/parakeet-rs/src/model.rs new file mode 100644 index 0000000..b3cd131 --- /dev/null +++ b/vendor/parakeet-rs/src/model.rs @@ -0,0 +1,93 @@ +use crate::config::ModelConfig; +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use ndarray::Array2; +use ort::session::Session; +use std::path::Path; + +pub struct ParakeetModel { + session: Session, + config: ModelConfig, +} + +impl ParakeetModel { + pub fn from_pretrained<P: AsRef<Path>>(model_path: P) -> Result<Self> { + Self::from_pretrained_with_config(model_path, ExecutionConfig::default()) + } + + pub fn from_pretrained_with_config<P: AsRef<Path>>( + model_path: P, + exec_config: ExecutionConfig, + ) -> Result<Self> { + let model_path = model_path.as_ref(); + + // Use default config (hardcoded constants for Parakeet-CTC-0.6b: please see: json files https://huggingface.co/onnx-community/parakeet-ctc-0.6b-ONNX/tree/main) + let config = ModelConfig::default(); + + let builder = Session::builder()?; + let builder = exec_config.apply_to_session_builder(builder)?; + let session = builder.commit_from_file(model_path)?; + + Ok(Self { session, config }) + } + pub fn forward(&mut self, features: Array2<f32>) -> Result<Array2<f32>> { + let batch_size = 1; + let time_steps = features.shape()[0]; + let feature_size = features.shape()[1]; + + let input = features + .to_shape((batch_size, time_steps, feature_size)) + .map_err(|e| Error::Model(format!("Failed to reshape input: {e}")))? + .to_owned(); + + use ndarray::Array2; + let attention_mask = Array2::<i64>::ones((batch_size, time_steps)); + + let input_value = ort::value::Value::from_array(input)?; + let attention_mask_value = ort::value::Value::from_array(attention_mask)?; + + let outputs = self.session.run(ort::inputs!( + "input_features" => input_value, + "attention_mask" => attention_mask_value + ))?; + + let logits_value = &outputs["logits"]; + let (shape, data) = logits_value + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; + + let shape_dims = shape.as_ref(); + if shape_dims.len() != 3 { + return Err(Error::Model(format!( + "Expected 3D logits, got shape: {shape_dims:?}" + ))); + } + + let batch_size = shape_dims[0] as usize; + let time_steps_out = shape_dims[1] as usize; + let vocab_size = shape_dims[2] as usize; + + if batch_size != 1 { + return Err(Error::Model(format!( + "Expected batch size 1, got {batch_size}" + ))); + } + + let logits_2d = Array2::from_shape_vec((time_steps_out, vocab_size), data.to_vec()) + .map_err(|e| Error::Model(format!("Failed to create array: {e}")))?; + + Ok(logits_2d) + } + + pub fn config(&self) -> &ModelConfig { + &self.config + } + + pub fn vocab_size(&self) -> usize { + self.config.vocab_size + } + + pub fn pad_token_id(&self) -> usize { + self.config.pad_token_id + } +} diff --git a/vendor/parakeet-rs/src/model_eou.rs b/vendor/parakeet-rs/src/model_eou.rs new file mode 100644 index 0000000..5b56e6d --- /dev/null +++ b/vendor/parakeet-rs/src/model_eou.rs @@ -0,0 +1,183 @@ +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use ndarray::{Array1, Array2, Array3, Array4}; +use ort::session::Session; +use std::path::Path; + +/// Encoder cache state for streaming inference +/// The cache maintains temporal context across chunks +pub struct EncoderCache { + /// channel cache: [1, 1, 70, 512] - batch=1, 70 frame lookback + pub cache_last_channel: Array4<f32>, + /// time cache: [1, 1, 512, 8] - batch=1, fixed 8 time steps + pub cache_last_time: Array4<f32>, + /// cache length: [1] with value 0 initially + pub cache_last_channel_len: Array1<i64>, +} + +impl EncoderCache { + /// 17 layers, batch=1, 70 frame lookback, 512 features + pub fn new() -> Self { + Self { + cache_last_channel: Array4::zeros((17, 1, 70, 512)), + cache_last_time: Array4::zeros((17, 1, 512, 8)), + cache_last_channel_len: Array1::from_vec(vec![0i64]), + } + } +} + +pub struct ParakeetEOUModel { + encoder: Session, + decoder_joint: Session, +} + +impl ParakeetEOUModel { + pub fn from_pretrained<P: AsRef<Path>>( + model_dir: P, + exec_config: ExecutionConfig, + ) -> Result<Self> { + let model_dir = model_dir.as_ref(); + + let encoder_path = model_dir.join("encoder.onnx"); + let decoder_path = model_dir.join("decoder_joint.onnx"); + + if !encoder_path.exists() || !decoder_path.exists() { + return Err(Error::Config(format!( + "Missing ONNX files in {}. Expected encoder.onnx and decoder_joint.onnx", + model_dir.display() + ))); + } + + // Load encoder + let builder = Session::builder()?; + let builder = exec_config.apply_to_session_builder(builder)?; + let encoder = builder.commit_from_file(&encoder_path)?; + + // Load decoder + let builder = Session::builder()?; + let builder = exec_config.apply_to_session_builder(builder)?; + let decoder_joint = builder.commit_from_file(&decoder_path)?; + + Ok(Self { + encoder, + decoder_joint, + }) + } + + /// Run the stateful encoder with cache + /// Input: features [1, 128, T], cache state + /// Output: (encoded [1, 512, T], new_cache) + pub fn run_encoder( + &mut self, + features: &Array3<f32>, + length: i64, + cache: &EncoderCache + ) -> Result<(Array3<f32>, EncoderCache)> { + let length_arr = Array1::from_vec(vec![length]); + + let outputs = self.encoder.run(ort::inputs![ + "audio_signal" => ort::value::Value::from_array(features.clone())?, + "length" => ort::value::Value::from_array(length_arr)?, + "cache_last_channel" => ort::value::Value::from_array(cache.cache_last_channel.clone())?, + "cache_last_time" => ort::value::Value::from_array(cache.cache_last_time.clone())?, + "cache_last_channel_len" => ort::value::Value::from_array(cache.cache_last_channel_len.clone())? + ])?; + + // Extract encoder output [1, 512, T] + let (shape, data) = outputs["outputs"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; + + let shape_dims = shape.as_ref(); + let b = shape_dims[0] as usize; + let d = shape_dims[1] as usize; + let t = shape_dims[2] as usize; + + let encoder_out = Array3::from_shape_vec((b, d, t), data.to_vec()) + .map_err(|e| Error::Model(format!("Failed to reshape encoder output: {e}")))?; + + // Extract new cache states + let (ch_shape, ch_data) = outputs["new_cache_last_channel"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract cache_last_channel: {e}")))?; + + let (tm_shape, tm_data) = outputs["new_cache_last_time"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract cache_last_time: {e}")))?; + + let (len_shape, len_data) = outputs["new_cache_last_channel_len"] + .try_extract_tensor::<i64>() + .map_err(|e| Error::Model(format!("Failed to extract cache_len: {e}")))?; + + // Build new cache with extracted shapes + let new_cache = EncoderCache { + cache_last_channel: Array4::from_shape_vec( + (ch_shape[0] as usize, ch_shape[1] as usize, ch_shape[2] as usize, ch_shape[3] as usize), + ch_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_channel: {e}")))?, + + cache_last_time: Array4::from_shape_vec( + (tm_shape[0] as usize, tm_shape[1] as usize, tm_shape[2] as usize, tm_shape[3] as usize), + tm_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape cache_last_time: {e}")))?, + + cache_last_channel_len: Array1::from_shape_vec( + len_shape[0] as usize, + len_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape cache_len: {e}")))?, + }; + + Ok((encoder_out, new_cache)) + } + + /// Run the stateful decoder + /// Returns: (logits [1, 1, 1, vocab], new_state_h, new_state_c) + pub fn run_decoder( + &mut self, + encoder_frame: &Array3<f32>, // [1, 512, 1] + last_token: &Array2<i32>, // [1, 1] + state_h: &Array3<f32>, // [1, 1, 640] + state_c: &Array3<f32>, // [1, 1, 640] + ) -> Result<(Array3<f32>, Array3<f32>, Array3<f32>)> { + + // Target length is always 1 for single step + let target_len = Array1::from_vec(vec![1i32]); + + let outputs = self.decoder_joint.run(ort::inputs![ + "encoder_outputs" => ort::value::Value::from_array(encoder_frame.clone())?, + "targets" => ort::value::Value::from_array(last_token.clone())?, + "target_length" => ort::value::Value::from_array(target_len)?, + "input_states_1" => ort::value::Value::from_array(state_h.clone())?, + "input_states_2" => ort::value::Value::from_array(state_c.clone())? + ])?; + + // 1. Extract Logits + let (l_shape, l_data) = outputs["outputs"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; + + // 2. Extract States (output_states_1, output_states_2) + let (_h_shape, h_data) = outputs["output_states_1"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract state h: {e}")))?; + + let (_c_shape, c_data) = outputs["output_states_2"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract state c: {e}")))?; + + // Reconstruct Arrays + // Logits: I simplify to [1, 1, vocab] + let vocab_size = l_shape[3] as usize; + let logits = Array3::from_shape_vec((1, 1, vocab_size), l_data.to_vec()) + .map_err(|e| Error::Model(format!("Reshape logits failed: {e}")))?; + + // States: [1, 1, 640] + let new_h = Array3::from_shape_vec((1, 1, 640), h_data.to_vec()) + .map_err(|e| Error::Model(format!("Reshape state h failed: {e}")))?; + + let new_c = Array3::from_shape_vec((1, 1, 640), c_data.to_vec()) + .map_err(|e| Error::Model(format!("Reshape state c failed: {e}")))?; + + Ok((logits, new_h, new_c)) + } +}
\ No newline at end of file diff --git a/vendor/parakeet-rs/src/model_tdt.rs b/vendor/parakeet-rs/src/model_tdt.rs new file mode 100644 index 0000000..e00ebdc --- /dev/null +++ b/vendor/parakeet-rs/src/model_tdt.rs @@ -0,0 +1,263 @@ +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use ndarray::{Array1, Array2, Array3}; +use ort::session::Session; +use std::path::{Path, PathBuf}; + +/// TDT model configs +#[derive(Debug, Clone)] +pub struct TDTModelConfig { + pub vocab_size: usize, +} + +impl Default for TDTModelConfig { + fn default() -> Self { + Self { + vocab_size: 8193, + } + } +} + +pub struct ParakeetTDTModel { + encoder: Session, + decoder_joint: Session, + config: TDTModelConfig, +} + +impl ParakeetTDTModel { + /// Load TDT model from directory containing encoder and decoder_joint ONNX files + pub fn from_pretrained<P: AsRef<Path>>( + model_dir: P, + exec_config: ExecutionConfig, + ) -> Result<Self> { + let model_dir = model_dir.as_ref(); + + // Find encoder and decoder_joint files + let encoder_path = Self::find_encoder(model_dir)?; + let decoder_joint_path = Self::find_decoder_joint(model_dir)?; + + let config = TDTModelConfig::default(); + + // Load encoder + let builder = Session::builder()?; + let builder = exec_config.apply_to_session_builder(builder)?; + let encoder = builder.commit_from_file(&encoder_path)?; + + // Load decoder_joint + let builder = Session::builder()?; + let builder = exec_config.apply_to_session_builder(builder)?; + let decoder_joint = builder.commit_from_file(&decoder_joint_path)?; + + + Ok(Self { + encoder, + decoder_joint, + config, + }) + } + + fn find_encoder(dir: &Path) -> Result<PathBuf> { + let candidates = ["encoder-model.onnx", "encoder.onnx"]; + for candidate in &candidates { + let path = dir.join(candidate); + if path.exists() { + return Ok(path); + } + } + Err(Error::Config(format!( + "No encoder model found in {}", + dir.display() + ))) + } + + fn find_decoder_joint(dir: &Path) -> Result<PathBuf> { + let candidates = [ + "decoder_joint-model.onnx", + "decoder_joint.onnx", + "decoder-model.onnx", + ]; + for candidate in &candidates { + let path = dir.join(candidate); + if path.exists() { + return Ok(path); + } + } + Err(Error::Config(format!( + "No decoder_joint model found in {}", + dir.display() + ))) + } + + /// Run greedy decoding - returns (token_ids, frame_indices, durations) + pub fn forward(&mut self, features: Array2<f32>) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> { + // Run encoder + let (encoder_out, encoder_len) = self.run_encoder(&features)?; + + // Run greedy decoding with decoder_joint + let (tokens, frame_indices, durations) = self.greedy_decode(&encoder_out, encoder_len)?; + + Ok((tokens, frame_indices, durations)) + } + + fn run_encoder(&mut self, features: &Array2<f32>) -> Result<(Array3<f32>, i64)> { + let batch_size = 1; + let time_steps = features.shape()[0]; + let feature_size = features.shape()[1]; + + // TDT encoder expects (batch, features, time) not (batch, time, features) + let input = features + .t() + .to_shape((batch_size, feature_size, time_steps)) + .map_err(|e| Error::Model(format!("Failed to reshape encoder input: {e}")))? + .to_owned(); + + let input_length = Array1::from_vec(vec![time_steps as i64]); + + let input_value = ort::value::Value::from_array(input)?; + let length_value = ort::value::Value::from_array(input_length)?; + + let outputs = self.encoder.run(ort::inputs!( + "audio_signal" => input_value, + "length" => length_value + ))?; + + let encoder_out = &outputs["outputs"]; + let encoder_lens = &outputs["encoded_lengths"]; + + let (shape, data) = encoder_out + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract encoder output: {e}")))?; + + let (_, lens_data) = encoder_lens + .try_extract_tensor::<i64>() + .map_err(|e| Error::Model(format!("Failed to extract encoder lengths: {e}")))?; + + let shape_dims = shape.as_ref(); + if shape_dims.len() != 3 { + return Err(Error::Model(format!( + "Expected 3D encoder output, got shape: {shape_dims:?}" + ))); + } + + let b = shape_dims[0] as usize; + let t = shape_dims[1] as usize; + let d = shape_dims[2] as usize; + + let encoder_array = Array3::from_shape_vec((b, t, d), data.to_vec()) + .map_err(|e| Error::Model(format!("Failed to create encoder array: {e}")))?; + + // TDT encoder outputs [batch, encoder_dim, time] directly + Ok((encoder_array, lens_data[0])) + } + + fn greedy_decode(&mut self, encoder_out: &Array3<f32>, _encoder_len: i64) -> Result<(Vec<usize>, Vec<usize>, Vec<usize>)> { + // encoder_out shape: [batch, encoder_dim, time] + let encoder_dim = encoder_out.shape()[1]; + let time_steps = encoder_out.shape()[2]; + let vocab_size = self.config.vocab_size; + let max_tokens_per_step = 10; + let blank_id = vocab_size - 1; + + // States: (num_layers=2, batch=1, hidden_dim=640) + let mut state_h = Array3::<f32>::zeros((2, 1, 640)); + let mut state_c = Array3::<f32>::zeros((2, 1, 640)); + + let mut tokens = Vec::new(); + let mut frame_indices = Vec::new(); + let mut durations = Vec::new(); + + let mut t = 0; + let mut emitted_tokens = 0; + let mut last_emitted_token = blank_id as i32; + + // Frame-by-frame RNN-T/TDT greedy decoding + while t < time_steps { + // Get single encoder frame: slice [0, :, t] and reshape to [1, encoder_dim, 1] + let frame = encoder_out.slice(ndarray::s![0, .., t]).to_owned(); + let frame_reshaped = frame + .to_shape((1, encoder_dim, 1)) + .map_err(|e| Error::Model(format!("Failed to reshape frame: {e}")))? + .to_owned(); + + // Current token for prediction network + let targets = Array2::from_shape_vec((1, 1), vec![last_emitted_token]) + .map_err(|e| Error::Model(format!("Failed to create targets: {e}")))?; + + // Run decoder_joint + let outputs = self.decoder_joint.run(ort::inputs!( + "encoder_outputs" => ort::value::Value::from_array(frame_reshaped)?, + "targets" => ort::value::Value::from_array(targets)?, + "target_length" => ort::value::Value::from_array(Array1::from_vec(vec![1i32]))?, + "input_states_1" => ort::value::Value::from_array(state_h.clone())?, + "input_states_2" => ort::value::Value::from_array(state_c.clone())? + ))?; + + // Extract logits + let (_, logits_data) = outputs["outputs"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract logits: {e}")))?; + + // TDT outputs vocab_size + 5 durations (8193 + 5 = 8198) + let vocab_logits: Vec<f32> = logits_data.iter().take(vocab_size).copied().collect(); + let duration_logits: Vec<f32> = logits_data.iter().skip(vocab_size).copied().collect(); + + let token_id = vocab_logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx) + .unwrap_or(blank_id); + + let duration_step = if !duration_logits.is_empty() { + duration_logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(idx, _)| idx) + .unwrap_or(0) + } else { + 0 + }; + + // Check if blank token + if token_id != blank_id { + // Update states when we emit a token + if let Ok((h_shape, h_data)) = outputs["output_states_1"].try_extract_tensor::<f32>() { + let dims = h_shape.as_ref(); + state_h = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), h_data.to_vec()) + .map_err(|e| Error::Model(format!("Failed to update state_h: {e}")))?; + } + if let Ok((c_shape, c_data)) = outputs["output_states_2"].try_extract_tensor::<f32>() { + let dims = c_shape.as_ref(); + state_c = Array3::from_shape_vec((dims[0] as usize, dims[1] as usize, dims[2] as usize), c_data.to_vec()) + .map_err(|e| Error::Model(format!("Failed to update state_c: {e}")))?; + } + + tokens.push(token_id); + frame_indices.push(t); + durations.push(duration_step); + last_emitted_token = token_id as i32; + emitted_tokens += 1; + + // Don't advance yet - try to emit more tokens from the same frame + } else { + // Blank token - advance frame pointer + // Duration prediction applies when we finally move to next frame after emitting tokens + if duration_step > 0 && emitted_tokens > 0 { + t += duration_step; + } else { + t += 1; + } + emitted_tokens = 0; + } + + // Safety check: if we've emitted too many tokens from the same frame, advance + if emitted_tokens >= max_tokens_per_step { + t += 1; + emitted_tokens = 0; + } + } + + Ok((tokens, frame_indices, durations)) + } +} diff --git a/vendor/parakeet-rs/src/parakeet.rs b/vendor/parakeet-rs/src/parakeet.rs new file mode 100644 index 0000000..d2aabdd --- /dev/null +++ b/vendor/parakeet-rs/src/parakeet.rs @@ -0,0 +1,210 @@ +use crate::audio; +use crate::config::PreprocessorConfig; +use crate::decoder::{ParakeetDecoder, TranscriptionResult}; +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use crate::model::ParakeetModel; +use crate::timestamps::{process_timestamps, TimestampMode}; +use std::path::{Path, PathBuf}; + +pub struct Parakeet { + model: ParakeetModel, + decoder: ParakeetDecoder, + preprocessor_config: PreprocessorConfig, + model_dir: PathBuf, +} + +impl Parakeet { + /// Load Parakeet model from path with optional configuration. + /// + /// # Arguments + /// * `path` - Directory containing model files, or path to specific model file + /// * `config` - Optional execution configuration (defaults to CPU if None) + /// + /// # Examples + /// ```no_run + /// use parakeet_rs::Parakeet; + /// + /// // Load from directory with CPU (default) + /// let parakeet = Parakeet::from_pretrained(".", None)?; + /// + /// // Or load from specific model file + /// let parakeet = Parakeet::from_pretrained("model_q4.onnx", None)?; + /// # Ok::<(), Box<dyn std::error::Error>>(()) + /// ``` + /// + /// For GPU acceleration, enable the corresponding feature (cuda, tensorrt, webgpu, etc.) + /// and pass an `ExecutionConfig` with the desired execution provider. + pub fn from_pretrained<P: AsRef<Path>>( + path: P, + config: Option<ExecutionConfig>, + ) -> Result<Self> { + let path = path.as_ref(); + + // Determine if path is a directory or file + let (model_path, tokenizer_path, model_dir) = if path.is_dir() { + // Directory mode: auto-detect model file + let model_path = Self::find_model_file(path)?; + let tokenizer_path = path.join("tokenizer.json"); + (model_path, tokenizer_path, path.to_path_buf()) + } else if path.is_file() { + // File mode: path points directly to model file + let model_dir = path + .parent() + .ok_or_else(|| Error::Config("Invalid model path".to_string()))?; + let tokenizer_path = model_dir.join("tokenizer.json"); + (path.to_path_buf(), tokenizer_path, model_dir.to_path_buf()) + } else { + return Err(Error::Config(format!( + "Path does not exist: {}", + path.display() + ))); + }; + + // Check tokenizer exists + if !tokenizer_path.exists() { + return Err(Error::Config(format!( + "Required file 'tokenizer.json' not found in {}", + model_dir.display() + ))); + } + + let preprocessor_config = PreprocessorConfig::default(); + let exec_config = config.unwrap_or_default(); + + let model = ParakeetModel::from_pretrained_with_config(&model_path, exec_config)?; + let decoder = ParakeetDecoder::from_pretrained(&tokenizer_path)?; + + Ok(Self { + model, + decoder, + preprocessor_config, + model_dir, + }) + } + + fn find_model_file(dir: &Path) -> Result<PathBuf> { + // Priority order: model.onnx > model_fp16.onnx > model_int8.onnx > model_q4.onnx + let candidates = [ + "model.onnx", + "model_fp16.onnx", + "model_int8.onnx", + "model_q4.onnx", + ]; + + for candidate in &candidates { + let path = dir.join(candidate); + if path.exists() { + return Ok(path); + } + } + + // If none of the standard names found, search for any .onnx file + if let Ok(entries) = std::fs::read_dir(dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("onnx") { + return Ok(path); + } + } + } + + Err(Error::Config(format!( + "No model file (*.onnx) found in directory: {}", + dir.display() + ))) + } + + /// Transcribe audio samples. + /// + /// # Arguments + /// + /// * `audio` - Audio samples as f32 values + /// * `sample_rate` - Sample rate in Hz + /// * `channels` - Number of audio channels + /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) + /// + /// # Returns + /// + /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested level. + pub fn transcribe_samples( + &mut self, + audio: Vec<f32>, + sample_rate: u32, + channels: u16, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; + let logits = self.model.forward(features)?; + + let mut result = self.decoder.decode_with_timestamps( + &logits, + self.preprocessor_config.hop_length, + self.preprocessor_config.sampling_rate, + )?; + + // Process timestamps to requested output mode + let mode = mode.unwrap_or(TimestampMode::Tokens); + result.tokens = process_timestamps(&result.tokens, mode); + + // Rebuild full text from processed tokens to ensure consistency + result.text = result.tokens.iter() + .map(|t| t.text.as_str()) + .collect::<Vec<_>>() + .join(" "); + + Ok(result) + } + + /// Transcribe an audio file with timestamps + /// + /// # Arguments + /// + /// * `audio_path` - A path to the audio file that needs to be transcribed. + /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level. + pub fn transcribe_file<P: AsRef<Path>>( + &mut self, + audio_path: P, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let audio_path = audio_path.as_ref(); + let (audio, spec) = audio::load_audio(audio_path)?; + + self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) + } + + /// Transcribes multiple audio files in batch. + /// + /// # Arguments + /// + /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. + /// * `mode` - Optional timestamp output mode (Tokens, Words, or Sentences) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested level. + pub fn transcribe_file_batch<P: AsRef<Path>>( + &mut self, + audio_paths: &[P], + mode: Option<TimestampMode>, + ) -> Result<Vec<TranscriptionResult>> { + let mut results = Vec::with_capacity(audio_paths.len()); + for path in audio_paths { + let result = self.transcribe_file(path, mode)?; + results.push(result); + } + Ok(results) + } + + pub fn model_dir(&self) -> &Path { + &self.model_dir + } + + pub fn preprocessor_config(&self) -> &PreprocessorConfig { + &self.preprocessor_config + } +} diff --git a/vendor/parakeet-rs/src/parakeet_eou.rs b/vendor/parakeet-rs/src/parakeet_eou.rs new file mode 100644 index 0000000..25c7d64 --- /dev/null +++ b/vendor/parakeet-rs/src/parakeet_eou.rs @@ -0,0 +1,304 @@ +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use crate::model_eou::{EncoderCache, ParakeetEOUModel}; +use ndarray::{s, Array2, Array3}; +use rustfft::{num_complex::Complex, FftPlanner}; +use std::collections::VecDeque; +use std::f32::consts::PI; +use std::path::Path; + +const SAMPLE_RATE: usize = 16000; + +const N_FFT: usize = 512; +const WIN_LENGTH: usize = 400; +const HOP_LENGTH: usize = 160; +const N_MELS: usize = 128; +const PREEMPH: f32 = 0.97; +const LOG_ZERO_GUARD: f32 = 5.960464478e-8; +const FMAX: f32 = 8000.0; + +/// Parakeet RealTime EOU model for streaming ASR with end-of-utterance detection. +/// Uses cache-aware streaming with audio buffering for pre-encode context. +pub struct ParakeetEOU { + model: ParakeetEOUModel, + tokenizer: tokenizers::Tokenizer, + encoder_cache: EncoderCache, + state_h: Array3<f32>, + state_c: Array3<f32>, + last_token: Array2<i32>, + blank_id: i32, + eou_id: i32, + mel_basis: Array2<f32>, + window: Vec<f32>, + audio_buffer: VecDeque<f32>, + buffer_size_samples: usize, +} + +impl ParakeetEOU { + /// Load Parakeet EOU model from path + /// + /// # Arguments + /// * `path` - Directory containing encoder.onnx, decoder_joint.onnx, and tokenizer.json + /// * `config` - Optional execution configuration (defaults to CPU if None) + pub fn from_pretrained<P: AsRef<Path>>(path: P, config: Option<ExecutionConfig>) -> Result<Self> { + let path = path.as_ref(); + let tokenizer_path = path.join("tokenizer.json"); + let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path) + .map_err(|e| Error::Config(format!("Failed to load tokenizer: {e}")))?; + + let vocab_size = tokenizer.get_vocab_size(true); + let blank_id = (vocab_size - 1) as i32; + let blank_id = if blank_id < 1000 { 1026 } else { blank_id }; + let eou_id = tokenizer.token_to_id("<EOU>").map(|id| id as i32).unwrap_or(1024); + + let exec_config = config.unwrap_or_default(); + let model = ParakeetEOUModel::from_pretrained(path, exec_config)?; + + // Buffer size: 4 seconds of audio + // Provides long history for feature extraction context + // Note that, I pick those "magic numbers" by looking NeMo's ring buffer approach. + let buffer_size_samples = SAMPLE_RATE * 4; // 4 seconds = 64000 samples + + Ok(Self { + model, + tokenizer, + encoder_cache: EncoderCache::new(), + state_h: Array3::zeros((1, 1, 640)), + state_c: Array3::zeros((1, 1, 640)), + last_token: Array2::from_elem((1, 1), blank_id), + blank_id, + eou_id, + mel_basis: Self::create_mel_filterbank(), + window: Self::create_window(), + audio_buffer: VecDeque::with_capacity(buffer_size_samples), + buffer_size_samples, + }) + } + + /// Transcribe a chunk of audio samples. + /// + /// # Arguments + /// * `chunk` - Audio chunk (typically 160ms / 2560 samples at 16kHz) + /// * `reset_on_eou` - If true, reset decoder state when end-of-utterance is detected + /// + /// # Streaming Behavior + /// Cache-aware streaming + /// - Maintains 4-second ring buffer for feature extraction context + /// - Extracts features from full buffer + /// - Slices last (pre_encode_cache + new_frames) for encoder input + /// - pre_encode_cache=9 frames, new_frames=~16, total=~25 frames to encoder + pub fn transcribe(&mut self, chunk: &[f32], reset_on_eou: bool) -> Result<String> { + // Add new chunk to rolling buffer + self.audio_buffer.extend(chunk.iter().copied()); + + // Trim buffer to keep only the most recent samples + while self.audio_buffer.len() > self.buffer_size_samples { + self.audio_buffer.pop_front(); + } + + // Wait until buffer has minimum samples (at least 1 second for stable features) + const MIN_BUFFER_SAMPLES: usize = SAMPLE_RATE; // 1 second + if self.audio_buffer.len() < MIN_BUFFER_SAMPLES { + return Ok(String::new()); + } + + // Extract features from FULL buffer (provides context for feature extraction) + let buffer_slice: Vec<f32> = self.audio_buffer.iter().copied().collect(); + let full_features = self.extract_mel_features(&buffer_slice); + let total_frames = full_features.shape()[2]; + + // Slice to take only (pre_encode_cache + new_frames) for encoder + // pre_encode_cache = 9 frames, new_frames = ~16 for 160ms chunk + const PRE_ENCODE_CACHE: usize = 9; + const FRAMES_PER_CHUNK: usize = 16; + const SLICE_LEN: usize = PRE_ENCODE_CACHE + FRAMES_PER_CHUNK; + + let start_frame = if total_frames > SLICE_LEN { + total_frames - SLICE_LEN + } else { + 0 + }; + + let features = full_features.slice(s![.., .., start_frame..]).to_owned(); + let time_steps = features.shape()[2]; + + // Encode with cache - encoder sees full buffer context + let (encoder_out, new_cache) = self.model.run_encoder(&features, time_steps as i64, &self.encoder_cache)?; + self.encoder_cache = new_cache; + + let total_frames = encoder_out.shape()[2]; + if total_frames == 0 { + return Ok(String::new()); + } + + // Process all output frames (typically 1 frame per chunk) + let new_frames = encoder_out; + + let mut text_output = String::new(); + + for t in 0..new_frames.shape()[2] { + let current_frame = new_frames.slice(s![.., .., t..t + 1]).to_owned(); + let mut syms_added = 0; + + while syms_added < 5 { + let (logits, new_h, new_c) = self.model.run_decoder( + ¤t_frame, + &self.last_token, + &self.state_h, + &self.state_c, + )?; + + let vocab = logits.slice(s![0, 0, ..]); + + let mut max_idx = 0; + let mut max_val = f32::NEG_INFINITY; + for (i, &val) in vocab.iter().enumerate() { + if val.is_finite() && val > max_val { + max_val = val; + max_idx = i as i32; + } + } + + if max_idx == self.blank_id || max_idx == 0 { + break; + } + + if max_idx == self.eou_id { + if reset_on_eou { + self.reset_states(); + return Ok(text_output + " [EOU]"); + } + break; + } + + if max_idx as usize >= self.tokenizer.get_vocab_size(true) { + break; + } + + self.state_h = new_h; + self.state_c = new_c; + self.last_token.fill(max_idx); + + if let Some(token) = self.tokenizer.id_to_token(max_idx as u32) { + let clean = token.replace('▁', " "); + text_output.push_str(&clean); + } + syms_added += 1; + } + } + Ok(text_output) + } + + fn reset_states(&mut self) { + // Soft reset: Only reset decoder states + // at this state, we need to keep encoder cache and audio buffer flowing for continuous context + // self.encoder_cache = EncoderCache::new(); // DON'T reset!!! + self.state_h.fill(0.0); + self.state_c.fill(0.0); + self.last_token.fill(self.blank_id); + // self.audio_buffer.clear(); // DON'T clear!! + } + + fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> { + let audio_pre = Self::apply_preemphasis(audio); + let spec = self.stft(&audio_pre); + let mel = self.mel_basis.dot(&spec); + let mel_log = mel.mapv(|x| (x.max(0.0) + LOG_ZERO_GUARD).ln()); + mel_log.insert_axis(ndarray::Axis(0)) + } + + fn apply_preemphasis(audio: &[f32]) -> Vec<f32> { + let mut result = Vec::with_capacity(audio.len()); + if audio.is_empty() { + return result; + } + + let safe_x = |x: f32| if x.is_finite() { x } else { 0.0 }; + + result.push(safe_x(audio[0])); + for i in 1..audio.len() { + result.push(safe_x(audio[i]) - PREEMPH * safe_x(audio[i - 1])); + } + result + } + + fn stft(&self, audio: &[f32]) -> Array2<f32> { + let mut planner = FftPlanner::<f32>::new(); + let fft = planner.plan_fft_forward(N_FFT); + + let pad_amount = N_FFT / 2; + let mut padded_audio = vec![0.0; pad_amount]; + padded_audio.extend_from_slice(audio); + padded_audio.extend(std::iter::repeat(0.0).take(pad_amount)); + + let num_frames = 1 + (padded_audio.len().saturating_sub(WIN_LENGTH)) / HOP_LENGTH; + let freq_bins = N_FFT / 2 + 1; + let mut spec = Array2::zeros((freq_bins, num_frames)); + + for frame_idx in 0..num_frames { + let start = frame_idx * HOP_LENGTH; + if start + WIN_LENGTH > padded_audio.len() { + break; + } + + let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT]; + for i in 0..WIN_LENGTH { + buffer[i] = Complex::new(padded_audio[start + i] * self.window[i], 0.0); + } + fft.process(&mut buffer); + for (i, val) in buffer.iter().take(freq_bins).enumerate() { + let mag_sq = val.norm_sqr(); + spec[[i, frame_idx]] = if mag_sq.is_finite() { mag_sq } else { 0.0 }; + } + } + spec + } + + fn create_window() -> Vec<f32> { + (0..WIN_LENGTH) + .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / ((WIN_LENGTH - 1) as f32)).cos()) + .collect() + } + + fn create_mel_filterbank() -> Array2<f32> { + let num_freqs = N_FFT / 2 + 1; + + let hz_to_mel = |hz: f32| 2595.0 * (1.0 + hz / 700.0).log10(); + let mel_to_hz = |mel: f32| 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0); + + let mel_min = hz_to_mel(0.0); + let mel_max = hz_to_mel(FMAX); + + let mel_points: Vec<f32> = (0..=N_MELS + 1) + .map(|i| mel_to_hz(mel_min + (mel_max - mel_min) * i as f32 / (N_MELS + 1) as f32)) + .collect(); + + let fft_freqs: Vec<f32> = (0..num_freqs) + .map(|i| (SAMPLE_RATE as f32 / N_FFT as f32) * i as f32) + .collect(); + + let mut weights = Array2::zeros((N_MELS, num_freqs)); + + for i in 0..N_MELS { + let left = mel_points[i]; + let center = mel_points[i + 1]; + let right = mel_points[i + 2]; + for (j, &freq) in fft_freqs.iter().enumerate() { + if freq >= left && freq <= center { + weights[[i, j]] = (freq - left) / (center - left); + } else if freq > center && freq <= right { + weights[[i, j]] = (right - freq) / (right - center); + } + } + } + + for i in 0..N_MELS { + let enorm = 2.0 / (mel_points[i + 2] - mel_points[i]); + for j in 0..num_freqs { + weights[[i, j]] *= enorm; + } + } + + weights + } +} diff --git a/vendor/parakeet-rs/src/parakeet_tdt.rs b/vendor/parakeet-rs/src/parakeet_tdt.rs new file mode 100644 index 0000000..719ae75 --- /dev/null +++ b/vendor/parakeet-rs/src/parakeet_tdt.rs @@ -0,0 +1,167 @@ +use crate::audio; +use crate::config::PreprocessorConfig; +use crate::decoder::TranscriptionResult; +use crate::decoder_tdt::ParakeetTDTDecoder; +use crate::error::{Error, Result}; +use crate::execution::ModelConfig as ExecutionConfig; +use crate::model_tdt::ParakeetTDTModel; +use crate::timestamps::{process_timestamps, TimestampMode}; +use crate::vocab::Vocabulary; +use std::path::{Path, PathBuf}; + +/// Parakeet TDT model for multilingual ASR +pub struct ParakeetTDT { + model: ParakeetTDTModel, + decoder: ParakeetTDTDecoder, + preprocessor_config: PreprocessorConfig, + model_dir: PathBuf, +} + +impl ParakeetTDT { + /// Load Parakeet TDT model from path with optional configuration. + /// + /// # Arguments + /// * `path` - Directory containing encoder-model.onnx, decoder_joint-model.onnx, and vocab.txt + /// * `config` - Optional execution configuration (defaults to CPU if None) + pub fn from_pretrained<P: AsRef<Path>>( + path: P, + config: Option<ExecutionConfig>, + ) -> Result<Self> { + let path = path.as_ref(); + + if !path.is_dir() { + return Err(Error::Config(format!( + "TDT model path must be a directory: {}", + path.display() + ))); + } + + let vocab_path = path.join("vocab.txt"); + if !vocab_path.exists() { + return Err(Error::Config(format!( + "vocab.txt not found in {}", + path.display() + ))); + } + + // TDT-specific preprocessor config (128 features instead of 80) + let preprocessor_config = PreprocessorConfig { + feature_extractor_type: "ParakeetFeatureExtractor".to_string(), + feature_size: 128, + hop_length: 160, + n_fft: 512, + padding_side: "right".to_string(), + padding_value: 0.0, + preemphasis: 0.97, + processor_class: "ParakeetProcessor".to_string(), + return_attention_mask: true, + sampling_rate: 16000, + win_length: 400, + }; + + let exec_config = config.unwrap_or_default(); + + let model = ParakeetTDTModel::from_pretrained(path, exec_config)?; + let vocab = Vocabulary::from_file(&vocab_path)?; + let decoder = ParakeetTDTDecoder::from_vocab(vocab); + + Ok(Self { + model, + decoder, + preprocessor_config, + model_dir: path.to_path_buf(), + }) + } + + /// Transcribe audio samples. + /// + /// # Arguments + /// + /// * `audio` - Audio samples as f32 values + /// * `sample_rate` - Sample rate in Hz + /// * `channels` - Number of audio channels + /// * `mode` - Optional timestamp mode (Token, Word, or Segment) + /// + /// # Returns + /// + /// A `TranscriptionResult` containing the transcribed text and timestamps at the requested mode. + pub fn transcribe_samples( + &mut self, + audio: Vec<f32>, + sample_rate: u32, + channels: u16, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let features = audio::extract_features_raw(audio, sample_rate, channels, &self.preprocessor_config)?; + let (tokens, frame_indices, durations) = self.model.forward(features)?; + + let mut result = self.decoder.decode_with_timestamps( + &tokens, + &frame_indices, + &durations, + self.preprocessor_config.hop_length, + self.preprocessor_config.sampling_rate, + )?; + + // Apply timestamp mode conversion + let mode = mode.unwrap_or(TimestampMode::Tokens); + result.tokens = process_timestamps(&result.tokens, mode); + + // Rebuild full text from processed tokens + result.text = result.tokens.iter() + .map(|t| t.text.as_str()) + .collect::<Vec<_>>() + .join(" "); + + Ok(result) + } + + /// Transcribe an audio file with timestamps + /// + /// # Arguments + /// + /// * `audio_path` - A path to the audio file that needs to be transcribed. + /// * `mode` - Optional timestamp mode (Token, Word, or Segment) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. + pub fn transcribe_file<P: AsRef<Path>>( + &mut self, + audio_path: P, + mode: Option<TimestampMode>, + ) -> Result<TranscriptionResult> { + let audio_path = audio_path.as_ref(); + let (audio, spec) = audio::load_audio(audio_path)?; + + self.transcribe_samples(audio, spec.sample_rate, spec.channels, mode) + } + + /// Transcribes multiple audio files in batch. + /// + /// # Arguments + /// + /// * `audio_paths`: A slice of paths to the audio files that need to be transcribed. + /// * `mode` - Optional timestamp mode (Token, Word, or Segment) + /// + /// # Returns + /// + /// This function returns a `TranscriptionResult` which includes the transcribed text along with timestamps at the requested mode. + pub fn transcribe_file_batch<P: AsRef<Path>>( + &mut self, + audio_paths: &[P], + mode: Option<TimestampMode>, + ) -> Result<Vec<TranscriptionResult>> { + let mut results = Vec::with_capacity(audio_paths.len()); + for path in audio_paths { + let result = self.transcribe_file(path, mode)?; + results.push(result); + } + Ok(results) + } + + /// Get model directory path + pub fn model_dir(&self) -> &Path { + &self.model_dir + } +} diff --git a/vendor/parakeet-rs/src/sortformer.rs b/vendor/parakeet-rs/src/sortformer.rs new file mode 100644 index 0000000..2b1e5a3 --- /dev/null +++ b/vendor/parakeet-rs/src/sortformer.rs @@ -0,0 +1,1062 @@ +//! NVIDIA Sortformer v2 Streaming Speaker Diarization +//! +//! This module implements NVIDIA's Sortformer v2 streaming model for speaker diarization. +//! +//! Key features: +//! - Streaming inference with ~10s chunks (124 frames at 80ms each) +//! - FIFO buffer for context management +//! - Smart speaker cache compression (keeps important frames, not just recent) +//! - Silence profile tracking +//! - Post-processing: median filtering, hysteresis thresholding +//! - Supports up to 4 speakers +//! +//! Reference: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 +//! Note that, my ONNX export: +//! CHUNK_LEN = 124 +//! FIFO_LEN = 124 +//! CACHE_LEN = 188 +//! FEAT_DIM = 128 +//! EMB_DIM = 512 +//! Note, my stft code is adapted from: https://librosa.org/doc/main/generated/librosa.stft.html + +use crate::error::{Error, Result}; +use crate::execution::ModelConfig; +use ndarray::{s, Array1, Array2, Array3, Axis}; +use ort::session::Session; +use rustfft::{num_complex::Complex, FftPlanner}; +use std::f32::consts::PI; +use std::path::Path; + +// Model constants +const N_FFT: usize = 512; +const WIN_LENGTH: usize = 400; +const HOP_LENGTH: usize = 160; +const N_MELS: usize = 128; +const PREEMPH: f32 = 0.97; +const LOG_ZERO_GUARD: f32 = 5.960464478e-8; +const SAMPLE_RATE: usize = 16000; +const FMIN: f32 = 0.0; +const FMAX: f32 = 8000.0; + +// Streaming constants +const CHUNK_LEN: usize = 124; // Frames per chunk (~10s at 80ms) +const FIFO_LEN: usize = 124; // FIFO buffer length +const SPKCACHE_LEN: usize = 188; // Speaker cache length +const SPKCACHE_UPDATE_PERIOD: usize = 124; +const SUBSAMPLING: usize = 8; // Audio frames -> model frames +const EMB_DIM: usize = 512; // Embedding dimension +const NUM_SPEAKERS: usize = 4; // Model supports 4 speakers +const FRAME_DURATION: f32 = 0.08; // 80ms per frame + +// Cache compression params (from NeMo) +const SPKCACHE_SIL_FRAMES_PER_SPK: usize = 3; +const PRED_SCORE_THRESHOLD: f32 = 0.25; +const STRONG_BOOST_RATE: f32 = 0.75; +const WEAK_BOOST_RATE: f32 = 1.5; +const MIN_POS_SCORES_RATE: f32 = 0.5; +const SIL_THRESHOLD: f32 = 0.2; +const MAX_INDEX: usize = 99999; + +/// Post-processing configuration for speaker diarization. (NVIDIA official configs from v2 YAMLs) +/// +/// Controls how raw model predictions are converted into speaker segments. +/// NVIDIA provides pre-tuned configs for different datasets (CallHome, DIHARD3, AMI). +/// +/// # Parameters +/// - `onset`: Probability threshold to START a speaker segment (higher = more strict) +/// - `offset`: Probability threshold to END a speaker segment (lower = longer segments) +/// - `pad_onset`: Seconds to subtract from segment start times +/// - `pad_offset`: Seconds to add to segment end times +/// - `min_duration_on`: Minimum segment length in seconds (filters short blips) +/// - `min_duration_off`: Minimum gap between segments before merging +/// - `median_window`: Smoothing window size (odd number, higher = smoother) +/// +/// # Pre-tuned Configs +/// - `callhome()` - (default) +/// - `dihard3()` +/// +/// # Custom Config +/// Use `custom(onset, offset)` to create your own config for fine-tuning. +/// +/// See: https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/speaker_tasks/diarization/conf/neural_diarizer +#[derive(Debug, Clone)] +pub struct DiarizationConfig { + pub onset: f32, + pub offset: f32, + pub pad_onset: f32, + pub pad_offset: f32, + pub min_duration_on: f32, + pub min_duration_off: f32, + pub median_window: usize, +} + +impl Default for DiarizationConfig { + fn default() -> Self { + Self::callhome() + } +} + +impl DiarizationConfig { + /// CallHome dataset config for v2 (default) + /// From: diar_streaming_sortformer_4spk-v2_callhome-part1.yaml + pub fn callhome() -> Self { + Self { + onset: 0.641, + offset: 0.561, + pad_onset: 0.229, + pad_offset: 0.079, + min_duration_on: 0.511, + min_duration_off: 0.296, + median_window: 11, + } + } + + /// DIHARD3 dataset config for v2 + /// From: diar_streaming_sortformer_4spk-v2_dihard3-dev.yaml + pub fn dihard3() -> Self { + Self { + onset: 0.56, + offset: 1.0, + pad_onset: 0.063, + pad_offset: 0.002, + min_duration_on: 0.007, + min_duration_off: 0.151, + median_window: 11, + } + } + + /// Create a custom config for fine-tuning diarization behavior. + /// + /// # Arguments + /// * `onset` - Probability threshold to start a segment (0.0-1.0, typical: 0.5-0.7) + /// * `offset` - Probability threshold to end a segment (0.0-1.0, typical: 0.4-0.6) + /// + /// # Example + /// ```rust + /// use parakeet_rs::sortformer::DiarizationConfig; + /// + /// // More sensitive detection (lower thresholds) + /// let sensitive = DiarizationConfig::custom(0.5, 0.4); + /// + /// // Stricter detection (higher thresholds, fewer false positives) + /// let strict = DiarizationConfig::custom(0.7, 0.6); + /// + /// // Full customization + /// let mut config = DiarizationConfig::custom(0.6, 0.5); + /// config.min_duration_on = 0.3; // Ignore segments shorter than 300ms + /// config.median_window = 15; // More smoothing + /// ``` + pub fn custom(onset: f32, offset: f32) -> Self { + Self { + onset, + offset, + pad_onset: 0.0, + pad_offset: 0.0, + min_duration_on: 0.1, + min_duration_off: 0.1, + median_window: 11, + } + } +} + +/// Speaker segment with start time, end time, and speaker ID +#[derive(Debug, Clone)] +pub struct SpeakerSegment { + pub start: f32, + pub end: f32, + pub speaker_id: usize, +} + +/// Streaming Sortformer v2 speaker diarization engine +pub struct Sortformer { + session: Session, + config: DiarizationConfig, + // Streaming state. note that, Same way as Nemo + spkcache: Array3<f32>, // (1, 0..SPKCACHE_LEN, EMB_DIM) + spkcache_preds: Option<Array3<f32>>, // (1, 0..SPKCACHE_LEN, NUM_SPEAKERS) + fifo: Array3<f32>, // (1, 0..FIFO_LEN, EMB_DIM) + fifo_preds: Array3<f32>, // (1, 0..FIFO_LEN, NUM_SPEAKERS) + mean_sil_emb: Array2<f32>, // (1, EMB_DIM) + n_sil_frames: usize, + // Mel filterbank (cached) + mel_basis: Array2<f32>, +} + +impl Sortformer { + /// a new Sortformer instance from ONNX model path + pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> { + Self::with_config(model_path, None, DiarizationConfig::default()) + } + + /// Create with custom config + pub fn with_config<P: AsRef<Path>>( + model_path: P, + execution_config: Option<ModelConfig>, + config: DiarizationConfig, + ) -> Result<Self> { + let config_to_use = execution_config.unwrap_or_default(); + + let session = config_to_use + .apply_to_session_builder(Session::builder()?)? + .commit_from_file(model_path.as_ref())?; + + let mel_basis = Self::create_mel_filterbank(); + + let mut instance = Self { + session, + config, + spkcache: Array3::zeros((1, 0, EMB_DIM)), + spkcache_preds: None, + fifo: Array3::zeros((1, 0, EMB_DIM)), + fifo_preds: Array3::zeros((1, 0, NUM_SPEAKERS)), + mean_sil_emb: Array2::zeros((1, EMB_DIM)), + n_sil_frames: 0, + mel_basis, + }; + instance.reset_state(); + Ok(instance) + } + + /// Reset streaming state + pub fn reset_state(&mut self) { + self.spkcache = Array3::zeros((1, 0, EMB_DIM)); + self.spkcache_preds = None; + self.fifo = Array3::zeros((1, 0, EMB_DIM)); + self.fifo_preds = Array3::zeros((1, 0, NUM_SPEAKERS)); + self.mean_sil_emb = Array2::zeros((1, EMB_DIM)); + self.n_sil_frames = 0; + } + + /// Main diarization entry point + pub fn diarize( + &mut self, + mut audio: Vec<f32>, + sample_rate: u32, + channels: u16, + ) -> Result<Vec<SpeakerSegment>> { + // Resample if needed + if sample_rate != SAMPLE_RATE as u32 { + return Err(Error::Audio(format!( + "Expected {} Hz, got {} Hz", + SAMPLE_RATE, sample_rate + ))); + } + + // Convert to mono + if channels > 1 { + audio = audio + .chunks(channels as usize) + .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) + .collect(); + } + + // Reset state for new audio + self.reset_state(); + + // Extract mel features (B, T, D) + let features = self.extract_mel_features(&audio); + let total_frames = features.shape()[1]; + + // Process in chunks + let chunk_stride = CHUNK_LEN * SUBSAMPLING; + let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; + + let mut all_chunk_preds = Vec::new(); + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_stride; + let end = (start + chunk_stride).min(total_frames); + let current_len = end - start; + + // Extract chunk features + let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); + + // Pad last chunk if needed + if current_len < chunk_stride { + let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); + padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); + chunk_feat = padded; + } + + // Run streaming update + let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; + all_chunk_preds.push(chunk_preds); + } + + // Concatenate all predictions + let full_preds = Self::concat_predictions(&all_chunk_preds); + + // Apply median filtering + let filtered_preds = if self.config.median_window > 1 { + self.median_filter(&full_preds) + } else { + full_preds + }; + + // Binarize to segments + let segments = self.binarize(&filtered_preds); + + Ok(segments) + } + + /// Streaming diarization that maintains state across calls. + /// + /// Unlike `diarize()`, this method does NOT reset the internal state, + /// allowing speaker embeddings to be preserved across multiple audio chunks. + /// Call `reset_state()` manually when starting a new audio session. + /// + /// This enables consistent speaker identification across long audio streams + /// by maintaining the speaker cache between processing windows. + /// + /// # Arguments + /// * `audio` - Audio samples (will be converted to mono if multi-channel) + /// * `sample_rate` - Must be 16000 Hz + /// * `channels` - Number of audio channels + /// + /// # Example + /// ```ignore + /// // Start of session + /// sortformer.reset_state(); + /// + /// // Process sliding windows + /// let segments1 = sortformer.diarize_streaming(window1, 16000, 1)?; + /// let segments2 = sortformer.diarize_streaming(window2, 16000, 1)?; // Maintains speaker IDs + /// ``` + pub fn diarize_streaming( + &mut self, + mut audio: Vec<f32>, + sample_rate: u32, + channels: u16, + ) -> Result<Vec<SpeakerSegment>> { + // Resample if needed + if sample_rate != SAMPLE_RATE as u32 { + return Err(Error::Audio(format!( + "Expected {} Hz, got {} Hz", + SAMPLE_RATE, sample_rate + ))); + } + + // Convert to mono + if channels > 1 { + audio = audio + .chunks(channels as usize) + .map(|chunk| chunk.iter().sum::<f32>() / channels as f32) + .collect(); + } + + // NOTE: Unlike diarize(), we do NOT call reset_state() here + // This preserves speaker embeddings across calls + + // Extract mel features (B, T, D) + let features = self.extract_mel_features(&audio); + let total_frames = features.shape()[1]; + + // Process in chunks + let chunk_stride = CHUNK_LEN * SUBSAMPLING; + let num_chunks = (total_frames + chunk_stride - 1) / chunk_stride; + + let mut all_chunk_preds = Vec::new(); + + for chunk_idx in 0..num_chunks { + let start = chunk_idx * chunk_stride; + let end = (start + chunk_stride).min(total_frames); + let current_len = end - start; + + // Extract chunk features + let mut chunk_feat = features.slice(s![.., start..end, ..]).to_owned(); + + // Pad last chunk if needed + if current_len < chunk_stride { + let mut padded = Array3::zeros((1, chunk_stride, N_MELS)); + padded.slice_mut(s![.., ..current_len, ..]).assign(&chunk_feat); + chunk_feat = padded; + } + + // Run streaming update + let chunk_preds = self.streaming_update(&chunk_feat, current_len)?; + all_chunk_preds.push(chunk_preds); + } + + // Concatenate all predictions + let full_preds = Self::concat_predictions(&all_chunk_preds); + + // Apply median filtering + let filtered_preds = if self.config.median_window > 1 { + self.median_filter(&full_preds) + } else { + full_preds + }; + + // Binarize to segments + let segments = self.binarize(&filtered_preds); + + Ok(segments) + } + + /// NeMo's streaming_update with smart cache compression. + /// Public to allow incremental streaming diarization. + pub fn streaming_update(&mut self, chunk_feat: &Array3<f32>, current_len: usize) -> Result<Array2<f32>> { + let spkcache_len = self.spkcache.shape()[1]; + let fifo_len = self.fifo.shape()[1]; + + // Prepare inputs + let chunk_lengths = Array1::from_vec(vec![current_len as i64]); + let spkcache_lengths = Array1::from_vec(vec![spkcache_len as i64]); + let fifo_lengths = Array1::from_vec(vec![fifo_len as i64]); + + // Prepare FIFO input + let fifo_input = if fifo_len > 0 { + self.fifo.clone() + } else { + Array3::zeros((1, 0, EMB_DIM)) + }; + + // Prepare spkcache input (may be empty) + let spkcache_input = if spkcache_len > 0 { + self.spkcache.clone() + } else { + Array3::zeros((1, 0, EMB_DIM)) + }; + + // Create input values + let chunk_value = ort::value::Value::from_array(chunk_feat.clone())?; + let chunk_lengths_value = ort::value::Value::from_array(chunk_lengths)?; + let spkcache_value = ort::value::Value::from_array(spkcache_input)?; + let spkcache_lengths_value = ort::value::Value::from_array(spkcache_lengths)?; + let fifo_value = ort::value::Value::from_array(fifo_input)?; + let fifo_lengths_value = ort::value::Value::from_array(fifo_lengths)?; + + // Run ONNX inference and extract all data in a block to release borrow + let (preds, new_embs, chunk_len) = { + let outputs = self.session.run(ort::inputs!( + "chunk" => chunk_value, + "chunk_lengths" => chunk_lengths_value, + "spkcache" => spkcache_value, + "spkcache_lengths" => spkcache_lengths_value, + "fifo" => fifo_value, + "fifo_lengths" => fifo_lengths_value + ))?; + + // Extract outputs + let (preds_shape, preds_data) = outputs["spkcache_fifo_chunk_preds"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract preds: {e}")))?; + let (embs_shape, embs_data) = outputs["chunk_pre_encode_embs"] + .try_extract_tensor::<f32>() + .map_err(|e| Error::Model(format!("Failed to extract embs: {e}")))?; + + // Convert to ndarray + let preds_dims = preds_shape.as_ref(); + let embs_dims = embs_shape.as_ref(); + + let preds = Array3::from_shape_vec( + (preds_dims[0] as usize, preds_dims[1] as usize, preds_dims[2] as usize), + preds_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape preds: {e}")))?; + + let new_embs = Array3::from_shape_vec( + (embs_dims[0] as usize, embs_dims[1] as usize, embs_dims[2] as usize), + embs_data.to_vec() + ).map_err(|e| Error::Model(format!("Failed to reshape embs: {e}")))?; + + // Calculate valid frames + let valid_frames = (current_len + SUBSAMPLING - 1) / SUBSAMPLING; + + (preds, new_embs, valid_frames) + }; + + // Extract predictions for different parts + let fifo_preds = if fifo_len > 0 { + preds.slice(s![0, spkcache_len..spkcache_len + fifo_len, ..]).to_owned() + } else { + Array2::zeros((0, NUM_SPEAKERS)) + }; + + let chunk_preds = preds.slice(s![0, spkcache_len + fifo_len..spkcache_len + fifo_len + chunk_len, ..]).to_owned(); + let chunk_embs = new_embs.slice(s![0, ..chunk_len, ..]).to_owned(); + + // Append chunk embeddings to FIFO + self.fifo = Self::concat_axis1(&self.fifo, &chunk_embs.insert_axis(Axis(0))); + + // Update FIFO predictions + if fifo_len > 0 { + let combined = Self::concat_axis1_2d(&fifo_preds, &chunk_preds); + self.fifo_preds = combined.insert_axis(Axis(0)); + } else { + self.fifo_preds = chunk_preds.clone().insert_axis(Axis(0)); + } + + let fifo_len_after = self.fifo.shape()[1]; + + // Move from FIFO to cache when FIFO exceeds limit + if fifo_len_after > FIFO_LEN { + let mut pop_out_len = SPKCACHE_UPDATE_PERIOD; + pop_out_len = pop_out_len.max(chunk_len.saturating_sub(FIFO_LEN) + fifo_len); + pop_out_len = pop_out_len.min(fifo_len_after); + + let pop_out_embs = self.fifo.slice(s![.., ..pop_out_len, ..]).to_owned(); + let pop_out_preds = self.fifo_preds.slice(s![.., ..pop_out_len, ..]).to_owned(); + + // Update silence profile + self.update_silence_profile(&pop_out_embs, &pop_out_preds); + + // Remove from FIFO + self.fifo = self.fifo.slice(s![.., pop_out_len.., ..]).to_owned(); + self.fifo_preds = self.fifo_preds.slice(s![.., pop_out_len.., ..]).to_owned(); + + // Append to cache + self.spkcache = Self::concat_axis1(&self.spkcache, &pop_out_embs); + + if let Some(ref cache_preds) = self.spkcache_preds { + self.spkcache_preds = Some(Self::concat_axis1(cache_preds, &pop_out_preds)); + } + + // Smart compression when cache exceeds limit + if self.spkcache.shape()[1] > SPKCACHE_LEN { + if self.spkcache_preds.is_none() { + // Initialize cache predictions from initial output + let initial_cache_preds = preds.slice(s![.., ..spkcache_len, ..]).to_owned(); + let combined = Self::concat_axis1(&initial_cache_preds, &pop_out_preds); + self.spkcache_preds = Some(combined); + } + + // Use smart compression + self.compress_spkcache(); + } + } + + Ok(chunk_preds) + } + + /// Update mean silence embedding + fn update_silence_profile(&mut self, embs: &Array3<f32>, preds: &Array3<f32>) { + let preds_2d = preds.slice(s![0, .., ..]); + + for t in 0..preds_2d.shape()[0] { + let sum: f32 = (0..NUM_SPEAKERS).map(|s| preds_2d[[t, s]]).sum(); + if sum < SIL_THRESHOLD { + // This is a silence frame + let emb = embs.slice(s![0, t, ..]); + + // Update running mean + let old_sum: Vec<f32> = self.mean_sil_emb.slice(s![0, ..]).iter() + .map(|&x| x * self.n_sil_frames as f32) + .collect(); + + self.n_sil_frames += 1; + + for i in 0..EMB_DIM { + self.mean_sil_emb[[0, i]] = (old_sum[i] + emb[i]) / self.n_sil_frames as f32; + } + } + } + } + + /// Smart cache compression + fn compress_spkcache(&mut self) { + let cache_preds = match &self.spkcache_preds { + Some(p) => p.clone(), + None => return, + }; + + let n_frames = self.spkcache.shape()[1]; + let spkcache_len_per_spk = SPKCACHE_LEN / NUM_SPEAKERS - SPKCACHE_SIL_FRAMES_PER_SPK; + let strong_boost_per_spk = (spkcache_len_per_spk as f32 * STRONG_BOOST_RATE) as usize; + let weak_boost_per_spk = (spkcache_len_per_spk as f32 * WEAK_BOOST_RATE) as usize; + let min_pos_scores_per_spk = (spkcache_len_per_spk as f32 * MIN_POS_SCORES_RATE) as usize; + + // Calculate quality scores + let preds_2d = cache_preds.slice(s![0, .., ..]).to_owned(); + let mut scores = self.get_log_pred_scores(&preds_2d); + + // Disable low scores + scores = self.disable_low_scores(&preds_2d, scores, min_pos_scores_per_spk); + + // Boost important frames + scores = self.boost_topk_scores(scores, strong_boost_per_spk, 2.0); + scores = self.boost_topk_scores(scores, weak_boost_per_spk, 1.0); + + // Add silence frames placeholder + if SPKCACHE_SIL_FRAMES_PER_SPK > 0 { + let mut padded = Array2::from_elem((n_frames + SPKCACHE_SIL_FRAMES_PER_SPK, NUM_SPEAKERS), f32::NEG_INFINITY); + padded.slice_mut(s![..n_frames, ..]).assign(&scores); + for i in n_frames..n_frames + SPKCACHE_SIL_FRAMES_PER_SPK { + for j in 0..NUM_SPEAKERS { + padded[[i, j]] = f32::INFINITY; + } + } + scores = padded; + } + + // Select top frames + let (topk_indices, is_disabled) = self.get_topk_indices(&scores, n_frames); + + // Gather embeddings + let (new_embs, new_preds) = self.gather_spkcache(&topk_indices, &is_disabled); + + self.spkcache = new_embs; + self.spkcache_preds = Some(new_preds); + } + + /// Calculate quality scores + fn get_log_pred_scores(&self, preds: &Array2<f32>) -> Array2<f32> { + let mut scores = Array2::zeros(preds.dim()); + + for t in 0..preds.shape()[0] { + let mut log_1_probs_sum = 0.0f32; + for s in 0..NUM_SPEAKERS { + let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); + let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); + log_1_probs_sum += log_1_p; + } + + for s in 0..NUM_SPEAKERS { + let p = preds[[t, s]].max(PRED_SCORE_THRESHOLD); + let log_p = p.ln(); + let log_1_p = (1.0 - p).max(PRED_SCORE_THRESHOLD).ln(); + scores[[t, s]] = log_p - log_1_p + log_1_probs_sum - 0.5f32.ln(); + } + } + + scores + } + + /// Disable non-speech and overlapped speech + fn disable_low_scores(&self, preds: &Array2<f32>, mut scores: Array2<f32>, min_pos_scores_per_spk: usize) -> Array2<f32> { + // Count positive scores per speaker + let mut pos_count = vec![0usize; NUM_SPEAKERS]; + for t in 0..scores.shape()[0] { + for s in 0..NUM_SPEAKERS { + if scores[[t, s]] > 0.0 { + pos_count[s] += 1; + } + } + } + + for t in 0..preds.shape()[0] { + for s in 0..NUM_SPEAKERS { + let is_speech = preds[[t, s]] > 0.5; + + if !is_speech { + scores[[t, s]] = f32::NEG_INFINITY; + } else { + let is_pos = scores[[t, s]] > 0.0; + if !is_pos && pos_count[s] >= min_pos_scores_per_spk { + scores[[t, s]] = f32::NEG_INFINITY; + } + } + } + } + + scores + } + + /// Boost top K frames per speaker + fn boost_topk_scores(&self, mut scores: Array2<f32>, n_boost_per_spk: usize, scale_factor: f32) -> Array2<f32> { + for s in 0..NUM_SPEAKERS { + // Get column for this speaker + let col: Vec<(usize, f32)> = (0..scores.shape()[0]) + .map(|t| (t, scores[[t, s]])) + .collect(); + + // Sort by score descending + let mut sorted = col.clone(); + sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + + // Boost top K + for i in 0..n_boost_per_spk.min(sorted.len()) { + let t = sorted[i].0; + if scores[[t, s]] != f32::NEG_INFINITY { + scores[[t, s]] -= scale_factor * 0.5f32.ln(); + } + } + } + + scores + } + + /// Get indices of top frames + fn get_topk_indices(&self, scores: &Array2<f32>, n_frames_no_sil: usize) -> (Vec<usize>, Vec<bool>) { + let n_frames = scores.shape()[0]; + + // Flatten scores as (S, T) then reshape to (S*T,) + // This means we iterate: speaker 0 all times, then speaker 1 all times, etc. + // flat_index = speaker * n_frames + time + let mut flat_scores: Vec<(usize, f32)> = Vec::with_capacity(n_frames * NUM_SPEAKERS); + for s in 0..NUM_SPEAKERS { + for t in 0..n_frames { + let flat_idx = s * n_frames + t; + flat_scores.push((flat_idx, scores[[t, s]])); + } + } + + // Sort by score descending to get top-K + flat_scores.sort_by(|a, b| { + b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) + }); + + // Take top SPKCACHE_LEN and replace invalid scores with MAX_INDEX + let mut topk_flat: Vec<usize> = flat_scores + .iter() + .take(SPKCACHE_LEN) + .map(|(idx, score)| { + if *score == f32::NEG_INFINITY { + MAX_INDEX + } else { + *idx + } + }) + .collect(); + + // Sort flat indices ascending (this puts MAX_INDEX at the end) + topk_flat.sort(); + + // Compute is_disabled and convert to frame indices + let mut is_disabled = vec![false; SPKCACHE_LEN]; + let mut frame_indices = vec![0usize; SPKCACHE_LEN]; + + for (i, &flat_idx) in topk_flat.iter().enumerate() { + if flat_idx == MAX_INDEX { + // Invalid entries are disabled + is_disabled[i] = true; + frame_indices[i] = 0; // We set disabled to 0 + } else { + // convert to frame index + let frame_idx = flat_idx % n_frames; + + // check if frame is beyond valid range + if frame_idx >= n_frames_no_sil { + is_disabled[i] = true; + frame_indices[i] = 0; // same as abov: set disabled to 0 + } else { + frame_indices[i] = frame_idx; + } + } + } + + (frame_indices, is_disabled) + } + + /// Gather selected frames + fn gather_spkcache(&self, indices: &[usize], is_disabled: &[bool]) -> (Array3<f32>, Array3<f32>) { + let mut new_embs = Array3::zeros((1, SPKCACHE_LEN, EMB_DIM)); + let mut new_preds = Array3::zeros((1, SPKCACHE_LEN, NUM_SPEAKERS)); + + let cache_preds = self.spkcache_preds.as_ref().unwrap(); + + for (i, (&idx, &disabled)) in indices.iter().zip(is_disabled.iter()).enumerate() { + if i >= SPKCACHE_LEN { + break; + } + + if disabled { + // Use silence embedding + new_embs.slice_mut(s![0, i, ..]).assign(&self.mean_sil_emb.slice(s![0, ..])); + // Predictions stay zero + } else if idx < self.spkcache.shape()[1] { + new_embs.slice_mut(s![0, i, ..]).assign(&self.spkcache.slice(s![0, idx, ..])); + new_preds.slice_mut(s![0, i, ..]).assign(&cache_preds.slice(s![0, idx, ..])); + } + } + + (new_embs, new_preds) + } + + /// Concatenate along axis 1 for 3D arrays + fn concat_axis1(a: &Array3<f32>, b: &Array3<f32>) -> Array3<f32> { + if a.shape()[1] == 0 { + return b.clone(); + } + if b.shape()[1] == 0 { + return a.clone(); + } + ndarray::concatenate(Axis(1), &[a.view(), b.view()]).unwrap() + } + + /// Concatenate along axis 0 for 2D arrays + fn concat_axis1_2d(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> { + if a.shape()[0] == 0 { + return b.clone(); + } + if b.shape()[0] == 0 { + return a.clone(); + } + ndarray::concatenate(Axis(0), &[a.view(), b.view()]).unwrap() + } + + /// Concatenate predictions + fn concat_predictions(preds: &[Array2<f32>]) -> Array2<f32> { + if preds.is_empty() { + return Array2::zeros((0, NUM_SPEAKERS)); + } + if preds.len() == 1 { + return preds[0].clone(); + } + + let views: Vec<_> = preds.iter().map(|p| p.view()).collect(); + ndarray::concatenate(Axis(0), &views).unwrap() + } + + /// Apply median filter to predictions + fn median_filter(&self, preds: &Array2<f32>) -> Array2<f32> { + let window = self.config.median_window; + let half = window / 2; + let mut filtered = preds.clone(); + + for spk in 0..NUM_SPEAKERS { + for t in 0..preds.shape()[0] { + let start = t.saturating_sub(half); + let end = (t + half + 1).min(preds.shape()[0]); + + let mut values: Vec<f32> = (start..end) + .map(|i| preds[[i, spk]]) + .collect(); + values.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + filtered[[t, spk]] = values[values.len() / 2]; + } + } + + filtered + } + + /// Binarize predictions to segments (padding applied during thresholding) + fn binarize(&self, preds: &Array2<f32>) -> Vec<SpeakerSegment> { + let mut segments = Vec::new(); + let num_frames = preds.shape()[0]; + + for spk in 0..NUM_SPEAKERS { + let mut in_seg = false; + let mut seg_start = 0; + let mut temp_segments = Vec::new(); + + for t in 0..num_frames { + let p = preds[[t, spk]]; + + if p >= self.config.onset && !in_seg { + in_seg = true; + seg_start = t; + } else if p < self.config.offset && in_seg { + in_seg = false; + + // Apply padding during conversion + let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); + let end_t = t as f32 * FRAME_DURATION + self.config.pad_offset; + + if end_t - start_t >= self.config.min_duration_on { + temp_segments.push(SpeakerSegment { + start: start_t, + end: end_t, + speaker_id: spk, + }); + } + } + } + + // Handle segment at end + if in_seg { + let start_t = (seg_start as f32 * FRAME_DURATION - self.config.pad_onset).max(0.0); + let end_t = num_frames as f32 * FRAME_DURATION + self.config.pad_offset; + + if end_t - start_t >= self.config.min_duration_on { + temp_segments.push(SpeakerSegment { + start: start_t, + end: end_t, + speaker_id: spk, + }); + } + } + + // Merge close segments (min_duration_off) + if temp_segments.len() > 1 { + let mut filtered = vec![temp_segments[0].clone()]; + for seg in temp_segments.into_iter().skip(1) { + let last = filtered.last_mut().unwrap(); + let gap = seg.start - last.end; + if gap < self.config.min_duration_off { + last.end = seg.end; // Merge + } else { + filtered.push(seg); + } + } + segments.extend(filtered); + } else { + segments.extend(temp_segments); + } + } + + // Sort by start time + segments.sort_by(|a, b| a.start.partial_cmp(&b.start).unwrap()); + segments + } + + + fn apply_preemphasis(audio: &[f32]) -> Vec<f32> { + let mut result = Vec::with_capacity(audio.len()); + result.push(audio[0]); + for i in 1..audio.len() { + result.push(audio[i] - PREEMPH * audio[i - 1]); + } + result + } + + fn hann_window(window_length: usize) -> Vec<f32> { + // Librosa uses periodic window (fftbins=True): divide by N, not N-1 + (0..window_length) + .map(|i| 0.5 - 0.5 * ((2.0 * PI * i as f32) / window_length as f32).cos()) + .collect() + } + + fn stft(audio: &[f32]) -> Array2<f32> { + let mut planner = FftPlanner::<f32>::new(); + let fft = planner.plan_fft_forward(N_FFT); + + // Create Hann window of length win_length, then zero-pad to n_fft (centered) + // This is exactly what librosa does: util.pad_center(fft_window, size=n_fft) + let hann = Self::hann_window(WIN_LENGTH); + let win_offset = (N_FFT - WIN_LENGTH) / 2; + let mut fft_window = vec![0.0f32; N_FFT]; + for i in 0..WIN_LENGTH { + fft_window[win_offset + i] = hann[i]; + } + + // Pad signal for center=True (like librosa/torch.stft) + // Padding is n_fft // 2 on each side + let pad_amount = N_FFT / 2; + let mut padded_audio = vec![0.0; pad_amount]; + padded_audio.extend_from_slice(audio); + padded_audio.extend(vec![0.0; pad_amount]); + + let num_frames = (padded_audio.len() - N_FFT) / HOP_LENGTH + 1; + let freq_bins = N_FFT / 2 + 1; + let mut spectrogram = Array2::<f32>::zeros((freq_bins, num_frames)); + + for frame_idx in 0..num_frames { + let start = frame_idx * HOP_LENGTH; + let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); N_FFT]; + + // Extract n_fft samples and multiply by zero-padded window + for i in 0..N_FFT { + if start + i < padded_audio.len() { + frame[i] = Complex::new(padded_audio[start + i] * fft_window[i], 0.0); + } + } + + fft.process(&mut frame); + for k in 0..freq_bins { + let magnitude = frame[k].norm(); + // Power spectrum (magnitude^2) - NeMo uses mag_power=2.0 + spectrogram[[k, frame_idx]] = magnitude * magnitude; + } + } + + spectrogram + } + + // Librosa's Slaney mel scale (htk=False, which is the default) + fn hz_to_mel_slaney(hz: f64) -> f64 { + let f_min = 0.0; + let f_sp = 200.0 / 3.0; + let min_log_hz = 1000.0; + let min_log_mel = (min_log_hz - f_min) / f_sp; + let logstep = (6.4f64).ln() / 27.0; + + if hz >= min_log_hz { + min_log_mel + (hz / min_log_hz).ln() / logstep + } else { + (hz - f_min) / f_sp + } + } + + fn mel_to_hz_slaney(mel: f64) -> f64 { + let f_min = 0.0; + let f_sp = 200.0 / 3.0; + let min_log_hz = 1000.0; + let min_log_mel = (min_log_hz - f_min) / f_sp; + let logstep = (6.4f64).ln() / 27.0; + + if mel >= min_log_mel { + min_log_hz * (logstep * (mel - min_log_mel)).exp() + } else { + f_min + f_sp * mel + } + } + + fn create_mel_filterbank() -> Array2<f32> { + // lets use f64 for intermediate calculations to avoid precision loss + let freq_bins = N_FFT / 2 + 1; + let mut filterbank = Array2::<f32>::zeros((N_MELS, freq_bins)); + + // FFT frequencies: fftfreqs[k] = k * sr / n_fft + let fftfreqs: Vec<f64> = (0..freq_bins) + .map(|k| k as f64 * SAMPLE_RATE as f64 / N_FFT as f64) + .collect(); + + // Mel center frequencies using Slaney scale (librosa default, htk=False) + let fmin_mel = Self::hz_to_mel_slaney(FMIN as f64); + let fmax_mel = Self::hz_to_mel_slaney(FMAX as f64); + let mel_f: Vec<f64> = (0..=N_MELS + 1) + .map(|i| { + let mel = fmin_mel + (fmax_mel - fmin_mel) * i as f64 / (N_MELS + 1) as f64; + Self::mel_to_hz_slaney(mel) + }) + .collect(); + + // Differences between consecutive mel frequencies + let fdiff: Vec<f64> = mel_f.windows(2).map(|w| w[1] - w[0]).collect(); + + // Compute filterbank weights (reference: librosa's ramp method) + // https://librosa.org/doc/main/generated/librosa.stft.html + for i in 0..N_MELS { + for k in 0..freq_bins { + // Lower slope: (fftfreqs[k] - mel_f[i]) / fdiff[i] + let lower = (fftfreqs[k] - mel_f[i]) / fdiff[i]; + // Upper slope: (mel_f[i+2] - fftfreqs[k]) / fdiff[i+1] + let upper = (mel_f[i + 2] - fftfreqs[k]) / fdiff[i + 1]; + // Weight is max(0, min(lower, upper)) + filterbank[[i, k]] = 0.0f64.max(lower.min(upper)) as f32; + } + } + + // Apply Slaney normalization: 2.0 / (mel_f[i+2] - mel_f[i]) + for i in 0..N_MELS { + let enorm = 2.0 / (mel_f[i + 2] - mel_f[i]); + for k in 0..freq_bins { + filterbank[[i, k]] *= enorm as f32; + } + } + + filterbank + } + + fn extract_mel_features(&self, audio: &[f32]) -> Array3<f32> { + // 1. Add dither (small random noise to prevent log(0)) + // NeMo uses dither=1e-5, but for determinism we skip random noise + // The log_zero_guard handles zero values + + // 2. Apply preemphasis (NeMo uses preemph=0.97) + let preemphasized = Self::apply_preemphasis(audio); + + // 3. STFT + let spectrogram = Self::stft(&preemphasized); + + // 4. Apply mel filterbank (with Slaney normalization) + let mel_spec = self.mel_basis.dot(&spectrogram); + + // 5. Log with guard value (NeMo uses log_zero_guard_value = 2^-24) + // NeMo uses normalize='NA' which means NO normalization + let log_mel_spec = mel_spec.mapv(|x| (x + LOG_ZERO_GUARD).ln()); + + let num_frames = log_mel_spec.shape()[1]; + let mut features = Array3::<f32>::zeros((1, num_frames, N_MELS)); + + // Transpose to (batch, time, features) - NeMo outputs (B, D, T), model expects (B, T, D) + for t in 0..num_frames { + for m in 0..N_MELS { + features[[0, t, m]] = log_mel_spec[[m, t]]; + } + } + + features + } +} diff --git a/vendor/parakeet-rs/src/timestamps.rs b/vendor/parakeet-rs/src/timestamps.rs new file mode 100644 index 0000000..81ea600 --- /dev/null +++ b/vendor/parakeet-rs/src/timestamps.rs @@ -0,0 +1,280 @@ +use crate::decoder::TimedToken; + +/// Timestamp output mode for transcription results +/// +/// Determines how token-level timestamps are grouped and presented: +/// - `Tokens`: Raw token-level output from the model (most detailed) +/// - `Words`: Tokens grouped into individual words +/// - `Sentences`: Tokens grouped by sentence boundaries (., ?, !) +/// +/// # Model-Specific Recommendations +/// +/// - **Parakeet CTC (English)**: Use `Words` mode. The CTC model only outputs lowercase +/// alphabet without punctuation, so sentence segmentation is not possible. +/// - **Parakeet TDT (Multilingual)**: Use `Sentences` mode. The TDT model predicts +/// punctuation, enabling natural sentence boundaries. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TimestampMode { + /// Raw token-level timestamps from the model + Tokens, + /// Word-level timestamps (groups subword tokens) + Words, + /// Sentence-level timestamps (groups by punctuation) + /// + /// Note: Only works with models that predict punctuation (e.g., Parakeet TDT). + /// CTC models don't predict punctuation, so use `Words` mode instead. + Sentences, +} + +impl Default for TimestampMode { + fn default() -> Self { + Self::Tokens + } +} + +/// Convert token timestamps to the requested output mode +/// +/// Takes raw token-level timestamps from the model and optionally groups them +/// into words or sentences while preserving the original timing information. +/// +/// # Arguments +/// +/// * `tokens` - Raw token-level timestamps from model output +/// * `mode` - Desired grouping level (Tokens, Words, or Sentences) +/// +/// # Returns +/// +/// Vector of TimedToken with timestamps at the requested granularity +pub fn process_timestamps(tokens: &[TimedToken], mode: TimestampMode) -> Vec<TimedToken> { + match mode { + TimestampMode::Tokens => tokens.to_vec(), + TimestampMode::Words => group_by_words(tokens), + TimestampMode::Sentences => group_by_sentences(tokens), + } +} + +// Group tokens into words based on word boundary markers +fn group_by_words(tokens: &[TimedToken]) -> Vec<TimedToken> { + if tokens.is_empty() { + return Vec::new(); + } + + let mut words = Vec::new(); + let mut current_word_text = String::new(); + let mut current_word_start = 0.0; + let mut last_word_lower = String::new(); + + for (i, token) in tokens.iter().enumerate() { + // Skip empty tokens + if token.text.trim().is_empty() { + continue; + } + + // Check if this starts a new word (SentencePiece uses ▁ or space prefix) + // Also treat PURE punctuation marks (like ".", ",") as separate words + // But NOT contractions like "'re" or "'s" which should attach to previous word + let is_pure_punctuation = !token.text.is_empty() && + token.text.chars().all(|c| c.is_ascii_punctuation()); + + // Check if this is a contraction suffix + // These should NOT start a new word - they attach to the previous word + let token_without_marker = token.text.trim_start_matches('▁').trim_start_matches(' '); + let is_contraction = token_without_marker.starts_with('\''); + + let starts_word = (token.text.starts_with('▁') + || token.text.starts_with(' ') + || is_pure_punctuation) + && !is_contraction + || i == 0; + + if starts_word && !current_word_text.is_empty() { + // Save previous word (with deduplication) + let word_lower = current_word_text.to_lowercase(); + if word_lower != last_word_lower { + words.push(TimedToken { + text: current_word_text.clone(), + start: current_word_start, + end: tokens[i - 1].end, + }); + last_word_lower = word_lower; + } + current_word_text.clear(); + } + + // Start new word or append to current + if current_word_text.is_empty() { + current_word_start = token.start; + } + + // Add token text, removing word boundary markers + let token_text = token + .text + .trim_start_matches('▁') + .trim_start_matches(' '); + current_word_text.push_str(token_text); + } + + // Add final word + if !current_word_text.is_empty() { + let word_lower = current_word_text.to_lowercase(); + if word_lower != last_word_lower { + words.push(TimedToken { + text: current_word_text, + start: current_word_start, + end: tokens.last().unwrap().end, + }); + } + } + + words +} + +// Group words into sentences based on punctuation +fn group_by_sentences(tokens: &[TimedToken]) -> Vec<TimedToken> { + // First get word-level grouping + let words = group_by_words(tokens); + if words.is_empty() { + return Vec::new(); + } + + let mut sentences = Vec::new(); + let mut current_sentence = Vec::new(); + + for word in words { + current_sentence.push(word.clone()); + + // Check if word ends with sentence terminator + let ends_sentence = word.text.contains('.') + || word.text.contains('?') + || word.text.contains('!'); + + if ends_sentence { + let sentence_text = format_sentence(¤t_sentence); + let start = current_sentence.first().unwrap().start; + let end = current_sentence.last().unwrap().end; + + if !sentence_text.is_empty() { + sentences.push(TimedToken { + text: sentence_text, + start, + end, + }); + } + current_sentence.clear(); + } + } + + // Add final sentence if exists + if !current_sentence.is_empty() { + let sentence_text = format_sentence(¤t_sentence); + let start = current_sentence.first().unwrap().start; + let end = current_sentence.last().unwrap().end; + + if !sentence_text.is_empty() { + sentences.push(TimedToken { + text: sentence_text, + start, + end, + }); + } + } + + sentences +} + +// Join words with punctuation spacing +fn format_sentence(words: &[TimedToken]) -> String { + let result: Vec<&str> = words.iter().map(|w| w.text.as_str()).collect(); + + // Join words, but don't add space before certain punctuation + let mut output = String::new(); + for (i, word) in result.iter().enumerate() { + // Check if this word is standalone punctuation that shouldn't have space before it + // Contractions like "'re" or "'s" should have spaces before them + let is_standalone_punct = word.len() == 1 && + word.chars().all(|c| matches!(c, '.' | ',' | '!' | '?' | ';' | ':' | ')')); + + if i > 0 && !is_standalone_punct { + output.push(' '); + } + output.push_str(word); + } + output +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_word_grouping() { + let tokens = vec![ + TimedToken { + text: "▁Hello".to_string(), + start: 0.0, + end: 0.5, + }, + TimedToken { + text: "▁world".to_string(), + start: 0.5, + end: 1.0, + }, + ]; + + let words = group_by_words(&tokens); + assert_eq!(words.len(), 2); + assert_eq!(words[0].text, "Hello"); + assert_eq!(words[1].text, "world"); + } + + #[test] + fn test_sentence_grouping() { + let tokens = vec![ + TimedToken { + text: "▁Hello".to_string(), + start: 0.0, + end: 0.5, + }, + TimedToken { + text: "▁world".to_string(), + start: 0.5, + end: 1.0, + }, + TimedToken { + text: ".".to_string(), + start: 1.0, + end: 1.1, + }, + ]; + + let sentences = group_by_sentences(&tokens); + assert_eq!(sentences.len(), 1); + assert_eq!(sentences[0].text, "Hello world."); + assert_eq!(sentences[0].start, 0.0); + assert_eq!(sentences[0].end, 1.1); + } + + #[test] + fn test_repetition_preservation() { + let words = vec![ + TimedToken { + text: "uh".to_string(), + start: 0.0, + end: 0.5, + }, + TimedToken { + text: "uh".to_string(), + start: 0.5, + end: 1.0, + }, + TimedToken { + text: "hello".to_string(), + start: 1.0, + end: 1.5, + }, + ]; + + let result = format_sentence(&words); + assert_eq!(result, "uh uh hello"); + } +} diff --git a/vendor/parakeet-rs/src/vocab.rs b/vendor/parakeet-rs/src/vocab.rs new file mode 100644 index 0000000..888568e --- /dev/null +++ b/vendor/parakeet-rs/src/vocab.rs @@ -0,0 +1,63 @@ +use crate::error::{Error, Result}; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::Path; + +/// Vocabulary parser for vocab.txt format used by TDT models +#[derive(Debug, Clone)] +pub struct Vocabulary { + pub id_to_token: Vec<String>, + pub _blank_id: usize, +} + +impl Vocabulary { + /// Load vocabulary from vocab.txt file + pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> { + let file = File::open(path.as_ref()).map_err(|e| { + Error::Config(format!("Failed to open vocab file: {}", e)) + })?; + + let reader = BufReader::new(file); + let mut id_to_token = Vec::new(); + let mut blank_id = 0; + + for line in reader.lines() { + let line = line.map_err(|e| { + Error::Config(format!("Failed to read vocab file: {}", e)) + })?; + + let parts: Vec<&str> = line.splitn(2, ' ').collect(); + if parts.len() == 2 { + let token = parts[0].to_string(); + let id: usize = parts[1].parse().map_err(|e| { + Error::Config(format!("Invalid token ID in vocab: {}", e)) + })?; + + if id >= id_to_token.len() { + id_to_token.resize(id + 1, String::new()); + } + id_to_token[id] = token.clone(); + + // Track blank token + if token == "<blk>" || token == "<blank>" { + blank_id = id; + } + } + } + + // Default to last token if no blank found + if blank_id == 0 && !id_to_token.is_empty() { + blank_id = id_to_token.len() - 1; + } + + Ok(Self { + id_to_token, + _blank_id: blank_id, + }) + } + + /// Get token by ID + pub fn id_to_text(&self, id: usize) -> Option<&str> { + self.id_to_token.get(id).map(|s| s.as_str()) + } +} |
