summaryrefslogtreecommitdiff
path: root/makima/src/tts/qwen3/model.rs
diff options
context:
space:
mode:
authorsoryu <soryu@soryu.co>2026-01-30 03:07:52 +0000
committersoryu <soryu@soryu.co>2026-01-30 03:07:52 +0000
commitc526f93aa4255cb581eeb3f7a495c1689683b0a2 (patch)
treefbdc579d04fe92dc610ec8c84b77eeffb9141622 /makima/src/tts/qwen3/model.rs
parenta9655dccdad116db2b92c13794ddd559f160148d (diff)
downloadsoryu-c526f93aa4255cb581eeb3f7a495c1689683b0a2.tar.gz
soryu-c526f93aa4255cb581eeb3f7a495c1689683b0a2.zip
Fix Qwen3-TTS tensor paths to match HuggingFace model structure
The HuggingFace model uses different tensor name prefixes: - talker.model.text_embedding instead of model.embed_tokens - talker.codec_head instead of lm_head - talker.code_predictor.model.codec_embedding instead of code_embeddings - talker.code_predictor.lm_head instead of output_heads Also removed input_proj which doesn't exist in the HF model. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'makima/src/tts/qwen3/model.rs')
-rw-r--r--makima/src/tts/qwen3/model.rs11
1 files changed, 7 insertions, 4 deletions
diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs
index 8a1e986..e19e5f9 100644
--- a/makima/src/tts/qwen3/model.rs
+++ b/makima/src/tts/qwen3/model.rs
@@ -389,9 +389,12 @@ pub struct Qwen3Model {
impl Qwen3Model {
pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
- let model_vb = vb.pp("model");
+ // HuggingFace Qwen3-TTS uses "talker.model.*" prefix
+ let talker_vb = vb.pp("talker");
+ let model_vb = talker_vb.pp("model");
- let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("embed_tokens"))?;
+ // Text embedding (called "text_embedding" in HF, not "embed_tokens")
+ let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for i in 0..config.num_hidden_layers {
@@ -401,8 +404,8 @@ impl Qwen3Model {
let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?;
- // LM head — may or may not share weights with embed_tokens
- let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
+ // Codec head (called "codec_head" in HF, not "lm_head")
+ let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?;
let dtype = vb.dtype();
let device = vb.device().clone();