mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	sampling : fix off-by-one in tail-free sampling
ggml-ci
This commit is contained in:
		| @@ -963,7 +963,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, | ||||
|         } | ||||
|     ).set_sparam()); | ||||
|     add_opt(llama_arg( | ||||
|         {"--tfs"}, "N", | ||||
|         {"--tfs", "--tfs-z"}, "Z", | ||||
|         format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z), | ||||
|         [](gpt_params & params, const std::string & value) { | ||||
|             params.sparams.tfs_z = std::stof(value); | ||||
|   | ||||
| @@ -756,20 +756,22 @@ static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_tok | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     assert(cur_p->size > 0); // guaranteed earlier | ||||
|     size_t last_idx = cur_p->size - 1; | ||||
|  | ||||
|     float cum_sum = 0.0f; | ||||
|     size_t last_idx = cur_p->size; | ||||
|     for (size_t i = 0; i < second_derivatives.size(); ++i) { | ||||
|         cum_sum += second_derivatives[i]; | ||||
|  | ||||
|         // Check if the running sum is greater than z or if we have kept at least min_keep tokens | ||||
|         if (cum_sum > ctx->z && i >= ctx->min_keep) { | ||||
|         if (cum_sum > ctx->z && (i + 1) >= ctx->min_keep) { | ||||
|             last_idx = i; | ||||
|             break; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Resize the output vector to keep only the tokens above the tail location | ||||
|     cur_p->size = last_idx; | ||||
|     cur_p->size = last_idx + 1; | ||||
| } | ||||
|  | ||||
| static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) { | ||||
|   | ||||
| @@ -272,8 +272,8 @@ int main(void) { | ||||
|     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.00f); | ||||
|  | ||||
|     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f},               0.25f); | ||||
|     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f); | ||||
|     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f); | ||||
|     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f},        0.50f); | ||||
|     test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f, 0.20f}, 0.80f); | ||||
|  | ||||
|     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); | ||||
|     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov