model: add support for qwen3vl series (#16780)

* support qwen3vl series.

Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>

* bugfix: fix the arch check for qwen3vl-moe.

* use build_ffn

* optimize deepstack structure

* optimize deepstack feature saving

* Revert "optimize deepstack feature saving" for temporal fix

This reverts commit f321b9fdf1.

* code clean

* use fused qkv in clip

* clean up / rm is_deepstack_layers for simplification

* add test model

* move test model to "big" section

* fix imrope check

* remove trailing whitespace

* fix rope fail

* metal : add imrope support

* add imrope support for sycl

* vulkan: add imrope w/o check

* fix vulkan

* webgpu: add imrope w/o check

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix tensor mapping

---------

Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
JJJYmmm
2025-10-30 23:19:14 +08:00
committed by GitHub
parent dcca0d3ab8
commit d261223d24
28 changed files with 1125 additions and 97 deletions

View File

@@ -241,6 +241,18 @@ struct clip_layer {
// layer scale (no bias)
ggml_tensor * ls_1_w = nullptr;
ggml_tensor * ls_2_w = nullptr;
// qwen3vl deepstack merger
ggml_tensor * deepstack_norm_w = nullptr;
ggml_tensor * deepstack_norm_b = nullptr;
ggml_tensor * deepstack_fc1_w = nullptr;
ggml_tensor * deepstack_fc1_b = nullptr;
ggml_tensor * deepstack_fc2_w = nullptr;
ggml_tensor * deepstack_fc2_b = nullptr;
bool has_deepstack() const {
return deepstack_fc1_w != nullptr;
}
};
struct clip_model {
@@ -260,6 +272,8 @@ struct clip_model {
std::vector<clip_layer> layers;
int32_t n_deepstack_layers = 0; // used by Qwen3-VL, calculated from clip_layer
ggml_tensor * post_ln_w;
ggml_tensor * post_ln_b;
@@ -840,6 +854,189 @@ struct clip_graph {
return gf;
}
// Qwen3VL
ggml_cgraph * build_qwen3vl() {
GGML_ASSERT(model.patch_bias != nullptr);
GGML_ASSERT(model.position_embeddings != nullptr);
GGML_ASSERT(model.class_embedding == nullptr);
const int batch_size = 1;
const int n_pos = n_patches;
const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
norm_type norm_t = NORM_TYPE_NORMAL;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
ggml_tensor * inp_raw = build_inp_raw();
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
// second conv dimension
{
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_add(ctx0, inp, inp_1);
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
inp = ggml_cont_4d(
ctx0, inp,
n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
inp = ggml_reshape_4d(
ctx0, inp,
n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
inp = ggml_cont_3d(
ctx0, inp,
n_embd, n_patches_x * n_patches_y, batch_size);
}
// add patch bias
if (model.patch_bias != nullptr) {
inp = ggml_add(ctx0, inp, model.patch_bias);
cb(inp, "patch_bias", -1);
}
// calculate absolute position embedding and apply
ggml_tensor * learned_pos_embd = resize_position_embeddings();
learned_pos_embd = ggml_cont_4d(
ctx0, learned_pos_embd,
n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
learned_pos_embd = ggml_reshape_4d(
ctx0, learned_pos_embd,
n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
learned_pos_embd = ggml_cont_3d(
ctx0, learned_pos_embd,
n_embd, n_patches_x * n_patches_y, batch_size);
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "inp_pos_emb", -1);
ggml_tensor * inpL = inp;
ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
// pre-layernorm
if (model.pre_ln_w) {
inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
}
// deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
ggml_tensor * deepstack_features = nullptr;
const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl
// loop over layers
for (int il = 0; il < n_layer; il++) {
auto & layer = model.layers[il];
ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
// layernorm1
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
cb(cur, "ln1", il);
// self-attention
{
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
cur = ggml_add(ctx0, cur, layer.qkv_b);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], 0);
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], n_embd * sizeof(float));
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, d_head*sizeof(float),
cur->nb[1], 2 * n_embd * sizeof(float));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// apply M-RoPE
Qcur = ggml_rope_multi(
ctx0, Qcur, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
Kcur = ggml_rope_multi(
ctx0, Kcur, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
cb(Qcur, "Qcur_rope", il);
cb(Kcur, "Kcur_rope", il);
cur = build_attn(layer.o_w, layer.o_b,
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, inpL);
inpL = cur; // inpL = residual, cur = hidden_states
cb(cur, "ffn_inp", il);
// layernorm2
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
cb(cur, "ffn_inp_normed", il);
// ffn
cur = build_ffn(cur,
layer.ff_up_w, layer.ff_up_b,
layer.ff_gate_w, layer.ff_gate_b,
layer.ff_down_w, layer.ff_down_b,
hparams.ffn_op, il);
cb(cur, "ffn_out", il);
// residual 2
cur = ggml_add(ctx0, inpL, cur);
cb(cur, "layer_out", il);
if (layer.has_deepstack()) {
ggml_tensor * feat = ggml_reshape_3d(ctx0, cur, n_embd * merge_factor, n_pos / merge_factor, batch_size);
feat = build_norm(feat, layer.deepstack_norm_w, layer.deepstack_norm_b, norm_t, eps, il);
feat = build_ffn(feat,
layer.deepstack_fc1_w, layer.deepstack_fc1_b,
nullptr, nullptr,
layer.deepstack_fc2_w, layer.deepstack_fc2_b,
ffn_op_type::FFN_GELU, il);
if(!deepstack_features) {
deepstack_features = feat;
} else {
// concat along the feature dimension
deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
}
}
inpL = cur;
}
// post-layernorm
if (model.post_ln_w) {
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
}
// multimodal projection
ggml_tensor * embeddings = inpL;
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
embeddings = build_ffn(embeddings,
model.mm_0_w, model.mm_0_b,
nullptr, nullptr,
model.mm_1_w, model.mm_1_b,
ffn_op_type::FFN_GELU, -1);
embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
// build the graph
ggml_build_forward_expand(gf, embeddings);
return gf;
}
ggml_cgraph * build_minicpmv() {
const int batch_size = 1;
@@ -2211,6 +2408,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_qwen2vl();
} break;
case PROJECTOR_TYPE_QWEN3VL:
{
res = graph.build_qwen3vl();
} break;
case PROJECTOR_TYPE_MINICPMV:
{
res = graph.build_minicpmv();
@@ -2534,6 +2735,12 @@ struct clip_model_loader {
hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
} break;
case PROJECTOR_TYPE_QWEN3VL:
{
hparams.image_size = 1024; // still need this?
hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
case PROJECTOR_TYPE_LLAMA4:
{
hparams.rope_theta = 10000.0f;
@@ -2572,6 +2779,9 @@ struct clip_model_loader {
LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
if (hparams.spatial_merge_size > 0) {
LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
}
} else if (is_audio) {
LOG_INF("\n--- audio hparams ---\n");
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
@@ -2671,6 +2881,18 @@ struct clip_model_loader {
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
// qwen3vl deepstack layer
layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false);
layer.deepstack_norm_b = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "bias"), false);
layer.deepstack_fc1_w = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "weight"), false);
layer.deepstack_fc1_b = get_tensor(string_format(TN_DEEPSTACK_FC1, il, "bias"), false);
layer.deepstack_fc2_w = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "weight"), false);
layer.deepstack_fc2_b = get_tensor(string_format(TN_DEEPSTACK_FC2, il, "bias"), false);
if (layer.has_deepstack()) {
model.n_deepstack_layers++;
}
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
bool is_ffn_swapped = (
@@ -2806,6 +3028,13 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_QWEN3VL:
{
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_GEMMA3:
{
model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
@@ -3689,7 +3918,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->grid_y = inst.grid_size.height;
return true;
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
clip_image_u8 resized;
auto patch_size = params.patch_size * 2;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
@@ -3915,7 +4144,7 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->model.hparams;
const int n_total = clip_n_output_tokens(ctx, img);
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
}
return n_total;
@@ -3923,7 +4152,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->model.hparams;
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
}
return 1;
@@ -3979,6 +4208,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
{
// dynamic size (2 conv, so double patch size)
int patch_size = params.patch_size * 2;
@@ -4292,6 +4522,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
set_input_f32("pos_embed", pos_embed);
} break;
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN3VL:
{
const int merge_ratio = 2;
const int pw = image_size_width / patch_size;
@@ -4540,6 +4771,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
return ctx->model.mm_1_b->ne[0];
case PROJECTOR_TYPE_QWEN3VL:
// main path + deepstack paths
return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers);
case PROJECTOR_TYPE_GEMMA3:
return ctx->model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
@@ -4576,7 +4810,8 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
}
bool clip_is_llava(const struct clip_ctx * ctx) {