mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : move random seed generation to the samplers (#9398)
* llama_sampler_penalties : clamp penalty_last_n to zero
This commit is contained in:
		| @@ -8,6 +8,7 @@ | ||||
| #include <cstring> | ||||
| #include <ctime> | ||||
| #include <cfloat> | ||||
| #include <chrono> | ||||
| #include <cmath> | ||||
| #include <numeric> | ||||
| #include <random> | ||||
| @@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) | ||||
|     cur_p->size = k; | ||||
| } | ||||
|  | ||||
| static uint32_t get_rng_seed(uint32_t seed) { | ||||
|     if (seed == LLAMA_DEFAULT_SEED) { | ||||
|         // use system clock if std::random_device is not a true RNG | ||||
|         static bool is_rd_prng = std::random_device().entropy() == 0; | ||||
|         if (is_rd_prng) { | ||||
|             return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); | ||||
|         } | ||||
|         std::random_device rd; | ||||
|         return rd(); | ||||
|     } | ||||
|     return seed; | ||||
| } | ||||
|  | ||||
| // llama_sampler API | ||||
|  | ||||
| const char * llama_sampler_name(const struct llama_sampler * smpl) { | ||||
| @@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() { | ||||
|  | ||||
// State of the "dist" sampler.
// Member order matters: the struct is aggregate-initialized positionally.
struct llama_sampler_dist {
    // seed as passed to llama_sampler_init_dist (may be LLAMA_DEFAULT_SEED)
    const uint32_t seed;
    // seed most recently applied to `rng`; refreshed by the reset callback
          uint32_t seed_cur;

    // engine seeded with `seed_cur`
    std::mt19937 rng;
};
| @@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample | ||||
|  | ||||
| static void llama_sampler_dist_reset(struct llama_sampler * smpl) { | ||||
|     auto * ctx = (llama_sampler_dist *) smpl->ctx; | ||||
|     ctx->rng = std::mt19937(ctx->seed); | ||||
|     ctx->seed_cur = get_rng_seed(ctx->seed); | ||||
|     ctx->rng.seed(ctx->seed_cur); | ||||
| } | ||||
|  | ||||
| static void llama_sampler_dist_free(struct llama_sampler * smpl) { | ||||
| @@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|         /* .iface = */ &llama_sampler_dist_i, | ||||
|         /* .ctx   = */ new llama_sampler_dist { | ||||
|             /* .seed = */ seed, | ||||
|             /* .rng  = */ std::mt19937(seed), | ||||
|             /* .seed     = */ seed, | ||||
|             /* .seed_cur = */ seed_cur, | ||||
|             /* .rng      = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
| } | ||||
| @@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat { | ||||
|     const int32_t n_vocab; | ||||
|  | ||||
|     const uint32_t seed; | ||||
|           uint32_t seed_cur; | ||||
|  | ||||
|     const float tau; | ||||
|     const float eta; | ||||
| @@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa | ||||
| static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { | ||||
|     auto * ctx = (llama_sampler_mirostat *) smpl->ctx; | ||||
|     ctx->mu = 2.0f*ctx->tau; | ||||
|     ctx->rng = std::mt19937(ctx->seed); | ||||
|     ctx->seed_cur = get_rng_seed(ctx->seed); | ||||
|     ctx->rng.seed(ctx->seed_cur); | ||||
| } | ||||
|  | ||||
| static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { | ||||
| @@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|         /* .iface = */ &llama_sampler_mirostat_i, | ||||
|         /* .ctx   = */ new llama_sampler_mirostat { | ||||
|             /* .n_vocab = */ n_vocab, | ||||
|             /* .seed    = */ seed, | ||||
|             /* .tau     = */ tau, | ||||
|             /* .eta     = */ eta, | ||||
|             /* .m       = */ m, | ||||
|             /* .mu      = */ 2.0f*tau, | ||||
|             /* .rng     = */ std::mt19937(seed), | ||||
|             /* .n_vocab  = */ n_vocab, | ||||
|             /* .seed     = */ seed, | ||||
|             /* .seed_cur = */ seed_cur, | ||||
|             /* .tau      = */ tau, | ||||
|             /* .eta      = */ eta, | ||||
|             /* .m        = */ m, | ||||
|             /* .mu       = */ 2.0f*tau, | ||||
|             /* .rng      = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
| } | ||||
| @@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see | ||||
|  | ||||
| struct llama_sampler_mirostat_v2 { | ||||
|     const uint32_t seed; | ||||
|           uint32_t seed_cur; | ||||
|  | ||||
|     const float tau; | ||||
|     const float eta; | ||||
| @@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t | ||||
| static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { | ||||
|     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; | ||||
|     ctx->mu = 2.0f*ctx->tau; | ||||
|     ctx->rng = std::mt19937(ctx->seed); | ||||
|     ctx->seed_cur = get_rng_seed(ctx->seed); | ||||
|     ctx->rng.seed(ctx->seed_cur); | ||||
| } | ||||
|  | ||||
| static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { | ||||
| @@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|         /* .iface = */ &llama_sampler_mirostat_v2_i, | ||||
|         /* .ctx   = */ new llama_sampler_mirostat_v2 { | ||||
|             /* .seed  = */ seed, | ||||
|             /* .tau   = */ tau, | ||||
|             /* .eta   = */ eta, | ||||
|             /* .mu    = */ 2.0f*tau, | ||||
|             /* .rng   = */ std::mt19937(seed), | ||||
|             /* .seed     = */ seed, | ||||
|             /* .seed_cur = */ seed_cur, | ||||
|             /* .tau      = */ tau, | ||||
|             /* .eta      = */ eta, | ||||
|             /* .mu       = */ 2.0f*tau, | ||||
|             /* .rng      = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
| } | ||||
| @@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties( | ||||
|         ignore_eos = false; | ||||
|     } | ||||
|  | ||||
|     penalty_last_n = std::max(penalty_last_n, 0); | ||||
|  | ||||
|     return new llama_sampler { | ||||
|         /* .iface = */ &llama_sampler_penalties_i, | ||||
|         /* .ctx   = */ new llama_sampler_penalties { | ||||
| @@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { | ||||
|     const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; | ||||
|     return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); | ||||
| @@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias( | ||||
|         }, | ||||
|     }; | ||||
| } | ||||
|  | ||||
| // utils | ||||
|  | ||||
| uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { | ||||
|     if (smpl->iface == &llama_sampler_dist_i) { | ||||
|         return ((const llama_sampler_dist *) smpl->ctx)->seed_cur; | ||||
|     } | ||||
|  | ||||
|     if (smpl->iface == &llama_sampler_mirostat_i) { | ||||
|         return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur; | ||||
|     } | ||||
|  | ||||
|     if (smpl->iface == &llama_sampler_mirostat_v2_i) { | ||||
|         return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur; | ||||
|     } | ||||
|  | ||||
|     if (smpl->iface == &llama_sampler_chain_i) { | ||||
|         const auto * ctx = (const llama_sampler_chain *) smpl->ctx; | ||||
|         for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { | ||||
|             const uint32_t seed = llama_sampler_get_seed(*it); | ||||
|             if (seed != LLAMA_DEFAULT_SEED) { | ||||
|                 return seed; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     return LLAMA_DEFAULT_SEED; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren