mtmd: refactor preprocessing + support max/min pixels (#16878)
* mtmd: refactor preprocessing + support max/min pixels
* fix mlp type
* implement mix/max pixels
* improve hparams
* better image preproc for qwen
* fix
* fix out of bound composite
* fix (2)
* fix token calculation
* get_merge_kernel_size()
* fix llama4 and lfm2
* gonna fix them all
* use simple resize for qwen
* qwen: increase min tokens
* no resize if dst size == src size
* restore to initial min/max tokens value for qwen
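As a rough standalone sketch of the "min/max pixels" clamp this commit introduces (the same math as the smart-resize overload of calc_size_preserved_ratio() further down; patch_size = 14 and n_merge = 2 mirror the Qwen values set in the loader, the input image size is just an example):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // snap both sides up to a multiple of align, then rescale by sqrt(area ratio)
    // so that min_px <= W*H <= max_px while preserving the aspect ratio
    static int ceil_by(int f, float x)  { return static_cast<int>(std::ceil(x / f)) * f; }
    static int floor_by(int f, float x) { return static_cast<int>(std::floor(x / f)) * f; }

    int main() {
        const int align  = 14 * 2;                 // patch_size * n_merge
        const int min_px = 8    * align * align;   // set_limit_image_tokens(8, 2048)
        const int max_px = 2048 * align * align;

        int w = 4032, h = 3024;                    // e.g. a phone photo
        int w_bar = std::max(align, ceil_by(align, w));
        int h_bar = std::max(align, ceil_by(align, h));
        if (w_bar * h_bar > max_px) {
            const float beta = std::sqrt((float)(w * h) / max_px);
            w_bar = std::max(align, floor_by(align, w / beta));
            h_bar = std::max(align, floor_by(align, h / beta));
        } else if (w_bar * h_bar < min_px) {
            const float beta = std::sqrt((float)min_px / (w * h));
            w_bar = ceil_by(align, w * beta);
            h_bar = ceil_by(align, h * beta);
        }
        // prints: 4032x3024 -> 1456x1092 (2028 tokens)
        std::printf("%dx%d -> %dx%d (%d tokens)\n", w, h, w_bar, h_bar,
                    (w_bar / align) * (h_bar / align));
        return 0;
    }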
@@ -154,8 +154,8 @@ enum projector_type {
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_LIGHTONOCR,
-    PROJECTOR_TYPE_UNKNOWN,
     PROJECTOR_TYPE_COGVLM,
+    PROJECTOR_TYPE_UNKNOWN,
 };

 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
@@ -171,8 +171,10 @@ struct clip_hparams {
     int32_t n_head;
     int32_t n_layer;
     // idefics3
-    int32_t preproc_image_size = 0; // aka max_dimension
-    int32_t proj_scale_factor = 0;
+    int32_t image_longest_edge = 0;
+    int32_t image_min_pixels = 0;
+    int32_t image_max_pixels = 0;
+    int32_t n_merge = 0; // number of patch merges **per-side**

     float image_mean[3];
     float image_std[3];
@@ -194,7 +196,6 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
-    int32_t spatial_merge_size = 0;

     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
@@ -204,6 +205,21 @@ struct clip_hparams {
     bool has_llava_projector = false;
     int minicpmv_version = 0;
     int32_t minicpmv_query_num = 0; // MiniCPM-V query number
+
+    void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
+        const int cur_merge = n_merge == 0 ? 1 : n_merge;
+        const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
+        image_min_pixels = n_tokens_min * patch_area;
+        image_max_pixels = n_tokens_max * patch_area;
+        warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
+    }
+
+    void set_warmup_n_tokens(int n_tokens) {
+        int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
+        GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
+        const int cur_merge = n_merge == 0 ? 1 : n_merge;
+        warmup_image_size = n_tok_per_side * patch_size * cur_merge;
+    }
 };

 struct clip_layer {
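For a concrete sense of what set_limit_image_tokens() computes, a quick worked example (patch_size = 14 and n_merge = 2 are the Qwen-style values referenced later in this diff, assumed here for illustration):

    // patch_area = 14 * 14 * 2 * 2 = 784 pixels per output token
    // set_limit_image_tokens(8, 2048) then yields:
    //   image_min_pixels  = 8    * 784 = 6272
    //   image_max_pixels  = 2048 * 784 = 1605632
    //   warmup_image_size = (int)sqrt(1605632) = 1267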
@@ -532,7 +548,7 @@ struct clip_graph {
         const int batch_size = 1;
         GGML_ASSERT(n_patches_x == n_patches_y);
         const int patches_per_image = n_patches_x;
-        const int kernel_size = hparams.proj_scale_factor;
+        const int kernel_size = hparams.n_merge;

         cur = ggml_transpose(ctx0, cur);
         cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
@@ -554,13 +570,13 @@ struct clip_graph {
     } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
         // pixel_shuffle
         // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
-        const int scale_factor = model.hparams.proj_scale_factor;
+        const int scale_factor = model.hparams.n_merge;
         cur = build_patch_merge_permute(cur, scale_factor);
         cur = ggml_mul_mat(ctx0, model.projection, cur);

     } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
         // pixel unshuffle block
-        const int scale_factor = model.hparams.proj_scale_factor;
+        const int scale_factor = model.hparams.n_merge;
         cur = build_patch_merge_permute(cur, scale_factor);

         // projection
@@ -584,7 +600,7 @@ struct clip_graph {
     }

     ggml_cgraph * build_pixtral() {
-        const int n_merge = hparams.spatial_merge_size;
+        const int n_merge = hparams.n_merge;

         // 2D input positions
         ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
@@ -610,7 +626,7 @@ struct clip_graph {
         // mistral small 3.1 patch merger
         // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
         if (model.mm_patch_merger_w) {
-            GGML_ASSERT(hparams.spatial_merge_size > 0);
+            GGML_ASSERT(hparams.n_merge > 0);

             cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);

@@ -926,7 +942,7 @@ struct clip_graph {

         // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
         ggml_tensor * deepstack_features = nullptr;
-        const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl
+        const int merge_factor = hparams.n_merge > 0 ? hparams.n_merge * hparams.n_merge : 4; // default 2x2=4 for qwen3vl

         // loop over layers
         for (int il = 0; il < n_layer; il++) {
@@ -1149,7 +1165,7 @@ struct clip_graph {

         // pixel shuffle
         {
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             const int bsz = 1; // batch size, always 1 for now since we don't support batching
             const int height = n_patches_y;
             const int width = n_patches_x;
@@ -1239,7 +1255,7 @@ struct clip_graph {
         // based on Llama4VisionPixelShuffleMLP
         // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
         {
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             const int bsz = 1; // batch size, always 1 for now since we don't support batching
             GGML_ASSERT(scale_factor > 0);
             GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
@@ -1311,7 +1327,7 @@ struct clip_graph {

         {
             // patch_merger
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);

             // projection norm
@@ -2577,7 +2593,6 @@ struct clip_model_loader {

         if (is_vision) {
             get_u32(KEY_IMAGE_SIZE, hparams.image_size);
-            get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false);
             get_u32(KEY_PATCH_SIZE, hparams.patch_size);
             get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
             get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
@@ -2686,65 +2701,68 @@ struct clip_model_loader {
                     hparams.minicpmv_version = 2; // default to 2 if not set
                 }
             } break;
-        case PROJECTOR_TYPE_IDEFICS3:
-        case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_INTERNVL:
             {
-                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+            } break;
+        case PROJECTOR_TYPE_IDEFICS3:
+            {
+                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.image_longest_edge, false);
+            } break;
+        case PROJECTOR_TYPE_LFM2:
+            {
+                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json
+                hparams.set_limit_image_tokens(64, 256);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
+                // ref: https://huggingface.co/mistral-community/pixtral-12b/blob/main/preprocessor_config.json
+                // TODO: verify the image_min_tokens
                 hparams.rope_theta = 10000.0f;
-                hparams.warmup_image_size = hparams.patch_size * 8;
-                // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
-                // ref: https://github.com/ggml-org/llama.cpp/issues/14310
-                hparams.image_size = 1024;
-                get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                hparams.set_limit_image_tokens(8, 1024);
+                hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
             } break;
         case PROJECTOR_TYPE_KIMIVL:
             {
                 hparams.rope_theta = 10000.0f;
-                hparams.warmup_image_size = hparams.patch_size * 8;
-                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                // TODO: check kimivl preprocessor for exact values
+                hparams.set_limit_image_tokens(8, 1024);
+                hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
             } break;
         case PROJECTOR_TYPE_GEMMA3:
             {
                 // default value (used by all model sizes in gemma 3 family)
                 // number of patches for each **side** is reduced by a factor of 4
-                hparams.proj_scale_factor = 4;
+                hparams.n_merge = 4;
                 // test model (tinygemma3) has a different value, we optionally read it
-                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
             } break;
         case PROJECTOR_TYPE_QWEN2VL:
-            {
-                // max image size = sqrt(max_pixels) = 3584
-                // ref: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json
-                // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
-                // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
-                hparams.image_size = 1024;
-                hparams.warmup_image_size = hparams.patch_size * 8;
-            } break;
         case PROJECTOR_TYPE_QWEN25VL:
-            {
-                // max image size = sqrt(max_pixels)
-                // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
-                // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
-                // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
-                hparams.image_size = 1024;
-                hparams.warmup_image_size = hparams.patch_size * 8;
-                get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
-            } break;
         case PROJECTOR_TYPE_QWEN3VL:
             {
-                hparams.image_size = 1024; // still need this?
-                hparams.warmup_image_size = hparams.patch_size * 8;
-                get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                hparams.n_merge = 2; // default value for Qwen 2 and 2.5
+                get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, model.proj_type == PROJECTOR_TYPE_QWEN25VL); // only 2.5 requires it
+                // ref: https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
+                // the actual max limit is 12845056/14/14/2/2/4 = 4096 tokens
+                // but we set a lower value to avoid OOM
+                // TODO: make it configurable by user
+                // TODO (2): bbox coordinates become inaccurate with small number of tokens,
+                //           therefore we need to increase the min_tokens
+                //           see: https://github.com/ggml-org/llama.cpp/issues/16842#issuecomment-3475144858
+                hparams.set_limit_image_tokens(8, 2048);
+                hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
             } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
                 hparams.rope_theta = 10000.0f;
-                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+                get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
                 set_llava_uhd_res_candidates(model, 3);
             } break;
         case PROJECTOR_TYPE_ULTRAVOX:
@@ -2777,10 +2795,13 @@ struct clip_model_loader {
             LOG_INF("%s: patch_size:        %d\n", __func__, hparams.patch_size);
             LOG_INF("%s: has_llava_proj:    %d\n", __func__, hparams.has_llava_projector);
             LOG_INF("%s: minicpmv_version:  %d\n", __func__, hparams.minicpmv_version);
-            LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
+            LOG_INF("%s: n_merge:           %d\n", __func__, hparams.n_merge);
             LOG_INF("%s: n_wa_pattern:      %d\n", __func__, hparams.n_wa_pattern);
-            if (hparams.spatial_merge_size > 0) {
-                LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
+            if (hparams.image_min_pixels > 0) {
+                LOG_INF("%s: image_min_pixels:  %d\n", __func__, hparams.image_min_pixels);
+            }
+            if (hparams.image_max_pixels > 0) {
+                LOG_INF("%s: image_max_pixels:  %d\n", __func__, hparams.image_max_pixels);
             }
         } else if (is_audio) {
             LOG_INF("\n--- audio hparams ---\n");
@@ -3181,9 +3202,11 @@ struct clip_model_loader {
             if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
                 img->nx = hparams.warmup_image_size;
                 img->ny = hparams.warmup_image_size;
+                LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
             } else {
                 img->nx = hparams.warmup_audio_size;
                 img->ny = hparams.n_mel_bins;
+                LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
             }
             batch.entries.push_back(std::move(img));

@@ -3399,9 +3422,169 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32

 // set of tools to manupulate images
 // in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
-struct image_manipulation {
+struct img_tool {
+    enum resize_algo {
+        RESIZE_ALGO_BILINEAR,
+        RESIZE_ALGO_BICUBIC,
+        // RESIZE_ALGO_LANCZOS, // TODO
+    };
+
+    static void resize(
+            const clip_image_u8 & src,
+            clip_image_u8 & dst,
+            const clip_image_size & target_resolution,
+            resize_algo algo,
+            bool add_padding = true, // TODO: define the behavior for add_padding = false
+            std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
+        dst.nx = target_resolution.width;
+        dst.ny = target_resolution.height;
+        dst.buf.resize(3 * dst.nx * dst.ny);
+
+        if (dst.nx == src.nx && dst.ny == src.ny) {
+            // no resize needed, simple copy
+            dst.buf = src.buf;
+            return;
+        }
+
+        if (!add_padding) {
+            // direct resize
+            switch (algo) {
+                case RESIZE_ALGO_BILINEAR:
+                    resize_bilinear(src, dst, target_resolution.width, target_resolution.height);
+                    break;
+                case RESIZE_ALGO_BICUBIC:
+                    resize_bicubic(src, dst, target_resolution.width, target_resolution.height);
+                    break;
+                default:
+                    throw std::runtime_error("Unsupported resize algorithm");
+            }
+        } else {
+            // resize with padding
+            clip_image_u8 resized_image;
+            float scale_w = static_cast<float>(target_resolution.width)  / src.nx;
+            float scale_h = static_cast<float>(target_resolution.height) / src.ny;
+            float scale = std::min(scale_w, scale_h);
+            int new_width  = std::min(static_cast<int>(std::ceil(src.nx * scale)), target_resolution.width);
+            int new_height = std::min(static_cast<int>(std::ceil(src.ny * scale)), target_resolution.height);
+
+            switch (algo) {
+                case RESIZE_ALGO_BILINEAR:
+                    resize_bilinear(src, resized_image, new_width, new_height);
+                    break;
+                case RESIZE_ALGO_BICUBIC:
+                    resize_bicubic(src, resized_image, new_width, new_height);
+                    break;
+                default:
+                    throw std::runtime_error("Unsupported resize algorithm");
+            }
+
+            // fill dst with pad_color
+            fill(dst, pad_color);
+
+            int offset_x = (target_resolution.width  - new_width)  / 2;
+            int offset_y = (target_resolution.height - new_height) / 2;
+
+            composite(dst, resized_image, offset_x, offset_y);
+        }
+    }
+
+    static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
+        dst.nx = w;
+        dst.ny = h;
+        dst.buf.resize(3 * w * h);
+
+        for (int i = 0; i < h; ++i) {
+            for (int j = 0; j < w; ++j) {
+                int src_idx = 3 * ((y + i)*image.nx + (x + j));
+                int dst_idx = 3 * (i*w + j);
+                dst.buf[dst_idx]     = image.buf[src_idx];
+                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
+            }
+        }
+    }
+
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will be aligned to the nearest multiple of align_size
+    // if H or W size is larger than longest_edge, it will be resized to longest_edge
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int longest_edge) {
+        GGML_ASSERT(align_size > 0);
+        if (inp_size.width <= 0 || inp_size.height <= 0 || longest_edge <= 0) {
+            return {0, 0};
+        }
+
+        float scale = std::min(static_cast<float>(longest_edge) / inp_size.width,
+                               static_cast<float>(longest_edge) / inp_size.height);
+
+        float target_width_f  = static_cast<float>(inp_size.width)  * scale;
+        float target_height_f = static_cast<float>(inp_size.height) * scale;
+
+        auto ceil_by_factor = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        int aligned_width  = ceil_by_factor(target_width_f);
+        int aligned_height = ceil_by_factor(target_height_f);
+
+        return {aligned_width, aligned_height};
+    }
+
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will have min_pixels <= W*H <= max_pixels
+    // this is referred as "smart_resize" in transformers code
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int min_pixels, const int max_pixels) {
+        GGML_ASSERT(align_size > 0);
+        const int width  = inp_size.width;
+        const int height = inp_size.height;
+
+        auto ceil_by_factor  = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        auto floor_by_factor = [f = align_size](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
+
+        // always align up first
+        int h_bar = std::max(align_size, ceil_by_factor(height));
+        int w_bar = std::max(align_size, ceil_by_factor(width));
+
+        if (h_bar * w_bar > max_pixels) {
+            const auto beta = std::sqrt(static_cast<float>(height * width) / max_pixels);
+            h_bar = std::max(align_size, floor_by_factor(height / beta));
+            w_bar = std::max(align_size, floor_by_factor(width / beta));
+        } else if (h_bar * w_bar < min_pixels) {
+            const auto beta = std::sqrt(static_cast<float>(min_pixels) / (height * width));
+            h_bar = ceil_by_factor(height * beta);
+            w_bar = ceil_by_factor(width * beta);
+        }
+
+        return {w_bar, h_bar};
+    }
+
+    // draw src image into dst image at offset (offset_x, offset_y)
+    static void composite(clip_image_u8 & dst, const clip_image_u8 & src, int offset_x, int offset_y) {
+        for (int y = 0; y < src.ny; ++y) {
+            for (int x = 0; x < src.nx; ++x) {
+                int dx = x + offset_x;
+                int dy = y + offset_y;
+                // skip pixels that would be out of bounds in the destination
+                if (dx < 0 || dy < 0 || dx >= dst.nx || dy >= dst.ny) {
+                    continue;
+                }
+                size_t dst_idx = 3 * (static_cast<size_t>(dy) * dst.nx + static_cast<size_t>(dx));
+                size_t src_idx = 3 * (static_cast<size_t>(y)  * src.nx + static_cast<size_t>(x));
+                dst.buf[dst_idx + 0] = src.buf[src_idx + 0];
+                dst.buf[dst_idx + 1] = src.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = src.buf[src_idx + 2];
+            }
+        }
+    }
+
+    // fill the image with a solid color
+    static void fill(clip_image_u8 & img, const std::array<uint8_t, 3> & color) {
+        for (size_t i = 0; i < img.buf.size(); i += 3) {
+            img.buf[i]     = color[0];
+            img.buf[i + 1] = color[1];
+            img.buf[i + 2] = color[2];
+        }
+    }
+
+private:
     // Bilinear resize function
-    static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+    static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
         dst.nx = target_width;
         dst.ny = target_height;
         dst.buf.resize(3 * target_width * target_height);
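The padded path of resize() above scales by the limiting axis and centers the result on a pad_color canvas. A minimal standalone sketch of that offset arithmetic, independent of the clip_image_u8 type (all sizes here are illustrative):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        const int src_w = 640, src_h = 480;   // example input
        const int dst_w = 512, dst_h = 512;   // example target canvas
        // pick the smaller scale so the whole image fits inside the canvas
        const float scale = std::min((float)dst_w / src_w, (float)dst_h / src_h);
        const int new_w = std::min((int)std::ceil(src_w * scale), dst_w);
        const int new_h = std::min((int)std::ceil(src_h * scale), dst_h);
        // centered composite offsets; composite() skips out-of-bounds pixels
        const int off_x = (dst_w - new_w) / 2;
        const int off_y = (dst_h - new_h) / 2;
        // prints: scaled to 512x384, pasted at (0,64)
        std::printf("scaled to %dx%d, pasted at (%d,%d)\n", new_w, new_h, off_x, off_y);
        return 0;
    }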
@@ -3437,7 +3620,7 @@ struct image_manipulation {

     // Bicubic resize function
     // part of image will be cropped if the aspect ratio is different
-    static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+    static bool resize_bicubic(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
         const int nx = img.nx;
         const int ny = img.ny;

@@ -3500,93 +3683,6 @@ struct image_manipulation {
         return true;
     }

-    // llava-1.6 type of resize_and_pad
-    // if the ratio is not 1:1, padding with pad_color will be applied
-    // pad_color is single channel, default is 0 (black)
-    static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
-        int target_width  = target_resolution.width;
-        int target_height = target_resolution.height;
-
-        float scale_w = static_cast<float>(target_width) / image.nx;
-        float scale_h = static_cast<float>(target_height) / image.ny;
-
-        int new_width, new_height;
-
-        if (scale_w < scale_h) {
-            new_width = target_width;
-            new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
-        } else {
-            new_height = target_height;
-            new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
-        }
-
-        clip_image_u8 resized_image;
-        bicubic_resize(image, resized_image, new_width, new_height);
-
-        clip_image_u8 padded_image;
-        padded_image.nx = target_width;
-        padded_image.ny = target_height;
-        padded_image.buf.resize(3 * target_width * target_height);
-
-        // Fill the padded image with the fill color
-        for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
-            padded_image.buf[i]     = pad_color[0];
-            padded_image.buf[i + 1] = pad_color[1];
-            padded_image.buf[i + 2] = pad_color[2];
-        }
-
-        // Calculate padding offsets
-        int pad_x = (target_width  - new_width)  / 2;
-        int pad_y = (target_height - new_height) / 2;
-
-        // Copy the resized image into the center of the padded buffer
-        for (int y = 0; y < new_height; ++y) {
-            for (int x = 0; x < new_width; ++x) {
-                for (int c = 0; c < 3; ++c) {
-                    padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
-                }
-            }
-        }
-        dst = std::move(padded_image);
-    }
-
-    static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
-        dst.nx = w;
-        dst.ny = h;
-        dst.buf.resize(3 * w * h);
-
-        for (int i = 0; i < h; ++i) {
-            for (int j = 0; j < w; ++j) {
-                int src_idx = 3 * ((y + i)*image.nx + (x + j));
-                int dst_idx = 3 * (i*w + j);
-                dst.buf[dst_idx]     = image.buf[src_idx];
-                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
-                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
-            }
-        }
-    }
-
-    // calculate the size of the **resized** image, while preserving the aspect ratio
-    // the calculated size will be aligned to the nearest multiple of align_size
-    // if H or W size is larger than max_dimension, it will be resized to max_dimension
-    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
-        if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
-            return {0, 0};
-        }
-
-        float scale = std::min(static_cast<float>(max_dimension) / inp_size.width,
-                               static_cast<float>(max_dimension) / inp_size.height);
-
-        float target_width_f  = static_cast<float>(inp_size.width)  * scale;
-        float target_height_f = static_cast<float>(inp_size.height) * scale;
-
-        int aligned_width  = CLIP_ALIGN((int)target_width_f,  align_size);
-        int aligned_height = CLIP_ALIGN((int)target_height_f, align_size);
-
-        return {aligned_width, aligned_height};
-    }
-
-private:
     static inline int clip(int x, int lower, int upper) {
         return std::max(lower, std::min(x, upper));
     }
@@ -3735,10 +3831,11 @@ struct llava_uhd {

     static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
         std::vector<clip_image_u8_ptr> output;
+        img_tool::resize_algo interpolation = img_tool::RESIZE_ALGO_BILINEAR; // TODO: make it configurable

         // resize to overview size
         clip_image_u8_ptr resized_img(clip_image_u8_init());
-        image_manipulation::resize_and_pad_image(*img, *resized_img, inst.overview_size);
+        img_tool::resize(*img, *resized_img, inst.overview_size, interpolation);
         output.push_back(std::move(resized_img));
         if (inst.slices.empty()) {
             // no slices, just return the resized image
@@ -3748,9 +3845,11 @@ struct llava_uhd {
         // resize to refined size
         clip_image_u8_ptr refined_img(clip_image_u8_init());
         if (inst.padding_refined) {
-            image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
+            img_tool::resize(*img, *refined_img, inst.refined_size, interpolation);
         } else {
-            image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
+            // only algo bicubic preserves the ratio; old models rely on this behavior
+            // TODO: do we need to support other algos here?
+            img_tool::resize(*img, *refined_img, inst.refined_size, img_tool::RESIZE_ALGO_BICUBIC, false);
         }

         // create slices
@@ -3761,7 +3860,7 @@ struct llava_uhd {
             int h = slice.size.height;

             clip_image_u8_ptr img_slice(clip_image_u8_init());
-            image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
+            img_tool::crop(*refined_img, *img_slice, x, y, w, h);
             output.push_back(std::move(img_slice));
         }

@@ -3896,14 +3995,11 @@ private:
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
     clip_image_size original_size{img->nx, img->ny};
-    bool pad_to_square = true;
     auto & params = ctx->model.hparams;
-    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
-    if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
-        pad_to_square = false;
-    }

-    if (clip_is_minicpmv(ctx)) {
+    switch (ctx->proj_type()) {
+        case PROJECTOR_TYPE_MINICPMV:
+            {
                 auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
                 std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);

@@ -3916,21 +4012,30 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

                 res_imgs->grid_x = inst.grid_size.width;
                 res_imgs->grid_y = inst.grid_size.height;
                 return true;
+            } break;

-    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+            {
                 // step 1: make a blank canvas which aligns to the grid
                 clip_image_u8 resized;
-        auto patch_size = params.patch_size * 2;
-        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
-        image_manipulation::bicubic_resize(*img, resized, new_size.width, new_size.height);

+                const clip_image_size new_size = img_tool::calc_size_preserved_ratio(
+                        original_size,
+                        params.patch_size * 2,
+                        params.image_min_pixels,
+                        params.image_max_pixels);
+                img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false);
                 // clip_image_save_to_bmp(resized, "preproc.bmp");
                 clip_image_f32_ptr img_f32(clip_image_f32_init());
                 // clip_image_f32_ptr res(clip_image_f32_init());
                 normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
                 // res_imgs->data[0] = *res;
                 res_imgs->entries.push_back(std::move(img_f32));
                 return true;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+            } break;
+
+        case PROJECTOR_TYPE_IDEFICS3:
+            {
                 // The refined size has two steps:
                 // 1. Resize w/ aspect-ratio preserving such that the longer side is
                 //    the preprocessor longest size
@@ -3938,8 +4043,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 //    multiples of image_size (always rounding up)
                 //
                 // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737
-                const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio(
-                        original_size, params.image_size, params.preproc_image_size);
+                const clip_image_size refined_size = img_tool::calc_size_preserved_ratio(
+                        original_size, params.image_size, params.image_longest_edge);
                 // LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n",
                 //         __func__, original_size.width, original_size.height,
                 //         refined_size.width, refined_size.height);
@@ -3976,32 +4081,41 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

                 res_imgs->grid_x = instructions.grid_size.width;
                 res_imgs->grid_y = instructions.grid_size.height;
                 return true;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
-            || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
-            ) {
+            } break;
+
+        case PROJECTOR_TYPE_GLM_EDGE:
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
+            {
                 clip_image_u8 resized_image;
                 int sz = params.image_size;
-                image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
+                img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR);
                 clip_image_f32_ptr img_f32(clip_image_f32_init());
                 //clip_image_save_to_bmp(resized_image, "resized.bmp");
                 normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
                 res_imgs->entries.push_back(std::move(img_f32));
                 return true;
+            } break;

-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
-            || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
-            ) {
+        case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
+            {
+                GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
                 clip_image_u8 resized_image;
-        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
-        image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
+                // the original pixtral model doesn't have n_merge
+                const int cur_merge = params.n_merge == 0 ? 1 : params.n_merge;
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                        original_size,
+                        params.patch_size * cur_merge,
+                        params.image_min_pixels,
+                        params.image_max_pixels);
+                img_tool::resize(*img, resized_image, target_size, img_tool::RESIZE_ALGO_BILINEAR);
                 clip_image_f32_ptr img_f32(clip_image_f32_init());
                 normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
                 res_imgs->entries.push_back(std::move(img_f32));
                 return true;
+            } break;

-    } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
+        case PROJECTOR_TYPE_LLAMA4:
+            {
                 GGML_ASSERT(!params.image_res_candidates.empty());
                 auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
                 std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
@@ -4014,55 +4128,41 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

                 res_imgs->grid_x = inst.grid_size.width;
                 res_imgs->grid_y = inst.grid_size.height;
                 return true;

-    } else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
-             || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
-             ) {
-        GGML_ASSERT(params.proj_scale_factor);
-
-        // smart resize
-        const int width = img->nx;
-        const int height = img->ny;
-        const int total_factor = params.patch_size * params.proj_scale_factor;
-        constexpr int min_image_tokens = 64;
-        constexpr int max_image_tokens = 1024;
-        const float min_pixels = min_image_tokens * total_factor * total_factor;
-        const float max_pixels = max_image_tokens * total_factor * total_factor;
-
-        auto round_by_factor = [f = total_factor](float x) { return static_cast<int>(std::nearbyintf(x / static_cast<float>(f))) * f; };
-        auto ceil_by_factor  = [f = total_factor](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
-        auto floor_by_factor = [f = total_factor](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
-
-        int h_bar = std::max(total_factor, round_by_factor(height));
-        int w_bar = std::max(total_factor, round_by_factor(width));
-
-        if (h_bar * w_bar > max_pixels) {
-            const auto beta = std::sqrt((height * width) / max_pixels);
-            h_bar = std::max(total_factor, floor_by_factor(height / beta));
-            w_bar = std::max(total_factor, floor_by_factor(width / beta));
-        } else if (h_bar * w_bar < min_pixels) {
-            const auto beta = std::sqrt(min_pixels / (height * width));
-            h_bar = ceil_by_factor(height * beta);
-            w_bar = ceil_by_factor(width * beta);
-        }
+            } break;

+        case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                        original_size,
+                        params.patch_size * params.n_merge,
+                        params.image_min_pixels,
+                        params.image_max_pixels);
                 const std::array<uint8_t, 3> pad_color = {122, 116, 104};

                 clip_image_u8 resized_img;
-        image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color);
+                img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
                 clip_image_f32_ptr res(clip_image_f32_init());
                 normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
                 res_imgs->entries.push_back(std::move(res));
                 return true;
-    }
+            } break;

+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_COGVLM: // TODO @ngxson : is this correct for cogvlm?
+            {
+                // TODO @ngxson : refactor the code below to avoid duplicated logic
+
+                // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+                // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
                 clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily

-    if (pad_to_square) {
+                // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+                if (params.image_res_candidates.empty()) { // pad_to_square
                     // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
                     // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
                     const int longer_side = std::max(img->nx, img->ny);
@@ -4074,14 +4174,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                     const std::array<uint8_t, 3> pad_color = {122, 116, 104};

                     // resize the image to the target_size
-                    image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
+                    img_tool::resize(*img, *temp, clip_image_size{params.image_size, params.image_size}, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);

                     clip_image_f32_ptr res(clip_image_f32_init());
                     normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
                     res_imgs->entries.push_back(std::move(res));
-                    return true;

-    } else if (!params.image_res_candidates.empty()) {
+                } else {
                     // "spatial_unpad" with "anyres" processing for llava-1.6
                     auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
                     std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
@@ -4092,12 +4191,15 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                         normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
                         res_imgs->entries.push_back(std::move(res));
                     }
                 }
+            } break;

-        return true;
-    } else {
-        GGML_ABORT("Unknown image preprocessing type");
+        default:
+            LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type());
+            return false;
     }

+    return true;
 }

 ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
@@ -4145,7 +4247,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
     const auto & params = ctx->model.hparams;
     const int n_total = clip_n_output_tokens(ctx, img);
     if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
+        return img->nx / (params.patch_size * 2);
     }
     return n_total;
 }
@@ -4153,7 +4255,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
 int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
     if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
+        return img->ny / (params.patch_size * 2);
     }
     return 1;
 }
@@ -4211,9 +4313,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN3VL:
             {
                 // dynamic size (2 conv, so double patch size)
-                int patch_size = params.patch_size * 2;
-                int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
-                int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+                int x_patch = img->nx / (params.patch_size * 2);
+                int y_patch = img->ny / (params.patch_size * 2);
                 n_patches = x_patch * y_patch;
             } break;
         case PROJECTOR_TYPE_GEMMA3:
@@ -4222,15 +4323,14 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_LLAMA4:
            {
                 // both X and Y are downscaled by the scale factor
-                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                int scale_factor = ctx->model.hparams.n_merge;
                 n_patches /= (scale_factor * scale_factor);
             } break;
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             {
                 // dynamic size
-                int scale_factor = ctx->model.hparams.proj_scale_factor;
-                int out_patch_size = params.patch_size * scale_factor;
+                int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
                 int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
                 int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
                 n_patches = x_patch * y_patch;
@@ -4239,7 +4339,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // dynamic size
-                int n_merge = params.spatial_merge_size;
+                int n_merge = ctx->model.hparams.n_merge;
                 int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
                 int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
                 if (ctx->model.token_embd_img_break) {