Create mup_vector as float64

This commit is contained in:
ibrahimkhadraoui
2025-07-07 14:25:32 +04:00
parent 49d7420964
commit 97011d7a1f

View File

@@ -6591,7 +6591,7 @@ class FalconH1Model(Mamba2Model):
 groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
 vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])
-mup_vector = torch.ones(1, 1, vector_shape)
+mup_vector = torch.ones(1, 1, vector_shape, dtype=torch.float64)
 mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
 mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
 mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]