mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	samplers : Min-P sampler implementation [alternative to Top P/Top K] (#3841)
* Introduce the new Min-P sampler by @kalomaze. The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. * Min-P enabled and set to 0.05 by default --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
This commit is contained in:
		| @@ -218,6 +218,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||||||
|                 break; |                 break; | ||||||
|             } |             } | ||||||
|             sparams.top_p = std::stof(argv[i]); |             sparams.top_p = std::stof(argv[i]); | ||||||
|  |         } else if (arg == "--min-p") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             sparams.min_p = std::stof(argv[i]); | ||||||
|         } else if (arg == "--temp") { |         } else if (arg == "--temp") { | ||||||
|             if (++i >= argc) { |             if (++i >= argc) { | ||||||
|                 invalid_param = true; |                 invalid_param = true; | ||||||
| @@ -679,6 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |||||||
|     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch); |     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch); | ||||||
|     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); |     printf("  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); | ||||||
|     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); |     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); | ||||||
|  |     printf("  --min-p N             min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); | ||||||
|     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); |     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); | ||||||
|     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); |     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); | ||||||
|     printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); |     printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); | ||||||
| @@ -1275,6 +1282,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l | |||||||
|     fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); |     fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); | ||||||
|     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); |     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); | ||||||
|     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); |     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); | ||||||
|  |     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); | ||||||
|     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); |     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); | ||||||
|     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); |     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -89,10 +89,10 @@ std::string llama_sampling_print(const llama_sampling_params & params) { | |||||||
|  |  | ||||||
|     snprintf(result, sizeof(result), |     snprintf(result, sizeof(result), | ||||||
|             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" |             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" | ||||||
|             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n" |             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" | ||||||
|             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", |             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", | ||||||
|             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, |             params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, | ||||||
|             params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, |             params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, | ||||||
|             params.mirostat, params.mirostat_eta, params.mirostat_tau); |             params.mirostat, params.mirostat_eta, params.mirostat_tau); | ||||||
|  |  | ||||||
|     return std::string(result); |     return std::string(result); | ||||||
| @@ -110,6 +110,7 @@ llama_token llama_sampling_sample( | |||||||
|     const float   temp            = params.temp; |     const float   temp            = params.temp; | ||||||
|     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k; |     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k; | ||||||
|     const float   top_p           = params.top_p; |     const float   top_p           = params.top_p; | ||||||
|  |     const float   min_p           = params.min_p; | ||||||
|     const float   tfs_z           = params.tfs_z; |     const float   tfs_z           = params.tfs_z; | ||||||
|     const float   typical_p       = params.typical_p; |     const float   typical_p       = params.typical_p; | ||||||
|     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; |     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; | ||||||
| @@ -190,6 +191,7 @@ llama_token llama_sampling_sample( | |||||||
|             llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); |             llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); | ||||||
|             llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); |             llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); | ||||||
|             llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); |             llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); | ||||||
|  |             llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); | ||||||
|             llama_sample_temp     (ctx_main, &cur_p, temp); |             llama_sample_temp     (ctx_main, &cur_p, temp); | ||||||
|  |  | ||||||
|             id = llama_sample_token(ctx_main, &cur_p); |             id = llama_sample_token(ctx_main, &cur_p); | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ typedef struct llama_sampling_params { | |||||||
|     int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens. |     int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens. | ||||||
|     int32_t top_k             = 40;    // <= 0 to use vocab size |     int32_t top_k             = 40;    // <= 0 to use vocab size | ||||||
|     float   top_p             = 0.95f; // 1.0 = disabled |     float   top_p             = 0.95f; // 1.0 = disabled | ||||||
|  |     float   min_p             = 0.05f; // 0.0 = disabled | ||||||
|     float   tfs_z             = 1.00f; // 1.0 = disabled |     float   tfs_z             = 1.00f; // 1.0 = disabled | ||||||
|     float   typical_p         = 1.00f; // 1.0 = disabled |     float   typical_p         = 1.00f; // 1.0 = disabled | ||||||
|     float   temp              = 0.80f; // 1.0 = disabled |     float   temp              = 0.80f; // 1.0 = disabled | ||||||
|   | |||||||
| @@ -208,6 +208,14 @@ Top-p sampling, also known as nucleus sampling, is another text generation metho | |||||||
|  |  | ||||||
| Example usage: `--top-p 0.95` | Example usage: `--top-p 0.95` | ||||||
|  |  | ||||||
|  | ### Min P Sampling | ||||||
|  |  | ||||||
|  | -   `--min-p N`: Sets a minimum base probability threshold for token selection (default: 0.05). | ||||||
|  |  | ||||||
|  | The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. | ||||||
|  |  | ||||||
|  | Example usage: `--min-p 0.05` | ||||||
|  |  | ||||||
| ### Tail Free Sampling (TFS) | ### Tail Free Sampling (TFS) | ||||||
|  |  | ||||||
| -   `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). | -   `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -7368,6 +7368,32 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { | ||||||
|  |     if (p <= 0.0f || !candidates->size) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     llama_sample_softmax(ctx, candidates); | ||||||
|  |  | ||||||
|  |     const int64_t t_start_sample_us = ggml_time_us(); | ||||||
|  |  | ||||||
|  |     float scale = candidates->data[0].p; // scale by max prob | ||||||
|  |     size_t i = 1; // first token always matches | ||||||
|  |  | ||||||
|  |     for (; i < candidates->size; ++i) { | ||||||
|  |         if (candidates->data[i].p < p * scale && i >= min_keep) { | ||||||
|  |             break; // prob too small | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Resize the output vector to keep only the matching tokens | ||||||
|  |     candidates->size = i; | ||||||
|  |  | ||||||
|  |     if (ctx) { | ||||||
|  |         ctx->t_sample_us += ggml_time_us() - t_start_sample_us; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { | void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { | ||||||
|     if (z >= 1.0f || candidates->size <= 2) { |     if (z >= 1.0f || candidates->size <= 2) { | ||||||
|         return; |         return; | ||||||
|   | |||||||
							
								
								
									
										7
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								llama.h
									
									
									
									
									
								
							| @@ -598,6 +598,13 @@ extern "C" { | |||||||
|                            float   p, |                            float   p, | ||||||
|                           size_t   min_keep); |                           size_t   min_keep); | ||||||
|  |  | ||||||
|  |     /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 | ||||||
|  |     LLAMA_API void llama_sample_min_p( | ||||||
|  |             struct llama_context * ctx, | ||||||
|  |           llama_token_data_array * candidates, | ||||||
|  |                            float   p, | ||||||
|  |                           size_t   min_keep); | ||||||
|  |  | ||||||
|     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. |     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. | ||||||
|     LLAMA_API void llama_sample_tail_free( |     LLAMA_API void llama_sample_tail_free( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 kalomaze
					kalomaze