	code : normalize enum names (#5697)

* code : normalize enum names

ggml-ci

* code : cont

* code : cont
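
Since this is a pure rename, callers only have to switch to the suffixed spellings; there is no behavior change. A minimal caller-side sketch (illustrative only, not part of this commit) using the renamed constants together with the default-params helpers that appear in the diff below:

// Illustrative snippet, assuming the post-rename llama.h: shows the new
// *_MODE_* / *_TYPE_* spellings in place of the old unsuffixed constants.
#include "llama.h"

int main() {
    struct llama_model_params mparams = llama_model_default_params();
    // previously LLAMA_SPLIT_LAYER
    mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;

    struct llama_context_params cparams = llama_context_default_params();
    // previously LLAMA_ROPE_SCALING_YARN
    cparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN;

    // the same suffix pattern applies to LLAMA_POOLING_TYPE_* and
    // LLAMA_KV_OVERRIDE_TYPE_*, which are touched further down in the diff
    return 0;
}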
 llama.cpp | 64
 1 file changed, 32 insertions(+), 32 deletions(-)
--- a/llama.cpp
+++ b/llama.cpp
@@ -850,9 +850,9 @@ struct LLM_TN {
 //
 
 static std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
-    { LLAMA_ROPE_SCALING_NONE,   "none"   },
-    { LLAMA_ROPE_SCALING_LINEAR, "linear" },
-    { LLAMA_ROPE_SCALING_YARN,   "yarn"   },
+    { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
+    { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
+    { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };
 
 static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
@@ -862,7 +862,7 @@ static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
         }
     }
 
-    return LLAMA_ROPE_SCALING_UNSPECIFIED;
+    return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
 }
 
 static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
@@ -1580,7 +1580,7 @@ struct llama_hparams {
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    uint32_t pooling_type = LLAMA_POOLING_NONE;
+    uint32_t pooling_type = LLAMA_POOLING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
@@ -2345,9 +2345,9 @@ namespace GGUFMeta {
 
         static const char * override_type_to_str(const llama_model_kv_override_type ty) {
             switch (ty) {
-                case LLAMA_KV_OVERRIDE_BOOL:  return "bool";
-                case LLAMA_KV_OVERRIDE_INT:   return "int";
-                case LLAMA_KV_OVERRIDE_FLOAT: return "float";
+                case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
+                case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
+                case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
             }
             return "unknown";
         }
@@ -2358,13 +2358,13 @@ namespace GGUFMeta {
                 LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
                     __func__, override_type_to_str(override->tag), override->key);
                 switch (override->tag) {
-                    case LLAMA_KV_OVERRIDE_BOOL:  {
+                    case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
                         LLAMA_LOG_INFO("%s\n", override->bool_value ? "true" : "false");
                     } break;
-                    case LLAMA_KV_OVERRIDE_INT:   {
+                    case LLAMA_KV_OVERRIDE_TYPE_INT:   {
                         LLAMA_LOG_INFO("%" PRId64 "\n", override->int_value);
                     } break;
-                    case LLAMA_KV_OVERRIDE_FLOAT: {
+                    case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
                         LLAMA_LOG_INFO("%.6f\n", override->float_value);
                     } break;
                     default:
@@ -2383,7 +2383,7 @@ namespace GGUFMeta {
         template<typename OT>
         static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
         try_override(OT & target, const struct llama_model_kv_override *override) {
-            if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, override)) {
                 target = override->bool_value;
                 return true;
             }
@@ -2393,7 +2393,7 @@ namespace GGUFMeta {
         template<typename OT>
         static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
         try_override(OT & target, const struct llama_model_kv_override *override) {
-            if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, override)) {
                 target = override->int_value;
                 return true;
             }
@@ -2403,7 +2403,7 @@ namespace GGUFMeta {
         template<typename OT>
         static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
         try_override(T & target, const struct llama_model_kv_override *override) {
-            if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, override)) {
                 target = override->float_value;
                 return true;
             }
@@ -2999,7 +2999,7 @@ static void llm_load_hparams(
     std::string rope_scaling("linear");
     ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
     hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
-    GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
+    GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);
 
     // rope_freq_scale (inverse of the kv) is optional
     float ropescale = 0.0f;
@@ -3643,7 +3643,7 @@ static bool llm_load_tensors(
         model.buft_layer[i] = llama_default_buffer_type_cpu(true);
     }
 
-    if (split_mode == LLAMA_SPLIT_LAYER) {
+    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
         // calculate the split points
         int device_count = llama_get_device_count();
         bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
@@ -3682,10 +3682,10 @@ static bool llm_load_tensors(
         }
     } else {
         ggml_backend_buffer_type_t split_buft;
-        if (split_mode == LLAMA_SPLIT_ROW) {
+        if (split_mode == LLAMA_SPLIT_MODE_ROW) {
             split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
         } else {
-            // LLAMA_SPLIT_NONE or LLAMA_SPLIT_LAYER in backends where it is not supported
+            // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
             split_buft = llama_default_buffer_type_offload(main_gpu);
         }
         // assign the repeating layers
@@ -5070,7 +5070,7 @@ struct llm_build_context {
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_NONE),
+        pooling_type     (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_TYPE_NONE),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -6050,12 +6050,12 @@ struct llm_build_context {
         cur = inpL;
 
         // pooling layer
-        if (pooling_type == LLAMA_POOLING_MEAN) {
+        if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
             cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-        } else if (pooling_type == LLAMA_POOLING_CLS) {
+        } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
             cur = ggml_get_rows(ctx0, cur, inp_cls);
         } else {
-            GGML_ASSERT(pooling_type == LLAMA_POOLING_NONE && "Invalid pooling type");
+            GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
         }
         cb(cur, "result_embd", -1);
 
@@ -7754,7 +7754,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_MEAN) {
+    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -7782,7 +7782,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_CLS) {
+    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -11351,7 +11351,7 @@ static int llama_apply_lora_from_file_internal(
 struct llama_model_params llama_model_default_params() {
     struct llama_model_params result = {
         /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_LAYER,
+        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
@@ -11377,7 +11377,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_batch                     =*/ 512,
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
-        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
+        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -11565,16 +11565,16 @@ struct llama_context * llama_new_context_with_model(
     cparams.cb_eval_user_data = params.cb_eval_user_data;
 
     auto rope_scaling_type = params.rope_scaling_type;
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
     }
 
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
         cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
     }
 
     if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
-        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
+        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
     if (params.seed == LLAMA_DEFAULT_SEED) {
@@ -11608,8 +11608,8 @@ struct llama_context * llama_new_context_with_model(
         }
 #elif defined(GGML_USE_CUBLAS)
         if (model->n_gpu_layers > 0) {
-            // with split_mode LLAMA_SPLIT_NONE or LLAMA_SPLIT_ROW, only the main GPU backend is used
-            if (model->split_mode == LLAMA_SPLIT_NONE || model->split_mode == LLAMA_SPLIT_ROW) {
+            // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+            if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
                 ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
                 if (backend == nullptr) {
                     LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
@@ -11618,7 +11618,7 @@ struct llama_context * llama_new_context_with_model(
                 }
                 ctx->backends.push_back(backend);
             } else {
-                // LLAMA_SPLIT_LAYER requires a backend for each GPU
+                // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
                 for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
                     ggml_backend_t backend = ggml_backend_cuda_init(device);
                     if (backend == nullptr) {
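
The comments in the last two hunks also document the backend-selection rule: with LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW only the main GPU backend is created, while LLAMA_SPLIT_MODE_LAYER needs one backend per device. A condensed sketch of that logic under the renamed constants; the helper name init_cuda_backends() is hypothetical, and the error handling and context wiring from llama_new_context_with_model() are omitted:

// Condensed sketch of the CUDA backend selection shown above; the constants
// and ggml calls are the ones in the diff, the helper itself is illustrative.
#include <vector>
#include "llama.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

static std::vector<ggml_backend_t> init_cuda_backends(enum llama_split_mode split_mode, int main_gpu) {
    std::vector<ggml_backend_t> backends;
    if (split_mode == LLAMA_SPLIT_MODE_NONE || split_mode == LLAMA_SPLIT_MODE_ROW) {
        // only the main GPU backend is used
        backends.push_back(ggml_backend_cuda_init(main_gpu));
    } else {
        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
        for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
            backends.push_back(ggml_backend_cuda_init(device));
        }
    }
    return backends;
}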
Author: Georgi Gerganov