	llava : support for Yi-VL and fix for mobileVLM (#5093)
* Support for Yi-VL, templating fix for mobileVLM

* ws

* Update examples/llava/clip.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update llava-cli.cpp

* Update clip.cpp

bugfix for new conversions

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
examples/llava/clip.cpp

@@ -98,6 +98,7 @@ static std::string format(const char * fmt, ...) {
 
 enum projector_type {
     PROJECTOR_TYPE_MLP,
+    PROJECTOR_TYPE_MLP_NORM,
     PROJECTOR_TYPE_LDP,
     PROJECTOR_TYPE_UNKNOWN,
 };
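The new PROJECTOR_TYPE_MLP_NORM entry covers the Yi-VL projector. Where the plain LLaVA MLP projector is linear (mm_0) -> GELU -> linear (mm_2), the Yi-VL variant adds a LayerNorm after each linear, as the graph-building hunk further down shows:

    // Yi-VL (MLP_NORM) projector, as built in clip_image_build_graph:
    //   linear (mm_0) -> LayerNorm (mm_1) -> GELU -> linear (mm_3) -> LayerNorm (mm_4)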
@@ -304,10 +305,18 @@ struct clip_vision_model {
     struct ggml_tensor * projection;
 
     // LLaVA projection
-    struct ggml_tensor * mm_0_w;
-    struct ggml_tensor * mm_0_b;
-    struct ggml_tensor * mm_2_w;
-    struct ggml_tensor * mm_2_b;
+    struct ggml_tensor * mm_0_w = NULL;
+    struct ggml_tensor * mm_0_b = NULL;
+    struct ggml_tensor * mm_2_w = NULL;
+    struct ggml_tensor * mm_2_b = NULL;
+
+    // Yi type models with mlp+normalization projection
+    struct ggml_tensor * mm_1_w = NULL; // Yi type models have 0, 1, 3, 4
+    struct ggml_tensor * mm_1_b = NULL;
+    struct ggml_tensor * mm_3_w = NULL;
+    struct ggml_tensor * mm_3_b = NULL;
+    struct ggml_tensor * mm_4_w = NULL;
+    struct ggml_tensor * mm_4_b = NULL;
 
     // MobileVLM projection
     struct ggml_tensor * mm_model_mlp_1_w;
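The NULL initializers matter because both projector layouts now share one loader path: tensors a given GGUF file does not provide simply stay NULL. The indices follow the numbered mm.<N>.<suffix> naming of the LLaVA projector tensors, so a Yi-VL file carries blocks 0, 1, 3 and 4 while a plain LLaVA file carries 0 and 2. A minimal sketch of the names involved (assuming TN_LLAVA_PROJ expands to the "mm.%d.%s" pattern used elsewhere in clip.cpp):

    #include <cstdio>
    #include <initializer_list>

    int main() {
        // Yi-VL files carry projector blocks 0, 1, 3 and 4 (no block 2),
        // which is why the loader has to probe each index individually.
        for (int i : {0, 1, 3, 4}) {
            std::printf("mm.%d.weight / mm.%d.bias\n", i, i);
        }
        return 0;
    }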
@@ -460,6 +469,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     // pre-layernorm
     {
         embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "pre_ln");
 
         embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
     }
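Naming the node is a debugging aid: a named tensor can be fetched back out of the finished graph. A hedged sketch using ggml's graph lookup, where gf stands for the cgraph built by this function:

    // retrieve the node later, e.g. to print or inspect it while debugging
    struct ggml_tensor * pre_ln = ggml_graph_get_tensor(gf, "pre_ln");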
@@ -575,6 +585,27 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
 
             embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
             embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+
+        } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+            // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
+            // First LayerNorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
+                                model.mm_1_b);
+
+            // GELU activation
+            embeddings = ggml_gelu(ctx0, embeddings);
+
+            // Second linear layer
+            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
+
+            // Second LayerNorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
+                                model.mm_4_b);
         }
         else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
             // MobileVLM projector
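For reference, the MLP_NORM branch applies, per token embedding: linear, LayerNorm, GELU, linear, LayerNorm. A standalone scalar sketch of the same forward pass, with hypothetical helper names that are not part of clip.cpp:

    #include <cmath>
    #include <vector>

    using vec = std::vector<float>;
    using mat = std::vector<vec>; // row-major weight matrix

    // y = W*x + b (illustrative helper)
    static vec matvec(const mat & W, const vec & x, const vec & b) {
        vec y(W.size());
        for (size_t r = 0; r < W.size(); ++r) {
            float s = b[r];
            for (size_t c = 0; c < x.size(); ++c) s += W[r][c] * x[c];
            y[r] = s;
        }
        return y;
    }

    // normalize to zero mean / unit variance, then scale and shift
    static void layer_norm(vec & x, const vec & w, const vec & b, float eps = 1e-5f) {
        float mean = 0.0f, var = 0.0f;
        for (float v : x) mean += v;
        mean /= x.size();
        for (float v : x) var += (v - mean) * (v - mean);
        var /= x.size();
        const float inv = 1.0f / std::sqrt(var + eps);
        for (size_t i = 0; i < x.size(); ++i) x[i] = (x[i] - mean) * inv * w[i] + b[i];
    }

    // tanh-approximated GELU, as ggml_gelu computes it
    static float gelu(float v) {
        return 0.5f * v * (1.0f + std::tanh(0.79788456f * (v + 0.044715f * v * v * v)));
    }

    // one token embedding through the Yi-VL projector
    static vec mlp_norm_project(vec x,
                                const mat & W0, const vec & b0, const vec & w1, const vec & b1,
                                const mat & W3, const vec & b3, const vec & w4, const vec & b4) {
        x = matvec(W0, x, b0);           // mm_0: first linear
        layer_norm(x, w1, b1);           // mm_1: first LayerNorm
        for (float & v : x) v = gelu(v); // GELU activation
        x = matvec(W3, x, b3);           // mm_3: second linear
        layer_norm(x, w4, b4);           // mm_4: second LayerNorm
        return x;
    }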
@@ -808,6 +839,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         else {
             new_clip->proj_type = PROJECTOR_TYPE_MLP;
         }
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP) {
+            if (gguf_find_tensor(ctx, format(TN_LLAVA_PROJ, 3, "weight").c_str()) != -1) {
+                new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM;
+            }
+        }
     }
 
 #ifdef GGML_USE_CUBLAS
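Detection is purely structural: the loader first settles on PROJECTOR_TYPE_MLP and then upgrades to PROJECTOR_TYPE_MLP_NORM if the file also contains the third projector block, since gguf_find_tensor returns -1 for absent tensors. Spelled out with the assumed "mm.%d.%s" name expansion:

    // mm.3.weight only exists in Yi-VL style files, so its presence alone
    // distinguishes MLP_NORM from the plain two-layer MLP projector.
    if (gguf_find_tensor(ctx, "mm.3.weight") != -1) {
        new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM;
    }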
@@ -956,11 +992,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         vision_model.pre_ln_b            = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
 
         // LLaVA projection
-        if (new_clip->proj_type == PROJECTOR_TYPE_MLP) {
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
             vision_model.mm_0_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
             vision_model.mm_0_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
-            vision_model.mm_2_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
-            vision_model.mm_2_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
+            try {
+                // Yi-type llava
+                vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight"));
+                vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias"));
+            } catch (std::runtime_error & e) {  }
+            try {
+                // missing in Yi-type llava
+                vision_model.mm_2_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
+                vision_model.mm_2_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
+            } catch (std::runtime_error & e) {  }
+            try {
+                // Yi-type llava
+                vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight"));
+                vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias"));
+            } catch (std::runtime_error & e) {  }
+            try {
+                // Yi-type llava
+                vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight"));
+                vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias"));
+            } catch (std::runtime_error & e) {  }
         }
         else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
             // MobileVLM projection
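Each optional block gets its own try/catch because get_tensor throws std::runtime_error when a tensor is missing from the file; a failed probe simply leaves the corresponding pointer NULL (see the struct defaults above). The same pattern could be captured in a small helper; get_tensor_opt is a hypothetical name, not part of this commit:

    // Refactor sketch over clip.cpp internals: returns NULL instead of
    // throwing when an optional projector tensor is absent.
    static struct ggml_tensor * get_tensor_opt(struct ggml_context * ctx, const std::string & name) {
        try {
            return get_tensor(ctx, name);
        } catch (const std::runtime_error &) {
            return NULL;
        }
    }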
@@ -1432,6 +1486,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
         return ctx->vision_model.mm_2_b->ne[0];
+    } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+        return ctx->vision_model.mm_3_b->ne[0];
     }
     else {
         std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
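The returned width is read off the last linear layer of each projector: for the plain MLP that is mm_2, for MLP_NORM it is mm_3 (the trailing mm_4 LayerNorm preserves the width), so ne[0] of the corresponding bias is the embedding size handed to the language model:

    // MLP:      ... -> linear (mm_2)               => mm_2_b->ne[0]
    // MLP_NORM: ... -> linear (mm_3) -> LN (mm_4)  => mm_3_b->ne[0]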
examples/llava/llava-cli.cpp
@@ -148,10 +148,35 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
 
-    // llava chat format is "<system_prompt>\nUSER:<image_embeddings>\n<textual_prompt>\nASSISTANT:"
-    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant.  The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params->n_batch, &n_past, add_bos);
+    std::string system_prompt, user_prompt;
+    size_t image_pos = prompt.find("<image>");
+    if (image_pos != std::string::npos) {
+        // new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
+
+        system_prompt = prompt.substr(0, image_pos);
+        user_prompt = prompt.substr(image_pos + std::string("<image>").length());
+        // We replace \n with actual newlines in user_prompt, just in case -e was not used in templating string
+        size_t pos = 0;
+        while ((pos = user_prompt.find("\\n", pos)) != std::string::npos) {
+            user_prompt.replace(pos, 2, "\n");
+            pos += 1; // Advance past the replaced newline
+        }
+        while ((pos = system_prompt.find("\\n", pos)) != std::string::npos) {
+            system_prompt.replace(pos, 2, "\n");
+            pos += 1; // Advance past the replaced newline
+        }
+
+        printf("system_prompt: %s\n", system_prompt.c_str());
+        printf("user_prompt: %s\n", user_prompt.c_str());
+    } else {
+        // llava-1.5 native mode
+        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
+        user_prompt = prompt + "\nASSISTANT:";
+    }
+
+    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
-    eval_string(ctx_llava->ctx_llama, (prompt + "\nASSISTANT:").c_str(), params->n_batch, &n_past, false);
+    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
 
     // generate the response
 
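With this change a full chat template can be passed via -p, using <image> to mark where the image embeddings go; since the prompt may arrive with literal \n sequences when -e is not used, both halves are unescaped in place. A standalone sketch of that unescaping (note that a reusable helper naturally restarts its scan at 0 for each string, whereas the two inline loops above share a single pos variable):

    #include <string>

    // Replace literal "\n" escape sequences with real newlines, in place.
    static void unescape_newlines(std::string & s) {
        for (size_t pos = 0; (pos = s.find("\\n", pos)) != std::string::npos; ) {
            s.replace(pos, 2, "\n");
            pos += 1; // step past the inserted newline
        }
    }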
@@ -162,6 +187,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
         if (strcmp(tmp, "</s>") == 0) break;
+        if (strstr(tmp, "###")) break; // Yi-VL behavior
 
         printf("%s", tmp);
         fflush(stdout);
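The extra stop check matches Yi-VL's chat format, which separates turns with "###" rather than ending them with the </s> token, so generation is cut off as soon as the marker appears in a sampled piece:

    if (strcmp(tmp, "</s>") == 0) break; // llava-1.5 end-of-turn token
    if (strstr(tmp, "###")) break;       // Yi-VL ends turns with "###"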
|   | ||||
		Reference in New Issue
	
	Block a user
	 John
					John