convert : use F32 operations on Mamba A_log

This matches the previous behavior for BF16 tensors.
This commit is contained in:
Francis Couture-Harpin
2025-09-04 18:43:10 -04:00
parent 6792f66a93
commit fb879b40c0

View File

@@ -4356,7 +4356,7 @@ class Plamo2Model(TextModel):
del bid # unused
if name.endswith(".A_log"):
-data_torch = -torch.exp(data_torch)
+data_torch = -torch.exp(data_torch.float())
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif name.endswith(".dt_norm_weight"):
@@ -5829,7 +5829,7 @@ class MambaModel(TextModel):
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
-data_torch = -torch.exp(data_torch)
+data_torch = -torch.exp(data_torch.float())
# [4 1 8192 1] -> [4 8192 1 1]
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
@@ -5934,7 +5934,7 @@ class Mamba2Model(TextModel):
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
-data_torch = -torch.exp(data_torch)
+data_torch = -torch.exp(data_torch.float())
yield (new_name, data_torch)
@@ -6042,7 +6042,7 @@ class JambaModel(TextModel):
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
-data_torch = -torch.exp(data_torch)
+data_torch = -torch.exp(data_torch.float())
yield (new_name, data_torch)