mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama : add llama_sampler_init for safe usage of llama_sampler_free (#11727)
The C API in llama.h claims users can implement `llama_sampler_i` to create custom `llama_sampler`. The sampler chain takes ownership and calls `llama_sampler_free` on them. However, `llama_sampler_free` is hard-coded to use `delete`. This is undefined behavior if the object wasn't also allocated via `new` from libllama's C++ runtime. Callers in C and C-compatible languages do not use C++'s `new` operator. C++ callers may not be sharing the same heap as libllama.
This commit is contained in:
		 Christian Fillion
					Christian Fillion
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							ec3bc8270b
						
					
				
				
					commit
					7ee953a64a
				
			| @@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     return new llama_sampler{ | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_llg_i, | ||||
|         /* .ctx   = */ ctx, | ||||
|     }; | ||||
|         /* .ctx   = */ ctx | ||||
|     ); | ||||
| } | ||||
|  | ||||
| #else | ||||
|   | ||||
| @@ -1114,11 +1114,12 @@ extern "C" { | ||||
|     }; | ||||
|  | ||||
|     struct llama_sampler { | ||||
|         struct llama_sampler_i  * iface; | ||||
|         llama_sampler_context_t   ctx; | ||||
|         const struct llama_sampler_i * iface; | ||||
|         llama_sampler_context_t        ctx; | ||||
|     }; | ||||
|  | ||||
|     // mirror of llama_sampler_i: | ||||
|     LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); | ||||
|     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl); | ||||
|     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token); | ||||
|     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p); | ||||
|   | ||||
| @@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) { | ||||
|  | ||||
| // llama_sampler API | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { | ||||
|     return new llama_sampler { | ||||
|         /* .iface = */ iface, | ||||
|         /* .ctx   = */ ctx, | ||||
|     }; | ||||
| } | ||||
|  | ||||
| const char * llama_sampler_name(const struct llama_sampler * smpl) { | ||||
|     if (!smpl->iface) { | ||||
|         return "(null)"; | ||||
| @@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { | ||||
|     } | ||||
|  | ||||
|     if (smpl->ctx == nullptr) { | ||||
|         return new llama_sampler { | ||||
|         return llama_sampler_init( | ||||
|             /* .iface = */ smpl->iface, | ||||
|             /* .ctx   = */ nullptr, | ||||
|         }; | ||||
|             /* .ctx   = */ nullptr | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     GGML_ABORT("the sampler does not support cloning"); | ||||
| @@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_chain_i, | ||||
|         /* .ctx   = */ new llama_sampler_chain { | ||||
|             /* .params      = */ params, | ||||
|             /* .samplers    = */ {}, | ||||
|             /* .t_sample_us = */ 0, | ||||
|             /* .n_sample    = */ 0, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { | ||||
| @@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_greedy() { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_greedy_i, | ||||
|         /* .ctx   = */ nullptr, | ||||
|     }; | ||||
|         /* .ctx   = */ nullptr | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // dist | ||||
| @@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = { | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_dist_i, | ||||
|         /* .ctx   = */ new llama_sampler_dist { | ||||
|             /* .seed     = */ seed, | ||||
|             /* .seed_cur = */ seed_cur, | ||||
|             /* .rng      = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // softmax | ||||
| @@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_softmax() { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_softmax_i, | ||||
|         /* .ctx   = */ nullptr, | ||||
|     }; | ||||
|         /* .ctx   = */ nullptr | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // top-k | ||||
| @@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_top_k(int32_t k) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_top_k_i, | ||||
|         /* .ctx   = */ new llama_sampler_top_k { | ||||
|             /* .k = */ k, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // top-p | ||||
| @@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_top_p_i, | ||||
|         /* .ctx   = */ new llama_sampler_top_p { | ||||
|             /* .p        = */ p, | ||||
|             /* .min_keep = */ min_keep, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // min-p | ||||
| @@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_min_p_i, | ||||
|         /* .ctx   = */ new llama_sampler_min_p { | ||||
|             /* .p        = */ p, | ||||
|             /* .min_keep = */ min_keep, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // typical | ||||
| @@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_typical_i, | ||||
|         /* .ctx   = */ new llama_sampler_typical { | ||||
|             /* .p        = */ p, | ||||
|             /* .min_keep = */ min_keep, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // temp | ||||
| @@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_temp(float temp) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_temp_i, | ||||
|         /* .ctx   = */ new llama_sampler_temp { | ||||
|             /*.temp = */ temp, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // temp-ext | ||||
| @@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_temp_ext_i, | ||||
|         /* .ctx   = */ new llama_sampler_temp_ext { | ||||
|             /* .temp     = */ temp, | ||||
|             /* .delta    = */ delta, | ||||
|             /* .exponent = */ exponent, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // xtc | ||||
| @@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = { | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_xtc_i, | ||||
|         /* .ctx   = */ new llama_sampler_xtc { | ||||
|             /* .probability   = */ p, | ||||
| @@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, | ||||
|             /* .seed          = */ seed, | ||||
|             /* .seed_cur      = */ seed_cur, | ||||
|             /* .rng           = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // mirostat | ||||
| @@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_mirostat_i, | ||||
|         /* .ctx   = */ new llama_sampler_mirostat { | ||||
|             /* .n_vocab  = */ n_vocab, | ||||
| @@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see | ||||
|             /* .m        = */ m, | ||||
|             /* .mu       = */ 2.0f*tau, | ||||
|             /* .rng      = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // mirostat v2 | ||||
| @@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { | ||||
|     auto seed_cur = get_rng_seed(seed); | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_mirostat_v2_i, | ||||
|         /* .ctx   = */ new llama_sampler_mirostat_v2 { | ||||
|             /* .seed     = */ seed, | ||||
| @@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, | ||||
|             /* .eta      = */ eta, | ||||
|             /* .mu       = */ 2.0f*tau, | ||||
|             /* .rng      = */ std::mt19937(seed_cur), | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // grammar | ||||
| @@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_grammar_i, | ||||
|         /* .ctx   = */ ctx, | ||||
|     }; | ||||
|         /* .ctx   = */ ctx | ||||
|     ); | ||||
| } | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_grammar( | ||||
| @@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties( | ||||
|         float penalty_present) { | ||||
|     penalty_last_n = std::max(penalty_last_n, 0); | ||||
|  | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_penalties_i, | ||||
|         /* .ctx   = */ new llama_sampler_penalties { | ||||
|             /* .penalty_last_n  = */ penalty_last_n, | ||||
| @@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties( | ||||
|             /* .penalty_present = */ penalty_present, | ||||
|             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n), | ||||
|             /* .token_count     = */ {}, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // DRY | ||||
| @@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_dry_i, | ||||
|         /* .ctx   = */ new llama_sampler_dry { | ||||
|             /* .total_context_size     = */ context_size, | ||||
| @@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, | ||||
|             /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{}, | ||||
|             /* .dry_max_token_repeat   = */ {}, | ||||
|             /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // wrapper for test-sampling.cpp | ||||
| @@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias( | ||||
|                          int32_t   n_vocab, | ||||
|                          int32_t   n_logit_bias, | ||||
|           const llama_logit_bias * logit_bias) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_logit_bias_i, | ||||
|         /* .ctx   = */ new llama_sampler_logit_bias { | ||||
|             /* .n_vocab    = */ n_vocab, | ||||
|             /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias), | ||||
|             /* .to_search  = */ {}, | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // infill | ||||
| @@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = { | ||||
| }; | ||||
|  | ||||
| struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { | ||||
|     return new llama_sampler { | ||||
|     return llama_sampler_init( | ||||
|         /* .iface = */ &llama_sampler_infill_i, | ||||
|         /* .ctx   = */ new llama_sampler_infill { | ||||
|             /* .vocab = */ vocab, | ||||
|             /* .buf0  = */ std::vector<char>(512), | ||||
|             /* .buf1  = */ std::vector<char>(512), | ||||
|         }, | ||||
|     }; | ||||
|         } | ||||
|     ); | ||||
| } | ||||
|  | ||||
| // utils | ||||
|   | ||||
		Reference in New Issue
	
	Block a user