	convert : avoid AutoConfig for Mamba and Mamba2 hparams
Author: Francis Couture-Harpin
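The patch changes llama.cpp's convert_hf_to_gguf.py in two ways: MambaModel and Mamba2Model gain an __init__ that reads hparams straight from config.json instead of going through transformers' AutoConfig, and get_model_architecture learns to derive an architecture name from the "ssm_cfg" block when a non-HF checkpoint has no "architectures" list.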
```diff
@@ -4127,6 +4127,14 @@ class ARwkv7Model(Rwkv7Model):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
```
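For context, the hparams this constructor bypasses would otherwise be produced by transformers' AutoConfig. A minimal sketch of the two loading paths, assuming the helper names load_hparams_via_autoconfig and load_hparams_from_json (both hypothetical):

```python
import json
from pathlib import Path

from transformers import AutoConfig

def load_hparams_via_autoconfig(dir_model: Path) -> dict:
    # Roughly what the converter relied on before: transformers picks a
    # config class from "model_type" and fills any unspecified fields
    # with that class's defaults.
    return AutoConfig.from_pretrained(dir_model).to_dict()

def load_hparams_from_json(dir_model: Path) -> dict:
    # What the new __init__ does: take config.json at face value,
    # with no default-filling or normalization.
    with open(dir_model / "config.json", "r", encoding="utf-8") as f:
        return json.load(f)
```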
```diff
@@ -4205,6 +4213,15 @@ class MambaModel(TextModel):
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
```
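The extra comment here explains the motivation: transformers' Mamba2Config defaults are documented as matching Mamba-Codestral-7B-v0.1, so any field absent from a non-HF config.json would presumably be filled in with Codestral values. Note also that both constructors pop "hparams" from kwargs before reading the file, so an explicitly supplied dict still wins. A minimal sketch of that precedence (DemoModel is a hypothetical stand-in; the real classes forward to TextModel.__init__):

```python
import json
from pathlib import Path

class DemoModel:
    # Hypothetical stand-in for the hparams handling shared by
    # MambaModel and Mamba2Model.
    def __init__(self, dir_model: Path, **kwargs):
        hparams = kwargs.pop("hparams", None)  # explicit hparams win
        if hparams is None:  # otherwise fall back to config.json on disk
            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
                hparams = json.load(f)
        self.hparams = hparams

# With dir_model/config.json containing {"vocab_size": 50280}:
#   DemoModel(dir_model).hparams["vocab_size"]               -> 50280
#   DemoModel(dir_model, hparams={"vocab_size": 8}).hparams  -> {"vocab_size": 8}
```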
```diff
@@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
     hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch
 
 
```
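The new detection order can be exercised in isolation. A sketch of the core logic, with the text_config/vision_config overrides omitted and detect_arch a hypothetical name:

```python
def detect_arch(hparams: dict) -> str:
    # Prefer an explicit "architectures" list, as HF configs provide.
    arch = None
    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
        arch = arches[0]
    elif "ssm_cfg" in hparams:
        # Non-HF Mamba/Mamba2 checkpoints carry an "ssm_cfg" block instead;
        # its "layer" key defaults to "Mamba" when absent.
        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
    if arch is None:
        raise ValueError("Failed to detect model architecture")
    return arch

assert detect_arch({"architectures": ["MambaForCausalLM"]}) == "MambaForCausalLM"
assert detect_arch({"ssm_cfg": {"layer": "Mamba2"}}) == "Mamba2ForCausalLM"
assert detect_arch({"ssm_cfg": {}}) == "MambaForCausalLM"
```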