mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	sampling : custom samplers order (#4285)
* Samplers sequence order w parameter * Cleaned commented code * Fixed formatting * Rewrote with unordered_map * Revert and rewrite, too many problems and safeguards would be needed * Fixed code style * Code style fixes according to review * More readable samplers input string, fixed help * Style fix in sampler_queue * Formatting fixes * Fixing whitespaces
This commit is contained in:
		| @@ -99,6 +99,54 @@ std::string llama_sampling_print(const llama_sampling_params & params) { | ||||
|     return std::string(result); | ||||
| } | ||||
|  | ||||
| std::string llama_sampling_order_print(const llama_sampling_params & params) { | ||||
|     std::string result = "CFG -> Penalties "; | ||||
|     if (params.mirostat == 0) { | ||||
|         for (auto s : params.samplers_sequence) { | ||||
|             switch (s) { | ||||
|                 case 'k': result += "-> top_k "; break; | ||||
|                 case 'f': result += "-> tfs_z "; break; | ||||
|                 case 'y': result += "-> typical_p "; break; | ||||
|                 case 'p': result += "-> top_p "; break; | ||||
|                 case 'm': result += "-> min_p "; break; | ||||
|                 case 't': result += "-> temp "; break; | ||||
|                 default : break; | ||||
|             } | ||||
|         } | ||||
|     } else result += "-> mirostat "; | ||||
|  | ||||
|     return result; | ||||
| } | ||||
|  | ||||
| // no reasons to expose this function in header | ||||
| void sampler_queue( | ||||
|                    struct llama_context * ctx_main, | ||||
|             const llama_sampling_params & params, | ||||
|                  llama_token_data_array & cur_p, | ||||
|                                  size_t & min_keep) { | ||||
|     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); | ||||
|  | ||||
|     const float         temp              = params.temp; | ||||
|     const int32_t       top_k             = params.top_k <= 0 ? n_vocab : params.top_k; | ||||
|     const float         top_p             = params.top_p; | ||||
|     const float         min_p             = params.min_p; | ||||
|     const float         tfs_z             = params.tfs_z; | ||||
|     const float         typical_p         = params.typical_p; | ||||
|     const std::string & samplers_sequence = params.samplers_sequence; | ||||
|  | ||||
|     for (auto s : samplers_sequence) { | ||||
|         switch (s){ | ||||
|             case 'k': llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break; | ||||
|             case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break; | ||||
|             case 'y': llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break; | ||||
|             case 'p': llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break; | ||||
|             case 'm': llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break; | ||||
|             case 't': llama_sample_temp     (ctx_main, &cur_p, temp); break; | ||||
|             default : break; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| llama_token llama_sampling_sample( | ||||
|                   struct llama_sampling_context * ctx_sampling, | ||||
|                   struct llama_context * ctx_main, | ||||
| @@ -109,11 +157,6 @@ llama_token llama_sampling_sample( | ||||
|     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); | ||||
|  | ||||
|     const float   temp            = params.temp; | ||||
|     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k; | ||||
|     const float   top_p           = params.top_p; | ||||
|     const float   min_p           = params.min_p; | ||||
|     const float   tfs_z           = params.tfs_z; | ||||
|     const float   typical_p       = params.typical_p; | ||||
|     const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; | ||||
|     const float   penalty_repeat  = params.penalty_repeat; | ||||
|     const float   penalty_freq    = params.penalty_freq; | ||||
| @@ -188,12 +231,7 @@ llama_token llama_sampling_sample( | ||||
|             // temperature sampling | ||||
|             size_t min_keep = std::max(1, params.n_probs); | ||||
|  | ||||
|             llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); | ||||
|             llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); | ||||
|             llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); | ||||
|             llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); | ||||
|             llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); | ||||
|             llama_sample_temp     (ctx_main, &cur_p, temp); | ||||
|             sampler_queue(ctx_main, params, cur_p, min_keep); | ||||
|  | ||||
|             id = llama_sample_token(ctx_main, &cur_p); | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 MaggotHATE
					MaggotHATE