mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama : minor sampling refactor (2) (#9386)
This commit is contained in:
		| @@ -140,8 +140,6 @@ while n_cur <= n_len { | |||||||
|  |  | ||||||
|         let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) |         let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) | ||||||
|  |  | ||||||
|         llama_sampler_accept(smpl, new_token_id) |  | ||||||
|  |  | ||||||
|         // is it an end of stream? -> mark the stream as finished |         // is it an end of stream? -> mark the stream as finished | ||||||
|         if llama_token_is_eog(model, new_token_id) || n_cur == n_len { |         if llama_token_is_eog(model, new_token_id) || n_cur == n_len { | ||||||
|             i_batch[i] = -1 |             i_batch[i] = -1 | ||||||
|   | |||||||
| @@ -172,8 +172,6 @@ int main(int argc, char ** argv) { | |||||||
|  |  | ||||||
|             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); |             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); | ||||||
|  |  | ||||||
|             llama_sampler_accept(smpl, new_token_id); |  | ||||||
|  |  | ||||||
|             // is it an end of generation? -> mark the stream as finished |             // is it an end of generation? -> mark the stream as finished | ||||||
|             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { |             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { | ||||||
|                 i_batch[i] = -1; |                 i_batch[i] = -1; | ||||||
|   | |||||||
| @@ -121,7 +121,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std | |||||||
|         llama_decode(ctx, bat); |         llama_decode(ctx, bat); | ||||||
|  |  | ||||||
|         llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); |         llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); | ||||||
|         llama_sampler_accept(smpl, token); |  | ||||||
|  |  | ||||||
|         if (token == eos_token) { |         if (token == eos_token) { | ||||||
|             break; |             break; | ||||||
|   | |||||||
| @@ -414,8 +414,6 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( | |||||||
|     // sample the most likely token |     // sample the most likely token | ||||||
|     const auto new_token_id = llama_sampler_sample(sampler, context, -1); |     const auto new_token_id = llama_sampler_sample(sampler, context, -1); | ||||||
|  |  | ||||||
|     llama_sampler_accept(sampler, new_token_id); |  | ||||||
|  |  | ||||||
|     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); |     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); | ||||||
|     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { |     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { | ||||||
|         return nullptr; |         return nullptr; | ||||||
|   | |||||||
| @@ -152,8 +152,6 @@ actor LlamaContext { | |||||||
|  |  | ||||||
|         new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) |         new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) | ||||||
|  |  | ||||||
|         llama_sampler_accept(sampling, new_token_id) |  | ||||||
|  |  | ||||||
|         if llama_token_is_eog(model, new_token_id) || n_cur == n_len { |         if llama_token_is_eog(model, new_token_id) || n_cur == n_len { | ||||||
|             print("\n") |             print("\n") | ||||||
|             is_done = true |             is_done = true | ||||||
|   | |||||||
| @@ -220,8 +220,6 @@ int main(int argc, char ** argv) { | |||||||
|         { |         { | ||||||
|             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); |             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); | ||||||
|  |  | ||||||
|             llama_sampler_accept(smpl, new_token_id); |  | ||||||
|  |  | ||||||
|             // is it an end of generation? |             // is it an end of generation? | ||||||
|             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { |             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { | ||||||
|                 LOG_TEE("\n"); |                 LOG_TEE("\n"); | ||||||
|   | |||||||
| @@ -74,8 +74,6 @@ int main(int argc, char ** argv) { | |||||||
|         auto next_token     = llama_sampler_sample(smpl, ctx, -1); |         auto next_token     = llama_sampler_sample(smpl, ctx, -1); | ||||||
|         auto next_token_str = llama_token_to_piece(ctx, next_token); |         auto next_token_str = llama_token_to_piece(ctx, next_token); | ||||||
|  |  | ||||||
|         llama_sampler_accept(smpl, next_token); |  | ||||||
|  |  | ||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         result0 += next_token_str; |         result0 += next_token_str; | ||||||
|  |  | ||||||
| @@ -132,8 +130,6 @@ int main(int argc, char ** argv) { | |||||||
|         auto next_token     = llama_sampler_sample(smpl2, ctx2, -1); |         auto next_token     = llama_sampler_sample(smpl2, ctx2, -1); | ||||||
|         auto next_token_str = llama_token_to_piece(ctx2, next_token); |         auto next_token_str = llama_token_to_piece(ctx2, next_token); | ||||||
|  |  | ||||||
|         llama_sampler_accept(smpl2, next_token); |  | ||||||
|  |  | ||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         result1 += next_token_str; |         result1 += next_token_str; | ||||||
|  |  | ||||||
| @@ -222,8 +218,6 @@ int main(int argc, char ** argv) { | |||||||
|         auto next_token     = llama_sampler_sample(smpl3, ctx3, -1); |         auto next_token     = llama_sampler_sample(smpl3, ctx3, -1); | ||||||
|         auto next_token_str = llama_token_to_piece(ctx3, next_token); |         auto next_token_str = llama_token_to_piece(ctx3, next_token); | ||||||
|  |  | ||||||
|         llama_sampler_accept(smpl3, next_token); |  | ||||||
|  |  | ||||||
|         printf("%s", next_token_str.c_str()); |         printf("%s", next_token_str.c_str()); | ||||||
|         result2 += next_token_str; |         result2 += next_token_str; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -613,7 +613,7 @@ struct server_context { | |||||||
|  |  | ||||||
|     gpt_params params; |     gpt_params params; | ||||||
|  |  | ||||||
|     llama_batch batch; |     llama_batch batch = {}; | ||||||
|  |  | ||||||
|     bool clean_kv_cache = true; |     bool clean_kv_cache = true; | ||||||
|     bool add_bos_token  = true; |     bool add_bos_token  = true; | ||||||
|   | |||||||
| @@ -118,8 +118,6 @@ int main(int argc, char ** argv) { | |||||||
|         { |         { | ||||||
|             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); |             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); | ||||||
|  |  | ||||||
|             llama_sampler_accept(smpl, new_token_id); |  | ||||||
|  |  | ||||||
|             // is it an end of generation? |             // is it an end of generation? | ||||||
|             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { |             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { | ||||||
|                 LOG_TEE("\n"); |                 LOG_TEE("\n"); | ||||||
|   | |||||||
| @@ -1127,15 +1127,16 @@ extern "C" { | |||||||
|                              int32_t   n_logit_bias, |                              int32_t   n_logit_bias, | ||||||
|               const llama_logit_bias * logit_bias); |               const llama_logit_bias * logit_bias); | ||||||
|  |  | ||||||
|     // Shorthand for: |     /// @details Sample and accept a token from the idx-th output of the last evaluation | ||||||
|     // |     // | ||||||
|  |     // Shorthand for: | ||||||
|     //    const auto * logits = llama_get_logits_ith(ctx, idx); |     //    const auto * logits = llama_get_logits_ith(ctx, idx); | ||||||
|     //    llama_token_data_array cur_p = { ... init from logits ... }; |     //    llama_token_data_array cur_p = { ... init from logits ... }; | ||||||
|     //    llama_sampler_apply(smpl, &cur_p); |     //    llama_sampler_apply(smpl, &cur_p); | ||||||
|     //    return cur_p.data[cur_p.selected].id; |     //    auto token = cur_p.data[cur_p.selected].id; | ||||||
|     // |     //    llama_sampler_accept(smpl, token); | ||||||
|     // At this point, this is mostly a convenience function. |     //    return token; | ||||||
|     // |     // Returns the sampled token | ||||||
|     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); |     LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); | ||||||
|  |  | ||||||
|     // TODO: extend in the future |     // TODO: extend in the future | ||||||
|   | |||||||
| @@ -8,49 +8,44 @@ | |||||||
| #include <cstring> | #include <cstring> | ||||||
| #include <ctime> | #include <ctime> | ||||||
| #include <cfloat> | #include <cfloat> | ||||||
|  | #include <cmath> | ||||||
| #include <numeric> | #include <numeric> | ||||||
| #include <random> | #include <random> | ||||||
| #include <unordered_map> | #include <unordered_map> | ||||||
|  |  | ||||||
| static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) { | static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { | ||||||
| #if 1 |     // iterator for the probabilities | ||||||
|     probs.resize(cur_p->size); | #ifdef __GNUC__ | ||||||
|     for (size_t i = 0; i < cur_p->size; ++i) { |  | ||||||
|         probs[i] = cur_p->data[i].p; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     std::discrete_distribution<size_t> dist(probs.begin(), probs.end()); |  | ||||||
| #else |  | ||||||
|     // avoid the copy with a custom iterator |  | ||||||
|     #pragma GCC diagnostic push |     #pragma GCC diagnostic push | ||||||
|     #pragma GCC diagnostic ignored "-Wunused-local-typedefs" |     #pragma GCC diagnostic ignored "-Wunused-local-typedefs" | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     struct probs_iterator { |     struct probs_iterator { | ||||||
|         typedef std::input_iterator_tag iterator_category; |         typedef std::input_iterator_tag iterator_category; | ||||||
|         typedef float value_type; |         typedef float value_type; | ||||||
|         typedef float * pointer; |         typedef float * pointer; | ||||||
|         typedef float & reference; |         typedef float & reference; | ||||||
|         typedef size_t difference_type; |         typedef ptrdiff_t difference_type; | ||||||
|  |  | ||||||
|         const llama_token_data_array * data; |         const llama_token_data * data; | ||||||
|         size_t i; |  | ||||||
|  |  | ||||||
|         bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; } |         bool operator==(const probs_iterator & other) const { return data == other.data; } | ||||||
|         bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; } |         bool operator!=(const probs_iterator & other) const { return data != other.data; } | ||||||
|         float operator*() const { return data->data[i].p; } |         const float & operator*() const { return data->p; } | ||||||
|         probs_iterator & operator++() { ++i; return *this; } |         probs_iterator & operator++() { ++data; return *this; } | ||||||
|         probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; } |         probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; } | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|  | #ifdef __GNUC__ | ||||||
|     #pragma GCC diagnostic pop |     #pragma GCC diagnostic pop | ||||||
|  |  | ||||||
|     std::discrete_distribution<size_t> dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size}); |  | ||||||
|  |  | ||||||
|     GGML_UNUSED(probs); |  | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  |     std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size}); | ||||||
|  |  | ||||||
|     return dist(rng); |     return dist(rng); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /* | ||||||
| static void llama_log_softmax(float * array, size_t size) { | static void llama_log_softmax(float * array, size_t size) { | ||||||
|     float max_l = *std::max_element(array, array + size); |     float max_l = *std::max_element(array, array + size); | ||||||
|     float sum = 0.f; |     float sum = 0.f; | ||||||
| @@ -64,6 +59,7 @@ static void llama_log_softmax(float * array, size_t size) { | |||||||
|         array[i] = logf(array[i] / sum); |         array[i] = logf(array[i] / sum); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | */ | ||||||
|  |  | ||||||
| static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { | static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { | ||||||
|     GGML_ASSERT(cur_p->size > 0); |     GGML_ASSERT(cur_p->size > 0); | ||||||
| @@ -231,67 +227,92 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte | |||||||
|         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; |         cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; |     llama_token_data_array cur_p = { | ||||||
|  |         /* .data       = */ cur.data(), | ||||||
|  |         /* .size       = */ cur.size(), | ||||||
|  |         /* .selected   = */ -1, | ||||||
|  |         /* .sorted     = */ false, | ||||||
|  |     }; | ||||||
|  |  | ||||||
|     llama_sampler_apply(smpl, &cur_p); |     llama_sampler_apply(smpl, &cur_p); | ||||||
|  |  | ||||||
|     return cur_p.data[cur_p.selected].id; |     GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); | ||||||
|  |  | ||||||
|  |     auto token = cur_p.data[cur_p.selected].id; | ||||||
|  |  | ||||||
|  |     llama_sampler_accept(smpl, token); | ||||||
|  |  | ||||||
|  |     return token; | ||||||
| } | } | ||||||
|  |  | ||||||
| // sampler chain | // sampler chain | ||||||
|  |  | ||||||
|  | static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { | ||||||
|  |     return "chain"; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { | ||||||
|  |     auto * chain = (llama_sampler_chain *) smpl->ctx; | ||||||
|  |  | ||||||
|  |     time_meas tm(chain->t_sample_us, chain->params.no_perf); | ||||||
|  |  | ||||||
|  |     for (auto * smpl : chain->samplers) { | ||||||
|  |         llama_sampler_accept(smpl, token); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     chain->n_sample++; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | ||||||
|  |     auto * chain = (llama_sampler_chain *) smpl->ctx; | ||||||
|  |  | ||||||
|  |     time_meas tm(chain->t_sample_us, chain->params.no_perf); | ||||||
|  |  | ||||||
|  |     for (auto * smpl : chain->samplers) { | ||||||
|  |         llama_sampler_apply(smpl, cur_p); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void llama_sampler_chain_reset(struct llama_sampler * smpl) { | ||||||
|  |     auto * chain = (llama_sampler_chain *) smpl->ctx; | ||||||
|  |  | ||||||
|  |     for (auto * smpl : chain->samplers) { | ||||||
|  |         llama_sampler_reset(smpl); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     chain->t_sample_us = 0; | ||||||
|  |     chain->n_sample    = 0; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { | ||||||
|  |     const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; | ||||||
|  |  | ||||||
|  |     auto * result = llama_sampler_chain_init(chain_src->params); | ||||||
|  |  | ||||||
|  |     for (auto * smpl : chain_src->samplers) { | ||||||
|  |         llama_sampler_chain_add(result, llama_sampler_clone(smpl)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return result; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static void llama_sampler_chain_free(struct llama_sampler * smpl) { | ||||||
|  |     auto * chain = (llama_sampler_chain *) smpl->ctx; | ||||||
|  |  | ||||||
|  |     for (auto * smpl : chain->samplers) { | ||||||
|  |         llama_sampler_free(smpl); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     delete chain; | ||||||
|  | } | ||||||
|  |  | ||||||
| static struct llama_sampler_i llama_sampler_chain_i = { | static struct llama_sampler_i llama_sampler_chain_i = { | ||||||
|     /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, |     /* .name   = */ llama_sampler_chain_name, | ||||||
|     /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { |     /* .accept = */ llama_sampler_chain_accept, | ||||||
|         auto * chain = (llama_sampler_chain *) smpl->ctx; |     /* .apply  = */ llama_sampler_chain_apply, | ||||||
|  |     /* .reset  = */ llama_sampler_chain_reset, | ||||||
|         time_meas tm(chain->t_sample_us, chain->params.no_perf); |     /* .clone  = */ llama_sampler_chain_clone, | ||||||
|  |     /* .free   = */ llama_sampler_chain_free, | ||||||
|         for (auto * smpl : chain->samplers) { |  | ||||||
|             llama_sampler_accept(smpl, token); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         chain->n_sample++; |  | ||||||
|     }, |  | ||||||
|     /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { |  | ||||||
|         auto * chain = (llama_sampler_chain *) smpl->ctx; |  | ||||||
|  |  | ||||||
|         time_meas tm(chain->t_sample_us, chain->params.no_perf); |  | ||||||
|  |  | ||||||
|         for (auto * smpl : chain->samplers) { |  | ||||||
|             llama_sampler_apply(smpl, cur_p); |  | ||||||
|         } |  | ||||||
|     }, |  | ||||||
|     /* .reset  = */ [](struct llama_sampler * smpl) { |  | ||||||
|         auto * chain = (llama_sampler_chain *) smpl->ctx; |  | ||||||
|  |  | ||||||
|         for (auto * smpl : chain->samplers) { |  | ||||||
|             llama_sampler_reset(smpl); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         chain->t_sample_us = 0; |  | ||||||
|         chain->n_sample    = 0; |  | ||||||
|     }, |  | ||||||
|     /* .clone  = */ [](const struct llama_sampler * smpl) { |  | ||||||
|         const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; |  | ||||||
|  |  | ||||||
|         auto * result = llama_sampler_chain_init(chain_src->params); |  | ||||||
|  |  | ||||||
|         for (auto * smpl : chain_src->samplers) { |  | ||||||
|             llama_sampler_chain_add(result, llama_sampler_clone(smpl)); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         return result; |  | ||||||
|     }, |  | ||||||
|     /* .free   = */ [](struct llama_sampler * smpl) { |  | ||||||
|         auto * chain = (llama_sampler_chain *) smpl->ctx; |  | ||||||
|  |  | ||||||
|         for (auto * smpl : chain->samplers) { |  | ||||||
|             llama_sampler_free(smpl); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         delete chain; |  | ||||||
|     }, |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { | struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { | ||||||
| @@ -368,8 +389,6 @@ struct llama_sampler_dist { | |||||||
|     const uint32_t seed; |     const uint32_t seed; | ||||||
|  |  | ||||||
|     std::mt19937 rng; |     std::mt19937 rng; | ||||||
|  |  | ||||||
|     std::vector<float> probs; // work array |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { | static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { | ||||||
| @@ -378,7 +397,7 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl* | |||||||
|  |  | ||||||
| static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | ||||||
|     auto * ctx = (llama_sampler_dist *) smpl->ctx; |     auto * ctx = (llama_sampler_dist *) smpl->ctx; | ||||||
|     cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); |     cur_p->selected = llama_sample_dist(cur_p, ctx->rng); | ||||||
| } | } | ||||||
|  |  | ||||||
| static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { | static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { | ||||||
| @@ -419,7 +438,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { | |||||||
|         /* .ctx   = */ new llama_sampler_dist { |         /* .ctx   = */ new llama_sampler_dist { | ||||||
|             /* .seed = */ seed, |             /* .seed = */ seed, | ||||||
|             /* .rng  = */ std::mt19937(seed), |             /* .rng  = */ std::mt19937(seed), | ||||||
|             /* .probs = */ {}, |  | ||||||
|         }, |         }, | ||||||
|     }; |     }; | ||||||
| } | } | ||||||
| @@ -1023,8 +1041,6 @@ struct llama_sampler_mirostat { | |||||||
|     float mu; |     float mu; | ||||||
|  |  | ||||||
|     std::mt19937 rng; |     std::mt19937 rng; | ||||||
|  |  | ||||||
|     std::vector<float> probs; |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { | static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { | ||||||
| @@ -1055,7 +1071,7 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke | |||||||
|     llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); |     llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); | ||||||
|     llama_sampler_softmax_impl(cur_p); |     llama_sampler_softmax_impl(cur_p); | ||||||
|  |  | ||||||
|     const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); |     const int idx = llama_sample_dist(cur_p, ctx->rng); | ||||||
|  |  | ||||||
|     cur_p->selected = idx; |     cur_p->selected = idx; | ||||||
|  |  | ||||||
| @@ -1111,7 +1127,6 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see | |||||||
|             /* .m       = */ m, |             /* .m       = */ m, | ||||||
|             /* .mu      = */ 2.0f*tau, |             /* .mu      = */ 2.0f*tau, | ||||||
|             /* .rng     = */ std::mt19937(seed), |             /* .rng     = */ std::mt19937(seed), | ||||||
|             /* .probs   = */ {}, |  | ||||||
|         }, |         }, | ||||||
|     }; |     }; | ||||||
| } | } | ||||||
| @@ -1127,8 +1142,6 @@ struct llama_sampler_mirostat_v2 { | |||||||
|     float mu; |     float mu; | ||||||
|  |  | ||||||
|     std::mt19937 rng; |     std::mt19937 rng; | ||||||
|  |  | ||||||
|     std::vector<float> probs; |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { | static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { | ||||||
| @@ -1152,7 +1165,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t | |||||||
|     // Normalize the probabilities of the remaining words |     // Normalize the probabilities of the remaining words | ||||||
|     llama_sampler_softmax_impl(cur_p); |     llama_sampler_softmax_impl(cur_p); | ||||||
|  |  | ||||||
|     const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); |     const int idx = llama_sample_dist(cur_p, ctx->rng); | ||||||
|  |  | ||||||
|     cur_p->selected = idx; |     cur_p->selected = idx; | ||||||
|  |  | ||||||
| @@ -1207,7 +1220,6 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, | |||||||
|             /* .eta   = */ eta, |             /* .eta   = */ eta, | ||||||
|             /* .mu    = */ 2.0f*tau, |             /* .mu    = */ 2.0f*tau, | ||||||
|             /* .rng   = */ std::mt19937(seed), |             /* .rng   = */ std::mt19937(seed), | ||||||
|             /* .probs = */ {}, |  | ||||||
|         }, |         }, | ||||||
|     }; |     }; | ||||||
| } | } | ||||||
| @@ -1527,6 +1539,10 @@ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * / | |||||||
| static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | ||||||
|     auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; |     auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; | ||||||
|  |  | ||||||
|  |     if (ctx->logit_bias.empty()) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     ctx->to_search.clear(); |     ctx->to_search.clear(); | ||||||
|  |  | ||||||
|     // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) |     // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) | ||||||
| @@ -1538,6 +1554,10 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     if (ctx->to_search.empty()) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // search for the remaining candidates that were not found in the previous step |     // search for the remaining candidates that were not found in the previous step | ||||||
|     for (size_t i = 0; i < cur_p->size; ++i) { |     for (size_t i = 0; i < cur_p->size; ++i) { | ||||||
|         for (const auto & lb : ctx->to_search) { |         for (const auto & lb : ctx->to_search) { | ||||||
|   | |||||||
| @@ -245,7 +245,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n", |     printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n", | ||||||
|            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); |            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren