mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	llama: Add support for Gemma2ForCausalLM (#8156)
* Inference support for Gemma 2 model family * Update convert-hf-to-gguf.py, constants, and tensor mappings * cleanup * format fix * Fix special token vocab bug * Don't add space prefix * fix deleted lines * Update src/llama.cpp Co-authored-by: slaren <slarengh@gmail.com> * Add model type names * Add control vector * Fix model type identification --------- Co-authored-by: Andrei Betlen <abetlen@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
		@@ -150,6 +150,7 @@ class MODEL_ARCH(IntEnum):
 | 
			
		||||
    INTERNLM2    = auto()
 | 
			
		||||
    MINICPM      = auto()
 | 
			
		||||
    GEMMA        = auto()
 | 
			
		||||
    GEMMA2       = auto()
 | 
			
		||||
    STARCODER2   = auto()
 | 
			
		||||
    MAMBA        = auto()
 | 
			
		||||
    XVERSE       = auto()
 | 
			
		||||
@@ -180,10 +181,13 @@ class MODEL_TENSOR(IntEnum):
 | 
			
		||||
    ATTN_NORM            = auto()
 | 
			
		||||
    ATTN_NORM_2          = auto()
 | 
			
		||||
    ATTN_OUT_NORM        = auto()
 | 
			
		||||
    ATTN_POST_NORM       = auto()
 | 
			
		||||
    ATTN_ROT_EMBD        = auto()
 | 
			
		||||
    FFN_GATE_INP         = auto()
 | 
			
		||||
    FFN_GATE_INP_SHEXP   = auto()
 | 
			
		||||
    FFN_NORM             = auto()
 | 
			
		||||
    FFN_PRE_NORM         = auto()
 | 
			
		||||
    FFN_POST_NORM        = auto()
 | 
			
		||||
    FFN_GATE             = auto()
 | 
			
		||||
    FFN_DOWN             = auto()
 | 
			
		||||
    FFN_UP               = auto()
 | 
			
		||||
@@ -270,6 +274,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
 | 
			
		||||
    MODEL_ARCH.INTERNLM2:      "internlm2",
 | 
			
		||||
    MODEL_ARCH.MINICPM:        "minicpm",
 | 
			
		||||
    MODEL_ARCH.GEMMA:          "gemma",
 | 
			
		||||
    MODEL_ARCH.GEMMA2:         "gemma2",
 | 
			
		||||
    MODEL_ARCH.STARCODER2:     "starcoder2",
 | 
			
		||||
    MODEL_ARCH.MAMBA:          "mamba",
 | 
			
		||||
    MODEL_ARCH.XVERSE:         "xverse",
 | 
			
		||||
@@ -303,9 +308,12 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
 | 
			
		||||
    MODEL_TENSOR.ATTN_Q_NORM:          "blk.{bid}.attn_q_norm",
 | 
			
		||||
    MODEL_TENSOR.ATTN_K_NORM:          "blk.{bid}.attn_k_norm",
 | 
			
		||||
    MODEL_TENSOR.ATTN_OUT_NORM:        "blk.{bid}.attn_output_norm",
 | 
			
		||||
    MODEL_TENSOR.ATTN_POST_NORM:       "blk.{bid}.post_attention_norm",
 | 
			
		||||
    MODEL_TENSOR.FFN_GATE_INP:         "blk.{bid}.ffn_gate_inp",
 | 
			
		||||
    MODEL_TENSOR.FFN_GATE_INP_SHEXP:   "blk.{bid}.ffn_gate_inp_shexp",
 | 
			
		||||
    MODEL_TENSOR.FFN_NORM:             "blk.{bid}.ffn_norm",
 | 
			
		||||
    MODEL_TENSOR.FFN_PRE_NORM:         "blk.{bid}.ffn_norm",
 | 
			
		||||
    MODEL_TENSOR.FFN_POST_NORM:        "blk.{bid}.post_ffw_norm",
 | 
			
		||||
    MODEL_TENSOR.FFN_GATE:             "blk.{bid}.ffn_gate",
 | 
			
		||||
    MODEL_TENSOR.FFN_DOWN:             "blk.{bid}.ffn_down",
 | 
			
		||||
    MODEL_TENSOR.FFN_UP:               "blk.{bid}.ffn_up",
 | 
			
		||||
@@ -751,6 +759,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
 | 
			
		||||
        MODEL_TENSOR.FFN_UP,
 | 
			
		||||
        MODEL_TENSOR.FFN_NORM,
 | 
			
		||||
    ],
 | 
			
		||||
    MODEL_ARCH.GEMMA2: [
 | 
			
		||||
        MODEL_TENSOR.TOKEN_EMBD,
 | 
			
		||||
        MODEL_TENSOR.OUTPUT_NORM,
 | 
			
		||||
        MODEL_TENSOR.ATTN_Q,
 | 
			
		||||
        MODEL_TENSOR.ATTN_K,
 | 
			
		||||
        MODEL_TENSOR.ATTN_V,
 | 
			
		||||
        MODEL_TENSOR.ATTN_OUT,
 | 
			
		||||
        MODEL_TENSOR.FFN_GATE,
 | 
			
		||||
        MODEL_TENSOR.FFN_DOWN,
 | 
			
		||||
        MODEL_TENSOR.FFN_UP,
 | 
			
		||||
        MODEL_TENSOR.ATTN_NORM,
 | 
			
		||||
        MODEL_TENSOR.ATTN_POST_NORM,
 | 
			
		||||
        MODEL_TENSOR.FFN_PRE_NORM,
 | 
			
		||||
        MODEL_TENSOR.FFN_POST_NORM,
 | 
			
		||||
    ],
 | 
			
		||||
    MODEL_ARCH.STARCODER2: [
 | 
			
		||||
        MODEL_TENSOR.TOKEN_EMBD,
 | 
			
		||||
        MODEL_TENSOR.OUTPUT_NORM,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user