diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 532cc879de..2debb6e63f 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4127,6 +4127,14 @@ class ARwkv7Model(Rwkv7Model):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4205,6 +4213,15 @@ class MambaModel(TextModel):
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
@@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
     hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
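
For reference, a minimal standalone sketch of the new architecture-detection fallback. The detect_arch helper below is illustrative only (it is not part of convert_hf_to_gguf.py) and simply mirrors the logic added to get_model_architecture(): a non-HF Mamba/Mamba2 checkpoint ships a config.json without an "architectures" list but with an "ssm_cfg" block, whose optional "layer" key selects between Mamba and Mamba2.

# Illustrative sketch of the fallback added to get_model_architecture().
def detect_arch(hparams: dict) -> str:
    arch = None
    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
        # HF-style config: take the first declared architecture
        arch = arches[0]
    elif "ssm_cfg" in hparams:
        # Non-HF Mamba/Mamba2 config, e.g. {"ssm_cfg": {"layer": "Mamba2"}}
        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
    if arch is None:
        raise ValueError("Failed to detect model architecture")
    return arch

assert detect_arch({"architectures": ["MambaForCausalLM"]}) == "MambaForCausalLM"
assert detect_arch({"ssm_cfg": {"layer": "Mamba2"}}) == "Mamba2ForCausalLM"
assert detect_arch({"ssm_cfg": {}}) == "MambaForCausalLM"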