Mirror of https://github.com/ggml-org/llama.cpp.git
Synced 2025-10-31 08:51:55 +00:00
llama : allow for user specified embedding pooling type (#5849)

* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:

Changed files (1):
    llama.cpp — 44 lines changed
							| @@ -873,16 +873,16 @@ struct LLM_TN { | ||||
| // gguf helpers | ||||
| // | ||||
|  | ||||
| static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = { | ||||
| static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = { | ||||
|     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   }, | ||||
|     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, | ||||
|     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   }, | ||||
| }; | ||||
|  | ||||
| static int32_t llama_rope_scaling_type_from_string(const std::string & name) { | ||||
| static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { | ||||
|     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { | ||||
|         if (kv.second == name) { | ||||
|             return kv.first; | ||||
|             return (llama_rope_scaling_type) kv.first; | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -1612,7 +1612,6 @@ struct llama_hparams { | ||||
|     float    rope_freq_base_train; | ||||
|     float    rope_freq_scale_train; | ||||
|     uint32_t n_yarn_orig_ctx; | ||||
|     int32_t  rope_scaling_type_train; | ||||
|  | ||||
|     float f_clamp_kqv      = 0.0f; | ||||
|     float f_max_alibi_bias = 0.0f; | ||||
| @@ -1620,8 +1619,9 @@ struct llama_hparams { | ||||
|     bool causal_attn = true; | ||||
|     bool need_kq_pos = false; | ||||
|  | ||||
|     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; | ||||
|     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE; | ||||
|     enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE; | ||||
|     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE; | ||||
|     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; | ||||
|  | ||||
|     bool operator!=(const llama_hparams & other) const { | ||||
|         if (this->vocab_only    != other.vocab_only)    return true; | ||||
| @@ -1670,8 +1670,8 @@ struct llama_cparams { | ||||
|     uint32_t n_threads;       // number of threads to use for generation | ||||
|     uint32_t n_threads_batch; // number of threads to use for batch processing | ||||
|  | ||||
|     float    rope_freq_base; | ||||
|     float    rope_freq_scale; | ||||
|     float rope_freq_base; | ||||
|     float rope_freq_scale; | ||||
|  | ||||
|     uint32_t n_yarn_orig_ctx; | ||||
|     // These hyperparameters are not exposed in GGUF, because all | ||||
| @@ -1683,7 +1683,7 @@ struct llama_cparams { | ||||
|     float defrag_thold; | ||||
|  | ||||
|     bool offload_kqv; | ||||
|     bool do_pooling; | ||||
|     enum llama_pooling_type pooling_type; | ||||
|  | ||||
|     ggml_backend_sched_eval_callback cb_eval; | ||||
|     void * cb_eval_user_data; | ||||
| @@ -2933,7 +2933,11 @@ template<> | ||||
| bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) { | ||||
|     uint32_t tmp; | ||||
|     const bool found = get_key(kid, tmp, required); | ||||
|     result = (enum llama_pooling_type) tmp; | ||||
|     if (found) { | ||||
|         result = (enum llama_pooling_type) tmp; | ||||
|     } else { | ||||
|         result = LLAMA_POOLING_TYPE_UNSPECIFIED; | ||||
|     } | ||||
|     return found; | ||||
| } | ||||
|  | ||||
| @@ -3210,7 +3214,7 @@ static void llm_load_hparams( | ||||
|                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps); | ||||
|                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn); | ||||
|                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); | ||||
|                 ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type); | ||||
|                 ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false); | ||||
|  | ||||
|                 switch (hparams.n_layer) { | ||||
|                     case 3: | ||||
| @@ -5175,7 +5179,7 @@ struct llm_build_context { | ||||
|         n_kv             (worst_case ? n_ctx            : kv_self.n), | ||||
|         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head), | ||||
|         n_orig_ctx       (cparams.n_yarn_orig_ctx), | ||||
|         pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE), | ||||
|         pooling_type     (cparams.pooling_type), | ||||
|         rope_type        (hparams.rope_type), | ||||
|         cb               (cb), | ||||
|         buf_compute_meta (lctx.buf_compute_meta) { | ||||
| @@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { | ||||
|     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|  | ||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); | ||||
| @@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { | ||||
|     if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { | ||||
|         const int64_t n_tokens = batch.n_tokens; | ||||
|  | ||||
|         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); | ||||
| @@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() { | ||||
|         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default | ||||
|         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS, | ||||
|         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, | ||||
|         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, | ||||
|         /*.rope_freq_base              =*/ 0.0f, | ||||
|         /*.rope_freq_scale             =*/ 0.0f, | ||||
|         /*.yarn_ext_factor             =*/ -1.0f, | ||||
| @@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() { | ||||
|         /*.logits_all                  =*/ false, | ||||
|         /*.embedding                   =*/ false, | ||||
|         /*.offload_kqv                 =*/ true, | ||||
|         /*.do_pooling                  =*/ true, | ||||
|         /*.abort_callback              =*/ nullptr, | ||||
|         /*.abort_callback_data         =*/ nullptr, | ||||
|     }; | ||||
| @@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model( | ||||
|     cparams.yarn_beta_slow   = params.yarn_beta_slow; | ||||
|     cparams.defrag_thold     = params.defrag_thold; | ||||
|     cparams.offload_kqv      = params.offload_kqv; | ||||
|     cparams.do_pooling       = params.do_pooling; | ||||
|     cparams.pooling_type     = params.pooling_type; | ||||
|  | ||||
|     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx; | ||||
|     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base; | ||||
| @@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model( | ||||
|         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; | ||||
|     } | ||||
|  | ||||
|     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { | ||||
|         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { | ||||
|             cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; | ||||
|         } else { | ||||
|             cparams.pooling_type = hparams.pooling_type; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (params.seed == LLAMA_DEFAULT_SEED) { | ||||
|         params.seed = time(NULL); | ||||
|     } | ||||
|   | ||||
Reference in New Issue · Block a user

Author: Douglas Hanley