summary | refs | log | tree | commit | diff
path: root/makima/src/tts/qwen3/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'makima/src/tts/qwen3/model.rs')
-rw-r--r-- makima/src/tts/qwen3/model.rs 11
1 files changed, 7 insertions, 4 deletions
diff --git a/makima/src/tts/qwen3/model.rs b/makima/src/tts/qwen3/model.rs
index 8a1e986..e19e5f9 100644
--- a/makima/src/tts/qwen3/model.rs
+++ b/makima/src/tts/qwen3/model.rs
@@ -389,9 +389,12 @@ pub struct Qwen3Model {
impl Qwen3Model {
pub fn new(config: &Qwen3LmConfig, vb: VarBuilder) -> Result<Self> {
- let model_vb = vb.pp("model");
+ // HuggingFace Qwen3-TTS uses "talker.model.*" prefix
+ let talker_vb = vb.pp("talker");
+ let model_vb = talker_vb.pp("model");
- let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("embed_tokens"))?;
+ // Text embedding (called "text_embedding" in HF, not "embed_tokens")
+ let embed_tokens = embedding(config.vocab_size, config.hidden_size, model_vb.pp("text_embedding"))?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for i in 0..config.num_hidden_layers {
@@ -401,8 +404,8 @@ impl Qwen3Model {
let norm = rms_norm(config.hidden_size, config.rms_norm_eps, model_vb.pp("norm"))?;
- // LM head — may or may not share weights with embed_tokens
- let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?;
+ // Codec head (called "codec_head" in HF, not "lm_head")
+ let lm_head = linear_no_bias(config.hidden_size, config.vocab_size, talker_vb.pp("codec_head"))?;
let dtype = vb.dtype();
let device = vb.device().clone();