mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-16 11:27:03 +00:00
Merge branch 'master' into compilade/imatrix-batched-chunks
This commit is contained in:
@@ -11,6 +11,11 @@ as an example for its usage.
|
||||
pip install gguf
|
||||
```
|
||||
|
||||
Optionally, you can install gguf with the extra 'gui' to enable the visual GGUF editor.
|
||||
```sh
|
||||
pip install gguf[gui]
|
||||
```
|
||||
|
||||
## API Examples/Simple Tools
|
||||
|
||||
[examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.
|
||||
@@ -25,6 +30,8 @@ pip install gguf
|
||||
|
||||
[gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values.
|
||||
|
||||
[gguf/scripts/gguf_editor_gui.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_editor_gui.py) — Allows for viewing, editing, adding, or removing metadata values within a GGUF file as well as viewing its tensors with a Qt interface.
|
||||
|
||||
## Development
|
||||
Maintainers who participate in development of this package are advised to install it in editable mode:
|
||||
|
||||
|
||||
@@ -104,6 +104,7 @@ class Keys:
|
||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||
@@ -139,6 +140,8 @@ class Keys:
|
||||
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
|
||||
SLIDING_WINDOW = "{arch}.attention.sliding_window"
|
||||
SCALE = "{arch}.attention.scale"
|
||||
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
|
||||
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
@@ -174,6 +177,9 @@ class Keys:
|
||||
EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.convnext.block_count"
|
||||
|
||||
class Classifier:
|
||||
OUTPUT_LABELS = "{arch}.classifier.output_labels"
|
||||
|
||||
class Tokenizer:
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
PRE = "tokenizer.ggml.pre"
|
||||
@@ -221,6 +227,46 @@ class Keys:
|
||||
CHUNK_SIZE = "imatrix.chunk_size"
|
||||
DATASETS = "imatrix.datasets"
|
||||
|
||||
class Clip:
|
||||
PROJECTOR_TYPE = "clip.projector_type"
|
||||
HAS_VISION_ENCODER = "clip.has_vision_encoder"
|
||||
HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
|
||||
HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
|
||||
|
||||
class ClipVision:
|
||||
IMAGE_SIZE = "clip.vision.image_size"
|
||||
PATCH_SIZE = "clip.vision.patch_size"
|
||||
EMBEDDING_LENGTH = "clip.vision.embedding_length"
|
||||
FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length"
|
||||
PROJECTION_DIM = "clip.vision.projection_dim"
|
||||
BLOCK_COUNT = "clip.vision.block_count"
|
||||
IMAGE_MEAN = "clip.vision.image_mean"
|
||||
IMAGE_STD = "clip.vision.image_std"
|
||||
SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
|
||||
USE_GELU = "clip.use_gelu"
|
||||
USE_SILU = "clip.use_silu"
|
||||
N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.vision.attention.head_count"
|
||||
LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"
|
||||
|
||||
class Projector:
|
||||
SCALE_FACTOR = "clip.vision.projector.scale_factor"
|
||||
|
||||
class ClipAudio:
|
||||
NUM_MEL_BINS = "clip.audio.num_mel_bins"
|
||||
EMBEDDING_LENGTH = "clip.audio.embedding_length"
|
||||
FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
|
||||
PROJECTION_DIM = "clip.audio.projection_dim"
|
||||
BLOCK_COUNT = "clip.audio.block_count"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
|
||||
|
||||
class Projector:
|
||||
STACK_FACTOR = "clip.audio.projector.stack_factor"
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
@@ -231,9 +277,11 @@ class GGUFType:
|
||||
MODEL = "model"
|
||||
ADAPTER = "adapter"
|
||||
IMATRIX = "imatrix"
|
||||
MMPROJ = "mmproj" # dummy, unused for now
|
||||
|
||||
|
||||
class MODEL_ARCH(IntEnum):
|
||||
MMPROJ = auto() # dummy arch for clip.cpp
|
||||
LLAMA = auto()
|
||||
LLAMA4 = auto()
|
||||
DECI = auto()
|
||||
@@ -248,6 +296,8 @@ class MODEL_ARCH(IntEnum):
|
||||
REFACT = auto()
|
||||
BERT = auto()
|
||||
NOMIC_BERT = auto()
|
||||
NOMIC_BERT_MOE = auto()
|
||||
NEO_BERT = auto()
|
||||
JINA_BERT_V2 = auto()
|
||||
BLOOM = auto()
|
||||
STABLELM = auto()
|
||||
@@ -300,6 +350,18 @@ class MODEL_ARCH(IntEnum):
|
||||
WAVTOKENIZER_DEC = auto()
|
||||
PLM = auto()
|
||||
BAILINGMOE = auto()
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
MLP = auto()
|
||||
LDP = auto()
|
||||
LDPV2 = auto()
|
||||
RESAMPLER = auto()
|
||||
GLM_EDGE = auto()
|
||||
MERGER = auto()
|
||||
GEMMA3 = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@@ -389,6 +451,8 @@ class MODEL_TENSOR(IntEnum):
|
||||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_K_B = auto()
|
||||
ATTN_V_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
FFN_SUB_NORM = auto()
|
||||
@@ -439,9 +503,68 @@ class MODEL_TENSOR(IntEnum):
|
||||
POSNET_ATTN_K = auto()
|
||||
POSNET_ATTN_V = auto()
|
||||
POSNET_ATTN_OUT = auto()
|
||||
# vision
|
||||
V_MMPROJ = auto()
|
||||
V_MMPROJ_FC = auto()
|
||||
V_MMPROJ_MLP = auto()
|
||||
V_MMPROJ_PEG = auto()
|
||||
V_ENC_EMBD_CLS = auto()
|
||||
V_ENC_EMBD_PATCH = auto()
|
||||
V_ENC_EMBD_POS = auto()
|
||||
V_ENC_INPUT_NORM = auto()
|
||||
V_ENC_ATTN_Q = auto()
|
||||
V_ENC_ATTN_Q_NORM = auto()
|
||||
V_ENC_ATTN_K = auto()
|
||||
V_ENC_ATTN_K_NORM = auto()
|
||||
V_ENC_ATTN_V = auto()
|
||||
V_ENC_ATTN_O = auto()
|
||||
V_ENC_ATTN_O_NORM = auto()
|
||||
V_ENC_POST_ATTN_NORM = auto()
|
||||
V_ENC_FFN_UP = auto()
|
||||
V_ENC_FFN_GATE = auto()
|
||||
V_ENC_FFN_DOWN = auto()
|
||||
V_LAYER_SCALE_1 = auto()
|
||||
V_LAYER_SCALE_2 = auto()
|
||||
V_PRE_NORM = auto()
|
||||
V_POST_NORM = auto()
|
||||
V_MM_INP_NORM = auto()
|
||||
V_MM_INP_PROJ = auto() # gemma3
|
||||
V_MM_SOFT_EMB_NORM = auto() # gemma3
|
||||
V_RESMPL_POS_EMBD_K = auto() # minicpmv
|
||||
V_RESMPL_ATTN_Q = auto() # minicpmv
|
||||
V_RESMPL_ATTN_K = auto() # minicpmv
|
||||
V_RESMPL_ATTN_V = auto() # minicpmv
|
||||
V_RESMPL_ATTN_OUT = auto() # minicpmv
|
||||
V_RESMPL_KV = auto() # minicpmv
|
||||
V_RESMPL_KV_NORM = auto() # minicpmv
|
||||
V_RESMPL_POST_NORM = auto() # minicpmv
|
||||
V_RESMPL_Q_NORM = auto() # minicpmv
|
||||
V_RESMPL_PROJ = auto() # minicpmv
|
||||
V_RESMPL_QUERY = auto() # minicpmv
|
||||
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
|
||||
V_MM_PATCH_MERGER = auto() # mistral small 3.1
|
||||
# audio (mtmd)
|
||||
A_ENC_EMBD_POS = auto()
|
||||
A_ENC_CONV1D = auto()
|
||||
A_PRE_NORM = auto()
|
||||
A_POST_NORM = auto()
|
||||
A_ENC_ATTN_Q = auto()
|
||||
A_ENC_ATTN_K = auto()
|
||||
A_ENC_ATTN_V = auto()
|
||||
A_ENC_INPUT_NORM = auto()
|
||||
A_ENC_OUTPUT = auto()
|
||||
A_ENC_OUTPUT_NORM = auto()
|
||||
A_ENC_FFN_UP = auto()
|
||||
A_ENC_FFN_GATE = auto()
|
||||
A_ENC_FFN_DOWN = auto()
|
||||
A_MMPROJ = auto()
|
||||
A_MMPROJ_FC = auto()
|
||||
A_MM_NORM_PRE = auto()
|
||||
A_MM_NORM_MID = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
|
||||
MODEL_ARCH.LLAMA: "llama",
|
||||
MODEL_ARCH.LLAMA4: "llama4",
|
||||
MODEL_ARCH.DECI: "deci",
|
||||
@@ -456,6 +579,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.REFACT: "refact",
|
||||
MODEL_ARCH.BERT: "bert",
|
||||
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
|
||||
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
|
||||
MODEL_ARCH.NEO_BERT: "neo-bert",
|
||||
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
|
||||
MODEL_ARCH.BLOOM: "bloom",
|
||||
MODEL_ARCH.STABLELM: "stablelm",
|
||||
@@ -508,6 +633,18 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
|
||||
MODEL_ARCH.PLM: "plm",
|
||||
MODEL_ARCH.BAILINGMOE: "bailingmoe",
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
VISION_PROJECTOR_TYPE.MLP: "mlp",
|
||||
VISION_PROJECTOR_TYPE.LDP: "ldp",
|
||||
VISION_PROJECTOR_TYPE.LDPV2: "ldpv2",
|
||||
VISION_PROJECTOR_TYPE.RESAMPLER: "resampler",
|
||||
VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter",
|
||||
VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger",
|
||||
VISION_PROJECTOR_TYPE.GEMMA3: "gemma3",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@@ -597,6 +734,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
|
||||
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||
@@ -647,9 +786,126 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
|
||||
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
|
||||
# vision
|
||||
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
|
||||
MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
|
||||
MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
|
||||
MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
|
||||
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
|
||||
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
|
||||
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
|
||||
MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
|
||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out",
|
||||
MODEL_TENSOR.V_RESMPL_KV: "resampler.kv",
|
||||
MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv",
|
||||
MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post",
|
||||
MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q",
|
||||
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
|
||||
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
|
||||
# audio (mtmd)
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
|
||||
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
|
||||
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
|
||||
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
|
||||
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
|
||||
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_ARCH.MMPROJ: [
|
||||
MODEL_TENSOR.V_MMPROJ,
|
||||
MODEL_TENSOR.V_MMPROJ_FC,
|
||||
MODEL_TENSOR.V_MMPROJ_MLP,
|
||||
MODEL_TENSOR.V_MMPROJ_PEG,
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS,
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH,
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM,
|
||||
MODEL_TENSOR.V_ENC_ATTN_V,
|
||||
MODEL_TENSOR.V_ENC_ATTN_O,
|
||||
MODEL_TENSOR.V_ENC_ATTN_O_NORM,
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
|
||||
MODEL_TENSOR.V_ENC_FFN_UP,
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE,
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1,
|
||||
MODEL_TENSOR.V_LAYER_SCALE_2,
|
||||
MODEL_TENSOR.V_PRE_NORM,
|
||||
MODEL_TENSOR.V_POST_NORM,
|
||||
MODEL_TENSOR.V_MM_INP_PROJ,
|
||||
MODEL_TENSOR.V_MM_INP_NORM,
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_Q,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_K,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_V,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_OUT,
|
||||
MODEL_TENSOR.V_RESMPL_KV,
|
||||
MODEL_TENSOR.V_RESMPL_KV_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_POST_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_Q_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_PROJ,
|
||||
MODEL_TENSOR.V_RESMPL_QUERY,
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER,
|
||||
# audio
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
MODEL_TENSOR.A_PRE_NORM,
|
||||
MODEL_TENSOR.A_POST_NORM,
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.A_ENC_ATTN_K,
|
||||
MODEL_TENSOR.A_ENC_ATTN_V,
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.A_ENC_OUTPUT,
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.A_ENC_FFN_UP,
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE,
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.A_MMPROJ,
|
||||
MODEL_TENSOR.A_MMPROJ_FC,
|
||||
MODEL_TENSOR.A_MM_NORM_PRE,
|
||||
MODEL_TENSOR.A_MM_NORM_MID,
|
||||
],
|
||||
MODEL_ARCH.LLAMA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
@@ -792,6 +1048,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.POS_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
@@ -816,6 +1073,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.NOMIC_BERT_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.TOKEN_TYPES,
|
||||
MODEL_TENSOR.POS_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.NEO_BERT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.CLS,
|
||||
MODEL_TENSOR.CLS_OUT,
|
||||
],
|
||||
MODEL_ARCH.JINA_BERT_V2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
@@ -1524,6 +1809,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_K_B,
|
||||
MODEL_TENSOR.ATTN_V_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
@@ -1720,6 +2007,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.CHAMELEON: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
@@ -1778,6 +2068,45 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.DOTS1: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.ARCEE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@@ -1860,6 +2189,8 @@ class PoolingType(IntEnum):
|
||||
NONE = 0
|
||||
MEAN = 1
|
||||
CLS = 2
|
||||
LAST = 3
|
||||
RANK = 4
|
||||
|
||||
|
||||
class GGMLQuantizationType(IntEnum):
|
||||
@@ -1986,6 +2317,19 @@ class GGUFValueType(IntEnum):
|
||||
raise ValueError(f"Unknown type: {type(val)}")
|
||||
|
||||
|
||||
class VisionProjectorType:
|
||||
GEMMA3 = "gemma3"
|
||||
IDEFICS3 = "idefics3"
|
||||
PIXTRAL = "pixtral"
|
||||
LLAMA4 = "llama4"
|
||||
QWEN2VL = "qwen2vl_merger"
|
||||
QWEN25VL = "qwen2.5vl_merger"
|
||||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
QWEN25O = "qwen2.5o" # omni
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
QK_K = 256
|
||||
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
||||
|
||||
@@ -251,7 +251,7 @@ class GGUFReader:
|
||||
offs += curr_size
|
||||
return offs - orig_offs, aparts, data_idxs, types
|
||||
# We can't deal with this one.
|
||||
raise ValueError('Unknown/unhandled field type {gtype}')
|
||||
raise ValueError(f'Unknown/unhandled field type {gtype}')
|
||||
|
||||
def _get_tensor_info_field(self, orig_offs: int) -> ReaderField:
|
||||
offs = orig_offs
|
||||
|
||||
@@ -49,6 +49,7 @@ class TensorInfo:
|
||||
class GGUFValue:
|
||||
value: Any
|
||||
type: GGUFValueType
|
||||
sub_type: GGUFValueType | None = None
|
||||
|
||||
|
||||
class WriterState(Enum):
|
||||
@@ -238,7 +239,7 @@ class GGUFWriter:
|
||||
|
||||
for key, val in kv_data.items():
|
||||
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
|
||||
|
||||
fout.write(kv_bytes)
|
||||
|
||||
@@ -268,11 +269,11 @@ class GGUFWriter:
|
||||
fout.flush()
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
|
||||
if any(key in kv_data for kv_data in self.kv_data):
|
||||
raise ValueError(f'Duplicated key name {key!r}')
|
||||
logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
|
||||
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
|
||||
|
||||
def add_uint8(self, key: str, val: int) -> None:
|
||||
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||
@@ -689,6 +690,12 @@ class GGUFWriter:
|
||||
def add_value_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_key_length_mla(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
|
||||
|
||||
def add_value_length_mla(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
|
||||
|
||||
def add_max_alibi_bias(self, bias: float) -> None:
|
||||
self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
|
||||
|
||||
@@ -722,6 +729,9 @@ class GGUFWriter:
|
||||
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
|
||||
|
||||
def add_moe_every_n_layers(self, value: int) -> None:
|
||||
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
|
||||
|
||||
def add_swin_norm(self, value: bool) -> None:
|
||||
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
|
||||
|
||||
@@ -887,7 +897,7 @@ class GGUFWriter:
|
||||
def add_remove_extra_whitespaces(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
|
||||
|
||||
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
|
||||
def add_precompiled_charsmap(self, charsmap: bytes) -> None:
|
||||
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
|
||||
|
||||
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
||||
@@ -925,13 +935,98 @@ class GGUFWriter:
|
||||
def add_eom_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.EOM_ID, id)
|
||||
|
||||
def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
|
||||
self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
|
||||
|
||||
# for vision models
|
||||
|
||||
def add_clip_has_vision_encoder(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
|
||||
|
||||
def add_clip_has_audio_encoder(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
|
||||
|
||||
def add_clip_projector_type(self, value: str) -> None:
|
||||
self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
|
||||
|
||||
def add_vision_projection_dim(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
|
||||
|
||||
def add_vision_patch_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
|
||||
|
||||
def add_vision_embedding_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
|
||||
|
||||
def add_vision_feed_forward_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
|
||||
|
||||
def add_vision_block_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
|
||||
|
||||
def add_vision_head_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
|
||||
|
||||
def add_vision_attention_layernorm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
|
||||
|
||||
def add_vision_image_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
|
||||
|
||||
def add_vision_image_mean(self, values: Sequence[float]) -> None:
|
||||
self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
|
||||
|
||||
def add_vision_image_std(self, values: Sequence[float]) -> None:
|
||||
self.add_array(Keys.ClipVision.IMAGE_STD, values)
|
||||
|
||||
def add_vision_spatial_merge_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
|
||||
|
||||
def add_vision_use_gelu(self, value: bool) -> None:
|
||||
self.add_bool(Keys.ClipVision.USE_GELU, value)
|
||||
|
||||
def add_vision_use_silu(self, value: bool) -> None:
|
||||
self.add_bool(Keys.ClipVision.USE_SILU, value)
|
||||
|
||||
def add_vision_projector_scale_factor(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
|
||||
|
||||
def add_vision_n_wa_pattern(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
|
||||
|
||||
# audio models
|
||||
|
||||
def add_audio_projection_dim(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
|
||||
|
||||
def add_audio_embedding_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
|
||||
|
||||
def add_audio_feed_forward_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
|
||||
|
||||
def add_audio_block_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
|
||||
|
||||
def add_audio_head_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
|
||||
|
||||
def add_audio_attention_layernorm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
|
||||
|
||||
def add_audio_num_mel_bins(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
|
||||
|
||||
def add_audio_stack_factor(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
|
||||
|
||||
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
||||
pack_prefix = ''
|
||||
if not skip_pack_prefix:
|
||||
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
||||
return struct.pack(f'{pack_prefix}{fmt}', value)
|
||||
|
||||
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
|
||||
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
|
||||
kv_data = bytearray()
|
||||
|
||||
if add_vtype:
|
||||
@@ -952,7 +1047,9 @@ class GGUFWriter:
|
||||
if len(val) == 0:
|
||||
raise ValueError("Invalid GGUF metadata array. Empty array")
|
||||
|
||||
if isinstance(val, bytes):
|
||||
if sub_type is not None:
|
||||
ltype = sub_type
|
||||
elif isinstance(val, bytes):
|
||||
ltype = GGUFValueType.UINT8
|
||||
else:
|
||||
ltype = GGUFValueType.get_type(val[0])
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
# pyright: reportUnusedImport=false
|
||||
|
||||
from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
|
||||
from .gguf_dump import main as gguf_dump_entrypoint
|
||||
from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
|
||||
from .gguf_new_metadata import main as gguf_new_metadata_entrypoint
|
||||
1621
gguf-py/gguf/scripts/gguf_editor_gui.py
Executable file
1621
gguf-py/gguf/scripts/gguf_editor_gui.py
Executable file
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,7 @@ class MetadataDetails(NamedTuple):
|
||||
type: gguf.GGUFValueType
|
||||
value: Any
|
||||
description: str = ''
|
||||
sub_type: gguf.GGUFValueType | None = None
|
||||
|
||||
|
||||
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
|
||||
@@ -57,7 +58,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||
logger.debug(f'Removing {field.name}')
|
||||
continue
|
||||
|
||||
old_val = MetadataDetails(field.types[0], field.contents())
|
||||
val_type = field.types[0]
|
||||
sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
|
||||
old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
|
||||
val = new_metadata.get(field.name, old_val)
|
||||
|
||||
if field.name in new_metadata:
|
||||
@@ -67,7 +70,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||
logger.debug(f'Copying {field.name}')
|
||||
|
||||
if val.value is not None:
|
||||
writer.add_key_value(field.name, val.value, val.type)
|
||||
writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
|
||||
|
||||
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
||||
logger.debug('Adding chat template(s)')
|
||||
|
||||
@@ -31,6 +31,7 @@ class TensorNameMap:
|
||||
"model.embeddings", # rwkv7
|
||||
"model.word_embeddings", # bailingmoe
|
||||
"language_model.model.embed_tokens", # llama4
|
||||
"encoder", # neobert
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@@ -68,7 +69,7 @@ class TensorNameMap:
|
||||
"output_layer", # chatglm
|
||||
"head", # rwkv
|
||||
"head.out", # wavtokenizer
|
||||
"language_model.lm_head", # llama4
|
||||
"lm_head", # llama4
|
||||
),
|
||||
|
||||
# Output norm
|
||||
@@ -91,7 +92,7 @@ class TensorNameMap:
|
||||
"rwkv.ln_out", # rwkv6
|
||||
"model.ln_out", # rwkv7
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
"language_model.model.norm", # llama4
|
||||
"model.norm", # llama4
|
||||
),
|
||||
|
||||
# Rope frequencies
|
||||
@@ -133,7 +134,8 @@ class TensorNameMap:
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
"language_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"model.layers.{bid}.input_layernorm", # llama4
|
||||
"transformer_encoder.{bid}.attention_norm", # neobert
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@@ -157,9 +159,11 @@ class TensorNameMap:
|
||||
"h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"encoder.layers.{bid}.mixer.Wqkv", # jina
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
"transformer_encoder.{bid}.qkv", # neobert
|
||||
),
|
||||
|
||||
# Attention query
|
||||
@@ -168,12 +172,13 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wq", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.query", # bert
|
||||
"transformer.layer.{bid}.attention.q_lin", # distillbert
|
||||
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
||||
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
|
||||
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention key
|
||||
@@ -182,13 +187,14 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wk", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.key", # bert
|
||||
"transformer.layer.{bid}.attention.k_lin", # distillbert
|
||||
"transformer.h.{bid}.attn.k_proj", # gpt-j
|
||||
"transformer.h.{bid}.attn.k", # refact
|
||||
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
|
||||
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention value
|
||||
@@ -196,13 +202,14 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"layers.{bid}.attention.wv", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.value", # bert
|
||||
"transformer.layer.{bid}.attention.v_lin", # distillbert
|
||||
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
||||
"transformer.h.{bid}.attn.v", # refact
|
||||
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
|
||||
"model.layers.{bid}.attention.wv", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
|
||||
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
),
|
||||
|
||||
# Attention output
|
||||
@@ -216,6 +223,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
"transformer.layer.{bid}.attention.out_lin", # distillbert
|
||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||
"model.layers.{bid}.self_attn.dense", # persimmon
|
||||
@@ -224,17 +232,20 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||
"model.layers.{bid}.attention.wo", # internlm2
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"encoder.layers.{bid}.mixer.out_proj", # jina
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||
"language_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"transformer_encoder.{bid}.wo", # neobert
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: (
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
"transformer.layer.{bid}.sa_layer_norm", # distillbert
|
||||
"encoder.layers.{bid}.norm1", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
|
||||
@@ -268,7 +279,8 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"transformer_encoder.{bid}.ffn_norm", # neobert
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -289,7 +301,8 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
"language_model.model.layers.{bid}.feed_forward.router", # llama4
|
||||
"model.layers.{bid}.feed_forward.router", # llama4
|
||||
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@@ -297,7 +310,7 @@ class TensorNameMap:
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||
),
|
||||
|
||||
# Feed-forward up
|
||||
@@ -310,6 +323,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
|
||||
"layers.{bid}.feed_forward.w3", # llama-pth
|
||||
"encoder.layer.{bid}.intermediate.dense", # bert
|
||||
"transformer.layer.{bid}.ffn.lin1", # distillbert
|
||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||
"transformer.h.{bid}.mlp.linear_3", # refact
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||
@@ -322,12 +336,16 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w3", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
|
||||
"encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
|
||||
"model.layers.{bid}.mlp.c_fc", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used)
|
||||
"encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU)
|
||||
"encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU)
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
"transformer_encoder.{bid}.ffn.w12", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -336,13 +354,14 @@ class TensorNameMap:
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
@@ -359,26 +378,26 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w1", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
@@ -391,6 +410,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
|
||||
"layers.{bid}.feed_forward.w2", # llama-pth
|
||||
"encoder.layer.{bid}.output.dense", # bert
|
||||
"transformer.layer.{bid}.ffn.lin2", # distillbert
|
||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||
@@ -407,7 +427,8 @@ class TensorNameMap:
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
"transformer_encoder.{bid}.ffn.w3", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -417,13 +438,15 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
|
||||
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
|
||||
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
@@ -450,6 +473,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: (
|
||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||
"transformer.layer.{bid}.output_layer_norm", # distillbert
|
||||
"encoder.layers.{bid}.norm2", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
|
||||
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
|
||||
@@ -677,6 +701,14 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_K_B: (
|
||||
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_V_B: (
|
||||
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||
),
|
||||
@@ -807,11 +839,14 @@ class TensorNameMap:
|
||||
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||
"encoder.final_layer_norm", # t5
|
||||
"layer_norm", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS: (
|
||||
"classifier", # jina
|
||||
"classifier.dense", # roberta
|
||||
"pre_classifier", # distillbert
|
||||
"dense", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS_OUT: (
|
||||
@@ -878,6 +913,295 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: (
|
||||
"backbone.posnet.{bid}.proj_out", # wavtokenizer
|
||||
),
|
||||
|
||||
#############################################################################
|
||||
## Vision encoder
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ: (
|
||||
"multi_modal_projector.linear_{bid}",
|
||||
"visual.merger.mlp.{bid}", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_FC: (
|
||||
"model.connector.modality_projection.proj", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_MLP: (
|
||||
"model.mm_projector.mlp.mlp.{bid}",
|
||||
"vision_model.vision_adapter.mlp.fc{bid}", # llama 4
|
||||
"mlp1.{bid}", # InternVL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_PEG: (
|
||||
"model.mm_projector.peg.peg.{bid}",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: (
|
||||
"vision_tower.vision_model.embeddings.class_embedding",
|
||||
"vision_model.class_embedding", # llama 4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
|
||||
"vision_tower.vision_model.embeddings.patch_embedding",
|
||||
"vpm.embeddings.patch_embedding",
|
||||
"model.vision_model.embeddings.patch_embedding", # SmolVLM
|
||||
"vision_tower.patch_conv", # pixtral
|
||||
"vision_model.patch_embedding.linear", # llama 4
|
||||
"visual.patch_embed.proj", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: (
|
||||
"vision_tower.vision_model.embeddings.position_embedding",
|
||||
"vpm.embeddings.position_embedding",
|
||||
"model.vision_model.embeddings.position_embedding", # SmolVLM
|
||||
"vision_model.positional_embedding_vlm", # llama 4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.q_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
|
||||
"vpm.encoder.layers.{bid}.layer_norm1",
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
|
||||
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
|
||||
"vision_model.model.layers.{bid}.input_layernorm", # llama4
|
||||
"visual.blocks.{bid}.norm1", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
|
||||
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
|
||||
"visual.blocks.{bid}.attn.proj", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
|
||||
"vpm.encoder.layers.{bid}.layer_norm2",
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
|
||||
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
|
||||
"visual.blocks.{bid}.norm2", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
|
||||
"vpm.encoder.layers.{bid}.mlp.fc1",
|
||||
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
|
||||
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
|
||||
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
||||
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
|
||||
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
|
||||
"vpm.encoder.layers.{bid}.mlp.fc2",
|
||||
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
|
||||
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
|
||||
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
|
||||
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
|
||||
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_SCALE_2: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_PRE_NORM: (
|
||||
"vision_tower.vision_model.pre_layrnorm",
|
||||
"vision_tower.ln_pre", # pixtral
|
||||
"vision_model.layernorm_pre", # llama4
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_POST_NORM: (
|
||||
"vision_tower.vision_model.post_layernorm",
|
||||
"model.vision_model.post_layernorm", # SmolVLM
|
||||
"vision_model.layernorm_post", # llama4
|
||||
"visual.merger.ln_q", # qwen2vl
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_INP_PROJ: (
|
||||
"multi_modal_projector.mm_input_projection",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_INP_NORM: (
|
||||
"multi_modal_projector.norm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
||||
"multi_modal_projector.mm_soft_emb_norm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
|
||||
"resampler.pos_embed_k",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_Q: (
|
||||
"resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_K: (
|
||||
"resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_V: (
|
||||
"resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
|
||||
"resampler.attn.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_KV: (
|
||||
"resampler.kv_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_POST_NORM: (
|
||||
"resampler.ln_post",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_KV_NORM: (
|
||||
"resampler.ln_kv",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_Q_NORM: (
|
||||
"resampler.ln_q",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_PROJ: (
|
||||
"resampler.proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_QUERY: (
|
||||
"resampler.query",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
|
||||
"v.token_embd.img_break", # for pixtral, this is a generated vector
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER: (
|
||||
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
|
||||
),
|
||||
|
||||
# audio (mtmd)
|
||||
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: (
|
||||
"audio_tower.embed_positions", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV1D: (
|
||||
"audio_tower.conv{bid}", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PRE_NORM: (),
|
||||
|
||||
MODEL_TENSOR.A_POST_NORM: (
|
||||
"audio_tower.layer_norm", # ultravox
|
||||
"audio_tower.ln_post", # qwen2omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q: (
|
||||
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_K: (
|
||||
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_V: (
|
||||
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM: (
|
||||
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_OUTPUT: (
|
||||
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
|
||||
"audio_tower.layers.{bid}.final_layer_norm", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_UP: (
|
||||
"audio_tower.layers.{bid}.fc1", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE: (),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN: (
|
||||
"audio_tower.layers.{bid}.fc2", # ultravox
|
||||
),
|
||||
|
||||
# note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors
|
||||
# this prefix is added in the conversion code in modify_tensors()
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ: (
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
"audio.multi_modal_projector.linear", # qwen2audio
|
||||
"audio_tower.proj", # qwen2omni
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: (
|
||||
"audio.multi_modal_projector.ln_pre", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_NORM_MID: (
|
||||
"audio.multi_modal_projector.ln_mid", # ultravox
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
||||
@@ -231,7 +231,7 @@ class SafetensorRemote:
|
||||
response.raise_for_status()
|
||||
|
||||
# Get raw byte data
|
||||
return response.content[:size]
|
||||
return response.content[slice(size if size > -1 else None)]
|
||||
|
||||
@classmethod
|
||||
def check_file_exist(cls, url: str) -> bool:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.16.0"
|
||||
version = "0.17.0"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
@@ -23,16 +23,21 @@ numpy = ">=1.17"
|
||||
tqdm = ">=4.27"
|
||||
pyyaml = ">=5.1"
|
||||
sentencepiece = ">=0.1.98,<=0.2.0"
|
||||
PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^5.2"
|
||||
|
||||
[tool.poetry.extras]
|
||||
gui = ["PySide6"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
gguf-convert-endian = "gguf.scripts:gguf_convert_endian_entrypoint"
|
||||
gguf-dump = "gguf.scripts:gguf_dump_entrypoint"
|
||||
gguf-set-metadata = "gguf.scripts:gguf_set_metadata_entrypoint"
|
||||
gguf-new-metadata = "gguf.scripts:gguf_new_metadata_entrypoint"
|
||||
gguf-convert-endian = "gguf.scripts.gguf_convert_endian:main"
|
||||
gguf-dump = "gguf.scripts.gguf_dump:main"
|
||||
gguf-set-metadata = "gguf.scripts.gguf_set_metadata:main"
|
||||
gguf-new-metadata = "gguf.scripts.gguf_new_metadata:main"
|
||||
gguf-editor-gui = "gguf.scripts.gguf_editor_gui:main"
|
||||
|
||||
Reference in New Issue
Block a user