sampling : optimize samplers by reusing bucket sort (#15665)

* sampling : optimize sorting using bucket sort in more places

ggml-ci

* sampling : do not sort in dist sampler

ggml-ci

* sampling : avoid heap allocations for sort buffers

ggml-ci

* common : add option to sort sampling candidates by probability

ggml-ci

* sampling : revert the change for preserving sort buffers

* sampling : use std::copy instead of memcpy

* sampling : clarify purpose of partial sort helpers

ggml-ci

* cont : remove wrong comment [no ci]

* common : update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Georgi Gerganov
2025-08-31 20:41:02 +03:00
committed by GitHub
parent 0d161f021a
commit e92d53b29e
9 changed files with 227 additions and 171 deletions

View File

@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
return &gsmpl->cur_p;
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
auto * res = &gsmpl->cur_p;
if (do_sort && !res->sorted) {
// remember the selected token before sorting
const llama_token id = res->data[res->selected].id;
std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.p > b.p;
});
// restore the selected token after sorting
for (size_t i = 0; i < res->size; ++i) {
if (res->data[i].id == id) {
res->selected = i;
break;
}
}
res->sorted = true;
}
return res;
}
llama_token common_sampler_last(const struct common_sampler * gsmpl) {

View File

@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers
// access the internal list of current candidate tokens
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
// the .sorted flag of the result indicates whether the returned candidates are sorted
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
// get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl);

View File

@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",