diff options
Diffstat (limited to 'makima/src/tts/qwen3/model.rs')
| -rw-r--r-- | makima/src/tts/qwen3/model.rs | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs index 8a1e986..e19e5f9 100644 --- a/makima/src/tts/qwen3/model.rs +++ b/makima/src/tts/qwen3/model.rs @@ -389,9 +389,12 @@ pub struct Qwen3Model { impl Qwen3Model { pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> { - let model_vb = vb.pp("model"); + // HuggingFace Qwen3-TTS uses "talker.model.*" prefix + let talker_vb = vb.pp("talker"); + let model_vb = talker_vb.pp("model"); - let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("embed_tokens"))?; + // Text embedding (called "text_embedding" in HF, not "embed_tokens") + let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?; let mut layers = Vec::with_capacity(config.num_hidden_layers); for i in 0..config.num_hidden_layers { @@ -401,8 +404,8 @@ impl Qwen3Model { let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?; - // LM head — may or may not share weights with embed_tokens - let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?; + // Codec head (called "codec_head" in HF, not "lm_head") + let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?; let dtype = vb.dtype(); let device = vb.device().clone(); |
