From a9f3a63dc10b05dba5b7f56a820fd91e3c35deca Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 7 Jul 2025 15:00:25 +0400
Subject: [PATCH] injected mup

---
 convert_hf_to_gguf.py          | 65 ++++++++++++++++++++++------------
 gguf-py/gguf/constants.py      |  3 --
 gguf-py/gguf/tensor_mapping.py |  6 +---
 src/llama-arch.cpp             | 12 -------
 src/llama-arch.h               | 11 ------
 src/llama-graph.cpp            |  7 ----
 src/llama-hparams.h            | 10 ------
 src/llama-model.cpp            | 29 ---------------
 src/llama-model.h              |  1 -
 9 files changed, 43 insertions(+), 101 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index bdf7363964..66073067e7 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6576,6 +6576,7 @@ class FalconH1Model(Mamba2Model):
         self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
         self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
         self.intermediate_size = self.find_hparam(["intermediate_size"])
+        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
 
     def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
         prefixed = []
@@ -6607,16 +6608,38 @@ class FalconH1Model(Mamba2Model):
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         tensors = list(super().modify_tensors(data_torch, name, bid))
+        tensor = tensors[0][1]
 
-        if self.ssm_multipliers is not None and "mamba.dt_bias" in name:
-            block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name)
-            if block_match:
-                block_id = int(block_match.group(1))
-                mup_tensor = self._generate_mup_vector(block_id)
-                mup_name = f"blk.{block_id}.ssm_mup_vec"
-                logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}")
-                tensors.append((self.map_tensor_name(mup_name), mup_tensor))
+        if "down_proj" in name:
+            tensor = tensor * self.mlp_multipliers[1]
+        elif "gate_proj" in name:
+            tensor = tensor * self.mlp_multipliers[0]
+        elif "k_proj" in name:
+            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
+        elif "q_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "v_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "o_proj" in name:
+            tensor = tensor * self.attention_out_multiplier
+        elif "out_proj" in name:
+            tensor = tensor * self.ssm_out_multiplier
+        elif "in_proj" in name:
+            tensor = tensor * self.ssm_in_multiplier
+            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+            intermediate_size = self.hparams["mamba_d_ssm"]
+            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
+            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
+            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
+            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
+            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
+        elif "lm_head" in name:
+            tensor = tensor * self.hparams["lm_head_multiplier"]
+        elif "embed_tokens" in name:
+            tensor = tensor * self.hparams["embedding_multiplier"]
+        tensors = [(tensors[0][0], tensor)]
 
         return tensors
 
     def set_gguf_parameters(self):
@@ -6644,8 +6667,8 @@ class FalconH1Model(Mamba2Model):
         self.gguf_writer.add_float64("falcon_h1.key_multiplier", self.hparams["key_multiplier"])
 
         ## Other params
-        self.gguf_writer.add_float64("falcon_h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
-        self.gguf_writer.add_float64("falcon_h1.embedding_multiplier", self.hparams["embedding_multiplier"])
+        # self.gguf_writer.add_float64("falcon_h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
+        # self.gguf_writer.add_float64("falcon_h1.embedding_multiplier", self.hparams["embedding_multiplier"])
 
         ## Validation ##
         assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
@@ -6661,20 +6684,16 @@ class FalconH1Model(Mamba2Model):
             self.find_hparam(["num_key_value_heads"], optional=True) or
             self.find_hparam(["num_attention_heads"]))
 
-        # Add multipliers as metadata instead of tensors
-        self.gguf_writer.add_float64("falcon_h1.attention_in_multiplier", self.attention_in_multiplier)
-        self.gguf_writer.add_float64("falcon_h1.attention_out_multiplier", self.attention_out_multiplier)
-        self.gguf_writer.add_float64("falcon_h1.ssm_in_multiplier", self.ssm_in_multiplier)
-        self.gguf_writer.add_float64("falcon_h1.ssm_out_multiplier", self.ssm_out_multiplier)
+        # # Add multipliers as metadata instead of tensors
+        # self.gguf_writer.add_float64("falcon_h1.attention_in_multiplier", self.attention_in_multiplier)
+        # self.gguf_writer.add_float64("falcon_h1.attention_out_multiplier", self.attention_out_multiplier)
+        # self.gguf_writer.add_float64("falcon_h1.ssm_in_multiplier", self.ssm_in_multiplier)
+        # self.gguf_writer.add_float64("falcon_h1.ssm_out_multiplier", self.ssm_out_multiplier)
 
-        # Add MLP multipliers
-        if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
-            self.gguf_writer.add_float64("falcon_h1.mlp_gate_multiplier", self.mlp_multipliers[0])
-            self.gguf_writer.add_float64("falcon_h1.mlp_down_multiplier", self.mlp_multipliers[1])
-
-        # Add has MuP flag if SSM multipliers are present
-        if self.ssm_multipliers is not None:
-            self.gguf_writer.add_bool("falcon_h1.ssm.has_mup", True)
+        # # Add MLP multipliers
+        # if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
+        #     self.gguf_writer.add_float64("falcon_h1.mlp_gate_multiplier", self.mlp_multipliers[0])
+        #     self.gguf_writer.add_float64("falcon_h1.mlp_down_multiplier", self.mlp_multipliers[1])
 
         # Add any other Falcon Mamba2 specific configuration
         self.gguf_writer.add_bool("falcon_h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True))
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 08223927a7..84a91e82ca 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -527,7 +527,6 @@ class MODEL_TENSOR(IntEnum):
     POSNET_ATTN_K = auto()
     POSNET_ATTN_V = auto()
     POSNET_ATTN_OUT = auto()
-    SSM_MUP_VEC = auto()
     # vision
     V_MMPROJ = auto()
     V_MMPROJ_FC = auto()
@@ -740,7 +739,6 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
-    MODEL_TENSOR.SSM_MUP_VEC: "blk.{bid}.ssm_mup_vec",
     MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
     MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
     MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -2230,7 +2228,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_OUT,  # Output projection
 
         # SSM components (Mamba2 specific)
-        MODEL_TENSOR.SSM_MUP_VEC,  # Mup vector
         MODEL_TENSOR.SSM_IN,  # Input projection for SSM
         MODEL_TENSOR.SSM_CONV1D,  # Convolution layer
         MODEL_TENSOR.SSM_DT,  # Delta time projection
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 245385dfdf..ff3f273bd5 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1176,11 +1176,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
             "resampler.attn.out_proj",
         ),
-
-        MODEL_TENSOR.SSM_MUP_VEC: (
-            "model.layers.{bid}.mamba.mup_vector",  # falcon_h1
-        ),
-
+
         MODEL_TENSOR.SSM_NORM: (
             "model.layers.{bid}.mamba.norm",
         ),
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 7fb81cfdc8..b43911eb30 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -228,18 +228,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_MAMBA_D_SSM, "%s.ssm.mamba_d_ssm" },
 
     { LLM_KV_FALCON_H1_USE_MLP, "%s.mamba_use_mlp" },
-    { LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, "%s.attention_in_multiplier" },
-    { LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, "%s.attention_out_multiplier" },
-    { LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, "%s.ssm_in_multiplier" },
-    { LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, "%s.ssm_out_multiplier" },
-    { LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, "%s.mlp_gate_multiplier" },
-    { LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, "%s.mlp_down_multiplier" },
-    { LLM_KV_FALCON_H1_SSM_HAS_MUP, "%s.ssm.has_mup" },
     { LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, "%s.mamba_norm_before_gate" },
     { LLM_KV_FALCON_H1_MAMBA_RMS_NORM, "%s.mamba_rms_norm" },
-    { LLM_KV_FALCON_H1_KEY_MULTIPLIER, "%s.key_multiplier" },
-    { LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, "%s.lm_head_multiplier" },
-    { LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, "%s.embedding_multiplier" },
     { LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, "%s.ssm.mamba_chunk_size" },
 
     { LLM_KV_ADAPTER_TYPE, "adapter.type" },
@@ -1062,7 +1052,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
             { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
             { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
-            { LLM_TENSOR_SSM_MUP_VEC, "blk.%d.ssm_mup_vec" },
             { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
             { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
             { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
@@ -1832,7 +1821,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
     {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_SSM_MUP_VEC, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 4ad1beb245..80af422361 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -163,18 +163,8 @@ enum llm_kv {
     LLM_KV_MAMBA_D_SSM,
     LLM_KV_N_LAYER,
     LLM_KV_FALCON_H1_USE_MLP,
-    LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER,
-    LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER,
-    LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER,
-    LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER,
-    LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER,
-    LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER,
-    LLM_KV_FALCON_H1_SSM_HAS_MUP,
     LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE,
     LLM_KV_FALCON_H1_MAMBA_RMS_NORM,
-    LLM_KV_FALCON_H1_KEY_MULTIPLIER,
-    LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER,
-    LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER,
     LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
@@ -410,7 +400,6 @@ enum llm_tensor {
     LLM_TENSOR_POS_NET_ATTN_K,
     LLM_TENSOR_POS_NET_ATTN_V,
     LLM_TENSOR_POS_NET_ATTN_OUT,
-    LLM_TENSOR_SSM_MUP_VEC,
     LLM_TENSOR_FFN_PRE_NORM,
     LLM_TENSOR_FINAL_NORM,
 };
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index eea8207c14..4443420132 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -545,10 +545,6 @@ ggml_tensor * llm_graph_context::build_ffn(
         case LLM_FFN_PAR:
             {
                 cur = build_lora_mm(gate, cur);
-                if (arch == LLM_ARCH_FALCON_H1) {
-                    cur = ggml_scale(ctx0, cur, hparams.mlp_gate_multiplier);
-                }
-
                 cb(cur, "ffn_gate", il);
             } break;
     }
@@ -635,9 +631,6 @@ ggml_tensor * llm_graph_context::build_ffn(
             // GLM4 seems to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
-        if (arch == LLM_ARCH_FALCON_H1) {
-            cur = ggml_scale(ctx0, cur, hparams.mlp_down_multiplier);
-        }
     }
 
     if (down_b) {
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 2142a74aaf..429eaf0482 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -122,17 +122,7 @@ struct llama_hparams {
     bool mamba_use_mlp = false;
     bool mamba_norm_before_gate = false;
     bool mamba_rms_norm = false;
-    double attention_in_multiplier = 1.0;
-    double attention_out_multiplier = 1.0;
-    double ssm_in_multiplier = 1.0;
-    double ssm_out_multiplier = 1.0;
-    double mlp_gate_multiplier = 1.0;
-    double mlp_down_multiplier = 1.0;
-    double key_multiplier = 1.0;
-    double lm_head_multiplier = 1.0;
     double rope_theta = 10000.0;
-    double embedding_multiplier = 1.0;
-    bool ssm_has_mup = false;
     uint32_t vocab_size = 0;
     uint32_t intermediate_size = 0;
     float mamba_expand = 0.0f;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 661256df64..bf6613a80e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1568,18 +1568,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // Falcon-H1 parameters
                 ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim);
                 ml.get_key(LLM_KV_FALCON_H1_USE_MLP, hparams.mamba_use_mlp);
-                ml.get_key(LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, hparams.attention_in_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, hparams.attention_out_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, hparams.ssm_in_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, hparams.ssm_out_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, hparams.mlp_gate_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, hparams.mlp_down_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_SSM_HAS_MUP, hparams.ssm_has_mup);
                 ml.get_key(LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, hparams.mamba_norm_before_gate);
                 ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm);
-                ml.get_key(LLM_KV_FALCON_H1_KEY_MULTIPLIER, hparams.key_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, hparams.lm_head_multiplier);
-                ml.get_key(LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, hparams.embedding_multiplier);
 
                 std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
 
@@ -4570,9 +4560,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     // no "weight" suffix for these
                     layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
                    layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
-                    if (hparams.ssm_has_mup == true) {
-                        layer.ssm_mup_vec = create_tensor(tn(LLM_TENSOR_SSM_MUP_VEC, i), {2*ssm_intermediate_size + 2*ssm_n_groups*ssm_state_size + ssm_num_heads}, 0);
-                    }
                     // ssm_norm
                     if (hparams.mamba_rms_norm == true) {
                         layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0);
@@ -14665,7 +14652,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
-        inpL = ggml_scale(ctx0, inpL, hparams.embedding_multiplier);
 
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
@@ -14684,7 +14670,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
-            cur = ggml_scale(ctx0, cur, hparams.attention_in_multiplier);
 
             // self-attention
             ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -14699,8 +14684,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier);
-
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
             Qcur = ggml_rope_ext(
@@ -14721,18 +14704,15 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
             ggml_tensor * attn_out = build_attn(inp, gf, model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
-            attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier);
             cb(attn_out, "attn_out", il);
 
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, il);
             // Mamba2 layer
-            cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier);
             cb(cur, "ssm_in", il);
 
             ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
-            ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier);
             cb(ssm_out, "ssm_out", il);
 
             // // Aggregation
@@ -14782,7 +14762,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
 
         // lm_head
         cur = build_lora_mm(model.output, cur);
-        cur = ggml_scale(ctx0, cur, hparams.lm_head_multiplier);
         cb(cur, "result_output", -1);
         res->t_logits = cur;
 
@@ -14829,14 +14808,6 @@ struct llm_build_falcon_h1 : public llm_graph_context {
         ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
         cb(zxBCdt, "zxBCdt", il);
 
-        // check if the models has ssm_multipliers (MuP)
-        if (hparams.ssm_has_mup) {
-            struct ggml_tensor * mup_vec = model.layers[il].ssm_mup_vec;
-            cur = ggml_mul(ctx0, zxBCdt, mup_vec);
-            cb(cur, "ssm_mup", il);
-            zxBCdt = cur;
-        }
-
         // split the above in three
         ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
         ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt));
diff --git a/src/llama-model.h b/src/llama-model.h
index 506fcd4789..1f089ebd2e 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -227,7 +227,6 @@ struct llama_layer {
 
     // falcon_h1
     struct ggml_tensor * ssm_in_b = nullptr;
-    struct ggml_tensor * ssm_mup_vec = nullptr;
 
     // ff MoE
     struct ggml_tensor * ffn_gate_inp = nullptr;
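
Net effect of the patch: the MuP multipliers are no longer exported as GGUF metadata (or as a per-layer ssm_mup_vec tensor) and applied with ggml_scale / ggml_mul at inference time; instead, modify_tensors() folds them directly into the converted weights, which is why the corresponding KV keys, hparams fields, tensor entries, and graph ops are removed on the C++ side. Below is a minimal numpy sketch of why the folding is equivalent for the in_proj case; the toy shapes, variable names, and multiplier values are illustrative only and not taken from the patch.

# Illustrative sketch (assumed names/shapes, not part of the patch): folding a
# per-row multiplier into a projection weight is equivalent to scaling the
# projection output at runtime, because diag(m) @ (W @ x) == (diag(m) @ W) @ x.
import numpy as np

d_model = 8
d_ssm = 4                      # stands in for hparams["mamba_d_ssm"]
n_group, d_state = 1, 2        # mamba_n_groups, mamba_d_state
n_head = 2
gts = n_group * d_state       # groups_time_state_size in modify_tensors()
d_in_proj = 2 * d_ssm + 2 * gts + n_head   # rows of in_proj: [z | x | B | C | dt]

rng = np.random.default_rng(0)
W = rng.standard_normal((d_in_proj, d_model))   # mamba.in_proj weight (out x in)
x = rng.standard_normal(d_model)
mup = [1.5, 0.5, 2.0, 0.25, 3.0]                # ssm_multipliers for z, x, B, C, dt

# Per-row scale vector, mirroring the slice boundaries used in the patch.
scale = np.concatenate([
    np.full(d_ssm, mup[0]),    # z block
    np.full(d_ssm, mup[1]),    # x block
    np.full(gts, mup[2]),      # B block
    np.full(gts, mup[3]),      # C block
    np.full(n_head, mup[4]),   # dt block
])

out_runtime = scale * (W @ x)             # old path: multiply zxBCdt by a MuP vector at inference
out_folded = (W * scale[:, None]) @ x     # new path: pre-scale the weight rows during conversion
assert np.allclose(out_runtime, out_folded)

The scalar cases (q/k/v/o projections, gate/down, lm_head, embed_tokens) follow from the same identity with a single multiplier per tensor, which is why the matching ggml_scale calls and the falcon_h1.* multiplier metadata keys can be dropped without changing the computed outputs.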