Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	grammar : check the full vocab only if necessary (opt) (#4306)
* Check the full vocab for grammar only if necessary

* Fix missing logit restoration step (?)

  Does this matter, actually?

* Fix whitespace / formatting

* Adjust comment

* Didn't mean to push test gbnf

* Split sampling into the helper function (?)

  And also revert the changes made to the header

* common : fix final newline

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
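In short, the optimization avoids running the grammar over the entire vocabulary on every sampling call: the token is first sampled without grammar constraints, then only that single sampled token is checked against the grammar, and a full grammar-constrained resample happens only if that token is rejected. Below is a minimal sketch of the single-token check, using only the types and calls that appear in the diff (llama_token_data, llama_token_data_array, llama_sample_grammar); the helper name token_passes_grammar and its exact signature are illustrative, not part of the commit.

#include <cmath>      // INFINITY
#include "llama.h"    // llama_token_data, llama_token_data_array, llama_sample_grammar

// Illustrative helper (not part of the commit): returns true if the already-sampled
// token `id` is allowed by the grammar, without touching the rest of the vocabulary.
static bool token_passes_grammar(
        struct llama_context * ctx_main,
        struct llama_grammar * grammar,
        llama_token            id,
        const float          * logits) {
    // Candidate array containing only the sampled token
    llama_token_data       single_token_data       = { id, logits[id], 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

    // llama_sample_grammar sets the logit of disallowed tokens to -INFINITY
    llama_sample_grammar(ctx_main, &single_token_data_array, grammar);

    return single_token_data_array.data[0].logit != -INFINITY;
}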
@@ -149,11 +149,12 @@ static void sampler_queue(
     }
 }
 
-llama_token llama_sampling_sample(
+static llama_token llama_sampling_sample_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(
 
     llama_token id = 0;
 
+    // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    if (!is_resampling) {
+        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
+        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
+    }
+
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
@@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    if (ctx_sampling->grammar != NULL) {
+    // If we are in the resampling phase, apply grammar checks before sampling logic
+    if (is_resampling && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
@@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
         }
     }
 
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Create an array with a single token data element for the sampled id
+        llama_token_data single_token_data = {id, logits[id], 0.0f};
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+
+        // Apply grammar constraints to the single token
+        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+
+        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
+        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+
+        // If the token is not valid according to the grammar, perform resampling
+        if (!is_valid) {
+            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+
+            // Restore logits from the copy
+            std::copy(original_logits.begin(), original_logits.end(), logits);
+
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+        }
+    }
+
     return id;
 }
 
+llama_token llama_sampling_sample(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    // Call the implementation function with is_resampling set to false by default
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
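Since the public entry point keeps its original signature (only the internal helper llama_sampling_sample_impl is new), existing call sites are unaffected. A hypothetical call site follows; it is not taken from this commit, and the bool apply_grammar argument to llama_sampling_accept is an assumption about the API of this period rather than something shown in the diff.

// Hypothetical call site (not from this commit): sample a token, then report
// the accepted token back to the sampling context so the grammar state advances.
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
llama_sampling_accept(ctx_sampling, ctx_main, id, /* apply_grammar = */ true);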
Author: kalomaze