mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-03 09:22:01 +00:00
Create mup_vector as float64
This commit is contained in:
@@ -6591,7 +6591,7 @@ class FalconH1Model(Mamba2Model):
|
|||||||
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
|
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
|
||||||
vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])
|
vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])
|
||||||
|
|
||||||
mup_vector = torch.ones(1, 1, vector_shape)
|
mup_vector = torch.ones(1, 1, vector_shape, dtype=torch.float64)
|
||||||
mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
|
mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
|
||||||
mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
|
mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
|
||||||
mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
|
mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
|
||||||
|
|||||||
Reference in New Issue
Block a user