From 46110e0630f9d52f8289c26dd9ec07c3e960e4fe Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 20 Sep 2025 12:00:14 +0700
Subject: [PATCH] split q_proj/gate

---
 convert_hf_to_gguf.py     |  6 +++++-
 gguf-py/gguf/constants.py |  3 +++
 src/llama-arch.cpp        |  2 ++
 src/llama-arch.h          |  1 +
 src/llama-model.cpp       | 39 +++++++++++++++++++--------------------
 src/llama-model.h         |  1 +
 6 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 85cb9b2142..f6ffc9c1fc 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3767,8 +3767,12 @@ class Qwen3NextModel(Qwen3MoeModel):
             name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
         elif "conv1d" in name:
             data_torch = data_torch.squeeze()
+        elif "q_proj.weight" in name:
+            q_proj, gate = data_torch.chunk(2, dim=0)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), gate)
+            data_torch = q_proj
 
-        return Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
+        yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
 
 
 @ModelBase.register("GPT2LMHeadModel")
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index a16b26f618..2cfd861cb8 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -433,6 +433,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM_2          = auto()
     ATTN_OUT_NORM        = auto()
     ATTN_POST_NORM       = auto()
+    ATTN_GATE            = auto()
     ATTN_ROT_EMBD        = auto()
     ATTN_SINKS           = auto()
     FFN_GATE_INP         = auto()
@@ -776,6 +777,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.ATTN_K_NORM:        "blk.{bid}.attn_k_norm",
    MODEL_TENSOR.ATTN_OUT_NORM:      "blk.{bid}.attn_output_norm",
    MODEL_TENSOR.ATTN_POST_NORM:     "blk.{bid}.post_attention_norm",
+   MODEL_TENSOR.ATTN_GATE:          "blk.{bid}.attn_gate",
    MODEL_TENSOR.FFN_GATE_INP:       "blk.{bid}.ffn_gate_inp",
    MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
    MODEL_TENSOR.FFN_NORM:           "blk.{bid}.ffn_norm",
@@ -1478,6 +1480,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_INP_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 4c9652c3a3..ce6ec355b2 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -769,6 +769,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
             { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_GATE,          "blk.%d.attn_gate" },
             { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
@@ -2245,6 +2246,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_V,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_GATE,                  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP,                     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index b9abe3c096..d4d5995715 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -381,6 +381,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
+    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_SUB_NORM,
     LLM_TENSOR_DEC_ATTN_NORM,
     LLM_TENSOR_DEC_ATTN_Q,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index acd7ed8e31..e7731c20ad 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2436,7 +2436,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     if ((i + 1) % 4 == 0) {
                         // TODO: magic 4
                         // Attention layers
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_ff }, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
                         layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
                         layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
@@ -2445,6 +2445,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
                         layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
 
+                        // attn gate
+                        layer.wq_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+
                     } else {
                         // Linear attention (gated delta net) specific tensors
                         // Create tensors with calculated dimensions
@@ -2454,7 +2457,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
                         layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
                         layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
-                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
                     }
 
                     layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
@@ -19032,30 +19035,27 @@ private:
             const int64_t n_embd_head,
             const int il) {
 
-        // QKV projection with gating
-        ggml_tensor * qkv_g = build_lora_mm(model.layers[il].wq, cur);
-        cb(qkv_g, "qkv_g", il);
-
-        // Split into Q and gate
-        const int64_t n_embd_q = hparams.n_head(il) * n_embd_head;
-        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
-                                          n_embd_head * sizeof(float), qkv_g->nb[1], 0);
-        ggml_tensor * gate = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
-                                          n_embd_head * sizeof(float), qkv_g->nb[1], n_embd_q * ggml_element_size(qkv_g));
-
-        // K and V projections
-        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
+
+        // compute Q and K and RoPE them
+        struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+        cb(Qcur, "Qcur", il);
+
+        struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
         cb(Kcur, "Kcur", il);
+
+        struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
         cb(Vcur, "Vcur", il);
 
-        Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
         // Apply Q/K normalization
         Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
         Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+        cb(Qcur, "Qcur_normed", il);
+        cb(Kcur, "Kcur_normed", il);
 
         // Apply RoPE
         Qcur = ggml_rope_ext(
@@ -19079,7 +19079,6 @@ private:
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
 
         // Apply gating
-        gate = ggml_reshape_2d(ctx0, ggml_cont(ctx0, gate), n_embd_q, n_tokens);
         cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);
 
diff --git a/src/llama-model.h b/src/llama-model.h
index 9b4eb27953..753e332537 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -228,6 +228,7 @@ struct llama_layer {
     struct ggml_tensor * wk_enc = nullptr;
     struct ggml_tensor * wv_enc = nullptr;
     struct ggml_tensor * wo_enc = nullptr;
+    struct ggml_tensor * wq_gate = nullptr;
 
     // attention bias
    struct ggml_tensor * bq = nullptr;
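
Note (not part of the patch): a minimal PyTorch sketch of why the conversion-time split is equivalent to the old fused path. It assumes, as the patch's data_torch.chunk(2, dim=0) does, that the fused q_proj weight stores the query rows in the first half and the gate rows in the second half of dim 0; the tensor sizes below are made up for the demo.

    # splitting the fused q_proj weight into two row blocks at conversion time
    # (ATTN_Q and the new ATTN_GATE) yields the same Q and gate as running the
    # fused matmul and slicing its output, which the removed ggml_view_3d code did
    import torch

    n_embd, n_head, head_dim = 64, 4, 16
    n_tokens = 3

    x       = torch.randn(n_tokens, n_embd)               # token activations
    w_fused = torch.randn(2 * n_head * head_dim, n_embd)  # fused [Q; gate] projection

    # conversion-time split: dim 0 is the output dimension of an nn.Linear weight
    w_q, w_gate = w_fused.chunk(2, dim=0)

    # old path: one matmul, then slice the output into Q and gate halves
    q_old, gate_old = (x @ w_fused.T).chunk(2, dim=-1)

    # new path: two separate matmuls on the split weights
    q_new    = x @ w_q.T
    gate_new = x @ w_gate.T

    torch.testing.assert_close(q_old, q_new)
    torch.testing.assert_close(gate_old, gate_new)

    # the graph applies the gate to the attention output as cur * sigmoid(gate);
    # q_new stands in for that output here only to show the shapes line up
    gated = q_new * torch.sigmoid(gate_new)
    print(gated.shape)  # torch.Size([3, 64])

The sigmoid gating of the attention output itself is unchanged by the split; only the source of Q and the gate (two separate mul_mats instead of one fused projection plus views) differs.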