mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	llava : support MiniCPM-V-2.5 (#7599)
* init
* rename
* add run android for termux in readme
* add android readme
* add instructions in readme
* change name in readme
* Update README.md
* fixed line
* add result in readme
* random pos_embed
* add positions index
* change for ollama
* change for ollama
* better pos_embed in clip
* support ollama
* updata cmakelist
* updata cmakelist
* rename wrapper
* clear code
* replace and organize code
* add link
* sync master
* fix warnings
* fix warnings
* fix bug in bicubic resize when need resize iamge smaller
* receive review comments and modify
* receive review comments and modify
* put all code into llava dir
* fix quality problem in pr code
* change n_layer
* add space in "-1"
* imitate reshape bug of python code
* fix bug in clip
* fix issues for merging
* fix llama-minicpmv-cli in cmake file
* change pr readme
* fix code review
* remove in line 33 directory in the /cmakelists.txt (not in example, in the main dir)
* fix cmakefile
* add warn
* fix KEY_HAS_MINICPMV_PROJ
* remove load_image_size into clip_ctx
* remove the extern "C", MINICPMV_API
* fix uhd code for review comment
* delete minicpmv-wrapper in pr
* remove uhd_image_embed
* Modify 2 notes
* clip : style changes
* del common.h in clip
* fix Type-Check error
* fix Type-Check error
* fix Type-Check error
* fix Type-Check error
* fix makefile error
* fix ubuntu-make error
* try fix clip
* try fix 1

---------

Co-authored-by: Hongji Zhu <fireyoucan@gmail.com>
Co-authored-by: harvestingmoon <leewenyeong@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
.gitignore (vendored, 1 change)
@@ -79,7 +79,6 @@ models-mnt
 !models/ggml-vocab-*.gguf*
 
 # Zig
-
 zig-out/
 zig-cache/
 

Makefile (12 changes)
@@ -19,6 +19,7 @@ BUILD_TARGETS = \
 	llama-imatrix \
 	llama-infill \
 	llama-llava-cli \
+	llama-minicpmv-cli\
 	llama-lookahead \
 	llama-lookup \
 	llama-lookup-create \
@@ -1463,6 +1464,17 @@ llama-llava-cli: examples/llava/llava-cli.cpp \
 	$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS)
 
+llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \
+	examples/llava/clip.h \
+	examples/llava/clip.cpp \
+	examples/llava/llava.h \
+	examples/llava/llava.cpp \
+	$(OBJ_ALL)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp  -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
+	$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) -o $@ $(LDFLAGS)
+
 ifeq ($(UNAME_S),Darwin)
 swift: examples/batched.swift
 	(cd examples/batched.swift; make build)
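Since the new binary is appended to `BUILD_TARGETS`, a plain `make` now builds it along with the other tools. Building only this target with the Makefile build might look like the following (a usage sketch, run from the repository root; the `-j` flag is optional):

```bash
# build just the MiniCPM-V example CLI via the legacy Makefile build (sketch)
make llama-minicpmv-cli -j
```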
examples/llava/CMakeLists.txt (7 changes)
@@ -36,3 +36,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET llama-minicpmv-cli)
+add_executable(${TARGET} minicpmv-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
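With a CMake-based build, the same target can be built on its own; a minimal sketch, assuming a fresh `build` directory and the default generator:

```bash
# configure once, then build only the new llama-minicpmv-cli target (sketch)
cmake -B build
cmake --build build --target llama-minicpmv-cli
```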
							
								
								
									
examples/llava/README-minicpmv2.5.md (new file, 99 lines)
## MiniCPM-Llama3-V 2.5

### Prepare models and code

Download the [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) PyTorch model from Hugging Face into a "MiniCPM-Llama3-V-2_5" folder.

Clone llama.cpp:
```bash
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
```

### Usage

Convert the PyTorch model to GGUF files (you can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) files we provide):

```bash
python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
python ./convert-hf-to-gguf.py ../MiniCPM-Llama3-V-2_5/model

# quantize int4 version
./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M
```

Build on Linux or macOS:

```bash
make
make llama-minicpmv-cli
```

Run inference on Linux or macOS:
```
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```

### Android

#### Build on an Android device using Termux
We found that building on the Android device itself gives better runtime performance, so we recommend building on the device.

[Termux](https://github.com/termux/termux-app#installation) is a terminal app for Android devices (no root required).

Install the tools in Termux:
```
apt update && apt upgrade -y
apt install git make cmake
```

It's recommended to move your model into the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```

#### Building the project using the Android NDK
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.

Execute the following commands on your computer to avoid downloading the NDK to your phone. Alternatively, you can also do this in Termux:

```bash
mkdir build-android
cd build-android
export NDK=/your_ndk_path
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
make
```

Install [Termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (on Android 11+, run the command twice).

Finally, copy the built `llama` binaries and the model file to your device storage. Because file permissions on the Android sdcard cannot be changed, copy the executables to `/data/data/com.termux/files/home/bin` and then run the following commands in Termux to make them executable:

(This assumes you have pushed the built executables to the /sdcard/llama.cpp/bin path using `adb push`.)
```
$ cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
$ cd /data/data/com.termux/files/home/bin
$ chmod +x ./*
```

Download the models and push them to `/sdcard/llama.cpp/`, then move them to `/data/data/com.termux/files/home/model/`:

```
$ mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
$ mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
```

Now, you can start chatting:
```
$ cd /data/data/com.termux/files/home/bin
$ ./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
```
examples/llava/clip.cpp

@@ -74,26 +74,27 @@ static std::string format(const char * fmt, ...) {
 // key constants
 //
 
-#define KEY_FTYPE          "general.file_type"
-#define KEY_NAME           "general.name"
-#define KEY_DESCRIPTION    "general.description"
-#define KEY_HAS_TEXT_ENC   "clip.has_text_encoder"
-#define KEY_HAS_VIS_ENC    "clip.has_vision_encoder"
-#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
-#define KEY_USE_GELU       "clip.use_gelu"
-#define KEY_N_EMBD         "clip.%s.embedding_length"
-#define KEY_N_FF           "clip.%s.feed_forward_length"
-#define KEY_N_BLOCK        "clip.%s.block_count"
-#define KEY_N_HEAD         "clip.%s.attention.head_count"
-#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
-#define KEY_PROJ_DIM       "clip.%s.projection_dim"
-#define KEY_TOKENS         "tokenizer.ggml.tokens"
-#define KEY_N_POSITIONS    "clip.text.context_length"
-#define KEY_IMAGE_SIZE     "clip.vision.image_size"
-#define KEY_PATCH_SIZE     "clip.vision.patch_size"
-#define KEY_IMAGE_MEAN     "clip.vision.image_mean"
-#define KEY_IMAGE_STD      "clip.vision.image_std"
-#define KEY_PROJ_TYPE      "clip.projector_type"
+#define KEY_FTYPE               "general.file_type"
+#define KEY_NAME                "general.name"
+#define KEY_DESCRIPTION         "general.description"
+#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
+#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
+#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
+#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
+#define KEY_USE_GELU            "clip.use_gelu"
+#define KEY_N_EMBD              "clip.%s.embedding_length"
+#define KEY_N_FF                "clip.%s.feed_forward_length"
+#define KEY_N_BLOCK             "clip.%s.block_count"
+#define KEY_N_HEAD              "clip.%s.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
+#define KEY_PROJ_DIM            "clip.%s.projection_dim"
+#define KEY_TOKENS              "tokenizer.ggml.tokens"
+#define KEY_N_POSITIONS         "clip.text.context_length"
+#define KEY_IMAGE_SIZE          "clip.vision.image_size"
+#define KEY_PATCH_SIZE          "clip.vision.patch_size"
+#define KEY_IMAGE_MEAN          "clip.vision.image_mean"
+#define KEY_IMAGE_STD           "clip.vision.image_std"
+#define KEY_PROJ_TYPE           "clip.projector_type"
 
 #define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
@@ -127,12 +128,20 @@ static std::string format(const char * fmt, ...) {
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE   "model.image_newline"
 
+#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
+#define TN_MINICPMV_QUERY "resampler.query"
+#define TN_MINICPMV_PROJ "resampler.proj.weight"
+#define TN_MINICPMV_KV_PROJ "resampler.kv.weight"
+#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
+#define TN_MINICPMV_LN "resampler.ln_%s.%s"
+
+
 enum projector_type {
     PROJECTOR_TYPE_MLP,
     PROJECTOR_TYPE_MLP_NORM,
     PROJECTOR_TYPE_LDP,
     PROJECTOR_TYPE_LDPV2,
+    PROJECTOR_TYPE_RESAMPLER,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -140,6 +149,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MLP, "mlp" },
     { PROJECTOR_TYPE_LDP, "ldp" },
     { PROJECTOR_TYPE_LDPV2, "ldpv2"},
+    { PROJECTOR_TYPE_RESAMPLER, "resampler"},
 };
 
 
@@ -492,12 +502,33 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_mlp_2_b;
     struct ggml_tensor * mm_model_peg_0_w;
     struct ggml_tensor * mm_model_peg_0_b;
+
+    // MINICPMV projection
+    struct ggml_tensor * mm_model_pos_embed_k;
+    struct ggml_tensor * mm_model_query;
+    struct ggml_tensor * mm_model_proj;
+    struct ggml_tensor * mm_model_kv_proj;
+    struct ggml_tensor * mm_model_attn_q_w;
+    struct ggml_tensor * mm_model_attn_q_b;
+    struct ggml_tensor * mm_model_attn_k_w;
+    struct ggml_tensor * mm_model_attn_k_b;
+    struct ggml_tensor * mm_model_attn_v_w;
+    struct ggml_tensor * mm_model_attn_v_b;
+    struct ggml_tensor * mm_model_attn_o_w;
+    struct ggml_tensor * mm_model_attn_o_b;
+    struct ggml_tensor * mm_model_ln_q_w;
+    struct ggml_tensor * mm_model_ln_q_b;
+    struct ggml_tensor * mm_model_ln_kv_w;
+    struct ggml_tensor * mm_model_ln_kv_b;
+    struct ggml_tensor * mm_model_ln_post_w;
+    struct ggml_tensor * mm_model_ln_post_b;
 };
 
 struct clip_ctx {
     bool has_text_encoder    = false;
     bool has_vision_encoder  = false;
     bool has_llava_projector = false;
+    bool has_minicpmv_projector = false;
 
     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
@@ -522,9 +553,11 @@ struct clip_ctx {
 
     ggml_backend_t backend       = NULL;
     ggml_gallocr_t compute_alloc = NULL;
+
+    struct clip_image_size * load_image_size;
 };
 
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
     if (!ctx->has_vision_encoder) {
         LOG_TEE("This gguf file seems to have no vision encoder\n");
         return nullptr;
@@ -533,20 +566,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
-    const int image_size           = hparams.image_size;
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+    if (ctx->has_minicpmv_projector) {
+        if (load_image_size == nullptr) {
+            load_image_size = clip_image_size_init();
+        }
+        LOG_TEE("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        image_size_width  = load_image_size->width;
+        image_size_height = load_image_size->height;
+        if (is_inf) {
+            image_size_width  = imgs->data->nx;
+            image_size_height = imgs->data->ny;
+        }
+    }
     const int patch_size           = hparams.patch_size;
-    const int num_patches          = ((image_size / patch_size) * (image_size / patch_size));
-    const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side);
+    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int num_positions        = num_patches + (ctx->has_class_embedding ? 1 : 0);
     const int hidden_size          = hparams.hidden_size;
     const int n_head               = hparams.n_head;
     const int d_head               = hidden_size / n_head;
-    const int n_layer              = hparams.n_layer;
+    int n_layer                    = hparams.n_layer;
     const float eps                = hparams.eps;
 
     const int batch_size = imgs->size;
 
-    if (ctx->has_llava_projector) {
+    if (ctx->has_llava_projector || ctx->has_minicpmv_projector) {
         GGML_ASSERT(batch_size == 1);
     }
 
@@ -559,7 +605,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
+    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
     ggml_set_name(inp_raw, "inp_raw");
     ggml_set_input(inp_raw);
 
@@ -572,19 +618,21 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
         inp = ggml_add(ctx0, inp, model.patch_bias);
     }
 
-    // concat class_embeddings and patch_embeddings
     struct ggml_tensor * embeddings = inp;
-    if (ctx->has_class_embedding) {
-        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-        ggml_set_name(embeddings, "embeddings");
-        ggml_set_input(embeddings);
-        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
-                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-        embeddings = ggml_acc(ctx0, embeddings, inp,
-                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
-    }
+    struct ggml_tensor * pos_embed = nullptr;
 
+    if (ctx->has_llava_projector) {
+        // concat class_embeddings and patch_embeddings
+        if (ctx->has_class_embedding) {
+            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+            ggml_set_name(embeddings, "embeddings");
+            ggml_set_input(embeddings);
+            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+            embeddings = ggml_acc(ctx0, embeddings, inp,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+        }
+    }
+
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
     ggml_set_name(positions, "positions");
@@ -593,6 +641,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     embeddings =
         ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
 
+    if (ctx->has_minicpmv_projector) {
+        int pos_w = image_size_width/patch_size;
+        int pos_h = image_size_height/patch_size;
+        pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
+        ggml_set_name(pos_embed, "pos_embed");
+        ggml_set_input(pos_embed);
+    }
+
     // pre-layernorm
     if (ctx->has_pre_norm) {
         embeddings = ggml_norm(ctx0, embeddings, eps);
@@ -602,6 +658,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     }
 
     // loop over layers
+    if (ctx->has_minicpmv_projector) {
+        n_layer += 1;
+    }
     for (int il = 0; il < n_layer - 1; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
 
@@ -691,7 +750,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     }
 
     // llava projector
-    {
+    if (ctx->has_llava_projector) {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
         struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
@@ -872,6 +931,65 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             GGML_ABORT("fatal error");
         }
     }
+    // minicpmv projector
+    else if (ctx->has_minicpmv_projector)
+    {
+        if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            struct ggml_tensor * q = model.mm_model_query;
+            { // layernorm
+                q = ggml_norm(ctx0, q, eps);
+                q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+            }
+            struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+            { // layernorm
+                v = ggml_norm(ctx0, v, eps);
+                v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+            }
+            struct ggml_tensor * k;
+            { // position
+                // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+                k = ggml_add(ctx0, v, pos_embed);
+            }
+
+            { // attention
+                const int hidden_size = 4096;
+                const int d_head = 128;
+                const int n_head = hidden_size/d_head;
+                const int num_query = 96;
+
+                struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+                Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+                struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+                struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+                // permute
+                Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+                Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+                Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+                K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+                K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+                K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+                V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+                V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+                V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+                KQ = ggml_soft_max_inplace(ctx0, KQ);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+                KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+
+                embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
+            }
+            { // layernorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+            }
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+    }
 
     // build the graph
     ggml_build_forward_expand(gf, embeddings);
@@ -1029,7 +1147,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
         }
 
-        GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
+        idx = gguf_find_key(ctx, KEY_HAS_MINICPMV_PROJ);
+        if (idx != -1) {
+            new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
+        }
+
+        // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
+
         GGML_ASSERT(new_clip->has_vision_encoder);
         GGML_ASSERT(!new_clip->has_text_encoder);
 
@@ -1040,6 +1164,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             LOG_TEE("%s: text_encoder:   %d\n", __func__, new_clip->has_text_encoder);
             LOG_TEE("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
             LOG_TEE("%s: llava_projector:  %d\n", __func__, new_clip->has_llava_projector);
+            LOG_TEE("%s: minicpmv_projector:  %d\n", __func__, new_clip->has_minicpmv_projector);
             LOG_TEE("%s: model size:     %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_TEE("%s: metadata size:  %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
         }
@@ -1281,6 +1406,27 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight"));
             vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias"));
         }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+            vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K);
+            vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY);
+            vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ);
+            vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ);
+            vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight"));
+            vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight"));
+            vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight"));
+            vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias"));
+            vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias"));
+            vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias"));
+            vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight"));
+            vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias"));
+            vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight"));
+            vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias"));
+            vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight"));
+            vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias"));
+            vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
+            vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
+        }
         else {
             std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
             throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@@ -1319,7 +1465,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
         clip_image_f32_batch batch;
         batch.size = 1;
-        ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
+        ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
         ggml_gallocr_reserve(new_clip->compute_alloc, gf);
         size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
         LOG_TEE("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
@@ -1328,6 +1474,17 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     return new_clip;
 }
 
+void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
+    ctx_clip->load_image_size = load_image_size;
+}
+
+struct clip_image_size * clip_image_size_init() {
+    struct clip_image_size * load_image_size = new struct clip_image_size();
+    load_image_size->width = 448;
+    load_image_size->height = 448;
+    return load_image_size;
+}
+
 struct clip_image_u8 * clip_image_u8_init() {
     return new clip_image_u8();
 }
@@ -1598,9 +1755,184 @@ static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & im
     return patches;
 }
 
+static int ensure_divide(int length, int patch_size) {
+    return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
+}
+
+static std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    int width = original_size.first;
+    int height = original_size.second;
+    if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+        float r = static_cast<float>(width) / height;
+        height = static_cast<int>(scale_resolution / std::sqrt(r));
+        width = static_cast<int>(height * r);
+    }
+    int best_width = ensure_divide(width, patch_size);
+    int best_height = ensure_divide(height, patch_size);
+    return std::make_pair(best_width, best_height);
+}
+
+static std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    int width, height;
+    std::tie(width, height) = original_size;
+    int grid_x, grid_y;
+    std::tie(grid_x, grid_y) = grid;
+
+    int refine_width = ensure_divide(width, grid_x);
+    int refine_height = ensure_divide(height, grid_y);
+
+    int grid_width = refine_width / grid_x;
+    int grid_height = refine_height / grid_y;
+
+   // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
+    auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
+    int best_grid_width, best_grid_height;
+    std::tie(best_grid_width, best_grid_height) = best_grid_size;
+
+  //  std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
+    std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
+    return refine_size;
+}
+
+inline int clip(int x, int lower, int upper) {
+    return std::max(lower, std::min(x, upper));
+}
+
+static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+    std::vector<int> candidate_split_grids_nums;
+    for (int i : {multiple - 1, multiple, multiple + 1}) {
+        if (i == 1 || i > max_slice_nums) {
+            continue;
+        }
+        candidate_split_grids_nums.push_back(i);
+    }
+
+    std::vector<std::pair<int, int>> candidate_grids;
+    for (int split_grids_nums : candidate_split_grids_nums) {
+        int m = 1;
+        while (m <= split_grids_nums) {
+            if (split_grids_nums % m == 0) {
+                candidate_grids.emplace_back(m, split_grids_nums / m);
+            }
+            ++m;
+        }
+    }
+
+    std::pair<int, int> best_grid{1, 1};
+    float min_error = std::numeric_limits<float>::infinity();
+    for (const auto& grid : candidate_grids) {
+        float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
+        if (error < min_error) {
+            best_grid = grid;
+            min_error = error;
+        }
+    }
+    return best_grid;
+}
+
+// inspired from LLaVA-UHD:
+//    -> https://arxiv.org/pdf/2403.11703
+//    -> https://github.com/thunlp/LLaVA-UHD
+//    -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
+static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
+    const std::pair<int, int> original_size={img->nx,img->ny};
+    const int original_width = img->nx;
+    const int original_height = img->ny;
+    const float log_ratio = log(1.0*original_width/original_height);
+    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
+    const int multiple = fmin(ceil(ratio), max_slice_nums);
+
+    std::vector<std::vector<clip_image_u8 *>> images;
+    LOG_TEE("%s: multiple %d\n", __func__, multiple);
+    images.push_back(std::vector<clip_image_u8 *>());
+
+    if (multiple <= 1) {
+        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
+        clip_image_u8 * source_image = clip_image_u8_init();
+        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
+        // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+        images[images.size()-1].push_back(source_image);
+    }
+    else if (multiple > 1) {
+        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
+        clip_image_u8 * source_image = clip_image_u8_init();
+        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
+        // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+        LOG_TEE("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
+        images[images.size()-1].push_back(source_image);
+
+        std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+        LOG_TEE("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
+
+        auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
+        clip_image_u8 * refine_image = clip_image_u8_init();
+        bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
+
+        LOG_TEE("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
+
+        // split_to_patches
+        int width = refine_image->nx;
+        int height = refine_image->ny;
+        int grid_x = int(width / best_grid.first);
+        int grid_y = int(height / best_grid.second);
+        for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
+            images.push_back(std::vector<clip_image_u8 *>());
+            for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
+                clip_image_u8 * patch = clip_image_u8_init();
+                patch->nx = grid_x;
+                patch->ny = grid_y;
+                patch->buf.resize(3 * patch->nx * patch->ny);
+                for (int y = patches_i; y < patches_i + grid_y; ++y) {
+                    for (int x = patches_j; x < patches_j + grid_x; ++x) {
+                        const int i = 3 * (y * refine_image->nx + x);
+                        const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j));
+                        patch->buf[j]   = refine_image->buf[i];
+                        patch->buf[j+1] = refine_image->buf[i+1];
+                        patch->buf[j+2] = refine_image->buf[i+2];
+                    }
+                }
+                images[images.size()-1].push_back(patch);
+            }
+        }
+    }
+    return images;
+}
+
+int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
+    const int max_slice_nums=9;
+    const int scale_resolution=448;
+    const int original_width = ctx_clip->load_image_size->width;
+    const int original_height = ctx_clip->load_image_size->height;
+    const float log_ratio = log(1.0*original_width/original_height);
+    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
+    const int multiple = fmin(ceil(ratio), max_slice_nums);
+    std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+    return best_grid.first;
+}
+
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
+    if (clip_is_minicpmv(ctx)) {
+        std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img);
+        res_imgs->size = 0;
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            res_imgs->size += imgs[i].size();
+        }
+        res_imgs->data = new clip_image_f32[res_imgs->size];
+        int idx = 0;
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
+                clip_image_f32 * res = clip_image_f32_init();
+                normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std);
+                res_imgs->data[idx++] = *res;
+                clip_image_f32_free(res);
+            }
+        }
+        return true;
+    }
+
     bool pad_to_square = true;
     if (!ctx->has_vision_encoder) {
         LOG_TEE("This gguf file seems to have no vision encoder\n");
| @@ -1816,11 +2148,99 @@ int clip_n_patches(const struct clip_ctx * ctx) { | |||||||
|  |  | ||||||
|     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { |     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { | ||||||
|         n_patches /= 4; |         n_patches /= 4; | ||||||
|  |     } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { | ||||||
|  |         n_patches = 96; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     return n_patches; |     return n_patches; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) { | ||||||
|  |     assert(embed_dim % 2 == 0); | ||||||
|  |     int H = pos.size(); | ||||||
|  |     int W = pos[0].size(); | ||||||
|  |  | ||||||
|  |     std::vector<float> omega(embed_dim / 2); | ||||||
|  |     for (int i = 0; i < embed_dim / 2; ++i) { | ||||||
|  |         omega[i] = 1.0 / pow(10000.0, static_cast<float>(i) / (embed_dim / 2)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim))); | ||||||
|  |     for (int h = 0; h < H; ++h) { | ||||||
|  |         for (int w = 0; w < W; ++w) { | ||||||
|  |             for (int d = 0; d < embed_dim / 2; ++d) { | ||||||
|  |                 float out_value = pos[h][w] * omega[d]; | ||||||
|  |                 emb[h][w][d] = sin(out_value); | ||||||
|  |                 emb[h][w][d + embed_dim / 2] = cos(out_value); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return emb; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static std::vector<std::vector<std::vector<float>>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector<std::vector<std::vector<float>>> & grid) { | ||||||
|  |     assert(embed_dim % 2 == 0); | ||||||
|  |     std::vector<std::vector<std::vector<float>>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) | ||||||
|  |     std::vector<std::vector<std::vector<float>>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) | ||||||
|  |  | ||||||
|  |     int H = emb_h.size(); | ||||||
|  |     int W = emb_h[0].size(); | ||||||
|  |     std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim))); | ||||||
|  |  | ||||||
|  |     for (int h = 0; h < H; ++h) { | ||||||
|  |         for (int w = 0; w < W; ++w) { | ||||||
|  |             for (int d = 0; d < embed_dim / 2; ++d) { | ||||||
|  |                 emb[h][w][d] = emb_h[h][w][d]; | ||||||
|  |                 emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return emb; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, const std::pair<int, int> image_size) { | ||||||
|  |     int grid_h_size = image_size.first; | ||||||
|  |     int grid_w_size = image_size.second; | ||||||
|  |  | ||||||
|  |     std::vector<float> grid_h(grid_h_size); | ||||||
|  |     std::vector<float> grid_w(grid_w_size); | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < grid_h_size; ++i) { | ||||||
|  |         grid_h[i] = static_cast<float>(i); | ||||||
|  |     } | ||||||
|  |     for (int i = 0; i < grid_w_size; ++i) { | ||||||
|  |         grid_w[i] = static_cast<float>(i); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::vector<std::vector<float>> grid(grid_h_size, std::vector<float>(grid_w_size)); | ||||||
|  |     for (int h = 0; h < grid_h_size; ++h) { | ||||||
|  |         for (int w = 0; w < grid_w_size; ++w) { | ||||||
|  |             grid[h][w] = grid_w[w]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     std::vector<std::vector<std::vector<float>>> grid_2d = {grid, grid}; | ||||||
|  |     for (int h = 0; h < grid_h_size; ++h) { | ||||||
|  |         for (int w = 0; w < grid_w_size; ++w) { | ||||||
|  |             grid_2d[0][h][w] = grid_h[h]; | ||||||
|  |             grid_2d[1][h][w] = grid_w[w]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::vector<std::vector<std::vector<float>>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); | ||||||
|  |  | ||||||
|  |     int H = image_size.first; | ||||||
|  |     int W = image_size.second; | ||||||
|  |     std::vector<std::vector<float>> pos_embed_2d(H * W, std::vector<float>(embed_dim)); | ||||||
|  |     for (int h = 0; h < H; ++h) { | ||||||
|  |         for (int w = 0; w < W; ++w) { | ||||||
|  |             pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return pos_embed_2d; | ||||||
|  | } | ||||||
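Note: the helpers above implement the standard 2D sinusoidal position embedding used by the MiniCPM-V resampler (the Python conversion script later in this commit computes the same table with NumPy). For a grid coordinate $p$, embedding width $D$ and half-dimension index $d \in \{0, \dots, D/2 - 1\}$, the inner loops evaluate

$$\omega_d = 10000^{-d/(D/2)}, \qquad \mathrm{emb}[d] = \sin(p\,\omega_d), \qquad \mathrm{emb}[d + D/2] = \cos(p\,\omega_d),$$

and get_2d_sincos_pos_embed_from_grid applies this once along the height axis and once along the width axis (D/2 channels each) before concatenating the two halves into one D-dimensional vector per patch.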
|  |  | ||||||
| bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { | bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { | ||||||
|     if (!ctx->has_vision_encoder) { |     if (!ctx->has_vision_encoder) { | ||||||
|         LOG_TEE("This gguf file seems to have no vision encoder\n"); |         LOG_TEE("This gguf file seems to have no vision encoder\n"); | ||||||
| @@ -1843,18 +2263,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|     if (ctx->has_llava_projector) { |     if (ctx->has_llava_projector) { | ||||||
|         GGML_ASSERT(batch_size == 1); // TODO: support multiple images |         GGML_ASSERT(batch_size == 1); // TODO: support multiple images | ||||||
|     } |     } | ||||||
|  |     if (ctx->has_minicpmv_projector) { | ||||||
|  |         GGML_ASSERT(batch_size == 1); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // build the inference graph |     // build the inference graph | ||||||
|     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); |     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); | ||||||
|     ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); |     ggml_gallocr_alloc_graph(ctx->compute_alloc, gf); | ||||||
|  |  | ||||||
|     // set inputs |     // set inputs | ||||||
|     const auto & model = ctx->vision_model; |     const auto & model = ctx->vision_model; | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
|     const int image_size    = hparams.image_size; |     const int image_size = hparams.image_size; | ||||||
|  |     int image_size_width  = image_size; | ||||||
|  |     int image_size_height = image_size; | ||||||
|  |     if (ctx->has_minicpmv_projector) { | ||||||
|  |         image_size_width  = imgs->data[0].nx; | ||||||
|  |         image_size_height = imgs->data[0].ny; | ||||||
|  |     } | ||||||
|     const int patch_size    = hparams.patch_size; |     const int patch_size    = hparams.patch_size; | ||||||
|     const int num_patches   = ((image_size / patch_size) * (image_size / patch_size)); |     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size)); | ||||||
|     const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); |     const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); | ||||||
|  |  | ||||||
|     { |     { | ||||||
| @@ -1864,7 +2293,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|         for (size_t i = 0; i < imgs->size; i++) { |         for (size_t i = 0; i < imgs->size; i++) { | ||||||
|             const int nx = imgs->data[i].nx; |             const int nx = imgs->data[i].nx; | ||||||
|             const int ny = imgs->data[i].ny; |             const int ny = imgs->data[i].ny; | ||||||
|             GGML_ASSERT(nx == image_size && ny == image_size); |             if (!ctx->has_minicpmv_projector) { | ||||||
|  |                 GGML_ASSERT(nx == image_size && ny == image_size); | ||||||
|  |             } | ||||||
|  |  | ||||||
|             const int n = nx * ny; |             const int n = nx * ny; | ||||||
|  |  | ||||||
| @@ -1881,37 +2312,75 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); |         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); | ||||||
|         free(data); |         free(data); | ||||||
|     } |     } | ||||||
|  |     if (ctx->has_minicpmv_projector) { | ||||||
|     { |         { | ||||||
|         if (ctx->has_class_embedding) { |             // inspired from siglip: | ||||||
|             struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); |             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit | ||||||
|  |             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 | ||||||
|             void* zero_mem = malloc(ggml_nbytes(embeddings)); |             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); | ||||||
|             memset(zero_mem, 0, ggml_nbytes(embeddings)); |             int* positions_data = (int*)malloc(ggml_nbytes(positions)); | ||||||
|             ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); |             for (int i = 0; i < num_positions; i++) { | ||||||
|             free(zero_mem); |                 positions_data[i] = std::floor(70.0*i/num_positions); | ||||||
|  |             } | ||||||
|  |             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); | ||||||
|  |             free(positions_data); | ||||||
|         } |         } | ||||||
|     } |  | ||||||
|  |  | ||||||
|     { |         { | ||||||
|         struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); |             // inspired from resampler of Qwen-VL: | ||||||
|  |             //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main | ||||||
|  |             //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 | ||||||
|  |             struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); | ||||||
|  |             if(ctx->load_image_size==nullptr){ | ||||||
|  |                 ctx->load_image_size= clip_image_size_init(); | ||||||
|  |             } | ||||||
|  |             int pos_w = ctx->load_image_size->width/patch_size; | ||||||
|  |             int pos_h = ctx->load_image_size->height/patch_size; | ||||||
|  |             int embed_dim = 4096; | ||||||
|  |             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); | ||||||
|  |  | ||||||
|         int* positions_data = (int*)malloc(ggml_nbytes(positions)); |             float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); | ||||||
|         for (int i = 0; i < num_positions; i++) { |             for(int i=0;i<pos_w * pos_h;++i){ | ||||||
|             positions_data[i] = i; |                 for(int j=0;j<embed_dim;++j){ | ||||||
|  |                     pos_embed_data[i*embed_dim+j]=pos_embed_t[i][j]; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed)); | ||||||
|  |             free(pos_embed_data); | ||||||
|         } |         } | ||||||
|         ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); |     } else { | ||||||
|         free(positions_data); |         { | ||||||
|     } |             if (ctx->has_class_embedding) { | ||||||
|  |                 struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); | ||||||
|  |  | ||||||
|     { |                 void* zero_mem = malloc(ggml_nbytes(embeddings)); | ||||||
|         struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); |                 memset(zero_mem, 0, ggml_nbytes(embeddings)); | ||||||
|         int* patches_data = (int*)malloc(ggml_nbytes(patches)); |                 ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); | ||||||
|         for (int i = 0; i < num_patches; i++) { |                 free(zero_mem); | ||||||
|             patches_data[i] = i + 1; |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         { | ||||||
|  |             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); | ||||||
|  |  | ||||||
|  |             int* positions_data = (int*)malloc(ggml_nbytes(positions)); | ||||||
|  |             for (int i = 0; i < num_positions; i++) { | ||||||
|  |                 positions_data[i] = i; | ||||||
|  |             } | ||||||
|  |             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); | ||||||
|  |             free(positions_data); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         { | ||||||
|  |             struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); | ||||||
|  |             int* patches_data = (int*)malloc(ggml_nbytes(patches)); | ||||||
|  |             for (int i = 0; i < num_patches; i++) { | ||||||
|  |                 patches_data[i] = i + 1; | ||||||
|  |             } | ||||||
|  |             ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); | ||||||
|  |             free(patches_data); | ||||||
|         } |         } | ||||||
|         ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); |  | ||||||
|         free(patches_data); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (ggml_backend_is_cpu(ctx->backend)) { |     if (ggml_backend_is_cpu(ctx->backend)) { | ||||||
| @@ -2081,7 +2550,14 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { | |||||||
|     if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { |     if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { | ||||||
|         return ctx->vision_model.mm_3_b->ne[0]; |         return ctx->vision_model.mm_3_b->ne[0]; | ||||||
|     } |     } | ||||||
|  |     if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { | ||||||
|  |         return 4096; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; |     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; | ||||||
|     throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); |     throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | bool clip_is_minicpmv(const struct clip_ctx * ctx) { | ||||||
|  |     return ctx->has_minicpmv_projector; | ||||||
|  | } | ||||||
|   | |||||||
| @@ -18,14 +18,17 @@ | |||||||
| #    define CLIP_API | #    define CLIP_API | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| struct clip_ctx; |  | ||||||
|  |  | ||||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||||
| extern "C" { | extern "C" { | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| struct clip_ctx; | struct clip_ctx; | ||||||
|  |  | ||||||
|  | struct clip_image_size { | ||||||
|  |     int width; | ||||||
|  |     int height; | ||||||
|  | }; | ||||||
|  |  | ||||||
| struct clip_image_u8_batch { | struct clip_image_u8_batch { | ||||||
|     struct clip_image_u8 * data; |     struct clip_image_u8 * data; | ||||||
|     size_t size; |     size_t size; | ||||||
| @@ -55,6 +58,10 @@ CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); | |||||||
| CLIP_API int clip_n_patches    (const struct clip_ctx * ctx); | CLIP_API int clip_n_patches    (const struct clip_ctx * ctx); | ||||||
| CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); | CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); | ||||||
|  |  | ||||||
|  | CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); | ||||||
|  | CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); | ||||||
|  |  | ||||||
|  | CLIP_API struct clip_image_size * clip_image_size_init(); | ||||||
| CLIP_API struct clip_image_u8  * clip_image_u8_init (); | CLIP_API struct clip_image_u8  * clip_image_u8_init (); | ||||||
| CLIP_API struct clip_image_f32 * clip_image_f32_init(); | CLIP_API struct clip_image_f32 * clip_image_f32_init(); | ||||||
|  |  | ||||||
| @@ -78,6 +85,8 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons | |||||||
|  |  | ||||||
| CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); | CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); | ||||||
|  |  | ||||||
|  | CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx); | ||||||
|  |  | ||||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
|   | |||||||
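To make the new clip.h entries concrete, here is a minimal sketch (not part of this patch) of the intended call order when encoding a MiniCPM-V image. The wrapper name encode_minicpmv_image and the thread count are illustrative only; ctx is assumed to be a clip_ctx returned by clip_model_load() and img an already preprocessed clip_image_f32:

    #include "clip.h"
    #include <vector>

    // Sketch only: record the preprocessed image size before encoding so the
    // graph can size the resampler's 2D sincos position embedding to match,
    // then run the encoder as usual.
    static bool encode_minicpmv_image(clip_ctx * ctx, clip_image_f32 * img, std::vector<float> & out) {
        if (clip_is_minicpmv(ctx)) {
            clip_image_size * load_image_size = clip_image_size_init();
            load_image_size->width  = img->nx;
            load_image_size->height = img->ny;
            clip_add_load_image_size(ctx, load_image_size);
        }
        out.resize(clip_embd_nbytes(ctx) / sizeof(float));
        return clip_image_encode(ctx, /*n_threads=*/4, img, out.data());
    }

This mirrors what encode_image_with_clip does in llava.cpp below.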
| @@ -202,6 +202,33 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> | |||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { | ||||||
|  |     int width = image->nx; | ||||||
|  |     int height = image->ny; | ||||||
|  |     int num_patches = (height / patch_size) * (width / patch_size); | ||||||
|  |     clip_image_f32 * patch = clip_image_f32_init(); | ||||||
|  |     patch->nx = patch_size * num_patches; | ||||||
|  |     patch->ny = patch_size; | ||||||
|  |     patch->buf.resize(3 * patch->nx * patch->ny); | ||||||
|  |  | ||||||
|  |     int patch_index = 0; | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < height; i += patch_size) { | ||||||
|  |         for (int j = 0; j < width; j += patch_size) { | ||||||
|  |             for (int pi = 0; pi < patch_size; ++pi) { | ||||||
|  |                 for (int pj = 0; pj < patch_size; ++pj) { | ||||||
|  |                     int input_index = ((i + pi) * width + (j + pj)) * 3; | ||||||
|  |                     int output_index = (pi * patch_size * num_patches + patch_index * patch_size + pj) * 3; | ||||||
|  |                     patch->buf[output_index] = image->buf[input_index]; | ||||||
|  |                     patch->buf[output_index+1] = image->buf[input_index+1]; | ||||||
|  |                     patch->buf[output_index+2] = image->buf[input_index+2]; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             patch_index++; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return patch; | ||||||
|  | } | ||||||
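only_v2_5_reshape_by_patch flattens an H x W image into a single strip that is patch_size pixels tall and patch_size * num_patches pixels wide, with the patches laid out left to right in raster order; as the name suggests, this presumably imitates the per-patch reshape of the original MiniCPM-V-2.5 Python preprocessing. Small worked example: with patch_size = 2 and a 4 x 4 input, the output is 2 x 8, where columns 0-1 hold the top-left patch, 2-3 the top-right patch, 4-5 the bottom-left patch and 6-7 the bottom-right patch.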
|  |  | ||||||
| static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { | static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { | ||||||
|     // std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 |     // std::vector<clip_image_f32*> img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 | ||||||
| @@ -218,7 +245,44 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli | |||||||
|  |  | ||||||
|     const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); |     const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); | ||||||
|  |  | ||||||
|     if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { |     if (clip_is_minicpmv(ctx_clip)) { | ||||||
|  |         std::vector<float *> image_embd_v; | ||||||
|  |         image_embd_v.resize(img_res_v.size); | ||||||
|  |         struct clip_image_size * load_image_size = clip_image_size_init(); | ||||||
|  |         for (size_t i = 0; i < img_res_v.size; i++) { | ||||||
|  |             const int64_t t_img_enc_step_start_us = ggml_time_us(); | ||||||
|  |             image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); | ||||||
|  |             int patch_size=14; | ||||||
|  |             load_image_size->width = img_res_v.data[i].nx; | ||||||
|  |             load_image_size->height = img_res_v.data[i].ny; | ||||||
|  |             clip_add_load_image_size(ctx_clip, load_image_size); | ||||||
|  |             const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); | ||||||
|  |             if (!encoded) { | ||||||
|  |                 LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|  |             const int64_t t_img_enc_steop_batch_us = ggml_time_us(); | ||||||
|  |             LOG_TEE("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); | ||||||
|  |         } | ||||||
|  |         const int64_t t_img_enc_batch_us = ggml_time_us(); | ||||||
|  |         LOG_TEE("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); | ||||||
|  |  | ||||||
|  |         int n_img_pos_out = 0; | ||||||
|  |         for (size_t i = 0; i < image_embd_v.size(); i++) { | ||||||
|  |             std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip)); | ||||||
|  |             n_img_pos_out += clip_n_patches(ctx_clip); | ||||||
|  |         } | ||||||
|  |         *n_img_pos = n_img_pos_out; | ||||||
|  |         for (size_t i = 0; i < image_embd_v.size(); i++) { | ||||||
|  |             free(image_embd_v[i]); | ||||||
|  |         } | ||||||
|  |         image_embd_v.clear(); | ||||||
|  |         load_image_size->width = img->nx; | ||||||
|  |         load_image_size->height = img->ny; | ||||||
|  |         clip_add_load_image_size(ctx_clip, load_image_size); | ||||||
|  |         LOG_TEE("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); | ||||||
|  |     } | ||||||
|  |     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { | ||||||
|         // flat / default llava-1.5 type embedding |         // flat / default llava-1.5 type embedding | ||||||
|         *n_img_pos = clip_n_patches(ctx_clip); |         *n_img_pos = clip_n_patches(ctx_clip); | ||||||
|         bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 |         bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 | ||||||
| @@ -228,7 +292,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli | |||||||
|  |  | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|     } else { |     } | ||||||
|  |     else { | ||||||
|         // spatial_unpad llava-1.6 type embedding |         // spatial_unpad llava-1.6 type embedding | ||||||
|         // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working |         // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working | ||||||
|         std::vector<float *> image_embd_v; |         std::vector<float *> image_embd_v; | ||||||
| @@ -297,7 +362,11 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * | |||||||
| } | } | ||||||
|  |  | ||||||
| bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { | bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { | ||||||
|     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*6); // TODO: base on gridsize/llava model |     int num_max_patches = 6; | ||||||
|  |     if (clip_is_minicpmv(ctx_clip)) { | ||||||
|  |         num_max_patches = 10; | ||||||
|  |     } | ||||||
|  |     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model | ||||||
|     if (!image_embd) { |     if (!image_embd) { | ||||||
|         LOG_TEE("Unable to allocate memory for image embeddings\n"); |         LOG_TEE("Unable to allocate memory for image embeddings\n"); | ||||||
|         return false; |         return false; | ||||||
|   | |||||||
| @@ -17,12 +17,11 @@ | |||||||
| #    define LLAVA_API | #    define LLAVA_API | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| struct clip_ctx; |  | ||||||
|  |  | ||||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||||
| extern "C" { | extern "C" { | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | struct clip_ctx; | ||||||
| struct llava_image_embed { | struct llava_image_embed { | ||||||
|     float * embed; |     float * embed; | ||||||
|     int n_image_pos; |     int n_image_pos; | ||||||
| @@ -37,8 +36,8 @@ LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, | |||||||
| LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); | LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); | ||||||
| /** build an image embed from a path to an image filename */ | /** build an image embed from a path to an image filename */ | ||||||
| LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); | LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); | ||||||
| LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); |  | ||||||
| /** free an embedding made with llava_image_embed_make_* */ | /** free an embedding made with llava_image_embed_make_* */ | ||||||
|  | LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); | ||||||
|  |  | ||||||
| /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ | /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ | ||||||
| LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); | LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); | ||||||
|   | |||||||
							
								
								
									
examples/llava/minicpmv-cli.cpp (new file, 309 lines)
							| @@ -0,0 +1,309 @@ | |||||||
|  | #include "ggml.h" | ||||||
|  | #include "log.h" | ||||||
|  | #include "common.h" | ||||||
|  | #include "clip.h" | ||||||
|  | #include "llava.h" | ||||||
|  | #include "llama.h" | ||||||
|  |  | ||||||
|  | #include <cstdio> | ||||||
|  | #include <cstdlib> | ||||||
|  | #include <vector> | ||||||
|  |  | ||||||
|  | struct llava_context { | ||||||
|  |     struct clip_ctx * ctx_clip = NULL; | ||||||
|  |     struct llama_context * ctx_llama = NULL; | ||||||
|  |     struct llama_model * model = NULL; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | static void show_additional_info(int /*argc*/, char ** argv) { | ||||||
|  |     LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); | ||||||
|  |     LOG_TEE("  note: a lower temperature value like 0.1 is recommended for better quality.\n"); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { | ||||||
|  |     (void) level; | ||||||
|  |     (void) user_data; | ||||||
|  |     LOG_TEE("%s", text); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct llama_model * llava_init(gpt_params * params) { | ||||||
|  |     llama_backend_init(); | ||||||
|  |     llama_numa_init(params->numa); | ||||||
|  |  | ||||||
|  |     llama_model_params model_params = llama_model_params_from_gpt_params(*params); | ||||||
|  |  | ||||||
|  |     llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); | ||||||
|  |     if (model == NULL) { | ||||||
|  |         LOG_TEE("%s: error: unable to load model\n" , __func__); | ||||||
|  |         return NULL; | ||||||
|  |     } | ||||||
|  |     return model; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) { | ||||||
|  |     auto prompt = params->prompt; | ||||||
|  |     if (prompt.empty()) { | ||||||
|  |         prompt = "describe the image in detail."; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); | ||||||
|  |     if (params->n_ctx < 2048) { | ||||||
|  |         // warn user here, "Image processing requires at least 2048 context, setting context to 2048" | ||||||
|  |         LOG_TEE("%s: warn: Image processing requires at least 2048 context, setting context to 2048\n" , __func__); | ||||||
|  |         ctx_params.n_ctx = 2048; | ||||||
|  |     } else { | ||||||
|  |         ctx_params.n_ctx = params->n_ctx; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); | ||||||
|  |  | ||||||
|  |     if (ctx_llama == NULL) { | ||||||
|  |         LOG_TEE("%s: error: failed to create the llama_context\n" , __func__); | ||||||
|  |         return NULL; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); | ||||||
|  |  | ||||||
|  |     ctx_llava->ctx_llama = ctx_llama; | ||||||
|  |     ctx_llava->model = model; | ||||||
|  |     return ctx_llava; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void llava_free(struct llava_context * ctx_llava) { | ||||||
|  |     if (ctx_llava->ctx_clip) { | ||||||
|  |         clip_free(ctx_llava->ctx_clip); | ||||||
|  |         ctx_llava->ctx_clip = NULL; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     llama_free(ctx_llava->ctx_llama); | ||||||
|  |     llama_free_model(ctx_llava->model); | ||||||
|  |     llama_backend_free(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct clip_ctx * clip_init_context(gpt_params * params) { | ||||||
|  |     const char * clip_path = params->mmproj.c_str(); | ||||||
|  |  | ||||||
|  |     auto prompt = params->prompt; | ||||||
|  |     if (prompt.empty()) { | ||||||
|  |         prompt = "describe the image in detail."; | ||||||
|  |     } | ||||||
|  |     auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); | ||||||
|  |     return ctx_clip; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) { | ||||||
|  |     int N = (int) tokens.size(); | ||||||
|  |     for (int i = 0; i < N; i += n_batch) { | ||||||
|  |         int n_eval = (int) tokens.size() - i; | ||||||
|  |         if (n_eval > n_batch) { | ||||||
|  |             n_eval = n_batch; | ||||||
|  |         } | ||||||
|  |         if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { | ||||||
|  |             LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  |         *n_past += n_eval; | ||||||
|  |     } | ||||||
|  |     return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { | ||||||
|  |     std::vector<llama_token> tokens; | ||||||
|  |     tokens.push_back(id); | ||||||
|  |     return eval_tokens(ctx_llama, tokens, 1, n_past); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ | ||||||
|  |     std::string              str2     = str; | ||||||
|  |     std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); | ||||||
|  |     return eval_tokens(ctx_llama, embd_inp, n_batch, n_past); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void process_eval_image_embed(struct llava_context * ctx_llava, const struct llava_image_embed * embeds, int n_batch, int * n_past, int idx) { | ||||||
|  |     float * image_embed = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip)); | ||||||
|  |     std::memcpy(image_embed, embeds->embed + idx * clip_n_patches(ctx_llava->ctx_clip) * clip_n_mmproj_embd(ctx_llava->ctx_clip), clip_embd_nbytes(ctx_llava->ctx_clip)); | ||||||
|  |  | ||||||
|  |     auto slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed)); | ||||||
|  |     slice_embed->embed = image_embed; | ||||||
|  |     slice_embed->n_image_pos = clip_n_patches(ctx_llava->ctx_clip); | ||||||
|  |     llava_eval_image_embed(ctx_llava->ctx_llama, slice_embed, n_batch, n_past); | ||||||
|  |     llava_image_embed_free(slice_embed); | ||||||
|  | } | ||||||
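Each call to process_eval_image_embed copies one slice of clip_n_patches(ctx_clip) x clip_n_mmproj_embd(ctx_clip) floats out of the flat buffer produced by encode_image_with_clip and evaluates it as a standalone llava_image_embed. With the values introduced by this patch (96 patches for the resampler and a 4096-wide projection), each slice is 96 * 4096 = 393,216 floats, i.e. 1.5 MiB at f32.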
|  |  | ||||||
|  | static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, gpt_params * params, int &n_past) { | ||||||
|  |     std::string system_prompt; | ||||||
|  |     int idx = 0; | ||||||
|  |     int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip); | ||||||
|  |     system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; | ||||||
|  |     LOG_TEE("%s: image token past: %d\n", __func__, n_past); | ||||||
|  |     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false); | ||||||
|  |     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); | ||||||
|  |     eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false); | ||||||
|  |     if (num_image_embeds > 1) { | ||||||
|  |         size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip); | ||||||
|  |         eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false); | ||||||
|  |         for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) { | ||||||
|  |             for (size_t j = 0; j < num_image_embeds_col; ++j) { | ||||||
|  |                 eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false); | ||||||
|  |                 process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); | ||||||
|  |                 eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false); | ||||||
|  |                 if (j == num_image_embeds_col - 1) { | ||||||
|  |                     eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false); | ||||||
|  |     } | ||||||
|  |     LOG_TEE("%s: image token past: %d\n", __func__, n_past); | ||||||
|  | } | ||||||
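For reference, the prompt produced by process_image has roughly the following shape (sketch only; shown for an image split into one full-image embedding plus a 2 x 2 grid of slices, i.e. num_image_embeds = 5 and clip_uhd_num_image_embeds_col() = 2; each {embed i} stands for clip_n_patches(ctx_clip) embedding positions written by process_eval_image_embed):

    <|begin_of_text|><|start_header_id|>user<|end_header_id|>

    <image>{embed 0}</image><slice><image>{embed 1}</image><image>{embed 2}</image>
    <image>{embed 3}</image><image>{embed 4}</image>
    </slice>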
|  |  | ||||||
|  | static const char * sample(struct llama_sampling_context * ctx_sampling, | ||||||
|  |                            struct llama_context * ctx_llama, | ||||||
|  |                            int * n_past) { | ||||||
|  |     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); | ||||||
|  |     llama_sampling_accept(ctx_sampling, ctx_llama, id, true); | ||||||
|  |     static std::string ret; | ||||||
|  |     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { | ||||||
|  |         ret = "</s>"; | ||||||
|  |     } else { | ||||||
|  |         ret = llama_token_to_piece(ctx_llama, id); | ||||||
|  |     } | ||||||
|  |     eval_id(ctx_llama, id, n_past); | ||||||
|  |     return ret.c_str(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ | ||||||
|  |     auto ctx_clip = clip_init_context(params); | ||||||
|  |     auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->n_threads, fname.c_str()); | ||||||
|  |     if (!embeds) { | ||||||
|  |         std::cerr << "error: failed to load image " << fname << ". Terminating\n\n"; | ||||||
|  |         return NULL; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // process the prompt | ||||||
|  |     if (params->prompt.empty() && params->interactive == false) { | ||||||
|  |         LOG_TEE("prompt should be given or interactive mode should be on"); | ||||||
|  |         return NULL; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto model = llava_init(params); | ||||||
|  |     if (model == NULL) { | ||||||
|  |         fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); | ||||||
|  |         return NULL; | ||||||
|  |     } | ||||||
|  |     const int64_t t_llava_init_start_us = ggml_time_us(); | ||||||
|  |     auto ctx_llava = llava_init_context(params, model); | ||||||
|  |     ctx_llava->ctx_clip = ctx_clip; | ||||||
|  |     const int64_t t_llava_init_end_us = ggml_time_us(); | ||||||
|  |     float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; | ||||||
|  |     LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); | ||||||
|  |  | ||||||
|  |     const int64_t t_process_image_start_us = ggml_time_us(); | ||||||
|  |     process_image(ctx_llava, embeds, params, n_past); | ||||||
|  |     const int64_t t_process_image_end_us = ggml_time_us(); | ||||||
|  |     float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; | ||||||
|  |     LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); | ||||||
|  |  | ||||||
|  |     llava_image_embed_free(embeds); | ||||||
|  |     return ctx_llava; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ | ||||||
|  |     std::string user_prompt = prompt; | ||||||
|  |     if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; | ||||||
|  |  | ||||||
|  |     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); | ||||||
|  |     eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); | ||||||
|  |     // generate the response | ||||||
|  |  | ||||||
|  |     LOG_TEE("\n"); | ||||||
|  |  | ||||||
|  |     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); | ||||||
|  |     return ctx_sampling; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){ | ||||||
|  |  | ||||||
|  |     const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); | ||||||
|  |     return tmp; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int main(int argc, char ** argv) { | ||||||
|  |     ggml_time_init(); | ||||||
|  |  | ||||||
|  |     gpt_params params; | ||||||
|  |  | ||||||
|  |     if (!gpt_params_parse(argc, argv, params)) { | ||||||
|  |         show_additional_info(argc, argv); | ||||||
|  |         return 1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  | #ifndef LOG_DISABLE_LOGS | ||||||
|  |     log_set_target(log_filename_generator("llava", "log")); | ||||||
|  |     LOG_TEE("Log start\n"); | ||||||
|  |     log_dump_cmdline(argc, argv); | ||||||
|  |     llama_log_set(llama_log_callback_logTee, nullptr); | ||||||
|  | #endif // LOG_DISABLE_LOGS | ||||||
|  |  | ||||||
|  |     if (params.mmproj.empty() || (params.image.empty())) { | ||||||
|  |         gpt_params_print_usage(argc, argv, params); | ||||||
|  |         show_additional_info(argc, argv); | ||||||
|  |         return 1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (auto & image : params.image) { | ||||||
|  |         int n_past = 0; | ||||||
|  |         auto ctx_llava = minicpmv_init(¶ms, image, n_past); | ||||||
|  |  | ||||||
|  |         if (!params.prompt.empty()) { | ||||||
|  |             LOG_TEE("<user>%s\n", params.prompt.c_str()); | ||||||
|  |             LOG_TEE("<assistant>"); | ||||||
|  |             auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); | ||||||
|  |             const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; | ||||||
|  |             std::string response = ""; | ||||||
|  |             bool have_tmp = false; | ||||||
|  |             for (int i = 0; i < max_tgt_len; i++) { | ||||||
|  |                 auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); | ||||||
|  |                 response += tmp; | ||||||
|  |                 if (strcmp(tmp, "</s>") == 0){ | ||||||
|  |                     if(!have_tmp)continue; | ||||||
|  |                     else break; | ||||||
|  |                 } | ||||||
|  |                 if (strstr(tmp, "###")) break; // Yi-VL behavior | ||||||
|  |                 have_tmp = true; | ||||||
|  |                 printf("%s", tmp); | ||||||
|  |                 if (strstr(response.c_str(), "<user>")) break; // minicpm-v | ||||||
|  |  | ||||||
|  |                 fflush(stdout); | ||||||
|  |             } | ||||||
|  |             llama_sampling_free(ctx_sampling); | ||||||
|  |         }else { | ||||||
|  |             while (true) { | ||||||
|  |                 LOG_TEE("<user>"); | ||||||
|  |                 std::string prompt; | ||||||
|  |                 std::getline(std::cin, prompt); | ||||||
|  |                 LOG_TEE("<assistant>"); | ||||||
|  |                 auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true); | ||||||
|  |                 const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; | ||||||
|  |                 std::string response = ""; | ||||||
|  |                 for (int i = 0; i < max_tgt_len; i++) { | ||||||
|  |                     auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); | ||||||
|  |                     response += tmp; | ||||||
|  |                     if (strcmp(tmp, "</s>") == 0) break; | ||||||
|  |                     if (strstr(tmp, "###")) break; // Yi-VL behavior | ||||||
|  |                     printf("%s", tmp);// mistral llava-1.6 | ||||||
|  |                     if (strstr(response.c_str(), "<user>")) break; // minicpm-v | ||||||
|  |                     fflush(stdout); | ||||||
|  |                 } | ||||||
|  |                 llama_sampling_free(ctx_sampling); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         printf("\n"); | ||||||
|  |         llama_print_timings(ctx_llava->ctx_llama); | ||||||
|  |  | ||||||
|  |         ctx_llava->model = NULL; | ||||||
|  |         llava_free(ctx_llava); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return 0; | ||||||
|  | } | ||||||
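Assuming the example is built to a binary named llama-minicpmv-cli, a typical invocation follows the usage string printed by show_additional_info (the model file names below are placeholders):

    ./llama-minicpmv-cli -m ./MiniCPM-V-2_5/ggml-model-q4_k.gguf \
        --mmproj ./MiniCPM-V-2_5/mmproj-model-f16.gguf \
        --image ./example.jpg --temp 0.1 -p "describe the image in detail."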
							
								
								
									
examples/llava/minicpmv-convert-image-encoder-to-gguf.py (new file, 382 lines)
							| @@ -0,0 +1,382 @@ | |||||||
|  | import argparse | ||||||
|  | import os | ||||||
|  | import json | ||||||
|  | import re | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | import numpy as np | ||||||
|  | from gguf import * | ||||||
|  | from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig | ||||||
|  |  | ||||||
|  | TEXT = "clip.text" | ||||||
|  | VISION = "clip.vision" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def add_key_str(raw_key: str, arch: str) -> str: | ||||||
|  |     return raw_key.format(arch=arch) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool: | ||||||
|  |     if name in ( | ||||||
|  |         "logit_scale", | ||||||
|  |         "text_model.embeddings.position_ids", | ||||||
|  |         "vision_model.embeddings.position_ids", | ||||||
|  |     ): | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     if has_minicpmv and name in ["visual_projection.weight"]: | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     if name.startswith("v") and not has_vision: | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     if name.startswith("t") and not has_text: | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     return False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_tensor_name(name: str) -> str: | ||||||
|  |     if "projection" in name: | ||||||
|  |         return name | ||||||
|  |     if "mm_projector" in name: | ||||||
|  |         name = name.replace("model.mm_projector", "mm") | ||||||
|  |         name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) | ||||||
|  |         name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) | ||||||
|  |         return name | ||||||
|  |  | ||||||
|  |     return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def bytes_to_unicode(): | ||||||
|  |     """ | ||||||
|  |     Returns list of utf-8 byte and a corresponding list of unicode strings. | ||||||
|  |     The reversible bpe codes work on unicode strings. | ||||||
|  |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | ||||||
|  |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | ||||||
|  |     This is a significant percentage of your normal, say, 32K bpe vocab. | ||||||
|  |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | ||||||
|  |     And avoids mapping to whitespace/control characters the bpe code barfs on. | ||||||
|  |     """ | ||||||
|  |     bs = ( | ||||||
|  |         list(range(ord("!"), ord("~") + 1)) | ||||||
|  |         + list(range(ord("¡"), ord("¬") + 1)) | ||||||
|  |         + list(range(ord("®"), ord("ÿ") + 1)) | ||||||
|  |     ) | ||||||
|  |     cs = bs[:] | ||||||
|  |     n = 0 | ||||||
|  |     for b in range(2**8): | ||||||
|  |         if b not in bs: | ||||||
|  |             bs.append(b) | ||||||
|  |             cs.append(2**8 + n) | ||||||
|  |             n += 1 | ||||||
|  |     cs = [chr(n) for n in cs] | ||||||
|  |     return dict(zip(bs, cs)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | ap = argparse.ArgumentParser() | ||||||
|  | ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) | ||||||
|  | ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") | ||||||
|  | ap.add_argument("--text-only", action="store_true", required=False, | ||||||
|  |                 help="Save a text-only model. It can't be used to encode images") | ||||||
|  | ap.add_argument("--vision-only", action="store_true", required=False, | ||||||
|  |                 help="Save a vision-only model. It can't be used to encode texts") | ||||||
|  | ap.add_argument("--clip-model-is-vision", action="store_true", required=False, | ||||||
|  |                 help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") | ||||||
|  | ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, | ||||||
|  |                 help="The clip model is from openclip (for ViT-SO400M type))") | ||||||
|  | ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.") | ||||||
|  | ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") | ||||||
|  | ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) | ||||||
|  | # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 | ||||||
|  | # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 | ||||||
|  | default_image_mean = [0.48145466, 0.4578275, 0.40821073] | ||||||
|  | default_image_std = [0.26862954, 0.26130258, 0.27577711] | ||||||
|  | ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) | ||||||
|  | ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) | ||||||
|  |  | ||||||
|  | # with proper | ||||||
|  | args = ap.parse_args() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if args.text_only and args.vision_only: | ||||||
|  |     print("--text-only and --vision-only arguments cannot be specified at the same time.") | ||||||
|  |     exit(1) | ||||||
|  |  | ||||||
|  | if args.use_f32: | ||||||
|  |     print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") | ||||||
|  |  | ||||||
|  | # output in the same directory as the model if output_dir is None | ||||||
|  | dir_model = args.model_dir | ||||||
|  |  | ||||||
|  | if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: | ||||||
|  |     vocab = None | ||||||
|  |     tokens = None | ||||||
|  | else: | ||||||
|  |     with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f: | ||||||
|  |         vocab = json.load(f) | ||||||
|  |         tokens = [key for key in vocab] | ||||||
|  |  | ||||||
|  | # possible data types | ||||||
|  | #   ftype == 0 -> float32 | ||||||
|  | #   ftype == 1 -> float16 | ||||||
|  | # | ||||||
|  | # map from ftype to string | ||||||
|  | ftype_str = ["f32", "f16"] | ||||||
|  |  | ||||||
|  | ftype = 1 | ||||||
|  | if args.use_f32: | ||||||
|  |     ftype = 0 | ||||||
|  |  | ||||||
|  | # if args.clip_model_is_vision or args.clip_model_is_openclip: | ||||||
|  | #     model = CLIPVisionModel.from_pretrained(dir_model) | ||||||
|  | #     processor = None | ||||||
|  | # else: | ||||||
|  | #     model = CLIPModel.from_pretrained(dir_model) | ||||||
|  | #     processor = CLIPProcessor.from_pretrained(dir_model) | ||||||
|  |  | ||||||
|  | default_vision_config = { | ||||||
|  |         "hidden_size": 1152, | ||||||
|  |         "image_size": 980, | ||||||
|  |         "intermediate_size": 4304, | ||||||
|  |         "model_type": "idefics2", | ||||||
|  |         "num_attention_heads": 16, | ||||||
|  |         "num_hidden_layers": 27, | ||||||
|  |         "patch_size": 14, | ||||||
|  |     } | ||||||
|  | vision_config = Idefics2VisionConfig(**default_vision_config) | ||||||
|  | model = Idefics2VisionTransformer(vision_config) | ||||||
|  |  | ||||||
|  | processor = None | ||||||
|  | # if model.attn_pool is not None: | ||||||
|  | #     model.attn_pool = torch.nn.Identity() | ||||||
|  |  | ||||||
|  | # model.blocks = model.blocks[:-1] | ||||||
|  | model.load_state_dict(torch.load(os.path.join(dir_model, "minicpmv.clip"))) | ||||||
|  |  | ||||||
|  | fname_middle = None | ||||||
|  | has_text_encoder = True | ||||||
|  | has_vision_encoder = True | ||||||
|  | has_minicpmv_projector = False | ||||||
|  | if args.text_only: | ||||||
|  |     fname_middle = "text-" | ||||||
|  |     has_vision_encoder = False | ||||||
|  | elif args.minicpmv_projector is not None: | ||||||
|  |     fname_middle = "mmproj-" | ||||||
|  |     has_text_encoder = False | ||||||
|  |     has_minicpmv_projector = True | ||||||
|  | elif args.vision_only: | ||||||
|  |     fname_middle = "vision-" | ||||||
|  |     has_text_encoder = False | ||||||
|  | else: | ||||||
|  |     fname_middle = "" | ||||||
|  |  | ||||||
|  | output_dir = args.output_dir if args.output_dir is not None else dir_model | ||||||
|  | os.makedirs(output_dir, exist_ok=True) | ||||||
|  | output_prefix = os.path.basename(output_dir).replace("ggml_", "") | ||||||
|  | fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") | ||||||
|  | fout = GGUFWriter(path=fname_out, arch="clip") | ||||||
|  |  | ||||||
|  | fout.add_bool("clip.has_text_encoder", has_text_encoder) | ||||||
|  | fout.add_bool("clip.has_vision_encoder", has_vision_encoder) | ||||||
|  | fout.add_bool("clip.has_minicpmv_projector", has_minicpmv_projector) | ||||||
|  | fout.add_file_type(ftype) | ||||||
|  | if args.text_only: | ||||||
|  |     fout.add_description("text-only CLIP model") | ||||||
|  | elif args.vision_only and not has_minicpmv_projector: | ||||||
|  |     fout.add_description("vision-only CLIP model") | ||||||
|  | elif has_minicpmv_projector: | ||||||
|  |     fout.add_description("image encoder for MiniCPM-V") | ||||||
|  |     # add projector type | ||||||
|  |     fout.add_string("clip.projector_type", "resampler") | ||||||
|  | else: | ||||||
|  |     fout.add_description("two-tower CLIP model") | ||||||
|  |  | ||||||
|  | if has_vision_encoder: | ||||||
|  |     # vision_model hparams | ||||||
|  |     fout.add_uint32("clip.vision.image_size", 448) | ||||||
|  |     fout.add_uint32("clip.vision.patch_size", 14) | ||||||
|  |     fout.add_uint32(add_key_str(KEY_EMBEDDING_LENGTH, VISION), 1152) | ||||||
|  |     fout.add_uint32(add_key_str(KEY_FEED_FORWARD_LENGTH, VISION), 4304) | ||||||
|  |     fout.add_uint32("clip.vision.projection_dim", 0) | ||||||
|  |     fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) | ||||||
|  |     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) | ||||||
|  |     block_count = 26 | ||||||
|  |     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) | ||||||
|  |  | ||||||
|  |     if processor is not None: | ||||||
|  |         image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean | ||||||
|  |         image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std | ||||||
|  |     else: | ||||||
|  |         image_mean = args.image_mean if args.image_mean is not None else default_image_mean | ||||||
|  |         image_std = args.image_std if args.image_std is not None else default_image_std | ||||||
|  |     fout.add_array("clip.vision.image_mean", image_mean) | ||||||
|  |     fout.add_array("clip.vision.image_std", image_std) | ||||||
|  |  | ||||||
|  | use_gelu = True | ||||||
|  | fout.add_bool("clip.use_gelu", use_gelu) | ||||||
|  |  | ||||||
|  | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): | ||||||
|  |     """ | ||||||
|  |     embed_dim: output dimension for each position | ||||||
|  |     pos: a list of positions to be encoded: size (M,) | ||||||
|  |     out: (M, D) | ||||||
|  |     """ | ||||||
|  |     assert embed_dim % 2 == 0 | ||||||
|  |     omega = np.arange(embed_dim // 2, dtype=np.float32) | ||||||
|  |     omega /= embed_dim / 2. | ||||||
|  |     omega = 1. / 10000 ** omega  # (D/2,) | ||||||
|  |  | ||||||
|  |     pos = pos.reshape(-1)  # (M,) | ||||||
|  |     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product | ||||||
|  |  | ||||||
|  |     emb_sin = np.sin(out)  # (M, D/2) | ||||||
|  |     emb_cos = np.cos(out)  # (M, D/2) | ||||||
|  |  | ||||||
|  |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D) | ||||||
|  |     return emb | ||||||
|  |  | ||||||
|  | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): | ||||||
|  |     assert embed_dim % 2 == 0 | ||||||
|  |  | ||||||
|  |     # use half of dimensions to encode grid_h | ||||||
|  |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2) | ||||||
|  |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2) | ||||||
|  |  | ||||||
|  |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D) | ||||||
|  |     return emb | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 | ||||||
|  | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): | ||||||
|  |     """ | ||||||
|  |     grid_size: int of the grid height and width | ||||||
|  |     return: | ||||||
|  |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) | ||||||
|  |     """ | ||||||
|  |     if isinstance(grid_size, int): | ||||||
|  |         grid_h_size, grid_w_size = grid_size, grid_size | ||||||
|  |     else: | ||||||
|  |         grid_h_size, grid_w_size = grid_size[0], grid_size[1] | ||||||
|  |  | ||||||
|  |     grid_h = np.arange(grid_h_size, dtype=np.float32) | ||||||
|  |     grid_w = np.arange(grid_w_size, dtype=np.float32) | ||||||
|  |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first | ||||||
|  |     grid = np.stack(grid, axis=0) | ||||||
|  |  | ||||||
|  |     grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) | ||||||
|  |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) | ||||||
|  |     if cls_token: | ||||||
|  |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) | ||||||
|  |     return pos_embed | ||||||
|  |  | ||||||
|  | def _replace_name_resampler(s, v): | ||||||
|  |     if re.match("resampler.pos_embed", s): | ||||||
|  |         return { | ||||||
|  |             s: v, | ||||||
|  |             re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), | ||||||
|  |         } | ||||||
|  |     if re.match("resampler.proj", s): | ||||||
|  |         return { | ||||||
|  |             re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))), | ||||||
|  |             re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(), | ||||||
|  |         } | ||||||
|  |     if re.match("resampler.attn.in_proj_.*", s): | ||||||
|  |         return { | ||||||
|  |             re.sub("attn.in_proj_", "attn.q.", s): v.chunk(3, dim=0)[0], | ||||||
|  |             re.sub("attn.in_proj_", "attn.k.", s): v.chunk(3, dim=0)[1], | ||||||
|  |             re.sub("attn.in_proj_", "attn.v.", s): v.chunk(3, dim=0)[2], | ||||||
|  |         } | ||||||
|  |     return {s: v} | ||||||
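For reference, the `v.chunk(3, dim=0)` calls above split the resampler's fused attention projection into equal query/key/value slices along the first dimension. A minimal sketch of that split, using a hypothetical embedding size rather than the model's real hidden size:

    import torch

    embed_dim = 4096                                        # hypothetical size, for illustration only
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # fused q/k/v projection weight
    q_w, k_w, v_w = in_proj_weight.chunk(3, dim=0)          # three equal slices along dim 0
    assert q_w.shape == (embed_dim, embed_dim)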
|  |  | ||||||
|  | if has_minicpmv_projector: | ||||||
|  |     projector = torch.load(args.minicpmv_projector) | ||||||
|  |     new_state_dict = {} | ||||||
|  |     for k, v in projector.items(): | ||||||
|  |         kvs = _replace_name_resampler(k, v) | ||||||
|  |         for nk, nv in kvs.items(): | ||||||
|  |             new_state_dict[nk] = nv | ||||||
|  |     projector = new_state_dict | ||||||
|  |     ftype_cur = 0 | ||||||
|  |     for name, data in projector.items(): | ||||||
|  |         name = get_tensor_name(name) | ||||||
|  |         data = data.squeeze().numpy() | ||||||
|  |  | ||||||
|  |         n_dims = len(data.shape) | ||||||
|  |         if ftype == 1: | ||||||
|  |             if name[-7:] == ".weight" and n_dims == 2: | ||||||
|  |                 print("  Converting to float16") | ||||||
|  |                 data = data.astype(np.float16) | ||||||
|  |                 ftype_cur = 1 | ||||||
|  |             else: | ||||||
|  |                 print("  Converting to float32") | ||||||
|  |                 data = data.astype(np.float32) | ||||||
|  |                 ftype_cur = 0 | ||||||
|  |         else: | ||||||
|  |             if data.dtype != np.float32: | ||||||
|  |                 print("  Converting to float32") | ||||||
|  |                 data = data.astype(np.float32) | ||||||
|  |                 ftype_cur = 0 | ||||||
|  |  | ||||||
|  |         fout.add_tensor(name, data) | ||||||
|  |         print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") | ||||||
|  |  | ||||||
|  |     print("Projector tensors added\n") | ||||||
|  |  | ||||||
|  | def _replace_name(s, v): | ||||||
|  |     s = "vision_model." + s | ||||||
|  |     if re.match("vision_model.embeddings.position_embedding", s): | ||||||
|  |         v = v.unsqueeze(0) | ||||||
|  |         return {s: v} | ||||||
|  |  | ||||||
|  |     return {s: v} | ||||||
|  |  | ||||||
|  | state_dict = model.state_dict() | ||||||
|  | new_state_dict = {} | ||||||
|  | for k, v in state_dict.items(): | ||||||
|  |     kvs = _replace_name(k, v) | ||||||
|  |     for nk, nv in kvs.items(): | ||||||
|  |         new_state_dict[nk] = nv | ||||||
|  | state_dict = new_state_dict | ||||||
|  | for name, data in state_dict.items(): | ||||||
|  |     if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector): | ||||||
|  |         # we don't need this | ||||||
|  |         print(f"skipping parameter: {name}") | ||||||
|  |         continue | ||||||
|  |  | ||||||
|  |     name = get_tensor_name(name) | ||||||
|  |     data = data.squeeze().numpy() | ||||||
|  |  | ||||||
|  |     n_dims = len(data.shape) | ||||||
|  |  | ||||||
|  |     # ftype == 0 -> float32, ftype == 1 -> float16 | ||||||
|  |     ftype_cur = 0 | ||||||
|  |     if n_dims == 4: | ||||||
|  |         print(f"tensor {name} is always saved in f16") | ||||||
|  |         data = data.astype(np.float16) | ||||||
|  |         ftype_cur = 1 | ||||||
|  |     elif ftype == 1: | ||||||
|  |         if name[-7:] == ".weight" and n_dims == 2: | ||||||
|  |             print("  Converting to float16") | ||||||
|  |             data = data.astype(np.float16) | ||||||
|  |             ftype_cur = 1 | ||||||
|  |         else: | ||||||
|  |             print("  Converting to float32") | ||||||
|  |             data = data.astype(np.float32) | ||||||
|  |             ftype_cur = 0 | ||||||
|  |     else: | ||||||
|  |         if data.dtype != np.float32: | ||||||
|  |             print("  Converting to float32") | ||||||
|  |             data = data.astype(np.float32) | ||||||
|  |             ftype_cur = 0 | ||||||
|  |  | ||||||
|  |     print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}") | ||||||
|  |     fout.add_tensor(name, data) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | fout.write_header_to_file() | ||||||
|  | fout.write_kv_data_to_file() | ||||||
|  | fout.write_tensors_to_file() | ||||||
|  | fout.close() | ||||||
|  |  | ||||||
|  | print("Done. Output file: " + fname_out) | ||||||
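The sin/cos helpers above are what generate the fixed `pos_embed_k` table attached to the resampler weights in `_replace_name_resampler`. A minimal sanity check, assuming the functions defined in this script are importable or already in scope:

    import numpy as np

    # 2D sin/cos table over a 70x70 grid with 4096-dim rows, matching the
    # get_2d_sincos_pos_embed(4096, (70, 70)) call used for pos_embed_k.
    pos_embed_k = get_2d_sincos_pos_embed(4096, (70, 70))
    assert pos_embed_k.shape == (70 * 70, 4096)

    # Each row concatenates the embedding of one grid axis (first half)
    # with the embedding of the other axis (second half).
    print(pos_embed_k.dtype, pos_embed_k.shape)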
							
								
								
									
47	examples/llava/minicpmv-surgery.py	Normal file
							| @@ -0,0 +1,47 @@ | |||||||
|  | import argparse | ||||||
|  | import os | ||||||
|  | import torch | ||||||
|  | from transformers import AutoModel, AutoTokenizer | ||||||
|  |  | ||||||
|  | ap = argparse.ArgumentParser() | ||||||
|  | ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model") | ||||||
|  | args = ap.parse_args() | ||||||
|  |  | ||||||
|  | # find the model part that includes the multimodal projector weights | ||||||
|  | model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) | ||||||
|  | checkpoint = model.state_dict() | ||||||
|  |  | ||||||
|  | # get a list of mm tensor names | ||||||
|  | mm_tensors = [k for k, v in checkpoint.items() if k.startswith("resampler")] | ||||||
|  |  | ||||||
|  | # store these tensors in a new dictionary and torch.save them | ||||||
|  | projector = {name: checkpoint[name].float() for name in mm_tensors} | ||||||
|  | torch.save(projector, f"{args.model}/minicpmv.projector") | ||||||
|  |  | ||||||
|  | clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vpm")] | ||||||
|  | if len(clip_tensors) > 0: | ||||||
|  |     clip = {name.replace("vpm.", ""): checkpoint[name].float() for name in clip_tensors} | ||||||
|  |     torch.save(clip, f"{args.model}/minicpmv.clip") | ||||||
|  |  | ||||||
|  |     # clear added_tokens.json so the language model part can be converted to a regular GGUF file | ||||||
|  |     if os.path.exists(f"{args.model}/added_tokens.json"): | ||||||
|  |         with open(f"{args.model}/added_tokens.json", "w") as f: | ||||||
|  |             f.write("{}\n") | ||||||
|  |  | ||||||
|  | config = model.llm.config | ||||||
|  | config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5" | ||||||
|  | config.auto_map = { | ||||||
|  |     "AutoConfig": "configuration_minicpm.MiniCPMConfig", | ||||||
|  |     "AutoModel": "modeling_minicpm.MiniCPMModel", | ||||||
|  |     "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM", | ||||||
|  |     "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM", | ||||||
|  |     "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification" | ||||||
|  | } | ||||||
|  | model.llm.save_pretrained(f"{args.model}/model") | ||||||
|  | tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) | ||||||
|  | tok.save_pretrained(f"{args.model}/model") | ||||||
|  | # os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py") | ||||||
|  |  | ||||||
|  | print("Done!") | ||||||
|  | print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") | ||||||
|  | print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.") | ||||||
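The surgery script is run once against a local copy of the model, e.g. `python ./examples/llava/minicpmv-surgery.py -m /path/to/MiniCPM-Llama3-V-2_5` (placeholder path). Before converting, the split artifacts can be spot-checked with a short sketch like the following, which only assumes the files written by the script above:

    import torch

    model_dir = "/path/to/MiniCPM-Llama3-V-2_5"  # placeholder path

    # Resampler (multimodal projector) tensors saved by the surgery script.
    projector = torch.load(f"{model_dir}/minicpmv.projector", map_location="cpu")
    for name, tensor in projector.items():
        print(name, tuple(tensor.shape), tensor.dtype)

    # Vision encoder tensors (saved with the "vpm." prefix stripped).
    clip = torch.load(f"{model_dir}/minicpmv.clip", map_location="cpu")
    print(len(clip), "vision tensors")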
| @@ -2,3 +2,4 @@ | |||||||
| --extra-index-url https://download.pytorch.org/whl/cpu | --extra-index-url https://download.pytorch.org/whl/cpu | ||||||
| pillow~=10.2.0 | pillow~=10.2.0 | ||||||
| torch~=2.2.1 | torch~=2.2.1 | ||||||
|  | torchvision==0.17.1 | ||||||
|   | |||||||