mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	server: fix reported top tokens for temperature 0 (#7203)
This commit is contained in:
		| @@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ | ||||
|  | ||||
|     result->prev.resize(params.n_prev); | ||||
|  | ||||
|     result->n_considered = 0; | ||||
|     result->n_valid = 0; | ||||
|  | ||||
|     llama_sampling_set_rng_seed(result, params.seed); | ||||
|  | ||||
| @@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { | ||||
|  | ||||
|     std::fill(ctx->prev.begin(), ctx->prev.end(), 0); | ||||
|     ctx->cur.clear(); | ||||
|     ctx->n_considered = 0; | ||||
|     ctx->n_valid = 0; | ||||
| } | ||||
|  | ||||
| void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { | ||||
| @@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl( | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     ctx_sampling->n_considered = cur_p.size; | ||||
|     ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size; | ||||
|  | ||||
|     return id; | ||||
| } | ||||
|   | ||||
| @@ -81,7 +81,7 @@ struct llama_sampling_context { | ||||
|     // TODO: replace with ring-buffer | ||||
|     std::vector<llama_token>      prev; | ||||
|     std::vector<llama_token_data> cur; | ||||
|     size_t n_considered; | ||||
|     size_t n_valid; // Number of correct top tokens with correct probabilities. | ||||
|  | ||||
|     std::mt19937 rng; | ||||
| }; | ||||
|   | ||||
| @@ -2270,10 +2270,10 @@ struct server_context { | ||||
|  | ||||
|                 const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); | ||||
|                 if (n_probs > 0) { | ||||
|                     const size_t n_considered = slot.ctx_sampling->n_considered; | ||||
|                     const size_t n_valid = slot.ctx_sampling->n_valid; | ||||
|  | ||||
|                     // Make sure at least n_probs top tokens are at the front of the vector: | ||||
|                     if (slot.sparams.temp == 0.0f && n_probs > n_considered) { | ||||
|                     if (slot.sparams.temp == 0.0f && n_probs > n_valid) { | ||||
|                         llama_sample_top_k(ctx, &cur_p, n_probs, 0); | ||||
|                     } | ||||
|  | ||||
| @@ -2289,7 +2289,7 @@ struct server_context { | ||||
|                         for (size_t i = 0; i < n_probs; ++i) { | ||||
|                             result.probs.push_back({ | ||||
|                                 cur_p.data[i].id, | ||||
|                                 i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. | ||||
|                                 i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. | ||||
|                             }); | ||||
|                         } | ||||
|                     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler