mirror of https://github.com/ggml-org/llama.cpp.git
@@ -604,10 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    // sorting is not necessary here
-    llama_sampler_softmax_impl(cur_p, false);
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+#if 1
+    // sample from the obtained probabilities and normalize the probs in a single pass
+    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+    //
+    std::uniform_real_distribution<double> dist(0.0, 1.0);
+    const double rnd = dist(ctx->rng);
+
+    double sum_run = 0.0;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            // accumulate probs until we reach the target sum
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    // fallback to the last token (don't think this can happen)
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+#else
+    // for clarity, this is the same as above, but does one pass for normalization and one extra pass for sampling
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= sum_cum;
+    }
+
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
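The `#if 1` branch works because sampling from softmax(logits) needs only the unnormalized weights exp(logit - max_l): the first index whose running weight sum crosses sum_cum*rnd is distributed exactly as under the normalized probabilities, so the division by sum_cum can be folded into the same loop. A minimal standalone sketch of that single-pass idea, assuming nothing from llama.cpp (the helper name `sample_index` and the test logits below are illustrative only, not part of the commit):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Sample an index proportionally to softmax(logits) without normalizing first:
// exponentiate against the max logit for stability, then walk the unnormalized
// weights until the running sum crosses total*rnd (inverse-CDF sampling).
static size_t sample_index(const std::vector<float> & logits, std::mt19937 & rng) {
    assert(!logits.empty());

    float max_l = logits[0];
    for (float l : logits) {
        max_l = std::max(max_l, l);
    }

    std::vector<double> w(logits.size());
    double total = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        w[i] = std::exp(logits[i] - max_l);
        total += w[i];
    }

    std::uniform_real_distribution<double> dist(0.0, 1.0);
    const double target = total*dist(rng);

    double run = 0.0;
    for (size_t i = 0; i < w.size(); ++i) {
        run += w[i];
        if (run >= target) {
            return i;
        }
    }

    return w.size() - 1; // round-off fallback, same role as the assert(found) above
}

int main() {
    std::mt19937 rng(42);
    printf("selected: %zu\n", sample_index({1.0f, 2.0f, 0.5f}, rng));
}

The trailing fallback plays the same role as the commit's assert(found) path: with >= comparisons the loop should always select an index, but accumulated floating-point round-off can in principle leave sum_run fractionally below sum_tgt when rnd is very close to 1.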