mtmd : support Kimi VL model (#15458)

* convert : fix tensor naming conflict for llama 4 vision

* convert ok

* support kimi vision model

* clean up

* fix style

* fix calc number of output tokens

* refactor resize_position_embeddings

* add test case

* rename build fn

* correct a small bug
This commit is contained in:
Xuan-Son Nguyen
2025-08-26 12:54:19 +02:00
committed by GitHub
parent 85cc1ae998
commit 79a546220c
6 changed files with 211 additions and 61 deletions

View File

@@ -2850,6 +2850,7 @@ class VisionProjectorType:
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
# Items here are (block size, type size)

View File

@@ -1122,6 +1122,7 @@ class TensorNameMap:
"vision_encoder.patch_conv", # pixtral
"vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
"vision_tower.patch_embed.proj", # kimi-vl
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -1130,6 +1131,7 @@ class TensorNameMap:
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
"vision_model.positional_embedding_vlm", # llama 4
"vision_tower.patch_embed.pos_emb", # kimi-vl
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1141,6 +1143,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@@ -1157,6 +1160,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@@ -1173,6 +1177,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -1185,6 +1190,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
"vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
),
MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1197,6 +1203,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1209,6 +1216,7 @@ class TensorNameMap:
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
),
MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1221,6 +1229,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1239,6 +1248,7 @@ class TensorNameMap:
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
),
MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1263,6 +1273,7 @@ class TensorNameMap:
"model.vision_model.post_layernorm", # SmolVLM
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
"vision_tower.encoder.final_layernorm", # kimi-vl
),
MODEL_TENSOR.V_MM_INP_PROJ: (
@@ -1272,6 +1283,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
),