Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661

ggml-ci
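The headline API change is the consolidated penalty pass: before this refactor, callers applied the repetition penalty via llama_sample_repetition_penalty and the frequency/presence penalties via a separate llama_sample_frequency_and_presence_penalties call (both still visible in the commented-out TODO block in the diff below); after it, all three penalties go through one call. As a rough illustration of what such a combined pass computes — a standalone sketch over raw logits, not the library's actual API (apply_penalties_sketch and token_id are ours):

#include <unordered_map>
#include <vector>

// Illustrative stand-in; llama.cpp's llama_token is an int32-like id.
using token_id = int;

// Sketch of a combined penalty pass over raw logits (NOT the library API;
// the real function operates on a llama_token_data_array internally).
void apply_penalties_sketch(std::vector<float> & logits,
                            const std::vector<token_id> & last_tokens,
                            float penalty_repeat,    // multiplicative ("repeat_penalty")
                            float penalty_freq,      // per-occurrence ("frequency_penalty")
                            float penalty_present) { // flat ("presence_penalty")
    // Count occurrences of each token in the recent window once,
    // so all three penalties can share the same counts.
    std::unordered_map<token_id, int> counts;
    for (token_id tok : last_tokens) {
        counts[tok]++;
    }

    for (const auto & [tok, count] : counts) {
        float & logit = logits[tok];

        // Repetition penalty: make the token less likely, scaling
        // negative logits up and positive logits down.
        if (logit <= 0.0f) {
            logit *= penalty_repeat;
        } else {
            logit /= penalty_repeat;
        }

        // OpenAI-style frequency + presence penalties, applied in the same pass.
        logit -= float(count) * penalty_freq + penalty_present;
    }
}

Both penalty styles operate on the same per-token occurrence counts over the recent window, which is what makes folding them into a single call natural.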
@@ -58,28 +58,30 @@ inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n
 
 // TODO: use common/sampling.h
 inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
-      // out of user input, sample next token
-    const float   temp      = params.sampling_params.temp;
-    const int32_t top_k     = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k;
-    const float   top_p     = params.sampling_params.top_p;
-    const float   tfs_z     = params.sampling_params.tfs_z;
-    const float   typical_p = params.sampling_params.typical_p;
-      // const int32_t repeat_last_n   = params.sampling_params.repeat_last_n < 0 ? n_ctx : params.sampling_params.repeat_last_n;
-      // const float   repeat_penalty  = params.sampling_params.repeat_penalty;
-      // const float   alpha_presence  = params.sampling_params.presence_penalty;
-      // const float   alpha_frequency = params.sampling_params.frequency_penalty;
-    const int     mirostat     = params.sampling_params.mirostat;
-    const float   mirostat_tau = params.sampling_params.mirostat_tau;
-    const float   mirostat_eta = params.sampling_params.mirostat_eta;
-      // const bool    penalize_nl     = params.sampling_params.penalize_nl;
+    auto & sparams = params.sparams;
+
+    // out of user input, sample next token
+    const float   temp      = sparams.temp;
+    const int32_t top_k     = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : sparams.top_k;
+    const float   top_p     = sparams.top_p;
+    const float   tfs_z     = sparams.tfs_z;
+    const float   typical_p = sparams.typical_p;
+    // const int32_t repeat_last_n   = sparams.repeat_last_n < 0 ? n_ctx : sparams.repeat_last_n;
+    // const float   repeat_penalty  = sparams.repeat_penalty;
+    // const float   alpha_presence  = sparams.presence_penalty;
+    // const float   alpha_frequency = sparams.frequency_penalty;
+    const int     mirostat     = sparams.mirostat;
+    const float   mirostat_tau = sparams.mirostat_tau;
+    const float   mirostat_eta = sparams.mirostat_eta;
+    // const bool    penalize_nl     = sparams.penalize_nl;
 
     llama_token id = 0;
     {
         auto logits  = llama_get_logits(ctx_llama);
         auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
 
-          // Apply params.logit_bias map
-        for (auto it = params.sampling_params.logit_bias.begin(); it != params.sampling_params.logit_bias.end(); it++) {
+        // Apply params.logit_bias map
+        for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
             logits[it->first] += it->second;
         }
 
@@ -91,18 +93,18 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
 
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-          // TODO: Apply penalties
-          // float nl_logit = logits[llama_token_nl(ctx)];
-          // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-          // llama_sample_repetition_penalty(ctx, &candidates_p,
-          //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-          //      last_n_repeat, repeat_penalty);
-          // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-          // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-          // last_n_repeat, alpha_frequency, alpha_presence);
-          // if (!penalize_nl) {
-          //     logits[llama_token_nl(ctx)] = nl_logit;
-          // }
+        // TODO: Apply penalties
+        // float nl_logit = logits[llama_token_nl(ctx)];
+        // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+        // llama_sample_repetition_penalty(ctx, &candidates_p,
+        //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        //      last_n_repeat, repeat_penalty);
+        // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+        // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        // last_n_repeat, alpha_frequency, alpha_presence);
+        // if (!penalize_nl) {
+        //     logits[llama_token_nl(ctx)] = nl_logit;
+        // }
 
         if (temp <= 0) {
-              // Greedy sampling
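For context, the commented-out TODO block above could be completed against the combined entry point this commit introduces. The following is a hedged sketch, not part of the commit's diff: it assumes the renamed sparams penalty fields announced in the commit message (penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl), a combined llama_sample_repetition_penalties function with this argument order, and a caller-maintained last_n_tokens history vector plus the sample_id locals (logits, candidates_p, n_ctx) — none of which are shown in the visible hunks.

        // Sketch only: completing the TODO with the combined penalties call.
        // last_n_tokens is an assumed history of recent tokens (not in this diff).
        const int32_t penalty_last_n = sparams.penalty_last_n < 0 ? n_ctx : sparams.penalty_last_n;
        const int     last_n_repeat  = std::min(std::min((int) last_n_tokens.size(), (int) penalty_last_n), n_ctx);

        const float nl_logit = logits[llama_token_nl(ctx_llama)];

        // One call now covers repetition, frequency and presence penalties.
        llama_sample_repetition_penalties(ctx_llama, &candidates_p,
                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                last_n_repeat,
                sparams.penalty_repeat,    // was repeat_penalty
                sparams.penalty_freq,      // was frequency_penalty
                sparams.penalty_present);  // was presence_penalty

        // Restore the newline logit if newline penalization is disabled.
        if (!sparams.penalize_nl) {
            logits[llama_token_nl(ctx_llama)] = nl_logit;
        }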
Author: Georgi Gerganov