mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-10-30 08:42:00 +00:00
	sampling : deduplicated code for probability distribution access (#6240)
* sampling: remove duplicated code for probability distribution access
* free original_logits
* fix original_logits allocation
* fixes based on review @cebtenzzre
* change function name to `llama_sampling_prepare`
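In short, the commit funnels both sampling entry points through one preparation routine, `llama_sampling_prepare`, which can optionally apply grammar checks and hand back a copy of the untouched logits. Below is a minimal caller sketch, not part of the diff: `ctx_sampling` and `ctx_main` are assumed to be set up as in common/sampling.cpp, and the include path is illustrative.

#include <vector>

#include "sampling.h" // common/sampling.h in the llama.cpp tree

// Sketch: a first-pass call matching the new llama_sampling_sample_impl
// (is_resampling == false, so apply_grammar == !is_resampling == true).
static llama_token pick_token(struct llama_sampling_context * ctx_sampling,
                              struct llama_context           * ctx_main,
                              int idx) {
    // The helper copies the raw logits out only when apply_grammar is true
    // and a destination vector is supplied (see llama_sampling_prepare_impl).
    std::vector<float> original_logits;
    llama_token_data_array cur_p = llama_sampling_prepare(
        ctx_sampling, ctx_main, /*ctx_cfg=*/NULL, idx,
        /*apply_grammar=*/true, &original_logits);

    // cur_p already has penalties, biases, and (here) grammar applied;
    // normalize and take the top candidate - greedy, for illustration only.
    llama_sample_softmax(ctx_main, &cur_p);
    return cur_p.data[0].id;
}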
common/sampling.cpp

@@ -168,77 +168,20 @@ static llama_token llama_sampling_sample_impl(
                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
    const llama_sampling_params & params = ctx_sampling->params;

    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

    const float   temp            = params.temp;
    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
    const float   penalty_repeat  = params.penalty_repeat;
    const float   penalty_freq    = params.penalty_freq;
    const float   penalty_present = params.penalty_present;
    const int     mirostat        = params.mirostat;
    const float   mirostat_tau    = params.mirostat_tau;
    const float   mirostat_eta    = params.mirostat_eta;
    const bool    penalize_nl     = params.penalize_nl;

    auto & prev = ctx_sampling->prev;
    auto & cur  = ctx_sampling->cur;

    std::vector<float> original_logits;
    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
    if (!is_resampling) {
        GGML_ASSERT(!original_logits.empty());
    }
    llama_token id = 0;

    // Get a pointer to the logits
    float * logits = llama_get_logits_ith(ctx_main, idx);

    // Declare original_logits at the beginning of the function scope
    std::vector<float> original_logits;

    if (!is_resampling) {
        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
    }

    // apply params.logit_bias map
    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
        logits[it->first] += it->second;
    }

    if (ctx_cfg) {
        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
    }

    cur.clear();

    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    // apply penalties
    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
    if (penalty_tokens_used_size) {
        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

        llama_sample_repetition_penalties(ctx_main, &cur_p,
                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

        if (!penalize_nl) {
            for (size_t idx = 0; idx < cur_p.size; idx++) {
                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
                    cur_p.data[idx].logit = nl_logit;
                    break;
                }
            }
        }
    }

    // If we are in the resampling phase, apply grammar checks before sampling logic
    if (is_resampling && ctx_sampling->grammar != NULL) {
        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
    }

    if (temp < 0.0) {
        // greedy sampling, with probs
        llama_sample_softmax(ctx_main, &cur_p);
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
    return id;
}

static llama_token_data_array llama_sample_probability_distribution_impl(
static llama_token_data_array llama_sampling_prepare_impl(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx) {
                  const int idx,
                  bool apply_grammar,
                  std::vector<float> * original_logits) {
    const llama_sampling_params & params = ctx_sampling->params;

    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
    const float   penalty_repeat  = params.penalty_repeat;
    const float   penalty_freq    = params.penalty_freq;
    const float   penalty_present = params.penalty_present;

    const bool    penalize_nl     = params.penalize_nl;

    auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
    // Get a pointer to the logits
    float * logits = llama_get_logits_ith(ctx_main, idx);

    // Declare original_logits at the beginning of the function scope
    std::vector<float> original_logits;
    if (apply_grammar && original_logits != NULL) {
        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
    }

    // apply params.logit_bias map
    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
        }
    }

    // apply grammar checks
    if (ctx_sampling->grammar != NULL) {
    // apply grammar checks before sampling logic
    if (apply_grammar && ctx_sampling->grammar != NULL) {
        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
    }

    llama_sample_softmax(ctx_main, &cur_p);
    return cur_p;
}

@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}

llama_token_data_array llama_sampling_probability_distribution(
llama_token_data_array llama_sampling_prepare(
                  struct llama_sampling_context * ctx_sampling,
                  struct llama_context * ctx_main,
                  struct llama_context * ctx_cfg,
                  const int idx) {
    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
                  const int idx,
                  bool apply_grammar,
                  std::vector<float> * original_logits) {
    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
}

void llama_sampling_accept(
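For context on why `original_logits` exists at all: later in `llama_sampling_sample_impl` (not shown in this hunk), the sampled token is checked against the grammar, and on rejection the saved logits are written back before a second, resampling pass. A condensed, assumed sketch of that check:

#include <cmath> // INFINITY

// Assumed helper condensing the post-sample grammar check that motivates
// keeping original_logits around; names mirror the diff, flow simplified.
static bool grammar_accepts(struct llama_context * ctx_main,
                            struct llama_sampling_context * ctx_sampling,
                            llama_token id, float logit) {
    llama_token_data       single     = { id, logit, 0.0f };
    llama_token_data_array single_arr = { &single, 1, false };
    llama_sample_grammar(ctx_main, &single_arr, ctx_sampling->grammar);
    // llama_sample_grammar sets the logit of disallowed tokens to -INFINITY.
    return single_arr.data[0].logit != -INFINITY;
}

On rejection, the caller copies `original_logits` back over the context's logits and re-enters `llama_sampling_sample_impl` with `is_resampling = true`.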
common/sampling.h

@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
        struct llama_context * ctx_cfg,
        int idx = 0);

// returns the probability that token of given id will be sampled
llama_token_data_array llama_sampling_probability_distribution(
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_prepare(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx = 0);
        int idx = 0,
        bool apply_grammar = true,
        std::vector<float> * original_logits = nullptr);

void llama_sampling_accept(
        struct llama_sampling_context * ctx_sampling,
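Because the two new parameters are defaulted, existing four-argument call sites stay source-compatible. The two calls below are equivalent (sketch; assumes a valid `ctx_sampling`/`ctx_main` pair):

llama_token_data_array a = llama_sampling_prepare(ctx_sampling, ctx_main, NULL, 0);
llama_token_data_array b = llama_sampling_prepare(ctx_sampling, ctx_main, NULL, 0,
                                                  /*apply_grammar=*/true,
                                                  /*original_logits=*/nullptr);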
examples/speculative/speculative.cpp

@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
                if (params.sparams.temp > 0) {
                    // stochastic verification

                    llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
                    llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
                    llama_sample_softmax(ctx_tgt, &dist_tgt);
                    float p_tgt = 0, p_dft = 0;

                    // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
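Note the design consequence visible here: the old helper applied softmax internally, but the `llama_sample_softmax` call was dropped from the renamed impl above, so this speculative-verification caller now normalizes explicitly before reading probabilities. A hypothetical lookup helper (not in the diff) showing how the normalized distribution would then be consumed:

#include <cstddef>

// After llama_sample_softmax fills the .p field of each candidate,
// look up the probability of one token id.
static float token_prob(const llama_token_data_array * dist, llama_token id) {
    for (size_t i = 0; i < dist->size; ++i) {
        if (dist->data[i].id == id) {
            return dist->data[i].p;
        }
    }
    return 0.0f; // token not among the candidates
}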
Author: Minsoo Cheong