mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-06 09:46:50 +00:00
support qwen3vl series.
Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com> Co-authored-by: yairpatch <yairpatch@users.noreply.github.com> Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>
This commit is contained in:
@@ -196,6 +196,8 @@ struct clip_hparams {
|
||||
int32_t n_wa_pattern = 0;
|
||||
int32_t spatial_merge_size = 0;
|
||||
|
||||
std::vector<int32_t> deepstack_layers; // qwen3vl deepstack layers
|
||||
|
||||
// audio
|
||||
int32_t n_mel_bins = 0; // whisper preprocessor
|
||||
int32_t proj_stack_factor = 0; // ultravox
|
||||
@@ -359,6 +361,17 @@ struct clip_model {
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
|
||||
// qwen3vl deepstack
|
||||
struct deepstack_merger {
|
||||
ggml_tensor * norm_w = nullptr;
|
||||
ggml_tensor * norm_b = nullptr;
|
||||
ggml_tensor * fc1_w = nullptr;
|
||||
ggml_tensor * fc1_b = nullptr;
|
||||
ggml_tensor * fc2_w = nullptr;
|
||||
ggml_tensor * fc2_b = nullptr;
|
||||
};
|
||||
std::vector<deepstack_merger> deepstack_mergers;
|
||||
|
||||
bool audio_has_avgpool() const {
|
||||
return proj_type == PROJECTOR_TYPE_QWEN2A
|
||||
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
|
||||
@@ -831,6 +844,201 @@ struct clip_graph {
|
||||
return gf;
|
||||
}
|
||||
|
||||
// Qwen3VL
|
||||
ggml_cgraph * build_qwen3vl() {
|
||||
GGML_ASSERT(model.patch_bias != nullptr);
|
||||
GGML_ASSERT(model.position_embeddings != nullptr);
|
||||
GGML_ASSERT(model.class_embedding == nullptr);
|
||||
GGML_ASSERT(!hparams.deepstack_layers.empty());
|
||||
|
||||
const int batch_size = 1;
|
||||
const int n_pos = n_patches;
|
||||
const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
|
||||
|
||||
norm_type norm_t = NORM_TYPE_NORMAL;
|
||||
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
ggml_tensor * inp_raw = build_inp_raw();
|
||||
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
|
||||
|
||||
// second conv dimension
|
||||
{
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
|
||||
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_cont_4d(
|
||||
ctx0, inp,
|
||||
n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
|
||||
inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
|
||||
inp = ggml_cont_3d(
|
||||
ctx0, inp,
|
||||
n_embd, n_patches_x * n_patches_y, batch_size);
|
||||
}
|
||||
|
||||
// add patch bias
|
||||
if (model.patch_bias != nullptr) {
|
||||
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||
cb(inp, "patch_bias", -1);
|
||||
}
|
||||
|
||||
// calculate absolute position embedding and apply
|
||||
ggml_tensor * learned_pos_embd = resize_position_embeddings();
|
||||
learned_pos_embd = ggml_cont_4d(
|
||||
ctx0, learned_pos_embd,
|
||||
n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
|
||||
learned_pos_embd = ggml_reshape_4d(
|
||||
ctx0, learned_pos_embd,
|
||||
n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
|
||||
learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
|
||||
learned_pos_embd = ggml_cont_3d(
|
||||
ctx0, learned_pos_embd,
|
||||
n_embd, n_patches_x * n_patches_y, batch_size);
|
||||
inp = ggml_add(ctx0, inp, learned_pos_embd);
|
||||
cb(inp, "inp_pos_emb", -1);
|
||||
|
||||
ggml_tensor * inpL = inp;
|
||||
|
||||
ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
|
||||
ggml_set_name(positions, "positions");
|
||||
ggml_set_input(positions);
|
||||
|
||||
// pre-layernorm
|
||||
if (model.pre_ln_w) {
|
||||
inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
|
||||
}
|
||||
|
||||
// deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
|
||||
ggml_tensor * deepstack_features = nullptr;
|
||||
const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl
|
||||
|
||||
// loop over layers
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
auto & layer = model.layers[il];
|
||||
|
||||
ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
|
||||
|
||||
// layernorm1
|
||||
cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
|
||||
cb(cur, "ln1", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * Qcur = ggml_add(ctx0,
|
||||
ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
|
||||
ggml_tensor * Kcur = ggml_add(ctx0,
|
||||
ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
|
||||
ggml_tensor * Vcur = ggml_add(ctx0,
|
||||
ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// apply M-RoPE
|
||||
Qcur = ggml_rope_multi(
|
||||
ctx0, Qcur, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
Kcur = ggml_rope_multi(
|
||||
ctx0, Kcur, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
|
||||
cur = build_attn(layer.o_w, layer.o_b,
|
||||
Qcur, Kcur, Vcur, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
// re-add the layer input, e.g., residual
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
inpL = cur; // inpL = residual, cur = hidden_states
|
||||
|
||||
cb(cur, "ffn_inp", il);
|
||||
|
||||
// layernorm2
|
||||
cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
|
||||
cb(cur, "ffn_inp_normed", il);
|
||||
|
||||
// ffn
|
||||
cur = build_ffn(cur,
|
||||
layer.ff_up_w, layer.ff_up_b,
|
||||
layer.ff_gate_w, layer.ff_gate_b,
|
||||
layer.ff_down_w, layer.ff_down_b,
|
||||
hparams.ffn_op, il);
|
||||
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// residual 2
|
||||
cur = ggml_add(ctx0, inpL, cur);
|
||||
cb(cur, "layer_out", il);
|
||||
|
||||
if (std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) != hparams.deepstack_layers.end()) {
|
||||
const int deepstack_idx = std::find(hparams.deepstack_layers.begin(), hparams.deepstack_layers.end(), il) - hparams.deepstack_layers.begin();
|
||||
auto & merger = model.deepstack_mergers[deepstack_idx];
|
||||
ggml_tensor * feat = ggml_dup(ctx0, cur);
|
||||
feat = ggml_reshape_3d(ctx0, feat, n_embd * merge_factor, n_pos / merge_factor, batch_size);
|
||||
|
||||
feat = build_norm(feat, merger.norm_w, merger.norm_b, norm_t, eps, il);
|
||||
feat = ggml_mul_mat(ctx0, merger.fc1_w, feat);
|
||||
feat = ggml_add(ctx0, feat, merger.fc1_b);
|
||||
|
||||
feat = ggml_gelu(ctx0, feat);
|
||||
|
||||
feat = ggml_mul_mat(ctx0, merger.fc2_w, feat);
|
||||
feat = ggml_add(ctx0, feat, merger.fc2_b);
|
||||
|
||||
if(!deepstack_features) {
|
||||
deepstack_features = feat;
|
||||
} else {
|
||||
// concat along the feature dimension
|
||||
deepstack_features = ggml_concat(ctx0, deepstack_features, feat, 0);
|
||||
}
|
||||
}
|
||||
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
if (model.post_ln_w) {
|
||||
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
|
||||
}
|
||||
|
||||
// multimodal projection
|
||||
ggml_tensor * embeddings = inpL;
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
// GELU activation
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
// Second linear layer
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||
|
||||
embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
ggml_cgraph * build_minicpmv() {
|
||||
const int batch_size = 1;
|
||||
|
||||
@@ -2103,6 +2311,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
res = graph.build_qwen2vl();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
res = graph.build_qwen3vl();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_MINICPMV:
|
||||
{
|
||||
res = graph.build_minicpmv();
|
||||
@@ -2421,6 +2633,13 @@ struct clip_model_loader {
|
||||
hparams.warmup_image_size = hparams.patch_size * 8;
|
||||
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
hparams.image_size = 1024; // still need this?
|
||||
hparams.warmup_image_size = hparams.patch_size * 8;
|
||||
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
|
||||
get_arr_int(KEY_DEEPSTACK_LAYERS, hparams.deepstack_layers, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
@@ -2459,6 +2678,15 @@ struct clip_model_loader {
|
||||
LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version);
|
||||
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
|
||||
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
|
||||
if (hparams.spatial_merge_size > 0) {
|
||||
LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
|
||||
}
|
||||
if (!hparams.deepstack_layers.empty()) {
|
||||
LOG_INF("%s: deepstack_layers: ", __func__);
|
||||
for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
|
||||
LOG_CNT("%d%s", hparams.deepstack_layers[i], i < hparams.deepstack_layers.size() - 1 ? ", " : "\n");
|
||||
}
|
||||
}
|
||||
} else if (is_audio) {
|
||||
LOG_INF("\n--- audio hparams ---\n");
|
||||
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
|
||||
@@ -2691,6 +2919,26 @@ struct clip_model_loader {
|
||||
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
|
||||
if (!hparams.deepstack_layers.empty()) {
|
||||
model.deepstack_mergers.resize(hparams.deepstack_layers.size());
|
||||
for (size_t i = 0; i < hparams.deepstack_layers.size(); i++) {
|
||||
auto & merger = model.deepstack_mergers[i];
|
||||
merger.norm_w = get_tensor(string_format("v.deepstack.%d.norm.weight", (int)i), false);
|
||||
merger.norm_b = get_tensor(string_format("v.deepstack.%d.norm.bias", (int)i), false);
|
||||
merger.fc1_w = get_tensor(string_format("v.deepstack.%d.fc1.weight", (int)i), false);
|
||||
merger.fc1_b = get_tensor(string_format("v.deepstack.%d.fc1.bias", (int)i), false);
|
||||
merger.fc2_w = get_tensor(string_format("v.deepstack.%d.fc2.weight", (int)i), false);
|
||||
merger.fc2_b = get_tensor(string_format("v.deepstack.%d.fc2.bias", (int)i), false);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
{
|
||||
model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
|
||||
@@ -3554,7 +3802,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->grid_y = inst.grid_size.height;
|
||||
return true;
|
||||
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
|
||||
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
|
||||
clip_image_u8 resized;
|
||||
auto patch_size = params.patch_size * 2;
|
||||
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
|
||||
@@ -3774,7 +4022,7 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
|
||||
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
const auto & params = ctx->model.hparams;
|
||||
const int n_total = clip_n_output_tokens(ctx, img);
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
|
||||
return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
|
||||
}
|
||||
return n_total;
|
||||
@@ -3782,7 +4030,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
|
||||
|
||||
int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
const auto & params = ctx->model.hparams;
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
|
||||
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
|
||||
return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
|
||||
}
|
||||
return 1;
|
||||
@@ -3838,6 +4086,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
// dynamic size (2 conv, so double patch size)
|
||||
int patch_size = params.patch_size * 2;
|
||||
@@ -4142,6 +4391,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
set_input_f32("pos_embed", pos_embed);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
{
|
||||
const int merge_ratio = 2;
|
||||
const int pw = image_size_width / patch_size;
|
||||
@@ -4387,6 +4637,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
case PROJECTOR_TYPE_QWEN2VL:
|
||||
case PROJECTOR_TYPE_QWEN25VL:
|
||||
return ctx->model.mm_1_b->ne[0];
|
||||
case PROJECTOR_TYPE_QWEN3VL:
|
||||
return ctx->model.mm_1_b->ne[0] * ((int)ctx->model.hparams.deepstack_layers.size() + 1); // main path + deepstack paths
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
return ctx->model.mm_input_proj_w->ne[0];
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
@@ -4421,7 +4673,8 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
|
||||
|
||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
|
||||
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL;
|
||||
}
|
||||
|
||||
bool clip_is_llava(const struct clip_ctx * ctx) {
|
||||
|
||||
Reference in New Issue
Block a user