mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	clip : Add Qwen2.5VL support (#12402)
* implment vision model architecture, gguf convertor * handle window attention inputs * add debug utils * fix few incorrect tensor memory layout * move position id remap out of ggml to avoid int32 cuda operations * cleaning up * ignore transformers Qwen2_5_xxx type check * remove not so often use `qwen2vl-cli` debug functions * remove commented-out code blocks * fix attn weight scaling after rebase * add `PROJECTOR_TYPE_QWEN2_5_VL` * remove `KEY_USE_GLU_MLP`, `KEY_USE_RMS_NORM` * replace `KEY_FULLATTN_BLK_IDX` with `KEY_WIN_ATTN_PATTERN` * remove `attn_window_size` from gguf * fix model conversion * clean up * fix merging problem * add test --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
		| @@ -28,6 +28,7 @@ | ||||
| #include <cinttypes> | ||||
| #include <limits> | ||||
| #include <array> | ||||
| #include <numeric> | ||||
|  | ||||
| struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; | ||||
|  | ||||
| @@ -169,6 +170,8 @@ struct clip_hparams { | ||||
|     std::vector<int32_t> image_grid_pinpoints; | ||||
|     int32_t image_crop_resolution; | ||||
|     std::unordered_set<int32_t> vision_feature_layer; | ||||
|     int32_t attn_window_size; | ||||
|     int32_t n_wa_pattern; | ||||
| }; | ||||
|  | ||||
| struct clip_layer { | ||||
| @@ -200,6 +203,9 @@ struct clip_layer { | ||||
|     struct ggml_tensor * ff_down_w = nullptr; | ||||
|     struct ggml_tensor * ff_down_b = nullptr; | ||||
|  | ||||
|     struct ggml_tensor * ff_g_w = NULL; | ||||
|     struct ggml_tensor * ff_g_b = NULL; | ||||
|  | ||||
|     // layernorm 2 | ||||
|     struct ggml_tensor * ln_2_w = nullptr; | ||||
|     struct ggml_tensor * ln_2_b = nullptr; | ||||
| @@ -319,6 +325,7 @@ struct clip_ctx { | ||||
|     float image_std[3]; | ||||
|     bool use_gelu = false; | ||||
|     bool use_silu = false; | ||||
|     int32_t ftype = 1; | ||||
|  | ||||
|     gguf_context_ptr ctx_gguf; | ||||
|     ggml_context_ptr ctx_data; | ||||
| @@ -762,6 +769,236 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i | ||||
|     return gf; | ||||
| } | ||||
|  | ||||
| static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) { | ||||
|     const auto & model = ctx->vision_model; | ||||
|     const auto & hparams = model.hparams; | ||||
|  | ||||
|     const int image_size_width  = imgs.entries[0]->nx; | ||||
|     const int image_size_height = imgs.entries[0]->ny; | ||||
|  | ||||
|     const bool use_mrope       = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; | ||||
|     const bool use_window_attn = hparams.n_wa_pattern > 0; | ||||
|  | ||||
|     const int n_wa_pattern         = hparams.n_wa_pattern; | ||||
|     const int patch_size           = hparams.patch_size; | ||||
|     const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size)); | ||||
|     const int patches_w            = image_size_width / patch_size; | ||||
|     const int patches_h            = image_size_height / patch_size; | ||||
|     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0); | ||||
|     const int num_position_ids     = use_mrope ? num_positions * 4 : num_positions; | ||||
|     const int hidden_size          = hparams.hidden_size; | ||||
|     const int n_head               = hparams.n_head; | ||||
|     const int d_head               = hidden_size / n_head; | ||||
|     const float eps                = hparams.eps; | ||||
|  | ||||
|     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; | ||||
|  | ||||
|     const int batch_size = imgs.entries.size(); | ||||
|     GGML_ASSERT(batch_size == 1); | ||||
|  | ||||
|     struct ggml_init_params params = { | ||||
|         /*.mem_size   =*/ ctx->buf_compute_meta.size(), | ||||
|         /*.mem_buffer =*/ ctx->buf_compute_meta.data(), | ||||
|         /*.no_alloc   =*/ true, | ||||
|     }; | ||||
|  | ||||
|     ggml_context_ptr ctx0_ptr(ggml_init(params)); | ||||
|     auto ctx0 = ctx0_ptr.get(); | ||||
|  | ||||
|     struct ggml_cgraph * gf = ggml_new_graph(ctx0); | ||||
|  | ||||
|     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); | ||||
|     ggml_set_name(inp_raw, "inp_raw"); | ||||
|     ggml_set_input(inp_raw); | ||||
|  | ||||
|     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); | ||||
|  | ||||
|     GGML_ASSERT(image_size_width  % (patch_size * 2) == 0); | ||||
|     GGML_ASSERT(image_size_height % (patch_size * 2) == 0); | ||||
|  | ||||
|     auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); | ||||
|     inp = ggml_add(ctx0, inp, inp_1); | ||||
|  | ||||
|     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b] | ||||
|     inp = ggml_reshape_4d( | ||||
|         ctx0, inp, | ||||
|         hidden_size * 2, patches_w / 2, patches_h, batch_size); | ||||
|     inp = ggml_reshape_4d( | ||||
|         ctx0, inp, | ||||
|         hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); | ||||
|     inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); | ||||
|     inp = ggml_reshape_3d( | ||||
|         ctx0, inp, | ||||
|         hidden_size, patches_w * patches_h, batch_size); | ||||
|  | ||||
|     if (model.patch_bias) { | ||||
|         // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); | ||||
|         inp = ggml_add(ctx0, inp, model.patch_bias); | ||||
|     } | ||||
|     struct ggml_tensor * embeddings     = inp; | ||||
|     struct ggml_tensor * window_mask    = nullptr; | ||||
|     struct ggml_tensor * window_idx     = nullptr; | ||||
|     struct ggml_tensor * inv_window_idx = nullptr; | ||||
|  | ||||
|     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); | ||||
|     ggml_set_name(positions, "positions"); | ||||
|     ggml_set_input(positions); | ||||
|  | ||||
|     // pre-layernorm | ||||
|     if (model.pre_ln_w) { | ||||
|         embeddings = ggml_rms_norm(ctx0, embeddings, eps); | ||||
|         ggml_set_name(embeddings, "pre_ln"); | ||||
|  | ||||
|         embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w); | ||||
|     } | ||||
|  | ||||
|     if (use_window_attn) { | ||||
|         // handle window attention inputs | ||||
|         inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); | ||||
|         ggml_set_name(inv_window_idx, "inv_window_idx"); | ||||
|         ggml_set_input(inv_window_idx); | ||||
|         // mask for window attention | ||||
|         window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions); | ||||
|         ggml_set_name(window_mask, "window_mask"); | ||||
|         ggml_set_input(window_mask); | ||||
|  | ||||
|         // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] | ||||
|         GGML_ASSERT(batch_size == 1); | ||||
|         embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4); | ||||
|         embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx); | ||||
|         embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size); | ||||
|     } | ||||
|  | ||||
|     // loop over layers | ||||
|     for (int il = 0; il < ctx->max_feature_layer; il++) { | ||||
|         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states | ||||
|  | ||||
|         // rmsnorm1 | ||||
|         cur = ggml_rms_norm(ctx0, cur, eps); | ||||
|         cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w); | ||||
|  | ||||
|         // self-attention | ||||
|         { | ||||
|  | ||||
|             struct ggml_tensor * Q = | ||||
|                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); | ||||
|  | ||||
|             Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); | ||||
|             Q = ggml_rope_multi( | ||||
|                 ctx0, Q, positions, nullptr, | ||||
|                 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); | ||||
|             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); | ||||
|             Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); | ||||
|  | ||||
|             struct ggml_tensor * K = | ||||
|                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); | ||||
|  | ||||
|             K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); | ||||
|             K = ggml_rope_multi( | ||||
|                 ctx0, K, positions, nullptr, | ||||
|                 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); | ||||
|             K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); | ||||
|             K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); | ||||
|  | ||||
|             struct ggml_tensor * V = | ||||
|                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); | ||||
|  | ||||
|             V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); | ||||
|             V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); | ||||
|             V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); | ||||
|  | ||||
|             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); | ||||
|             const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true; | ||||
|             if (full_attn) { | ||||
|                 KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); | ||||
|             } else { | ||||
|                 KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f); | ||||
|             } | ||||
|  | ||||
|             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); | ||||
|             KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); | ||||
|             KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); | ||||
|  | ||||
|             cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); | ||||
|         } | ||||
|  | ||||
|         // attention output | ||||
|         cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); | ||||
|  | ||||
|         // re-add the layer input, e.g., residual | ||||
|         cur = ggml_add(ctx0, cur, embeddings); | ||||
|  | ||||
|         embeddings = cur; // embeddings = residual, cur = hidden_states | ||||
|  | ||||
|         // rms norm2 | ||||
|         cur = ggml_rms_norm(ctx0, cur, eps); | ||||
|         cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w); | ||||
|  | ||||
|         // mlp | ||||
|         // ffn_up | ||||
|         auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); | ||||
|         cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_o_b); | ||||
|  | ||||
|         auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_g_w, cur); | ||||
|         cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_g_b); | ||||
|         // TODO : only 2 of these 3 are actually used, should we remove one of them? | ||||
|         if (ctx->use_gelu) { | ||||
|             cur_gate = ggml_gelu_inplace(ctx0, cur_gate); | ||||
|         } else if (ctx->use_silu) { | ||||
|             cur_gate = ggml_silu_inplace(ctx0, cur_gate); | ||||
|         } else { | ||||
|             cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate); | ||||
|         } | ||||
|         cur = ggml_mul(ctx0, cur_gate, cur_up); | ||||
|  | ||||
|         // ffn_down | ||||
|         cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); | ||||
|         cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); | ||||
|  | ||||
|         // residual 2 | ||||
|         cur = ggml_add(ctx0, embeddings, cur); | ||||
|  | ||||
|         embeddings = cur; | ||||
|     } | ||||
|  | ||||
|     // post-layernorm | ||||
|     if (model.post_ln_w) { | ||||
|         embeddings = ggml_rms_norm(ctx0, embeddings, eps); | ||||
|         ggml_set_name(embeddings, "post_ln"); | ||||
|  | ||||
|         embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w); | ||||
|     } | ||||
|  | ||||
|     embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); | ||||
|  | ||||
|     embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); | ||||
|     embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); | ||||
|  | ||||
|     // GELU activation | ||||
|     embeddings = ggml_gelu(ctx0, embeddings); | ||||
|  | ||||
|     // Second linear layer | ||||
|     embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); | ||||
|     embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); | ||||
|  | ||||
|     if (use_window_attn) { | ||||
|         window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4); | ||||
|         ggml_set_name(window_idx, "window_idx"); | ||||
|         ggml_set_input(window_idx); | ||||
|  | ||||
|         // embeddings shape: [hidden_size, patches_w * patches_h, batch_size] | ||||
|         GGML_ASSERT(batch_size == 1); | ||||
|         embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4); | ||||
|         embeddings = ggml_get_rows(ctx0, embeddings, window_idx); | ||||
|         embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size); | ||||
|     } | ||||
|  | ||||
|     // build the graph | ||||
|     ggml_build_forward_expand(gf, embeddings); | ||||
|  | ||||
|     return gf; | ||||
| } | ||||
|  | ||||
| static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { | ||||
|     const auto & model = ctx->vision_model; | ||||
|     const auto & hparams = model.hparams; | ||||
| @@ -1331,6 +1568,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 | ||||
|                 GGML_ASSERT(imgs.entries.size() == 1); | ||||
|                 res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]); | ||||
|             } break; | ||||
|         case PROJECTOR_TYPE_QWEN25VL: | ||||
|             { | ||||
|                 res = clip_image_build_graph_qwen25vl(ctx, imgs); | ||||
|             } break; | ||||
|         default: | ||||
|             { | ||||
|                 // TODO: we should have one build_* function per model | ||||
| @@ -1507,6 +1748,10 @@ struct clip_model_loader { | ||||
|                     { | ||||
|                         hparams.rope_theta = 10000.0f; | ||||
|                     } break; | ||||
|                 case PROJECTOR_TYPE_QWEN25VL: | ||||
|                     { | ||||
|                         get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern); | ||||
|                     } break; | ||||
|                 default: | ||||
|                     break; | ||||
|             } | ||||
| @@ -1600,8 +1845,10 @@ struct clip_model_loader { | ||||
|             // legacy naming (the in and out is reversed! don't ask me why) | ||||
|             layer.ff_i_w = layer.ff_down_w; | ||||
|             layer.ff_o_w = layer.ff_up_w; | ||||
|             layer.ff_g_w = layer.ff_gate_w; | ||||
|             layer.ff_i_b = layer.ff_down_b; | ||||
|             layer.ff_o_b = layer.ff_up_b; | ||||
|             layer.ff_g_b = layer.ff_gate_b; | ||||
|         } | ||||
|  | ||||
|         switch (ctx_clip.proj_type) { | ||||
| @@ -1700,6 +1947,7 @@ struct clip_model_loader { | ||||
|                     vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); | ||||
|                 } break; | ||||
|             case PROJECTOR_TYPE_QWEN2VL: | ||||
|             case PROJECTOR_TYPE_QWEN25VL: | ||||
|                 { | ||||
|                     vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); | ||||
|                     vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); | ||||
| @@ -2651,7 +2899,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i | ||||
|         else { | ||||
|             GGML_ABORT("Unknown minicpmv version"); | ||||
|         } | ||||
|     } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { | ||||
|     } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { | ||||
|         int patch_size = params.patch_size * 2; | ||||
|         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); | ||||
|         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); | ||||
| @@ -2792,6 +3040,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | ||||
|     const int pos_w = ctx->load_image_size.width / patch_size; | ||||
|     const int pos_h = ctx->load_image_size.height / patch_size; | ||||
|  | ||||
|     const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl | ||||
|  | ||||
|     { | ||||
|         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); | ||||
|         std::vector<float> inp_data(ggml_nelements(inp_raw)); | ||||
| @@ -2890,31 +3140,93 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | ||||
|         // non-minicpmv models | ||||
|  | ||||
|         if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { | ||||
|             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); | ||||
|             // pw * ph = number of tokens output by ViT after apply patch merger | ||||
|             // ipw * ipw = number of vision token been processed inside ViT | ||||
|             const int merge_ratio = 2; | ||||
|             const int pw  = image_size_width  / patch_size / merge_ratio; | ||||
|             const int ph  = image_size_height / patch_size / merge_ratio; | ||||
|             const int ipw = image_size_width  / patch_size; | ||||
|             const int iph = image_size_height / patch_size; | ||||
|  | ||||
|             const int pw = image_size_width / patch_size; | ||||
|             const int ph = image_size_height / patch_size; | ||||
|             int* positions_data = (int*)malloc(ggml_nbytes(positions)); | ||||
|             std::vector<int> idx    (ph * pw); | ||||
|             std::vector<int> inv_idx(ph * pw); | ||||
|  | ||||
|             if (use_window_attn) { | ||||
|                 const int attn_window_size = 112; | ||||
|                 struct ggml_tensor * window_idx     = ggml_graph_get_tensor(gf, "window_idx"); | ||||
|                 struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); | ||||
|                 struct ggml_tensor * window_mask    = ggml_graph_get_tensor(gf, "window_mask"); | ||||
|  | ||||
|                 const int grid_window = attn_window_size / patch_size / merge_ratio; | ||||
|                 int dst = 0; | ||||
|                 // [num_vision_tokens, num_vision_tokens] attention mask tensor | ||||
|                 std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest()); | ||||
|                 int mask_row = 0; | ||||
|  | ||||
|                 for (int y = 0; y < ph; y += grid_window) | ||||
|                 { | ||||
|                     for (int x = 0; x < pw; x += grid_window) | ||||
|                     { | ||||
|                         const int win_h = std::min(grid_window, ph - y); | ||||
|                         const int win_w = std::min(grid_window, pw - x); | ||||
|                         const int dst_0 = dst; | ||||
|                         // group all tokens belong to the same window togather (to a continue range) | ||||
|                         for (int dy = 0; dy < win_h; dy++) { | ||||
|                             for (int dx = 0; dx < win_w; dx++) { | ||||
|                                 const int src = (y + dy) * pw + (x + dx); | ||||
|                                 assert(src < (int)idx.size()); | ||||
|                                 assert(dst < (int)inv_idx.size()); | ||||
|                                 idx    [src] = dst; | ||||
|                                 inv_idx[dst] = src; | ||||
|                                 dst++; | ||||
|                             } | ||||
|                         } | ||||
|  | ||||
|                         for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { | ||||
|                             int row_offset = mask_row * (ipw * iph); | ||||
|                             std::fill( | ||||
|                                 mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), | ||||
|                                 mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio), | ||||
|                                 0.0); | ||||
|                             mask_row++; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 ggml_backend_tensor_set(window_idx,     idx.data(),     0, ggml_nbytes(window_idx)); | ||||
|                 ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); | ||||
|                 ggml_backend_tensor_set(window_mask,    mask.data(),    0, ggml_nbytes(window_mask)); | ||||
|             } else { | ||||
|                 std::iota(idx.begin(), idx.end(), 0); | ||||
|                 std::iota(inv_idx.begin(), inv_idx.end(), 0); | ||||
|             } | ||||
|  | ||||
|             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); | ||||
|             const int mpow = merge_ratio * merge_ratio; | ||||
|             std::vector<int> positions_data(ggml_nelements(positions)); | ||||
|             int * data = positions_data.data(); | ||||
|  | ||||
|             int ptr = 0; | ||||
|             for (int y = 0; y < ph; y+=2) | ||||
|             for (int y = 0; y < iph; y += merge_ratio) | ||||
|             { | ||||
|                 for (int x = 0; x < pw; x+=2) | ||||
|                 for (int x = 0; x < ipw; x += merge_ratio) | ||||
|                 { | ||||
|                     for (int dy = 0; dy < 2; dy++) { | ||||
|                         for (int dx = 0; dx < 2; dx++) { | ||||
|                             positions_data[ptr]                 = y + dy; | ||||
|                             positions_data[num_patches + ptr]     = x + dx; | ||||
|                             positions_data[num_patches * 2 + ptr] = y + dy; | ||||
|                             positions_data[num_patches * 3 + ptr] = x + dx; | ||||
|                             auto remap = idx[ptr / mpow]; | ||||
|                             remap = remap * mpow + (ptr % mpow); | ||||
|  | ||||
|                             data[                  remap] = y + dy; | ||||
|                             data[    num_patches + remap] = x + dx; | ||||
|                             data[2 * num_patches + remap] = y + dy; | ||||
|                             data[3 * num_patches + remap] = x + dx; | ||||
|                             ptr++; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); | ||||
|             free(positions_data); | ||||
|             ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); | ||||
|         } | ||||
|         else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { | ||||
|             // do nothing | ||||
| @@ -2967,6 +3279,65 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (use_window_attn && ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { | ||||
|         struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); | ||||
|         struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); | ||||
|         struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); | ||||
|  | ||||
|         const int merge_ratio = 2; | ||||
|         const int attn_window_size = 112; | ||||
|         const int pw = image_size_width / patch_size / merge_ratio; | ||||
|         const int ph = image_size_height / patch_size / merge_ratio; | ||||
|         const int grid_window = attn_window_size / patch_size / merge_ratio; | ||||
|         const int ipw = image_size_width / patch_size; | ||||
|         const int iph = image_size_height / patch_size; | ||||
|         /* | ||||
|         pw * ph = number of tokens output by ViT after apply patch merger | ||||
|         ipw * ipw = number of vision token been processed inside ViT | ||||
|         */ | ||||
|  | ||||
|         std::vector<int> idx(ph * pw); | ||||
|         std::vector<int> inv_idx(ph * pw); | ||||
|         int dst = 0; | ||||
|         // [num_vision_tokens, num_vision_tokens] attention mask tensor | ||||
|         std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest()); | ||||
|         int mask_row = 0; | ||||
|  | ||||
|         for (int y = 0; y < ph; y+=grid_window) | ||||
|         { | ||||
|             for (int x = 0; x < pw; x+=grid_window) | ||||
|             { | ||||
|                 const int win_h = std::min(grid_window, ph - y); | ||||
|                 const int win_w = std::min(grid_window, pw - x); | ||||
|                 const int dst_0 = dst; | ||||
|                 // group all tokens belong to the same window togather (to a continue range) | ||||
|                 for (int dy = 0; dy < win_h; dy++) { | ||||
|                     for (int dx = 0; dx < win_w; dx++) { | ||||
|                         const int src = (y + dy) * pw + (x + dx); | ||||
|                         assert(src < (int)idx.size()); | ||||
|                         assert(dst < (int)inv_idx.size()); | ||||
|                         idx[src] = dst; | ||||
|                         inv_idx[dst] = src; | ||||
|                         dst++; | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { | ||||
|                     int row_offset = mask_row * (ipw * iph); | ||||
|                     std::fill( | ||||
|                         mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), | ||||
|                         mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio), | ||||
|                         0.0); | ||||
|                     mask_row++; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); | ||||
|         ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); | ||||
|         ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); | ||||
|     } | ||||
|  | ||||
|     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); | ||||
|  | ||||
|     auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); | ||||
| @@ -3142,6 +3513,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { | ||||
|         case PROJECTOR_TYPE_GLM_EDGE: | ||||
|             return ctx->vision_model.mm_model_mlp_3_w->ne[1]; | ||||
|         case PROJECTOR_TYPE_QWEN2VL: | ||||
|         case PROJECTOR_TYPE_QWEN25VL: | ||||
|             return ctx->vision_model.mm_1_b->ne[0]; | ||||
|         case PROJECTOR_TYPE_GEMMA3: | ||||
|             return ctx->vision_model.mm_input_proj_w->ne[0]; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 HimariO
					HimariO