mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-31 08:51:55 +00:00
convert : avoid AutoConfig for Mamba and Mamba2 hparams
This commit is contained in:
@@ -4127,6 +4127,14 @@ class ARwkv7Model(Rwkv7Model):
|
|||||||
class MambaModel(TextModel):
|
class MambaModel(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||||
|
|
||||||
|
def __init__(self, dir_model: Path, *args, **kwargs):
|
||||||
|
# Avoid using AutoConfig for hparams
|
||||||
|
hparams = kwargs.pop("hparams", None)
|
||||||
|
if hparams is None:
|
||||||
|
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||||
|
hparams = json.load(f)
|
||||||
|
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
vocab_size = self.hparams["vocab_size"]
|
vocab_size = self.hparams["vocab_size"]
|
||||||
# Round vocab size to next multiple of 8
|
# Round vocab size to next multiple of 8
|
||||||
@@ -4205,6 +4213,15 @@ class MambaModel(TextModel):
|
|||||||
class Mamba2Model(TextModel):
|
class Mamba2Model(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.MAMBA2
|
model_arch = gguf.MODEL_ARCH.MAMBA2
|
||||||
|
|
||||||
|
def __init__(self, dir_model: Path, *args, **kwargs):
|
||||||
|
# Avoid using AutoConfig for hparams
|
||||||
|
# It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
|
||||||
|
hparams = kwargs.pop("hparams", None)
|
||||||
|
if hparams is None:
|
||||||
|
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||||
|
hparams = json.load(f)
|
||||||
|
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
vocab_size = self.hparams["vocab_size"]
|
vocab_size = self.hparams["vocab_size"]
|
||||||
# Round vocab size to next multiple of 16
|
# Round vocab size to next multiple of 16
|
||||||
@@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
|
|||||||
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
|
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
|
||||||
text_config = hparams.get("text_config", {})
|
text_config = hparams.get("text_config", {})
|
||||||
vision_config = hparams.get("vision_config", {})
|
vision_config = hparams.get("vision_config", {})
|
||||||
arch = hparams["architectures"][0]
|
arch = None
|
||||||
|
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
|
||||||
|
arch = arches[0]
|
||||||
|
elif "ssm_cfg" in hparams:
|
||||||
|
# For non-hf Mamba and Mamba2 models
|
||||||
|
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
|
||||||
|
|
||||||
# if "architectures" is found in the sub-config, use that instead
|
# if "architectures" is found in the sub-config, use that instead
|
||||||
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
|
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
|
||||||
arch = text_config["architectures"][0]
|
arch = text_config["architectures"][0]
|
||||||
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
|
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
|
||||||
arch = vision_config["architectures"][0]
|
arch = vision_config["architectures"][0]
|
||||||
|
if arch is None:
|
||||||
|
raise ValueError("Failed to detect model architecture")
|
||||||
return arch
|
return arch
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user