caitianchi
2024-05-29 02:21:41 +08:00
parent 9495504e7b
commit b37ab0b1e5
3 changed files with 33 additions and 37 deletions

@@ -1,7 +1,3 @@
-// NOTE: This is modified from clip.cpp only for LLaVA,
-// so there might be still unnecessary artifacts hanging around
-// I'll gradually clean and extend it
-// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
 #include "clip.h"
 #include "common.h"
 #include "log.h"
@@ -1664,6 +1660,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 }
 {
+// inspired from siglip:
+// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 int* positions_data = (int*)malloc(ggml_nbytes(positions));
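The hunk above is cut off after the malloc. For context, the usual pattern around this point in clip.cpp fills the "positions" buffer with one index per patch and uploads it to the backend; a minimal sketch, assuming num_positions holds the number of patch positions and the fill is a plain 0..n-1 sequence (neither is shown in this diff):

    // Sketch only: fill the "positions" input tensor with sequential patch
    // indices and copy the buffer into the backend tensor.
    struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
    int * positions_data = (int *) malloc(ggml_nbytes(positions));
    for (int i = 0; i < num_positions; i++) {
        positions_data[i] = i;   // assumed: simple 0..n-1 indices
    }
    ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
    free(positions_data);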
@@ -1675,6 +1674,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 }
 {
+// inspired from resampler of Qwen-VL:
+// -> https://huggingface.co/Qwen/Qwen-VL/tree/main
+// -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
 struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
 int pos_w = image_size_width/patch_size;
 int pos_h = image_size_height/patch_size;
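This hunk is also truncated; pos_w and pos_h describe the patch grid over which the Qwen-VL-style resampler position embedding is built. A rough sketch of how a 2D sin/cos embedding for that grid could be computed and uploaded to "pos_embed" (embed_dim, the frequency schedule, and the row-major layout are assumptions for illustration, not taken from this commit):

    // Sketch only: build a 2D sin/cos position embedding over a pos_w x pos_h
    // grid and upload it to the "pos_embed" input tensor.
    // Requires <cmath> and <vector>.
    const int embed_dim = 4096;   // assumed resampler width
    std::vector<float> pos_embed_data((size_t) pos_w * pos_h * embed_dim);
    for (int y = 0; y < pos_h; y++) {
        for (int x = 0; x < pos_w; x++) {
            float * dst = pos_embed_data.data() + ((size_t) y * pos_w + x) * embed_dim;
            // first half of the channels encodes x, second half encodes y
            for (int i = 0; i < embed_dim / 4; i++) {
                const float omega = 1.0f / powf(10000.0f, (float) i / (embed_dim / 4));
                dst[i]                     = sinf(omega * x);
                dst[embed_dim / 4 + i]     = cosf(omega * x);
                dst[embed_dim / 2 + i]     = sinf(omega * y);
                dst[3 * embed_dim / 4 + i] = cosf(omega * y);
            }
        }
    }
    ggml_backend_tensor_set(pos_embed, pos_embed_data.data(), 0, ggml_nbytes(pos_embed));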
@@ -1692,16 +1694,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 free(pos_embed_data);
 }
-// {
-// struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
-// int* patches_data = (int*)malloc(ggml_nbytes(patches));
-// for (int i = 0; i < num_patches; i++) {
-// patches_data[i] = i + 1;
-// }
-// ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
-// free(patches_data);
-// }
 if (ggml_backend_is_cpu(ctx->backend)) {
 ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
 }
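For context, the lines kept at the end of this hunk are the usual tail of clip_image_batch_encode: when the CPU backend is in use, the requested thread count is applied before the graph is computed. A minimal sketch of that pattern (the ggml_backend_graph_compute call is part of the standard ggml backend API, but it lies outside the lines shown here):

    // Sketch only: once every input tensor has been set, run the graph on the
    // selected backend; only the CPU backend takes an explicit thread count.
    if (ggml_backend_is_cpu(ctx->backend)) {
        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
    }
    ggml_backend_graph_compute(ctx->backend, gf);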