	grammars: fix resampling logic regression (#7424)

common/sampling.cpp:

@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
                   const int idx,
-                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
+                  bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const float   temp            = params.temp;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
     const float   mirostat_eta    = params.mirostat_eta;
 
     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
-    if (!is_resampling) {
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }
 
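The net effect of this change is that the grammar is now applied only on the resampling pass: the first pass samples from the unconstrained logits (keeping a copy of them whenever a grammar is active), and only if the chosen token is rejected by the grammar are the logits restored and the sampler re-entered with is_resampling = true. Below is a minimal standalone sketch of that control flow; toy_grammar and sample_impl here are toy stand-ins, not the llama.cpp API.

    // Toy model of the fixed two-pass flow: sample unconstrained first,
    // keep the raw logits, and resample under the grammar only on rejection.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct toy_grammar {
        // stand-in rule for a real grammar: only even token ids are legal
        bool accepts(int token) const { return token % 2 == 0; }
        void mask(std::vector<float> & logits) const {
            for (size_t i = 0; i < logits.size(); ++i) {
                if (i % 2 != 0) logits[i] = -1e9f; // forbid odd tokens
            }
        }
    };

    static int sample_impl(std::vector<float> & logits, const toy_grammar * grammar, bool is_resampling) {
        std::vector<float> original_logits;
        if (grammar != nullptr && !is_resampling) {
            original_logits = logits;      // mirrors the copy made in the prepare step
        }
        if (grammar != nullptr && is_resampling) {
            grammar->mask(logits);         // apply_grammar == is_resampling after the fix
        }
        // greedy "sampling" keeps the sketch deterministic
        int id = (int) (std::max_element(logits.begin(), logits.end()) - logits.begin());

        if (grammar != nullptr && !is_resampling && !grammar->accepts(id)) {
            logits = original_logits;      // restore the unconstrained logits
            return sample_impl(logits, grammar, /* is_resampling= */ true);
        }
        return id;
    }

    int main() {
        std::vector<float> logits = {0.1f, 0.9f, 0.5f, 0.2f}; // token 1 wins unconstrained
        toy_grammar grammar;
        // token 1 is rejected by the grammar, so the resample pass picks token 2
        printf("sampled token: %d\n", sample_impl(logits, &grammar, /* is_resampling= */ false));
        return 0;
    }
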
@@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != NULL) {
+    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+        GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }
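
Note how the two new conditions mirror each other: llama_sampling_prepare_impl now saves the unconstrained logits exactly when a grammar is present and apply_grammar is false, which, with apply_grammar = is_resampling in the caller, is precisely the pass on which llama_sampling_sample_impl may still need to restore them before resampling. A tiny standalone check of that pairing, with a plain assert standing in for GGML_ASSERT, could look like this:

    #include <cassert>

    int main() {
        for (int has_grammar = 0; has_grammar <= 1; ++has_grammar) {
            for (int is_resampling = 0; is_resampling <= 1; ++is_resampling) {
                const bool apply_grammar = is_resampling;                  // what sample_impl now passes to prepare
                const bool copies_logits = has_grammar && !apply_grammar;  // what prepare_impl now does
                const bool needs_restore = has_grammar && !is_resampling;  // what sample_impl may do afterwards
                assert(copies_logits == needs_restore); // the copy exists iff it can be needed
            }
        }
        return 0;
    }
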
@@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
                   struct llama_context * ctx_cfg,
                   const int idx) {
     // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }
 
 llama_token_data_array llama_sampling_prepare(

examples/main/main.cpp:

@@ -707,7 +707,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -728,7 +728,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
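
The two call sites above illustrate the intent behind the apply_grammar flag on llama_sampling_accept: every accepted token feeds the repetition-penalty history, but only generated tokens advance the grammar state, since the prompt itself is not required to conform to the grammar. A standalone sketch of that distinction, using toy types rather than the llama.cpp API:

    #include <cstdio>
    #include <vector>

    struct toy_sampling_ctx {
        std::vector<int> prev;   // history used for repetition penalties
        int grammar_pos = 0;     // toy stand-in for the grammar parse state
    };

    static void accept(toy_sampling_ctx & ctx, int token, bool apply_grammar) {
        ctx.prev.push_back(token);   // always recorded, so penalties see prompt + output
        if (apply_grammar) {
            ctx.grammar_pos++;       // only generated tokens drive the grammar forward
        }
    }

    int main() {
        toy_sampling_ctx ctx;
        const std::vector<int> prompt = {10, 11, 12};
        for (int tok : prompt) {
            accept(ctx, tok, /* apply_grammar= */ false); // prompt: penalties only
        }
        accept(ctx, 42, /* apply_grammar= */ true);        // generated token
        printf("history=%zu grammar_pos=%d\n", ctx.prev.size(), ctx.grammar_pos); // history=4 grammar_pos=1
        return 0;
    }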

Author: Olivier Chafik