mtmd : support Kimi VL model (#15458)

* convert : fix tensor naming conflict for llama 4 vision

* convert ok

* support kimi vision model

* clean up

* fix style

* fix calc number of output tokens

* refactor resize_position_embeddings

* add test case

* rename build fn

* correct a small bug
This commit is contained in:
Xuan-Son Nguyen
2025-08-26 12:54:19 +02:00
committed by GitHub
parent 85cc1ae998
commit 79a546220c
6 changed files with 211 additions and 61 deletions

View File

@@ -135,6 +135,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {

View File

@@ -526,57 +526,16 @@ struct clip_graph {
cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
// pixel_shuffle
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
const int scale_factor = model.hparams.proj_scale_factor;
const int n_embd = cur->ne[0];
const int seq = cur->ne[1];
const int bsz = 1; // batch size, always 1 for now since we don't support batching
const int height = std::sqrt(seq);
const int width = std::sqrt(seq);
GGML_ASSERT(scale_factor != 0);
cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_4d(ctx0, cur,
n_embd * scale_factor * scale_factor,
height / scale_factor,
width / scale_factor,
bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_3d(ctx0, cur,
n_embd * scale_factor * scale_factor,
seq / (scale_factor * scale_factor),
bsz);
cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block
const int scale_factor = model.hparams.proj_scale_factor;
GGML_ASSERT(scale_factor > 1);
const int n_embd = cur->ne[0];
int width = img.nx / patch_size;
int height = img.ny / patch_size;
// pad width and height to factor
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
if (pad_width || pad_height) {
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
width += pad_width;
height += pad_height;
}
// unshuffle h
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// unshuffle w
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = build_patch_merge_permute(cur, scale_factor);
// projection
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
@@ -1086,7 +1045,7 @@ struct clip_graph {
n_patches_x / scale_factor,
n_patches_y / scale_factor,
bsz);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
//cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// flatten to 2D
cur = ggml_cont_2d(ctx0, cur,
n_embd * scale_factor * scale_factor,
@@ -1113,6 +1072,67 @@ struct clip_graph {
return gf;
}
ggml_cgraph * build_kimivl() {
// 2D input positions
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * learned_pos_embd = resize_position_embeddings();
// build ViT with 2D position embeddings
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
// first half is X axis and second half is Y axis
return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
};
ggml_tensor * inp = build_inp();
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
learned_pos_embd,
add_pos);
cb(cur, "vit_out", -1);
{
// patch_merger
const int scale_factor = model.hparams.proj_scale_factor;
cur = build_patch_merge_permute(cur, scale_factor);
// projection norm
int proj_inp_dim = cur->ne[0];
cur = ggml_view_2d(ctx0, cur,
n_embd, cur->ne[1] * scale_factor * scale_factor,
ggml_row_size(cur->type, n_embd), 0);
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
cur = ggml_view_2d(ctx0, cur,
proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
ggml_row_size(cur->type, proj_inp_dim), 0);
cb(cur, "proj_inp_normed", -1);
// projection mlp
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_add(ctx0, cur, model.mm_1_b);
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
cur = ggml_add(ctx0, cur, model.mm_2_b);
cb(cur, "proj_out", -1);
}
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
@@ -1611,18 +1631,20 @@ private:
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
if (!pos_embd || height * width == pos_embd->ne[1]) {
GGML_ASSERT(pos_embd);
if (height == n_per_side && width == n_per_side) {
return pos_embd;
}
const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_pos_embd, n_pos_embd, n_embd)
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1); // -> (width, height, n_embd)
pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd); // -> (height * width, n_embd)
pos_embd = ggml_transpose(ctx0, pos_embd); // -> (n_embd, height * width)
pos_embd = ggml_cont(ctx0, pos_embd);
pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side)
pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd)
pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height)
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height)
return pos_embd;
}
@@ -2021,6 +2043,39 @@ private:
return cur;
}
// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
// support dynamic resolution
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
GGML_ASSERT(scale_factor > 1);
const int n_embd = cur->ne[0];
int width = img.nx / patch_size;
int height = img.ny / patch_size;
// pad width and height to factor
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
if (pad_width || pad_height) {
cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
width += pad_width;
height += pad_height;
}
// unshuffle h
cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// unshuffle w
cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cb(cur, "pixel_shuffle", -1);
return cur;
}
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
@@ -2063,6 +2118,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_whisper_enc();
} break;
case PROJECTOR_TYPE_KIMIVL:
{
res = graph.build_kimivl();
} break;
default:
{
res = graph.build_llava();
@@ -2313,6 +2372,12 @@ struct clip_model_loader {
hparams.image_size = 1024;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
case PROJECTOR_TYPE_KIMIVL:
{
hparams.rope_theta = 10000.0f;
hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
@@ -2477,7 +2542,20 @@ struct clip_model_loader {
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
bool is_ffn_swapped = (
// only old models need this fix
model.proj_type == PROJECTOR_TYPE_MLP
|| model.proj_type == PROJECTOR_TYPE_MLP_NORM
|| model.proj_type == PROJECTOR_TYPE_LDP
|| model.proj_type == PROJECTOR_TYPE_LDPV2
|| model.proj_type == PROJECTOR_TYPE_QWEN2VL
|| model.proj_type == PROJECTOR_TYPE_QWEN25VL
|| model.proj_type == PROJECTOR_TYPE_GLM_EDGE
|| model.proj_type == PROJECTOR_TYPE_GEMMA3
|| model.proj_type == PROJECTOR_TYPE_IDEFICS3
|| model.proj_type == PROJECTOR_TYPE_MINICPMV
) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
if (is_ffn_swapped) {
// swap up and down weights
ggml_tensor * tmp = layer.ff_up_w;
layer.ff_up_w = layer.ff_down_w;
@@ -2486,6 +2564,9 @@ struct clip_model_loader {
tmp = layer.ff_up_b;
layer.ff_up_b = layer.ff_down_b;
layer.ff_down_b = tmp;
if (il == 0) {
LOG_WRN("%s: ffn up/down are swapped\n", __func__);
}
}
}
@@ -2604,6 +2685,7 @@ struct clip_model_loader {
model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3507,7 +3589,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->grid_y = inst.grid_size.height;
return true;
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
} else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
|| ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
) {
GGML_ASSERT(params.proj_scale_factor);
// smart resize
@@ -3708,12 +3792,21 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_LLAMA4:
case PROJECTOR_TYPE_LFM2:
{
// both W and H are divided by proj_scale_factor
// both X and Y are downscaled by the scale factor
int scale_factor = ctx->model.hparams.proj_scale_factor;
n_patches /= (scale_factor * scale_factor);
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
{
// dynamic size
int scale_factor = ctx->model.hparams.proj_scale_factor;
int out_patch_size = params.patch_size * scale_factor;
int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
// dynamic size
@@ -4096,6 +4189,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
{
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
@@ -4250,6 +4344,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
return ctx->model.mm_2_w->ne[1];
default:
GGML_ABORT("Unknown projector type");

View File

@@ -86,6 +86,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then
add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
# add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"