Mirror of https://github.com/ggml-org/llama.cpp.git, last synced 2025-11-03 09:22:01 +00:00.
			
		
		
		
	convert : avoid AutoConfig for Mamba and Mamba2 hparams
This commit is contained in:
		@@ -4127,6 +4127,14 @@ class ARwkv7Model(Rwkv7Model):
 | 
			
		||||
class MambaModel(TextModel):
 | 
			
		||||
    model_arch = gguf.MODEL_ARCH.MAMBA
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dir_model: Path, *args, **kwargs):
        """Initialize the Mamba converter, loading hparams without AutoConfig.

        The hyperparameters are read straight from ``config.json`` in
        *dir_model* (unless the caller already passed ``hparams``), then
        forwarded to the base-class constructor.
        """
        # Honor an explicitly supplied hparams dict; otherwise parse the
        # checkpoint's own config.json ourselves.
        supplied = kwargs.pop("hparams", None)
        if supplied is not None:
            hparams = supplied
        else:
            config_path = dir_model / "config.json"
            with open(config_path, "r", encoding="utf-8") as cfg:
                hparams = json.load(cfg)
        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def set_vocab(self):
 | 
			
		||||
        vocab_size = self.hparams["vocab_size"]
 | 
			
		||||
        # Round vocab size to next multiple of 8
 | 
			
		||||
@@ -4205,6 +4213,15 @@ class MambaModel(TextModel):
 | 
			
		||||
class Mamba2Model(TextModel):
 | 
			
		||||
    model_arch = gguf.MODEL_ARCH.MAMBA2
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dir_model: Path, *args, **kwargs):
        """Initialize the Mamba2 converter, loading hparams without AutoConfig.

        AutoConfig is deliberately bypassed here: it wrongly assumes all
        Mamba2 models are Mamba-Codestral-7B-v0.1, so the hyperparameters
        are taken from ``config.json`` directly instead.
        """
        hparams = kwargs.pop("hparams", None)
        if hparams is None:
            # No hparams given by the caller — read the model's own config.
            cfg_file = dir_model / "config.json"
            with open(cfg_file, "r", encoding="utf-8") as fh:
                hparams = json.load(fh)
        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def set_vocab(self):
 | 
			
		||||
        vocab_size = self.hparams["vocab_size"]
 | 
			
		||||
        # Round vocab size to next multiple of 16
 | 
			
		||||
def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
    """Detect the architecture string for the model in *dir_model*.

    Resolution order:
      1. top-level ``"architectures"`` list in the hparams (HF-style configs);
      2. ``"ssm_cfg"`` for non-HF Mamba/Mamba2 checkpoints, synthesizing
         ``"<layer>ForCausalLM"`` (defaulting to ``"MambaForCausalLM"``);
      3. the ``text_config``/``vision_config`` sub-config's ``"architectures"``
         entry, which overrides the top-level value when present for the
         requested ``model_type``.

    Raises:
        ValueError: if no architecture can be determined from the hparams.
    """
    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
    text_config = hparams.get("text_config", {})
    vision_config = hparams.get("vision_config", {})
    # NOTE: do not index hparams["architectures"] directly — non-HF Mamba
    # checkpoints have no "architectures" key and that would raise KeyError.
    arch = None
    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
        arch = arches[0]
    elif "ssm_cfg" in hparams:
        # For non-hf Mamba and Mamba2 models
        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"

    # if "architectures" is found in the sub-config, use that instead
    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
        arch = text_config["architectures"][0]
    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
        arch = vision_config["architectures"][0]
    if arch is None:
        raise ValueError("Failed to detect model architecture")
    return arch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user