mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	clip : refactor set input for cgraph + fix qwen2.5vl input (#13136)
* clip : refactor set input for cgraph * more strict assert * minicpmv : use clip_n_mmproj_embd instead of copying the same code everywhere * split qwen2 and qwen2.5 code blocks * minor style fix
This commit is contained in:
		| @@ -170,8 +170,8 @@ struct clip_hparams { | |||||||
|     std::vector<int32_t> image_grid_pinpoints; |     std::vector<int32_t> image_grid_pinpoints; | ||||||
|     int32_t image_crop_resolution; |     int32_t image_crop_resolution; | ||||||
|     std::unordered_set<int32_t> vision_feature_layer; |     std::unordered_set<int32_t> vision_feature_layer; | ||||||
|     int32_t attn_window_size; |     int32_t attn_window_size = 0; | ||||||
|     int32_t n_wa_pattern; |     int32_t n_wa_pattern = 0; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct clip_layer { | struct clip_layer { | ||||||
| @@ -325,7 +325,6 @@ struct clip_ctx { | |||||||
|     float image_std[3]; |     float image_std[3]; | ||||||
|     bool use_gelu = false; |     bool use_gelu = false; | ||||||
|     bool use_silu = false; |     bool use_silu = false; | ||||||
|     int32_t ftype = 1; |  | ||||||
|  |  | ||||||
|     gguf_context_ptr ctx_gguf; |     gguf_context_ptr ctx_gguf; | ||||||
|     ggml_context_ptr ctx_data; |     ggml_context_ptr ctx_data; | ||||||
| @@ -776,7 +775,6 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ | |||||||
|     const int image_size_width  = imgs.entries[0]->nx; |     const int image_size_width  = imgs.entries[0]->nx; | ||||||
|     const int image_size_height = imgs.entries[0]->ny; |     const int image_size_height = imgs.entries[0]->ny; | ||||||
|  |  | ||||||
|     const bool use_mrope       = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; |  | ||||||
|     const bool use_window_attn = hparams.n_wa_pattern > 0; |     const bool use_window_attn = hparams.n_wa_pattern > 0; | ||||||
|  |  | ||||||
|     const int n_wa_pattern         = hparams.n_wa_pattern; |     const int n_wa_pattern         = hparams.n_wa_pattern; | ||||||
| @@ -785,10 +783,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ | |||||||
|     const int patches_w            = image_size_width / patch_size; |     const int patches_w            = image_size_width / patch_size; | ||||||
|     const int patches_h            = image_size_height / patch_size; |     const int patches_h            = image_size_height / patch_size; | ||||||
|     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0); |     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0); | ||||||
|     const int num_position_ids     = use_mrope ? num_positions * 4 : num_positions; |     const int num_position_ids     = num_positions * 4; // m-rope requires 4 dim per position | ||||||
|     const int hidden_size          = hparams.hidden_size; |     const int hidden_size          = hparams.hidden_size; | ||||||
|     const int n_head               = hparams.n_head; |     const int n_head               = hparams.n_head; | ||||||
|     const int d_head               = hidden_size / n_head; |     const int d_head               = hidden_size / n_head; | ||||||
|  |     const int n_layer              = hparams.n_layer; | ||||||
|     const float eps                = hparams.eps; |     const float eps                = hparams.eps; | ||||||
|  |  | ||||||
|     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; |     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; | ||||||
| @@ -870,7 +869,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_ | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // loop over layers |     // loop over layers | ||||||
|     for (int il = 0; il < ctx->max_feature_layer; il++) { |     for (int il = 0; il < n_layer; il++) { | ||||||
|         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states |         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states | ||||||
|  |  | ||||||
|         // rmsnorm1 |         // rmsnorm1 | ||||||
| @@ -1115,15 +1114,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im | |||||||
|     if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { |     if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { | ||||||
|         int pos_w = image_size_width/patch_size; |         int pos_w = image_size_width/patch_size; | ||||||
|         int pos_h = image_size_height/patch_size; |         int pos_h = image_size_height/patch_size; | ||||||
|         if (ctx->minicpmv_version == 2) { |         int n_output_dim = clip_n_mmproj_embd(ctx); | ||||||
|             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); |         pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, pos_w * pos_h, 1); | ||||||
|         } |  | ||||||
|         else if (ctx->minicpmv_version == 3) { |  | ||||||
|             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); |  | ||||||
|         } |  | ||||||
|         else if (ctx->minicpmv_version == 4) { |  | ||||||
|             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); |  | ||||||
|         } |  | ||||||
|         ggml_set_name(pos_embed, "pos_embed"); |         ggml_set_name(pos_embed, "pos_embed"); | ||||||
|         ggml_set_input(pos_embed); |         ggml_set_input(pos_embed); | ||||||
|     } |     } | ||||||
| @@ -1461,23 +1453,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         { // attention |         { // attention | ||||||
|             int hidden_size = 4096; |             int hidden_size = clip_n_mmproj_embd(ctx); | ||||||
|             const int d_head = 128; |             const int d_head = 128; | ||||||
|             int n_head = hidden_size/d_head; |             int n_head = hidden_size/d_head; | ||||||
|             int num_query = 96; |             int num_query = 96; | ||||||
|             if (ctx->minicpmv_version == 2) { |             if (ctx->minicpmv_version == 2) { | ||||||
|                 hidden_size = 4096; |  | ||||||
|                 n_head = hidden_size/d_head; |  | ||||||
|                 num_query = 96; |                 num_query = 96; | ||||||
|             } |             } | ||||||
|             else if (ctx->minicpmv_version == 3) { |             else if (ctx->minicpmv_version == 3) { | ||||||
|                 hidden_size = 3584; |  | ||||||
|                 n_head = hidden_size/d_head; |  | ||||||
|                 num_query = 64; |                 num_query = 64; | ||||||
|             } |             } | ||||||
|             else if (ctx->minicpmv_version == 4) { |             else if (ctx->minicpmv_version == 4) { | ||||||
|                 hidden_size = 3584; |  | ||||||
|                 n_head = hidden_size/d_head; |  | ||||||
|                 num_query = 64; |                 num_query = 64; | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -1760,6 +1746,8 @@ struct clip_model_loader { | |||||||
|             LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str()); |             LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str()); | ||||||
|             LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector); |             LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector); | ||||||
|             LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version); |             LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version); | ||||||
|  |             LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor); | ||||||
|  |             LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern); | ||||||
|             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); |             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); | ||||||
|             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); |             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); | ||||||
|         } |         } | ||||||
| @@ -3038,15 +3026,43 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|     const int patch_size    = hparams.patch_size; |     const int patch_size    = hparams.patch_size; | ||||||
|     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size)); |     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size)); | ||||||
|     const int num_positions = num_patches + (model.class_embedding ? 1 : 0); |     const int num_positions = num_patches + (model.class_embedding ? 1 : 0); | ||||||
|     const int pos_w = ctx->load_image_size.width / patch_size; |     const int pos_w = ctx->load_image_size.width  / patch_size; | ||||||
|     const int pos_h = ctx->load_image_size.height / patch_size; |     const int pos_h = ctx->load_image_size.height / patch_size; | ||||||
|  |  | ||||||
|     const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl |     const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl | ||||||
|  |  | ||||||
|  |     auto get_inp_tensor = [&gf](const char * name) { | ||||||
|  |         struct ggml_tensor * inp = ggml_graph_get_tensor(gf, name); | ||||||
|  |         if (inp == nullptr) { | ||||||
|  |             GGML_ABORT("Failed to get tensor %s", name); | ||||||
|  |         } | ||||||
|  |         if (!(inp->flags & GGML_TENSOR_FLAG_INPUT)) { | ||||||
|  |             GGML_ABORT("Tensor %s is not an input tensor", name); | ||||||
|  |         } | ||||||
|  |         return inp; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector<float> & values) { | ||||||
|  |         ggml_tensor * cur = get_inp_tensor(name); | ||||||
|  |         GGML_ASSERT(cur->type == GGML_TYPE_F32); | ||||||
|  |         GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size()); | ||||||
|  |         ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur)); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     auto set_input_i32 = [&get_inp_tensor](const char * name, std::vector<int32_t> & values) { | ||||||
|  |         ggml_tensor * cur = get_inp_tensor(name); | ||||||
|  |         GGML_ASSERT(cur->type == GGML_TYPE_I32); | ||||||
|  |         GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size()); | ||||||
|  |         ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur)); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     // set input pixel values | ||||||
|     { |     { | ||||||
|         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); |         size_t nelem = 0; | ||||||
|         std::vector<float> inp_data(ggml_nelements(inp_raw)); |         for (const auto & img : imgs.entries) { | ||||||
|         float * data = inp_data.data(); |             nelem += img->nx * img->ny * 3; | ||||||
|  |         } | ||||||
|  |         std::vector<float> inp_raw(nelem); | ||||||
|  |  | ||||||
|         // layout of data (note: the channel dim is unrolled to better visualize the layout): |         // layout of data (note: the channel dim is unrolled to better visualize the layout): | ||||||
|         // |         // | ||||||
| @@ -3065,7 +3081,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|             const int n = nx * ny; |             const int n = nx * ny; | ||||||
|  |  | ||||||
|             for (int b = 0; b < batch_size; b++) { |             for (int b = 0; b < batch_size; b++) { | ||||||
|                 float * batch_entry = data + b * (3*n); |                 float * batch_entry = inp_raw.data() + b * (3*n); | ||||||
|                 for (int y = 0; y < ny; y++) { |                 for (int y = 0; y < ny; y++) { | ||||||
|                     for (int x = 0; x < nx; x++) { |                     for (int x = 0; x < nx; x++) { | ||||||
|                         size_t base_src = 3*(y * nx + x); // idx of the first channel |                         size_t base_src = 3*(y * nx + x); // idx of the first channel | ||||||
| @@ -3077,266 +3093,207 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); |         set_input_f32("inp_raw", inp_raw); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { |     // set input per projector | ||||||
|         { |     switch (ctx->proj_type) { | ||||||
|             // inspired from siglip: |         case PROJECTOR_TYPE_MINICPMV: | ||||||
|             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit |             { | ||||||
|             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 |                 // inspired from siglip: | ||||||
|             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); |                 //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit | ||||||
|             std::vector<int> pos_data(ggml_nelements(positions)); |                 //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 | ||||||
|             int * data = pos_data.data(); |                 std::vector<int32_t> positions(pos_h * pos_w); | ||||||
|             int bucket_coords_h[1024]; |                 int bucket_coords_h[1024]; | ||||||
|             int bucket_coords_w[1024]; |                 int bucket_coords_w[1024]; | ||||||
|             for (int i = 0; i < pos_h; i++){ |                 for (int i = 0; i < pos_h; i++){ | ||||||
|                 bucket_coords_h[i] = std::floor(70.0*i/pos_h); |                     bucket_coords_h[i] = std::floor(70.0*i/pos_h); | ||||||
|             } |  | ||||||
|             for (int i = 0; i < pos_w; i++){ |  | ||||||
|                 bucket_coords_w[i] = std::floor(70.0*i/pos_w); |  | ||||||
|             } |  | ||||||
|             for (int i = 0, id = 0; i < pos_h; i++){ |  | ||||||
|                 for (int j = 0; j < pos_w; j++){ |  | ||||||
|                     data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; |  | ||||||
|                 } |                 } | ||||||
|             } |                 for (int i = 0; i < pos_w; i++){ | ||||||
|             ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); |                     bucket_coords_w[i] = std::floor(70.0*i/pos_w); | ||||||
|         } |  | ||||||
|  |  | ||||||
|         { |  | ||||||
|             // inspired from resampler of Qwen-VL: |  | ||||||
|             //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main |  | ||||||
|             //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 |  | ||||||
|             struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); |  | ||||||
|             int embed_dim = 4096; |  | ||||||
|             if (ctx->minicpmv_version == 2) { |  | ||||||
|                 embed_dim = 4096; |  | ||||||
|             } |  | ||||||
|             else if (ctx->minicpmv_version == 3) { |  | ||||||
|                 embed_dim = 3584; |  | ||||||
|             } |  | ||||||
|             else if (ctx->minicpmv_version == 4) { |  | ||||||
|                 embed_dim = 3584; |  | ||||||
|             } |  | ||||||
|             else { |  | ||||||
|                 GGML_ABORT("Unknown minicpmv version"); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos? |  | ||||||
|             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); |  | ||||||
|  |  | ||||||
|             std::vector<float> pos_data(ggml_nelements(pos_embed)); |  | ||||||
|             float * data = pos_data.data(); |  | ||||||
|             for(int i = 0; i < pos_w * pos_h; ++i){ |  | ||||||
|                 for(int j = 0; j < embed_dim; ++j){ |  | ||||||
|                     data[i * embed_dim + j] = pos_embed_t[i][j]; |  | ||||||
|                 } |                 } | ||||||
|             } |                 for (int i = 0, id = 0; i < pos_h; i++){ | ||||||
|  |                     for (int j = 0; j < pos_w; j++){ | ||||||
|  |                         positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 set_input_i32("positions", positions); | ||||||
|  |  | ||||||
|             ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed)); |                 // inspired from resampler of Qwen-VL: | ||||||
|         } |                 //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main | ||||||
|     } |                 //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 | ||||||
|     else { |                 int embed_dim = clip_n_mmproj_embd(ctx); | ||||||
|         // non-minicpmv models |  | ||||||
|  |  | ||||||
|         if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { |                 // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos? | ||||||
|             // pw * ph = number of tokens output by ViT after apply patch merger |                 auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); | ||||||
|             // ipw * ipw = number of vision token been processed inside ViT |  | ||||||
|             const int merge_ratio = 2; |  | ||||||
|             const int pw  = image_size_width  / patch_size / merge_ratio; |  | ||||||
|             const int ph  = image_size_height / patch_size / merge_ratio; |  | ||||||
|             const int ipw = image_size_width  / patch_size; |  | ||||||
|             const int iph = image_size_height / patch_size; |  | ||||||
|  |  | ||||||
|             std::vector<int> idx    (ph * pw); |                 std::vector<float> pos_embed(embed_dim * pos_w * pos_h); | ||||||
|             std::vector<int> inv_idx(ph * pw); |                 for(int i = 0; i < pos_w * pos_h; ++i){ | ||||||
|  |                     for(int j = 0; j < embed_dim; ++j){ | ||||||
|  |                         pos_embed[i * embed_dim + j] = pos_embed_t[i][j]; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|             if (use_window_attn) { |                 set_input_f32("pos_embed", pos_embed); | ||||||
|                 const int attn_window_size = 112; |             } break; | ||||||
|                 struct ggml_tensor * window_idx     = ggml_graph_get_tensor(gf, "window_idx"); |         case PROJECTOR_TYPE_QWEN2VL: | ||||||
|                 struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); |             { | ||||||
|                 struct ggml_tensor * window_mask    = ggml_graph_get_tensor(gf, "window_mask"); |                 const int merge_ratio = 2; | ||||||
|  |                 const int pw = image_size_width  / patch_size; | ||||||
|                 const int grid_window = attn_window_size / patch_size / merge_ratio; |                 const int ph = image_size_height / patch_size; | ||||||
|                 int dst = 0; |                 std::vector<int> positions(num_positions * 4); | ||||||
|                 // [num_vision_tokens, num_vision_tokens] attention mask tensor |                 int ptr = 0; | ||||||
|                 std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest()); |                 for (int y = 0; y < ph; y += merge_ratio) { | ||||||
|                 int mask_row = 0; |                     for (int x = 0; x < pw; x += merge_ratio) { | ||||||
|  |                         for (int dy = 0; dy < 2; dy++) { | ||||||
|                 for (int y = 0; y < ph; y += grid_window) |                             for (int dx = 0; dx < 2; dx++) { | ||||||
|                 { |                                 positions[                  ptr] = y + dy; | ||||||
|                     for (int x = 0; x < pw; x += grid_window) |                                 positions[    num_patches + ptr] = x + dx; | ||||||
|                     { |                                 positions[2 * num_patches + ptr] = y + dy; | ||||||
|                         const int win_h = std::min(grid_window, ph - y); |                                 positions[3 * num_patches + ptr] = x + dx; | ||||||
|                         const int win_w = std::min(grid_window, pw - x); |                                 ptr++; | ||||||
|                         const int dst_0 = dst; |  | ||||||
|                         // group all tokens belong to the same window togather (to a continue range) |  | ||||||
|                         for (int dy = 0; dy < win_h; dy++) { |  | ||||||
|                             for (int dx = 0; dx < win_w; dx++) { |  | ||||||
|                                 const int src = (y + dy) * pw + (x + dx); |  | ||||||
|                                 assert(src < (int)idx.size()); |  | ||||||
|                                 assert(dst < (int)inv_idx.size()); |  | ||||||
|                                 idx    [src] = dst; |  | ||||||
|                                 inv_idx[dst] = src; |  | ||||||
|                                 dst++; |  | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|  |  | ||||||
|                         for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { |  | ||||||
|                             int row_offset = mask_row * (ipw * iph); |  | ||||||
|                             std::fill( |  | ||||||
|                                 mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), |  | ||||||
|                                 mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio), |  | ||||||
|                                 0.0); |  | ||||||
|                             mask_row++; |  | ||||||
|                         } |  | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 ggml_backend_tensor_set(window_idx,     idx.data(),     0, ggml_nbytes(window_idx)); |                 set_input_i32("positions", positions); | ||||||
|                 ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); |             } break; | ||||||
|                 ggml_backend_tensor_set(window_mask,    mask.data(),    0, ggml_nbytes(window_mask)); |         case PROJECTOR_TYPE_QWEN25VL: | ||||||
|             } else { |  | ||||||
|                 std::iota(idx.begin(), idx.end(), 0); |  | ||||||
|                 std::iota(inv_idx.begin(), inv_idx.end(), 0); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); |  | ||||||
|             const int mpow = merge_ratio * merge_ratio; |  | ||||||
|             std::vector<int> positions_data(ggml_nelements(positions)); |  | ||||||
|             int * data = positions_data.data(); |  | ||||||
|  |  | ||||||
|             int ptr = 0; |  | ||||||
|             for (int y = 0; y < iph; y += merge_ratio) |  | ||||||
|             { |             { | ||||||
|                 for (int x = 0; x < ipw; x += merge_ratio) |                 // pw * ph = number of tokens output by ViT after apply patch merger | ||||||
|                 { |                 // ipw * ipw = number of vision token been processed inside ViT | ||||||
|                     for (int dy = 0; dy < 2; dy++) { |                 const int merge_ratio = 2; | ||||||
|                         for (int dx = 0; dx < 2; dx++) { |                 const int pw  = image_size_width  / patch_size / merge_ratio; | ||||||
|                             auto remap = idx[ptr / mpow]; |                 const int ph  = image_size_height / patch_size / merge_ratio; | ||||||
|                             remap = remap * mpow + (ptr % mpow); |                 const int ipw = image_size_width  / patch_size; | ||||||
|  |                 const int iph = image_size_height / patch_size; | ||||||
|  |  | ||||||
|                             data[                  remap] = y + dy; |                 std::vector<int> idx    (ph * pw); | ||||||
|                             data[    num_patches + remap] = x + dx; |                 std::vector<int> inv_idx(ph * pw); | ||||||
|                             data[2 * num_patches + remap] = y + dy; |  | ||||||
|                             data[3 * num_patches + remap] = x + dx; |                 if (use_window_attn) { | ||||||
|                             ptr++; |                     const int attn_window_size = 112; | ||||||
|  |                     const int grid_window = attn_window_size / patch_size / merge_ratio; | ||||||
|  |                     int dst = 0; | ||||||
|  |                     // [num_vision_tokens, num_vision_tokens] attention mask tensor | ||||||
|  |                     std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest()); | ||||||
|  |                     int mask_row = 0; | ||||||
|  |  | ||||||
|  |                     for (int y = 0; y < ph; y += grid_window) { | ||||||
|  |                         for (int x = 0; x < pw; x += grid_window) { | ||||||
|  |                             const int win_h = std::min(grid_window, ph - y); | ||||||
|  |                             const int win_w = std::min(grid_window, pw - x); | ||||||
|  |                             const int dst_0 = dst; | ||||||
|  |                             // group all tokens belong to the same window togather (to a continue range) | ||||||
|  |                             for (int dy = 0; dy < win_h; dy++) { | ||||||
|  |                                 for (int dx = 0; dx < win_w; dx++) { | ||||||
|  |                                     const int src = (y + dy) * pw + (x + dx); | ||||||
|  |                                     GGML_ASSERT(src < (int)idx.size()); | ||||||
|  |                                     GGML_ASSERT(dst < (int)inv_idx.size()); | ||||||
|  |                                     idx    [src] = dst; | ||||||
|  |                                     inv_idx[dst] = src; | ||||||
|  |                                     dst++; | ||||||
|  |                                 } | ||||||
|  |                             } | ||||||
|  |  | ||||||
|  |                             for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { | ||||||
|  |                                 int row_offset = mask_row * (ipw * iph); | ||||||
|  |                                 std::fill( | ||||||
|  |                                     mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), | ||||||
|  |                                     mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio), | ||||||
|  |                                     0.0); | ||||||
|  |                                 mask_row++; | ||||||
|  |                             } | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |  | ||||||
|  |                     set_input_i32("window_idx",     idx); | ||||||
|  |                     set_input_i32("inv_window_idx", inv_idx); | ||||||
|  |                     set_input_f32("window_mask",    mask); | ||||||
|  |                 } else { | ||||||
|  |                     for (int i = 0; i < ph * pw; i++) { | ||||||
|  |                         idx[i] = i; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 const int mpow = merge_ratio * merge_ratio; | ||||||
|  |                 std::vector<int> positions(num_positions * 4); | ||||||
|  |  | ||||||
|  |                 int ptr = 0; | ||||||
|  |                 for (int y = 0; y < iph; y += merge_ratio) { | ||||||
|  |                     for (int x = 0; x < ipw; x += merge_ratio) { | ||||||
|  |                         for (int dy = 0; dy < 2; dy++) { | ||||||
|  |                             for (int dx = 0; dx < 2; dx++) { | ||||||
|  |                                 auto remap = idx[ptr / mpow]; | ||||||
|  |                                 remap = (remap * mpow) + (ptr % mpow); | ||||||
|  |  | ||||||
|  |                                 positions[                  remap] = y + dy; | ||||||
|  |                                 positions[    num_patches + remap] = x + dx; | ||||||
|  |                                 positions[2 * num_patches + remap] = y + dy; | ||||||
|  |                                 positions[3 * num_patches + remap] = x + dx; | ||||||
|  |                                 ptr++; | ||||||
|  |                             } | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |  | ||||||
|  |  | ||||||
|             ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); |                 set_input_i32("positions", positions); | ||||||
|         } |             } break; | ||||||
|         else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { |         case PROJECTOR_TYPE_PIXTRAL: | ||||||
|             // do nothing |             { | ||||||
|         } |                 // set the 2D positions | ||||||
|         else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { |                 int n_patches_per_col = image_size_width / patch_size; | ||||||
|             // do nothing |                 std::vector<int> pos_data(num_positions); | ||||||
|         } |                 // dimension H | ||||||
|         else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { |                 for (int i = 0; i < num_positions; i++) { | ||||||
|             // set the 2D positions |                     pos_data[i] = i / n_patches_per_col; | ||||||
|             int n_patches_per_col = image_size_width / patch_size; |                 } | ||||||
|             std::vector<int> pos_data(num_positions); |                 set_input_i32("pos_h", pos_data); | ||||||
|             struct ggml_tensor * pos; |                 // dimension W | ||||||
|             // dimension H |                 for (int i = 0; i < num_positions; i++) { | ||||||
|             pos = ggml_graph_get_tensor(gf, "pos_h"); |                     pos_data[i] = i % n_patches_per_col; | ||||||
|             for (int i = 0; i < num_positions; i++) { |                 } | ||||||
|                 pos_data[i] = i / n_patches_per_col; |                 set_input_i32("pos_w", pos_data); | ||||||
|             } |             } break; | ||||||
|             ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos)); |         case PROJECTOR_TYPE_GLM_EDGE: | ||||||
|             // dimension W |         { | ||||||
|             pos = ggml_graph_get_tensor(gf, "pos_w"); |  | ||||||
|             for (int i = 0; i < num_positions; i++) { |  | ||||||
|                 pos_data[i] = i % n_patches_per_col; |  | ||||||
|             } |  | ||||||
|             ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos)); |  | ||||||
|         } |  | ||||||
|         else { |  | ||||||
|             // llava and other models |             // llava and other models | ||||||
|             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); |             std::vector<int32_t> positions(num_positions); | ||||||
|  |  | ||||||
|             int* positions_data = (int*)malloc(ggml_nbytes(positions)); |  | ||||||
|             for (int i = 0; i < num_positions; i++) { |             for (int i = 0; i < num_positions; i++) { | ||||||
|                 positions_data[i] = i; |                 positions[i] = i; | ||||||
|             } |             } | ||||||
|             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); |             set_input_i32("positions", positions); | ||||||
|             free(positions_data); |         } break; | ||||||
|  |         case PROJECTOR_TYPE_MLP: | ||||||
|  |         case PROJECTOR_TYPE_MLP_NORM: | ||||||
|  |         case PROJECTOR_TYPE_LDP: | ||||||
|  |         case PROJECTOR_TYPE_LDPV2: | ||||||
|  |             { | ||||||
|  |                 // llava and other models | ||||||
|  |                 std::vector<int32_t> positions(num_positions); | ||||||
|  |                 for (int i = 0; i < num_positions; i++) { | ||||||
|  |                     positions[i] = i; | ||||||
|  |                 } | ||||||
|  |                 set_input_i32("positions", positions); | ||||||
|  |  | ||||||
|             if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) { |  | ||||||
|                 struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); |  | ||||||
|                 // The patches vector is used to get rows to index into the embeds with; |                 // The patches vector is used to get rows to index into the embeds with; | ||||||
|                 // we should skip dim 0 only if we have CLS to avoid going out of bounds |                 // we should skip dim 0 only if we have CLS to avoid going out of bounds | ||||||
|                 // when retrieving the rows. |                 // when retrieving the rows. | ||||||
|                 int patch_offset = model.class_embedding ? 1 : 0; |                 int patch_offset = model.class_embedding ? 1 : 0; | ||||||
|                 int* patches_data = (int*)malloc(ggml_nbytes(patches)); |                 std::vector<int32_t> patches(num_patches); | ||||||
|                 for (int i = 0; i < num_patches; i++) { |                 for (int i = 0; i < num_patches; i++) { | ||||||
|                     patches_data[i] = i + patch_offset; |                     patches[i] = i + patch_offset; | ||||||
|                 } |                 } | ||||||
|                 ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); |                 set_input_i32("patches", patches); | ||||||
|                 free(patches_data); |             } break; | ||||||
|             } |         case PROJECTOR_TYPE_GEMMA3: | ||||||
|         } |         case PROJECTOR_TYPE_IDEFICS3: | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if (use_window_attn && (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL)) { |  | ||||||
|         struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx"); |  | ||||||
|         struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx"); |  | ||||||
|         struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask"); |  | ||||||
|  |  | ||||||
|         const int merge_ratio = 2; |  | ||||||
|         const int attn_window_size = 112; |  | ||||||
|         const int pw = image_size_width / patch_size / merge_ratio; |  | ||||||
|         const int ph = image_size_height / patch_size / merge_ratio; |  | ||||||
|         const int grid_window = attn_window_size / patch_size / merge_ratio; |  | ||||||
|         const int ipw = image_size_width / patch_size; |  | ||||||
|         const int iph = image_size_height / patch_size; |  | ||||||
|         /* |  | ||||||
|         pw * ph = number of tokens output by ViT after apply patch merger |  | ||||||
|         ipw * ipw = number of vision token been processed inside ViT |  | ||||||
|         */ |  | ||||||
|  |  | ||||||
|         std::vector<int> idx(ph * pw); |  | ||||||
|         std::vector<int> inv_idx(ph * pw); |  | ||||||
|         int dst = 0; |  | ||||||
|         // [num_vision_tokens, num_vision_tokens] attention mask tensor |  | ||||||
|         std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest()); |  | ||||||
|         int mask_row = 0; |  | ||||||
|  |  | ||||||
|         for (int y = 0; y < ph; y+=grid_window) |  | ||||||
|         { |  | ||||||
|             for (int x = 0; x < pw; x+=grid_window) |  | ||||||
|             { |             { | ||||||
|                 const int win_h = std::min(grid_window, ph - y); |                 // do nothing | ||||||
|                 const int win_w = std::min(grid_window, pw - x); |             } break; | ||||||
|                 const int dst_0 = dst; |         default: | ||||||
|                 // group all tokens belong to the same window togather (to a continue range) |             GGML_ABORT("Unknown projector type"); | ||||||
|                 for (int dy = 0; dy < win_h; dy++) { |  | ||||||
|                     for (int dx = 0; dx < win_w; dx++) { |  | ||||||
|                         const int src = (y + dy) * pw + (x + dx); |  | ||||||
|                         assert(src < (int)idx.size()); |  | ||||||
|                         assert(dst < (int)inv_idx.size()); |  | ||||||
|                         idx[src] = dst; |  | ||||||
|                         inv_idx[dst] = src; |  | ||||||
|                         dst++; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) { |  | ||||||
|                     int row_offset = mask_row * (ipw * iph); |  | ||||||
|                     std::fill( |  | ||||||
|                         mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio), |  | ||||||
|                         mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio), |  | ||||||
|                         0.0); |  | ||||||
|                     mask_row++; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx)); |  | ||||||
|         ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx)); |  | ||||||
|         ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask)); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); |     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); | ||||||
| @@ -3537,7 +3494,7 @@ bool clip_is_glm(const struct clip_ctx * ctx) { | |||||||
| } | } | ||||||
|  |  | ||||||
| bool clip_is_qwen2vl(const struct clip_ctx * ctx) { | bool clip_is_qwen2vl(const struct clip_ctx * ctx) { | ||||||
|     return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL; |     return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL; | ||||||
| } | } | ||||||
|  |  | ||||||
| bool clip_is_llava(const struct clip_ctx * ctx) { | bool clip_is_llava(const struct clip_ctx * ctx) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan-Son Nguyen
					Xuan-Son Nguyen