	clip : improve projector naming (#13118)
* clip : improve projector naming
* no more kv has_llava_projector
* rm unused kv
* rm more unused
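For readers skimming the diff below, the whole change boils down to one pattern: instead of reading a separate boolean GGUF key per projector family, the loader tracks a single projector_type enum and derives the remaining flags from it. The following is a condensed sketch of that pattern with simplified, hypothetical names — not the actual clip.cpp code:

    #include <cstdio>

    enum projector_type {
        PROJECTOR_TYPE_MLP,
        PROJECTOR_TYPE_MINICPMV,  // was: bool has_minicpmv_projector
        PROJECTOR_TYPE_GLM_EDGE,  // was: bool has_glm_projector
        PROJECTOR_TYPE_QWEN2VL,   // was: bool has_qwen2vl_merger
    };

    struct ctx_sketch {
        projector_type proj_type = PROJECTOR_TYPE_MLP;
        // derived from proj_type at load time instead of its own GGUF key
        bool has_llava_projector = false;
    };

    static void build_projector(const ctx_sketch & ctx) {
        // one switch replaces a chain of independent boolean tests,
        // so an unknown projector fails loudly instead of falling through
        switch (ctx.proj_type) {
            case PROJECTOR_TYPE_MINICPMV: printf("resampler\n"); break;
            case PROJECTOR_TYPE_GLM_EDGE: printf("adapter\n");   break;
            case PROJECTOR_TYPE_QWEN2VL:  printf("merger\n");    break;
            default:                      printf("mlp\n");       break;
        }
    }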
@@ -17,22 +17,15 @@
 #define KEY_FTYPE               "general.file_type"
 #define KEY_NAME                "general.name"
 #define KEY_DESCRIPTION         "general.description"
-#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
-#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
-#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
-#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
-#define KEY_HAS_GLM_PROJ        "clip.has_glm_projector"
 #define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
-#define KEY_HAS_QWEN2VL_MERGER  "clip.has_qwen2vl_merger"
 #define KEY_USE_GELU            "clip.use_gelu"
 #define KEY_USE_SILU            "clip.use_silu"
-#define KEY_N_EMBD              "clip.%s.embedding_length"
-#define KEY_N_FF                "clip.%s.feed_forward_length"
-#define KEY_N_BLOCK             "clip.%s.block_count"
-#define KEY_N_HEAD              "clip.%s.attention.head_count"
-#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
-#define KEY_PROJ_DIM            "clip.%s.projection_dim"
-#define KEY_TOKENS              "tokenizer.ggml.tokens"
+#define KEY_N_EMBD              "clip.vision.embedding_length"
+#define KEY_N_FF                "clip.vision.feed_forward_length"
+#define KEY_N_BLOCK             "clip.vision.block_count"
+#define KEY_N_HEAD              "clip.vision.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.vision.attention.layer_norm_epsilon"
+#define KEY_PROJ_DIM            "clip.vision.projection_dim"
 #define KEY_IMAGE_SIZE          "clip.vision.image_size"
 #define KEY_PATCH_SIZE          "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN          "clip.vision.image_mean"
@@ -96,9 +89,9 @@ enum projector_type {
     PROJECTOR_TYPE_MLP_NORM,
     PROJECTOR_TYPE_LDP,
     PROJECTOR_TYPE_LDPV2,
-    PROJECTOR_TYPE_RESAMPLER,
+    PROJECTOR_TYPE_MINICPMV,
     PROJECTOR_TYPE_GLM_EDGE,
-    PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_QWEN2VL,
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
@@ -109,9 +102,9 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MLP,       "mlp" },
     { PROJECTOR_TYPE_LDP,       "ldp" },
     { PROJECTOR_TYPE_LDPV2,     "ldpv2"},
-    { PROJECTOR_TYPE_RESAMPLER, "resampler"},
+    { PROJECTOR_TYPE_MINICPMV,  "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
-    { PROJECTOR_TYPE_MERGER,    "qwen2vl_merger"},
+    { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
@@ -308,13 +308,8 @@ struct clip_vision_model {
 };

 struct clip_ctx {
-    bool has_text_encoder    = false;
-    bool has_vision_encoder  = false;
     bool has_llava_projector = false;
-    bool has_minicpmv_projector = false;
-    bool has_glm_projector = false;
-    bool has_qwen2vl_merger = false;
-    int minicpmv_version = 2;
+    int minicpmv_version = 0;

     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
@@ -373,23 +368,20 @@ struct clip_ctx {
     }
 };

-static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;

-    const int image_size = hparams.image_size;
-    int image_size_width  = image_size;
-    int image_size_height = image_size;
+    int image_size_width  = img.nx;
+    int image_size_height = img.ny;

-    const int patch_size           = hparams.patch_size;
-    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int hidden_size          = hparams.hidden_size;
-    const int n_head               = hparams.n_head;
-    const int d_head               = hidden_size / n_head;
-    const int n_layer              = hparams.n_layer;
-    const float eps                = hparams.eps;
-
-    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
+    const int patch_size  = hparams.patch_size;
+    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int hidden_size = hparams.hidden_size;
+    const int n_head      = hparams.n_head;
+    const int d_head      = hidden_size / n_head;
+    const int n_layer     = hparams.n_layer;
+    const float eps       = hparams.eps;

     struct ggml_init_params params = {
         /*.mem_size   =*/ ctx->buf_compute_meta.size(),
@@ -621,15 +613,14 @@ static ggml_tensor * build_rope_2d(
     return cur;
 }

-static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;

     GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
-    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1

-    int image_size_width  = imgs.entries[0]->nx;
-    int image_size_height = imgs.entries[0]->ny;
+    int image_size_width  = img.nx;
+    int image_size_height = img.ny;

     const int patch_size  = hparams.patch_size;
     const int n_patches_x = image_size_width  / patch_size;
@@ -772,18 +763,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
 }

 static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("This gguf file seems to have no vision encoder\n");
-        return nullptr;
-    }
-
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;

     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
     int image_size_height = image_size;
-    if (ctx->has_minicpmv_projector) {
+
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
         image_size_width  = load_image_size.width;
         image_size_height = load_image_size.height;
@@ -792,7 +779,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             image_size_height = imgs.entries[0]->ny;
         }
     }
-    else if (ctx->has_qwen2vl_merger) {
+
+    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         // use the image's native resolution when image is avaible
         if (is_inf) {
         // if (imgs->data->nx && imgs->data->ny) {
@@ -800,12 +788,13 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             image_size_height = imgs.entries[0]->ny;
         }
     }

     const int patch_size           = hparams.patch_size;
     const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int patches_w            = image_size_width / patch_size;
     const int patches_h            = image_size_height / patch_size;
     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
-    const int num_position_ids     = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
+    const int num_position_ids     = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
     const int hidden_size          = hparams.hidden_size;
     const int n_head               = hparams.n_head;
     const int d_head               = hidden_size / n_head;
@@ -814,7 +803,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im

     const int batch_size = imgs.entries.size();

-    if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
+    if (ctx->has_llava_projector
+            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
+            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         GGML_ASSERT(batch_size == 1);
     }
@@ -835,8 +826,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im

     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);

-    if (ctx->has_qwen2vl_merger) {
-        GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
+    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
+        GGML_ASSERT(image_size_width  % (patch_size * 2) == 0);
         GGML_ASSERT(image_size_height % (patch_size * 2) == 0);

         auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@@ -865,29 +856,26 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     struct ggml_tensor * embeddings = inp;
     struct ggml_tensor * pos_embed = nullptr;

-    if (ctx->has_llava_projector) {
-        // concat class_embeddings and patch_embeddings
-        if (model.class_embedding) {
-            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-            ggml_set_name(embeddings, "embeddings");
-            ggml_set_input(embeddings);
-            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
-                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-            embeddings = ggml_acc(ctx0, embeddings, inp,
-                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
-        }
+    // concat class_embeddings and patch_embeddings
+    if (model.class_embedding) {
+        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
+        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+        embeddings = ggml_acc(ctx0, embeddings, inp,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
     }

     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
     ggml_set_name(positions, "positions");
     ggml_set_input(positions);

-    if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
+    if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings
         embeddings =
             ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
     }

-    if (ctx->has_minicpmv_projector) {
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         int pos_w = image_size_width/patch_size;
         int pos_h = image_size_height/patch_size;
         if (ctx->minicpmv_version == 2) {
@@ -941,7 +929,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);

             Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
-            if (ctx->has_qwen2vl_merger) {
+            if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
                 Q = ggml_rope_multi(
                     ctx0, Q, positions, nullptr,
                     d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
@@ -953,7 +941,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);

             K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-            if (ctx->has_qwen2vl_merger) {
+            if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
                 K = ggml_rope_multi(
                     ctx0, K, positions, nullptr,
                     d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
@@ -1218,106 +1206,98 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         }
     }
-    // minicpmv projector
-    else if (ctx->has_minicpmv_projector)
-    {
-        if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
-            struct ggml_tensor * q = model.mm_model_query;
-            { // layernorm
-                q = ggml_norm(ctx0, q, eps);
-                q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-            }
-            struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
-            { // layernorm
-                v = ggml_norm(ctx0, v, eps);
-                v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
-            }
-            struct ggml_tensor * k;
-            { // position
-                // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
-                k = ggml_add(ctx0, v, pos_embed);
-            }
-
-            { // attention
-                int hidden_size = 4096;
-                const int d_head = 128;
-                int n_head = hidden_size/d_head;
-                int num_query = 96;
-                if (ctx->minicpmv_version == 2) {
-                    hidden_size = 4096;
-                    n_head = hidden_size/d_head;
-                    num_query = 96;
-                }
-                else if (ctx->minicpmv_version == 3) {
-                    hidden_size = 3584;
-                    n_head = hidden_size/d_head;
-                    num_query = 64;
-                }
-                else if (ctx->minicpmv_version == 4) {
-                    hidden_size = 3584;
-                    n_head = hidden_size/d_head;
-                    num_query = 64;
-                }
-
-                struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
-                struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
-                struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
-                // permute
-                Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
-                Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-                Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
-                K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-                K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-                K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
-                V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
-                V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-                V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
-                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-                KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
-                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-                KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
-                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-                KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
-
-                embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
-            }
-            { // layernorm
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
-            }
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
-        }
-        else {
-            GGML_ASSERT(false);
-        }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
+        struct ggml_tensor * q = model.mm_model_query;
+        { // layernorm
+            q = ggml_norm(ctx0, q, eps);
+            q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+        }
+        struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+        { // layernorm
+            v = ggml_norm(ctx0, v, eps);
+            v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+        }
+        struct ggml_tensor * k;
+        { // position
+            // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+            k = ggml_add(ctx0, v, pos_embed);
+        }
+
+        { // attention
+            int hidden_size = 4096;
+            const int d_head = 128;
+            int n_head = hidden_size/d_head;
+            int num_query = 96;
+            if (ctx->minicpmv_version == 2) {
+                hidden_size = 4096;
+                n_head = hidden_size/d_head;
+                num_query = 96;
+            }
+            else if (ctx->minicpmv_version == 3) {
+                hidden_size = 3584;
+                n_head = hidden_size/d_head;
+                num_query = 64;
+            }
+            else if (ctx->minicpmv_version == 4) {
+                hidden_size = 3584;
+                n_head = hidden_size/d_head;
+                num_query = 64;
+            }
+
+            struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+            struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+            struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+            // permute
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);

+            embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
+        }
+        { // layernorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+        }
+        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
     }

-    // glm projector
-    else if (ctx->has_glm_projector) {
-        if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
-            size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
-            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
-            embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
-            embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
-            embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
-            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
-            embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
-            //GLU
-            {
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-                embeddings = ggml_gelu_inplace(ctx0, embeddings);
-                struct ggml_tensor * x = embeddings;
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
-                x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
-                embeddings = ggml_silu_inplace(ctx0, embeddings);
-                embeddings = ggml_mul(ctx0, embeddings,x);
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
-            }
-        } else {
-            GGML_ABORT("fatal error");
-        }
+    else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+        size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
+        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
+        embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
+        embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
+        embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
+        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
+        embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
+        // GLU
+        {
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+            embeddings = ggml_gelu_inplace(ctx0, embeddings);
+            struct ggml_tensor * x = embeddings;
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
+            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
+            embeddings = ggml_silu_inplace(ctx0, embeddings);
+            embeddings = ggml_mul(ctx0, embeddings,x);
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
+        }
     }
-    else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);

         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1343,11 +1323,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
             {
-                res = clip_image_build_graph_siglip(ctx, imgs);
+                GGML_ASSERT(imgs.entries.size() == 1);
+                res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
             {
-                res = clip_image_build_graph_pixtral(ctx, imgs);
+                GGML_ASSERT(imgs.entries.size() == 1);
+                res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
             } break;
         default:
             {
@@ -1419,8 +1401,8 @@ struct clip_model_loader {
         auto & hparams = ctx_clip.vision_model.hparams;

         // projector type
+        std::string proj_type;
         {
-            std::string proj_type;
             get_string(KEY_PROJ_TYPE, proj_type, false);
             if (!proj_type.empty()) {
                 ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
@@ -1432,33 +1414,27 @@ struct clip_model_loader {

         // other hparams
         {
-            get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false);
-            get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false);
-            GGML_ASSERT(ctx_clip.has_vision_encoder);
-            GGML_ASSERT(!ctx_clip.has_text_encoder);
-
-            get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false);
-            get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false);
+            // legacy keys, use KEY_PROJ_TYPE instead
             get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false);
-            get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false);
-            get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false);
+            // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead

             get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
             get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);

-            get_u32(string_format(KEY_N_EMBD,         "vision"), hparams.hidden_size);
-            get_u32(string_format(KEY_N_HEAD,         "vision"), hparams.n_head);
-            get_u32(string_format(KEY_N_FF,           "vision"), hparams.n_intermediate);
-            get_u32(string_format(KEY_N_BLOCK,        "vision"), hparams.n_layer);
-            get_u32(string_format(KEY_PROJ_DIM,       "vision"), hparams.projection_dim);
-            get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps);
-            get_u32(KEY_IMAGE_SIZE, hparams.image_size);
-            get_u32(KEY_PATCH_SIZE, hparams.patch_size);
-            get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
+            get_u32(KEY_N_EMBD,         hparams.hidden_size);
+            get_u32(KEY_N_HEAD,         hparams.n_head);
+            get_u32(KEY_N_FF,           hparams.n_intermediate);
+            get_u32(KEY_N_BLOCK,        hparams.n_layer);
+            get_u32(KEY_PROJ_DIM,       hparams.projection_dim);
+            get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
+            get_u32(KEY_IMAGE_SIZE,     hparams.image_size);
+            get_u32(KEY_PATCH_SIZE,     hparams.patch_size);
+            get_u32(KEY_IMAGE_CROP_RESOLUTION,    hparams.image_crop_resolution, false);
             get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
+
+            ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
+                                        || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
+                                        || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
+                                        || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;

             {
                 std::string mm_patch_merge_type;
                 get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
@@ -1491,32 +1467,56 @@ struct clip_model_loader {
             for (auto & layer : vision_feature_layer) {
                 hparams.vision_feature_layer.insert(layer);
             }
-            // Calculate the deepest feature layer based on hparams and projector type
-            ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip);

-            LOG_INF("%s: text_encoder:       %d\n", __func__, ctx_clip.has_text_encoder);
-            LOG_INF("%s: vision_encoder:     %d\n", __func__, ctx_clip.has_vision_encoder);
-            LOG_INF("%s: llava_projector:    %d\n", __func__, ctx_clip.has_llava_projector);
-            LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector);
+            // Calculate the deepest feature layer based on hparams and projector type
+            // NOTE: This is only used by build_graph_legacy()
+            {
+                // Get the index of the second to last layer; this is the default for models that have a llava projector
+                int n_layer = hparams.n_layer - 1;
+                int deepest_feature_layer = -1;
+
+                if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV
+                        || ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE
+                        || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL) {
+                    n_layer += 1;
+                }
+
+                // If we set explicit vision feature layers, only go up to the deepest one
+                // NOTE: only used by granite-vision models for now
+                for (const auto & feature_layer : hparams.vision_feature_layer) {
+                    if (feature_layer > deepest_feature_layer) {
+                        deepest_feature_layer = feature_layer;
+                    }
+                }
+                ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
+            }
+
+            // model-specific params
+            switch (ctx_clip.proj_type) {
+                case PROJECTOR_TYPE_MINICPMV:
+                    {
+                        if (ctx_clip.minicpmv_version == 0) {
+                            ctx_clip.minicpmv_version = 2; // default to 2 if not set
+                        }
+                    } break;
+                case PROJECTOR_TYPE_IDEFICS3:
+                    {
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
+                case PROJECTOR_TYPE_PIXTRAL:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                    } break;
+                default:
+                    break;
+            }
+
+            LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
+            LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector);
             LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version);
-            LOG_INF("%s: glm_projector:      %d\n", __func__, ctx_clip.has_glm_projector);
             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
         }
-
-        // model-specific params
-        switch (ctx_clip.proj_type) {
-            case PROJECTOR_TYPE_IDEFICS3:
-                {
-                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
-                } break;
-            case PROJECTOR_TYPE_PIXTRAL:
-                {
-                    hparams.rope_theta = 10000.0f;
-                } break;
-            default:
-                break;
-        }
     }

     void load_tensors() {
@@ -1569,9 +1569,6 @@ struct clip_model_loader {
         vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
         vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
         vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
-        if (vision_model.patch_embeddings_1 == nullptr) {
-            ctx_clip.has_qwen2vl_merger = false;
-        }

        vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
@@ -1669,7 +1666,7 @@ struct clip_model_loader {
                     vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
                     vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
                 } break;
-            case PROJECTOR_TYPE_RESAMPLER:
+            case PROJECTOR_TYPE_MINICPMV:
                 {
                     // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
                     vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
@@ -1702,7 +1699,7 @@ struct clip_model_loader {
                     vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight"));
                     vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
                 } break;
-            case PROJECTOR_TYPE_MERGER:
+            case PROJECTOR_TYPE_QWEN2VL:
                 {
                     vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
                     vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
@@ -2479,11 +2476,6 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-
     clip_image_size original_size{img->nx, img->ny};
     bool pad_to_square = true;
     auto & params = ctx->vision_model.hparams;
@@ -2504,7 +2496,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         }
         return true;
     }
-    else if (ctx->has_qwen2vl_merger) {
+    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         clip_image_u8 resized;
         auto patch_size = clip_get_patch_size(ctx) * 2;
         int nx = ceil((float)img->nx / patch_size) * patch_size;
@@ -2518,7 +2510,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
-    else if (ctx->has_glm_projector
+    else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
             || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         clip_image_u8 resized_image;
@@ -2646,7 +2638,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i

     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         n_patches /= 4;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         if (ctx->minicpmv_version == 2) {
             n_patches = 96;
         }
@@ -2656,7 +2648,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         else if (ctx->minicpmv_version == 4) {
             n_patches = 64;
         }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+        else {
+            GGML_ABORT("Unknown minicpmv version");
+        }
+    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         int patch_size = params.patch_size * 2;
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
@@ -2761,11 +2756,6 @@ static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, co
 }

 bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-
     clip_image_f32_batch imgs;
     clip_image_f32_ptr img_copy(clip_image_f32_init());
     *img_copy = *img;
@@ -2776,20 +2766,11 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3

 bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
     const clip_image_f32_batch & imgs = *imgs_c_ptr;

-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-
     int batch_size = imgs.entries.size();
-    if (ctx->has_llava_projector) {
-        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
-    }
-    if (ctx->has_minicpmv_projector) {
-        GGML_ASSERT(batch_size == 1);
-    }
-    if (ctx->has_glm_projector) {
+
+    if (ctx->has_llava_projector
+            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
+            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         GGML_ASSERT(batch_size == 1);
     }
@@ -2799,21 +2780,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);

     // set inputs
-    const auto & model = ctx->vision_model;
+    const auto & model   = ctx->vision_model;
     const auto & hparams = model.hparams;

-    // TODO @ngxson : this is ugly, need to refactor later
-    bool support_dynamic_size = ctx->has_minicpmv_projector
-        || ctx->has_qwen2vl_merger
-        || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+    const int image_size_width  = imgs.entries[0]->nx;
+    const int image_size_height = imgs.entries[0]->ny;

-    const int image_size = hparams.image_size;
-    int image_size_width  = image_size;
-    int image_size_height = image_size;
-    if (support_dynamic_size) {
-        image_size_width  = imgs.entries[0]->nx;
-        image_size_height = imgs.entries[0]->ny;
-    }
     const int patch_size    = hparams.patch_size;
     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
@@ -2839,14 +2811,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;

-            if (ctx->has_glm_projector
-                    || ctx->has_llava_projector
-                    || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-                    || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
-                GGML_ASSERT(nx == image_size && ny == image_size);
-            }
-
             const int n = nx * ny;

             for (int b = 0; b < batch_size; b++) {
@@ -2864,13 +2828,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
     }
-    if (ctx->has_minicpmv_projector) {
+
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         {
             // inspired from siglip:
             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
-            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+            std::vector<int> pos_data(ggml_nelements(positions));
+            int * data = pos_data.data();
             int bucket_coords_h[1024];
             int bucket_coords_w[1024];
             for (int i = 0; i < pos_h; i++){
@@ -2881,11 +2847,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             }
             for (int i = 0, id = 0; i < pos_h; i++){
                 for (int j = 0; j < pos_w; j++){
-                    positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                    data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
                 }
             }
-            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
-            free(positions_data);
+            ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
         }

         {
@@ -2903,30 +2868,28 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             else if (ctx->minicpmv_version == 4) {
                 embed_dim = 3584;
             }
+            else {
+                GGML_ABORT("Unknown minicpmv version");
+            }

             // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));

-            float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
-            for(int i=0;i < pos_w * pos_h; ++i){
-                for(int j=0; j < embed_dim; ++j){
-                    pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
+            std::vector<float> pos_data(ggml_nelements(pos_embed));
+            float * data = pos_data.data();
+            for(int i = 0; i < pos_w * pos_h; ++i){
+                for(int j = 0; j < embed_dim; ++j){
+                    data[i * embed_dim + j] = pos_embed_t[i][j];
                 }
             }

-            ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
-            free(pos_embed_data);
+            ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
         }
     }
     else {
-        if (model.class_embedding) {
-            struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
-
-            void* zero_mem = malloc(ggml_nbytes(embeddings));
-            memset(zero_mem, 0, ggml_nbytes(embeddings));
-            ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
-            free(zero_mem);
-        }
-
-        if (ctx->has_qwen2vl_merger) {
+        // non-minicpmv models
+
+        if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");

             const int pw = image_size_width / patch_size;
@@ -2978,6 +2941,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
         }
         else {
+            // llava and other models
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");

             int* positions_data = (int*)malloc(ggml_nbytes(positions));
@@ -2987,7 +2951,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
             free(positions_data);

-            if (!ctx->has_glm_projector) {
+            if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
                 struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
                 // The patches vector is used to get rows to index into the embeds with;
                 // we should skip dim 0 only if we have CLS to avoid going out of bounds
@@ -3166,7 +3130,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_2_b->ne[0];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->vision_model.mm_3_b->ne[0];
-        case PROJECTOR_TYPE_RESAMPLER:
+        case PROJECTOR_TYPE_MINICPMV:
             if (ctx->minicpmv_version == 2) {
                 return 4096;
             } else if (ctx->minicpmv_version == 3) {
@@ -3174,36 +3138,33 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             } else if (ctx->minicpmv_version == 4) {
                 return 3584;
             }
-            break; // Should not happen if version is valid
+            GGML_ABORT("Unknown minicpmv version");
         case PROJECTOR_TYPE_GLM_EDGE:
             return ctx->vision_model.mm_model_mlp_3_w->ne[1];
-        case PROJECTOR_TYPE_MERGER:
+        case PROJECTOR_TYPE_QWEN2VL:
             return ctx->vision_model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->vision_model.projection->ne[1];
         default:
-            break; // Fall through to throw
+            GGML_ABORT("Unknown projector type");
     }
-
-    std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
-    throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
 }

 int clip_is_minicpmv(const struct clip_ctx * ctx) {
-    if (ctx->has_minicpmv_projector) {
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         return ctx->minicpmv_version;
     }
     return 0;
 }

 bool clip_is_glm(const struct clip_ctx * ctx) {
-    return ctx->has_glm_projector;
+    return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
 }

 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
-    return ctx->has_qwen2vl_merger;
+    return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
 }

 bool clip_is_llava(const struct clip_ctx * ctx) {
@@ -3214,29 +3175,6 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
     return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
 }

-// Determine the number of encoder layers to iterate over
-int get_deepest_feature_layer(const struct clip_ctx * ctx) {
-    // Get the index of the second to last layer; this is the
-    // default for models that have a llava projector
-    const auto & hparams = ctx->vision_model.hparams;
-    int n_layer = hparams.n_layer - 1;
-    int deepest_feature_layer = -1;
-
-    // Handle other projectors; incrementing here indicates that we
-    // should use the last encoder layer for the vision features.
-    if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
-        n_layer += 1;
-    }
-
-    // If we set explicit vision feature layers, only go up to the deepest one
-    for (const auto & feature_layer : hparams.vision_feature_layer) {
-        if (feature_layer > deepest_feature_layer) {
-            deepest_feature_layer = feature_layer;
-        }
-    }
-    return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
-}
-
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
     clip_image_f32 clip_img;
     clip_img.buf.resize(h * w * 3);
@@ -114,8 +114,6 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);

-CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
-
 CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
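External callers are unaffected by the renames: the public predicates keep their signatures and now just compare against proj_type internally. A hedged usage sketch (hypothetical caller code, assuming only the clip.h functions shown in the diff — clip_is_minicpmv() returns the MiniCPM-V version or 0, the others return booleans):

    #include <cstdio>
    #include "clip.h"

    static void describe_projector(const struct clip_ctx * ctx) {
        if (int v = clip_is_minicpmv(ctx)) {
            // MiniCPM-V resampler projector; embed dims depend on version 2/3/4
            printf("minicpmv resampler, version %d\n", v);
        } else if (clip_is_glm(ctx)) {
            printf("glm-edge adapter\n");
        } else if (clip_is_qwen2vl(ctx)) {
            printf("qwen2vl merger (RoPE positions, no learned position embeddings)\n");
        } else {
            printf("other projector\n");
        }
    }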
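For reference, the max_feature_layer computation that this commit inlines into clip_model_loader (while deleting the public get_deepest_feature_layer() helper) can be read as a standalone function. This is a sketch with hypothetical parameter names, equivalent to the inlined block in the diff, not code from the repository:

    #include <set>

    static int max_feature_layer(int n_layer_hparam, bool uses_last_layer,
                                 const std::set<int> & vision_feature_layer) {
        // default: second-to-last encoder layer (llava-style projectors)
        int n_layer = n_layer_hparam - 1;
        // minicpmv / glm-edge / qwen2vl read the last encoder layer instead
        if (uses_last_layer) {
            n_layer += 1;
        }
        // explicit vision feature layers (granite-vision) cap the iteration
        int deepest = -1;
        for (int layer : vision_feature_layer) {
            if (layer > deepest) {
                deepest = layer;
            }
        }
        return deepest < 0 ? n_layer : deepest;
    }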