injected mup

younesbelkada
2025-07-07 15:00:25 +04:00
parent b3bc1fb237
commit a9f3a63dc1
9 changed files with 43 additions and 101 deletions


@@ -6576,6 +6576,7 @@ class FalconH1Model(Mamba2Model):
        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
        self.intermediate_size = self.find_hparam(["intermediate_size"])
        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)

    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
        prefixed = []
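The overridden find_hparam is cut off by the hunk right after `prefixed = []`. As a rough, standalone sketch of what a prefixed-key lookup like this is presumably doing (trying each key as given and with a "mamba_"-style prefix, returning None when optional), with the helper signature and prefix chosen here as assumptions rather than taken from the commit:

from typing import Any, Iterable

def find_hparam(hparams: dict, keys: Iterable[str], optional: bool = False) -> Any:
    # Assumed behaviour: try each key as given, then with a "mamba_" prefix.
    candidates = []
    for key in keys:
        candidates.append(key)
        candidates.append(f"mamba_{key}")
    for key in candidates:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {list(keys)}")

# Usage on a dummy config dict:
cfg = {"mamba_d_ssm": 64, "intermediate_size": 128}
assert find_hparam(cfg, ["d_ssm"]) == 64
assert find_hparam(cfg, ["key_multiplier"], optional=True) is None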
@@ -6607,16 +6608,38 @@ class FalconH1Model(Mamba2Model):
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        tensors = list(super().modify_tensors(data_torch, name, bid))
        tensor = tensors[0][1]

        if self.ssm_multipliers is not None and "mamba.dt_bias" in name:
            block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name)
            if block_match:
                block_id = int(block_match.group(1))
                mup_tensor = self._generate_mup_vector(block_id)
                mup_name = f"blk.{block_id}.ssm_mup_vec"
                logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}")
                tensors.append((self.map_tensor_name(mup_name), mup_tensor))

        if "down_proj" in name:
            tensor = tensor * self.mlp_multipliers[1]
        elif "gate_proj" in name:
            tensor = tensor * self.mlp_multipliers[0]
        elif "k_proj" in name:
            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
        elif "q_proj" in name:
            tensor = tensor * self.attention_in_multiplier
        elif "v_proj" in name:
            tensor = tensor * self.attention_in_multiplier
        elif "o_proj" in name:
            tensor = tensor * self.attention_out_multiplier
        elif "out_proj" in name:
            tensor = tensor * self.ssm_out_multiplier
        elif "in_proj" in name:
            tensor = tensor * self.ssm_in_multiplier
            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
            intermediate_size = self.hparams["mamba_d_ssm"]
            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
        elif "lm_head" in name:
            tensor = tensor * self.hparams["lm_head_multiplier"]
        elif "embed_tokens" in name:
            tensor = tensor * self.hparams["embedding_multiplier"]

        tensors = [(tensors[0][0], tensor)]
        return tensors

    def set_gguf_parameters(self):
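The in_proj branch above slices the fused projection rows into five segments, as the zxbcdt_multipliers name suggests: [z | x | B | C | dt], each scaled by its own ssm_multipliers entry. A minimal self-contained sketch of that slicing on a dummy tensor; the dimensions and multiplier values are illustrative stand-ins, not taken from a real Falcon-H1 config:

import torch

# Stand-ins for hparams["mamba_d_ssm"], ["mamba_n_groups"], ["mamba_d_state"]
d_ssm, n_groups, d_state = 4, 2, 3
gts = n_groups * d_state                      # groups_time_state_size
n_heads = 2                                   # remaining rows (dt) in this sketch

mult = [1.0, 2.0, 3.0, 4.0, 5.0]              # stand-in for ssm_multipliers
in_proj = torch.ones(2 * d_ssm + 2 * gts + n_heads, 5)

in_proj[:d_ssm, :] *= mult[0]                                   # z
in_proj[d_ssm:2 * d_ssm, :] *= mult[1]                          # x
in_proj[2 * d_ssm:2 * d_ssm + gts, :] *= mult[2]                # B
in_proj[2 * d_ssm + gts:2 * d_ssm + 2 * gts, :] *= mult[3]      # C
in_proj[2 * d_ssm + 2 * gts:, :] *= mult[4]                     # dt

# Each row block now carries exactly one multiplier.
assert in_proj[0, 0] == 1.0 and in_proj[-1, 0] == 5.0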
@@ -6644,8 +6667,8 @@ class FalconH1Model(Mamba2Model):
self.gguf_writer.add_float64("falcon_h1.key_multiplier", self.hparams["key_multiplier"])
## Other params
self.gguf_writer.add_float64("falcon_h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
self.gguf_writer.add_float64("falcon_h1.embedding_multiplier", self.hparams["embedding_multiplier"])
# self.gguf_writer.add_float64("falcon_h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
# self.gguf_writer.add_float64("falcon_h1.embedding_multiplier", self.hparams["embedding_multiplier"])
## Validation ##
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
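The lm_head_multiplier and embedding_multiplier metadata writes are commented out above; since modify_tensors now scales embed_tokens and lm_head directly, the runtime metadata is presumably redundant. A quick check of that equivalence on dummy data (multiplier value and shapes are illustrative):

import torch

embedding_multiplier = 0.3                 # illustrative value
embed_tokens = torch.randn(32, 8)          # dummy (vocab, hidden) weight
token_ids = torch.tensor([1, 5, 7])

scaled_at_runtime = embed_tokens[token_ids] * embedding_multiplier
baked_into_weight = (embed_tokens * embedding_multiplier)[token_ids]

assert torch.allclose(scaled_at_runtime, baked_into_weight)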
@@ -6661,20 +6684,16 @@ class FalconH1Model(Mamba2Model):
self.find_hparam(["num_key_value_heads"], optional=True) or
self.find_hparam(["num_attention_heads"]))
# Add multipliers as metadata instead of tensors
self.gguf_writer.add_float64("falcon_h1.attention_in_multiplier", self.attention_in_multiplier)
self.gguf_writer.add_float64("falcon_h1.attention_out_multiplier", self.attention_out_multiplier)
self.gguf_writer.add_float64("falcon_h1.ssm_in_multiplier", self.ssm_in_multiplier)
self.gguf_writer.add_float64("falcon_h1.ssm_out_multiplier", self.ssm_out_multiplier)
# # Add multipliers as metadata instead of tensors
# self.gguf_writer.add_float64("falcon_h1.attention_in_multiplier", self.attention_in_multiplier)
# self.gguf_writer.add_float64("falcon_h1.attention_out_multiplier", self.attention_out_multiplier)
# self.gguf_writer.add_float64("falcon_h1.ssm_in_multiplier", self.ssm_in_multiplier)
# self.gguf_writer.add_float64("falcon_h1.ssm_out_multiplier", self.ssm_out_multiplier)
# Add MLP multipliers
if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
self.gguf_writer.add_float64("falcon_h1.mlp_gate_multiplier", self.mlp_multipliers[0])
self.gguf_writer.add_float64("falcon_h1.mlp_down_multiplier", self.mlp_multipliers[1])
# Add has MuP flag if SSM multipliers are present
if self.ssm_multipliers is not None:
self.gguf_writer.add_bool("falcon_h1.ssm.has_mup", True)
# # Add MLP multipliers
# if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
# self.gguf_writer.add_float64("falcon_h1.mlp_gate_multiplier", self.mlp_multipliers[0])
# self.gguf_writer.add_float64("falcon_h1.mlp_down_multiplier", self.mlp_multipliers[1])
# Add any other Falcon Mamba2 specific configuration
self.gguf_writer.add_bool("falcon_h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True))
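Likewise, the MLP multiplier metadata above is commented out because modify_tensors now folds mlp_multipliers into the gate_proj/down_proj weights; for a linear layer, scaling the weight is the same as scaling the output. A minimal sanity check of that, with a hypothetical multiplier value and dummy shapes:

import torch
import torch.nn.functional as F

mlp_down_multiplier = 0.7            # stand-in for mlp_multipliers[1]
down_proj = torch.randn(8, 16)       # dummy (hidden, ffn) weight
h = torch.randn(3, 16)               # dummy FFN activations

scaled_output = F.linear(h, down_proj) * mlp_down_multiplier
scaled_weight = F.linear(h, down_proj * mlp_down_multiplier)

assert torch.allclose(scaled_output, scaled_weight, atol=1e-6)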