mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llava : support Minicpm-omni (#11289)
* init * add readme * update readme * no use make * update readme * update fix code * fix editorconfig-checker * no change convert py * use clip_image_u8_free
This commit is contained in:
		
							
								
								
									
										46
									
								
								examples/llava/README-minicpmo2.6.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								examples/llava/README-minicpmo2.6.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,46 @@ | |||||||
|  | ## MiniCPM-o 2.6 | ||||||
|  | Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible. | ||||||
|  |  | ||||||
|  | ### Prepare models and code | ||||||
|  |  | ||||||
|  | Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder. | ||||||
|  |  | ||||||
|  | Clone llama.cpp: | ||||||
|  | ```bash | ||||||
|  | git clone git@github.com:OpenBMB/llama.cpp.git | ||||||
|  | cd llama.cpp | ||||||
|  | git checkout minicpm-omni | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Usage of MiniCPM-o 2.6 | ||||||
|  |  | ||||||
|  | Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us) | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6 | ||||||
|  | python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 | ||||||
|  | python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model | ||||||
|  |  | ||||||
|  | # quantize int4 version | ||||||
|  | ./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Build llama.cpp using `CMake`: | ||||||
|  | https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | cmake -B build | ||||||
|  | cmake --build build --config Release | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Inference on Linux or Mac | ||||||
|  | ``` | ||||||
|  | # run f16 version | ||||||
|  | ./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" | ||||||
|  |  | ||||||
|  | # run quantized int4 version | ||||||
|  | ./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg  -p "What is in the image?" | ||||||
|  |  | ||||||
|  | # or run in interactive mode | ||||||
|  | ./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i | ||||||
|  | ``` | ||||||
| @@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 | |||||||
|         else if (ctx->minicpmv_version == 3) { |         else if (ctx->minicpmv_version == 3) { | ||||||
|             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); |             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); | ||||||
|         } |         } | ||||||
|  |         else if (ctx->minicpmv_version == 4) { | ||||||
|  |             pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); | ||||||
|  |         } | ||||||
|         ggml_set_name(pos_embed, "pos_embed"); |         ggml_set_name(pos_embed, "pos_embed"); | ||||||
|         ggml_set_input(pos_embed); |         ggml_set_input(pos_embed); | ||||||
|     } |     } | ||||||
| @@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 | |||||||
|                     n_head = hidden_size/d_head; |                     n_head = hidden_size/d_head; | ||||||
|                     num_query = 64; |                     num_query = 64; | ||||||
|                 } |                 } | ||||||
|  |                 else if (ctx->minicpmv_version == 4) { | ||||||
|  |                     hidden_size = 3584; | ||||||
|  |                     n_head = hidden_size/d_head; | ||||||
|  |                     num_query = 64; | ||||||
|  |                 } | ||||||
|  |  | ||||||
|                 struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); |                 struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); | ||||||
|                 Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); |                 Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); | ||||||
| @@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag | |||||||
|                 images[images.size()-1].push_back(patch); |                 images[images.size()-1].push_back(patch); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |         clip_image_u8_free(refine_image); | ||||||
|     } |     } | ||||||
|     return images; |     return images; | ||||||
| } | } | ||||||
| @@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli | |||||||
|                 clip_image_f32_free(res); |                 clip_image_f32_free(res); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |         for (size_t i = 0; i < imgs.size(); ++i) { | ||||||
|  |             for (size_t j = 0; j < imgs[i].size(); ++j) { | ||||||
|  |                 if (imgs[i][j] != nullptr) { | ||||||
|  |                     clip_image_u8_free(imgs[i][j]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|         return true; |         return true; | ||||||
|     } |     } | ||||||
|     else if (ctx->has_qwen2vl_merger) { |     else if (ctx->has_qwen2vl_merger) { | ||||||
| @@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i | |||||||
|         else if (ctx->minicpmv_version == 3) { |         else if (ctx->minicpmv_version == 3) { | ||||||
|             n_patches = 64; |             n_patches = 64; | ||||||
|         } |         } | ||||||
|  |         else if (ctx->minicpmv_version == 4) { | ||||||
|  |             n_patches = 64; | ||||||
|  |         } | ||||||
|     } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { |     } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { | ||||||
|         int patch_size = params.patch_size * 2; |         int patch_size = params.patch_size * 2; | ||||||
|         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); |         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); | ||||||
| @@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 |             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 | ||||||
|             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); |             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); | ||||||
|             int* positions_data = (int*)malloc(ggml_nbytes(positions)); |             int* positions_data = (int*)malloc(ggml_nbytes(positions)); | ||||||
|             int bucket_coords_h[70]; |             int bucket_coords_h[1024]; | ||||||
|             int bucket_coords_w[70]; |             int bucket_coords_w[1024]; | ||||||
|             for (int i = 0; i < pos_h; i++){ |             for (int i = 0; i < pos_h; i++){ | ||||||
|                 bucket_coords_h[i] = std::floor(70.0*i/pos_h); |                 bucket_coords_h[i] = std::floor(70.0*i/pos_h); | ||||||
|             } |             } | ||||||
| @@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima | |||||||
|             else if (ctx->minicpmv_version == 3) { |             else if (ctx->minicpmv_version == 3) { | ||||||
|                 embed_dim = 3584; |                 embed_dim = 3584; | ||||||
|             } |             } | ||||||
|  |             else if (ctx->minicpmv_version == 4) { | ||||||
|  |                 embed_dim = 3584; | ||||||
|  |             } | ||||||
|             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); |             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); | ||||||
|  |  | ||||||
|             float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); |             float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); | ||||||
| @@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { | |||||||
|         else if (ctx->minicpmv_version == 3) { |         else if (ctx->minicpmv_version == 3) { | ||||||
|             return 3584; |             return 3584; | ||||||
|         } |         } | ||||||
|  |         else if (ctx->minicpmv_version == 4) { | ||||||
|  |             return 3584; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { |     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { | ||||||
|         return ctx->vision_model.mm_1_b->ne[0]; |         return ctx->vision_model.mm_1_b->ne[0]; | ||||||
|   | |||||||
| @@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> | |||||||
|     return true; |     return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { | static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { | ||||||
|     int width = image->nx; |     int width = image->nx; | ||||||
|     int height = image->ny; |     int height = image->ny; | ||||||
|     int num_patches = (height / patch_size) * (width / patch_size); |     int num_patches = (height / patch_size) * (width / patch_size); | ||||||
| @@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli | |||||||
|                 encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); |                 encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); | ||||||
|             } |             } | ||||||
|             else { |             else { | ||||||
|                 int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); |                 encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); | ||||||
|                 if (has_minicpmv_projector == 2) { |  | ||||||
|                     encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); |  | ||||||
|                 } |  | ||||||
|                 else if (has_minicpmv_projector == 3) { |  | ||||||
|                     encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); |  | ||||||
|                 } |  | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if (!encoded) { |             if (!encoded) { | ||||||
| @@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli | |||||||
|         load_image_size->height = img->ny; |         load_image_size->height = img->ny; | ||||||
|         clip_add_load_image_size(ctx_clip, load_image_size); |         clip_add_load_image_size(ctx_clip, load_image_size); | ||||||
|         LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); |         LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); | ||||||
|  |         delete[] img_res_v.data; | ||||||
|  |         img_res_v.size = 0; | ||||||
|  |         img_res_v.data = nullptr; | ||||||
|     } |     } | ||||||
|     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { |     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { | ||||||
|         // flat / default llava-1.5 type embedding |         // flat / default llava-1.5 type embedding | ||||||
|   | |||||||
| @@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e | |||||||
|     else if (has_minicpmv_projector == 3) { |     else if (has_minicpmv_projector == 3) { | ||||||
|         system_prompt = "<|im_start|>user\n"; |         system_prompt = "<|im_start|>user\n"; | ||||||
|     } |     } | ||||||
|  |     else if (has_minicpmv_projector == 4) { | ||||||
|  |         system_prompt = "<|im_start|>user\n"; | ||||||
|  |     } | ||||||
|     LOG_INF("%s: image token past: %d\n", __func__, n_past); |     LOG_INF("%s: image token past: %d\n", __func__, n_past); | ||||||
|     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false); |     eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false); | ||||||
|     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); |     process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); | ||||||
| @@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm | |||||||
|         else if (has_minicpmv_projector == 3) { |         else if (has_minicpmv_projector == 3) { | ||||||
|             user_prompt = "<|im_start|>user\n" + prompt; |             user_prompt = "<|im_start|>user\n" + prompt; | ||||||
|         } |         } | ||||||
|  |         else if (has_minicpmv_projector == 4) { | ||||||
|  |             user_prompt = "<|im_start|>user\n" + prompt; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); |     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); | ||||||
| @@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm | |||||||
|     else if (has_minicpmv_projector == 3) { |     else if (has_minicpmv_projector == 3) { | ||||||
|         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); |         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); | ||||||
|     } |     } | ||||||
|  |     else if (has_minicpmv_projector == 4) { | ||||||
|  |         eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // generate the response |     // generate the response | ||||||
|  |  | ||||||
| @@ -308,7 +317,6 @@ int main(int argc, char ** argv) { | |||||||
|                     const auto * tmp = llama_loop(ctx_llava, smpl, n_past); |                     const auto * tmp = llama_loop(ctx_llava, smpl, n_past); | ||||||
|                     response += tmp; |                     response += tmp; | ||||||
|                     if (strcmp(tmp, "</s>") == 0) break; |                     if (strcmp(tmp, "</s>") == 0) break; | ||||||
|                     if (strstr(tmp, "###")) break; // Yi-VL behavior |  | ||||||
|                     printf("%s", tmp);// mistral llava-1.6 |                     printf("%s", tmp);// mistral llava-1.6 | ||||||
|                     if (strstr(response.c_str(), "<user>")) break; // minicpm-v |                     if (strstr(response.c_str(), "<user>")) break; // minicpm-v | ||||||
|                     fflush(stdout); |                     fflush(stdout); | ||||||
|   | |||||||
| @@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073] | |||||||
| default_image_std = [0.26862954, 0.26130258, 0.27577711] | default_image_std = [0.26862954, 0.26130258, 0.27577711] | ||||||
| ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) | ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) | ||||||
| ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) | ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) | ||||||
| ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) | ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) | ||||||
|  |  | ||||||
| # with proper | # with proper | ||||||
| args = ap.parse_args() | args = ap.parse_args() | ||||||
| @@ -545,12 +545,19 @@ if args.use_f32: | |||||||
|  |  | ||||||
| minicpmv_version = args.minicpmv_version | minicpmv_version = args.minicpmv_version | ||||||
| emb_dim = 4096 | emb_dim = 4096 | ||||||
|  | block_count = 26 | ||||||
| if minicpmv_version == 1: | if minicpmv_version == 1: | ||||||
|     emb_dim = 2304 |     emb_dim = 2304 | ||||||
|  |     block_count = 26 | ||||||
| elif minicpmv_version == 2: | elif minicpmv_version == 2: | ||||||
|     emb_dim = 4096 |     emb_dim = 4096 | ||||||
|  |     block_count = 27 | ||||||
| elif minicpmv_version == 3: | elif minicpmv_version == 3: | ||||||
|     emb_dim = 3584 |     emb_dim = 3584 | ||||||
|  |     block_count = 27 | ||||||
|  | elif minicpmv_version == 4: | ||||||
|  |     emb_dim = 3584 | ||||||
|  |     block_count = 27 | ||||||
|  |  | ||||||
| default_vision_config = { | default_vision_config = { | ||||||
|         "hidden_size": 1152, |         "hidden_size": 1152, | ||||||
| @@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config) | |||||||
| if minicpmv_version == 3: | if minicpmv_version == 3: | ||||||
|     vision_config = SiglipVisionConfig(**default_vision_config) |     vision_config = SiglipVisionConfig(**default_vision_config) | ||||||
|     model = SiglipVisionTransformer(vision_config) |     model = SiglipVisionTransformer(vision_config) | ||||||
|  | elif minicpmv_version == 4: | ||||||
|  |     vision_config = SiglipVisionConfig(**default_vision_config) | ||||||
|  |     model = SiglipVisionTransformer(vision_config) | ||||||
|  |  | ||||||
| processor = None | processor = None | ||||||
| # if model.attn_pool is not None: | # if model.attn_pool is not None: | ||||||
| @@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None: | |||||||
|     fname_middle = "mmproj-" |     fname_middle = "mmproj-" | ||||||
|     has_text_encoder = False |     has_text_encoder = False | ||||||
|     has_minicpmv_projector = True |     has_minicpmv_projector = True | ||||||
|     minicpmv_version = 3 |     minicpmv_version = 4 | ||||||
| elif args.vision_only: | elif args.vision_only: | ||||||
|     fname_middle = "vision-" |     fname_middle = "vision-" | ||||||
|     has_text_encoder = False |     has_text_encoder = False | ||||||
| @@ -625,7 +635,6 @@ if has_vision_encoder: | |||||||
|     fout.add_uint32("clip.vision.projection_dim", 0) |     fout.add_uint32("clip.vision.projection_dim", 0) | ||||||
|     fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) |     fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) | ||||||
|     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) |     fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) | ||||||
|     block_count = 26 |  | ||||||
|     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) |     fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) | ||||||
|  |  | ||||||
|     if processor is not None: |     if processor is not None: | ||||||
|   | |||||||
| @@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") | |||||||
| args = ap.parse_args() | args = ap.parse_args() | ||||||
|  |  | ||||||
| # find the model part that includes the the multimodal projector weights | # find the model part that includes the the multimodal projector weights | ||||||
| model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) | model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16) | ||||||
| checkpoint = model.state_dict() | checkpoint = model.state_dict() | ||||||
|  |  | ||||||
| # get a list of mm tensor names | # get a list of mm tensor names | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 tc-mb
					tc-mb