mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
|
||||
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
|
||||
// sorting is not necessary here
|
||||
llama_sampler_softmax_impl(cur_p, false);
|
||||
// edge cases
|
||||
if (cur_p->size == 0) {
|
||||
cur_p->selected = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
cur_p->selected = 0;
|
||||
|
||||
if (cur_p->size == 1) {
|
||||
cur_p->data[0].p = 1.0f;
|
||||
return;
|
||||
}
|
||||
|
||||
// max logit for numerical stability
|
||||
float max_l = cur_p->data[0].logit;
|
||||
if (!cur_p->sorted) {
|
||||
for (size_t i = 1; i < cur_p->size; ++i) {
|
||||
max_l = std::max(max_l, cur_p->data[i].logit);
|
||||
}
|
||||
}
|
||||
|
||||
// apply softmax to obtain the probabilities
|
||||
double sum_cum = 0.0f;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
float p = expf(cur_p->data[i].logit - max_l);
|
||||
cur_p->data[i].p = p;
|
||||
sum_cum += p;
|
||||
}
|
||||
|
||||
#if 1
|
||||
// sample from the obtained probabilities and normalize the probs in a single pass
|
||||
// this is ~3x faster on Mac with full gpt-oss vocab than the version below
|
||||
//
|
||||
std::uniform_real_distribution<double> dist(0.0f, 1.0f);
|
||||
const double rnd = dist(ctx->rng);
|
||||
|
||||
double sum_run = 0.0f;
|
||||
const double sum_tgt = sum_cum*rnd;
|
||||
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
if (!found) {
|
||||
// accumulate probs until we reach the target sum
|
||||
sum_run += cur_p->data[i].p;
|
||||
if (sum_run >= sum_tgt) {
|
||||
cur_p->selected = i;
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
|
||||
// normalize probs
|
||||
cur_p->data[i].p /= sum_cum;
|
||||
}
|
||||
|
||||
// fallback to the last token (don't think this can happen)
|
||||
assert(found);
|
||||
if (!found) {
|
||||
cur_p->selected = cur_p->size - 1;
|
||||
}
|
||||
#else
|
||||
// for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
|
||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||
cur_p->data[i].p /= sum_cum;
|
||||
}
|
||||
|
||||
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
||||
#endif
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
||||
|
||||
Reference in New Issue
Block a user