mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : fix top-p sampling to match the canonical definition (#1953)
* Fix top-p sampling to match the standard definition (smallest set that has probability mass at least p, not largest set with probability mass less than p) * top-p: correct gt to gte * add test for correct top-p behavior
This commit is contained in:
		| @@ -2015,9 +2015,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can | |||||||
|     for (size_t i = 0; i < candidates->size; ++i) { |     for (size_t i = 0; i < candidates->size; ++i) { | ||||||
|         cum_sum += candidates->data[i].p; |         cum_sum += candidates->data[i].p; | ||||||
|  |  | ||||||
|         // Check if the running sum is greater than p or if we have kept at least min_keep tokens |         // Check if the running sum is at least p or if we have kept at least min_keep tokens | ||||||
|         if (cum_sum > p && i >= min_keep) { |         // we set the last index to i+1 to indicate that the current iterate should be included in the set | ||||||
|             last_idx = i; |         if (cum_sum >= p && i + 1 >= min_keep) { | ||||||
|  |             last_idx = i + 1; | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -181,6 +181,7 @@ int main(void) { | |||||||
|  |  | ||||||
|     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); |     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); | ||||||
|     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); |     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); | ||||||
|  |     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f); | ||||||
|     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1); |     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1); | ||||||
|  |  | ||||||
|     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); |     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Alex Renda
					Alex Renda