convert : use F32 operations on Mamba A_log

This matches the previous behavior for BF16 tensors.
This commit is contained in:
Francis Couture-Harpin
2025-09-04 18:43:10 -04:00
parent c3738cfcef
commit 614b95a88d

View File

@@ -4770,7 +4770,7 @@ class Plamo2Model(TextModel):
del bid # unused
if name.endswith(".A_log"):
-            data_torch = -torch.exp(data_torch)
+            data_torch = -torch.exp(data_torch.float())
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif name.endswith(".dt_norm_weight"):
@@ -6294,7 +6294,7 @@ class MambaModel(TextModel):
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
-            data_torch = -torch.exp(data_torch)
+            data_torch = -torch.exp(data_torch.float())
# [4 1 8192 1] -> [4 8192 1 1]
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
@@ -6399,7 +6399,7 @@ class Mamba2Model(TextModel):
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
-            data_torch = -torch.exp(data_torch)
+            data_torch = -torch.exp(data_torch.float())
yield (new_name, data_torch)
@@ -6499,7 +6499,7 @@ class JambaModel(TextModel):
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
-            data_torch = -torch.exp(data_torch)
+            data_torch = -torch.exp(data_torch.float())
yield (new_name, data_torch)