	sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661

ggml-ci
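For orientation, a minimal sketch of what the generation loop looks like after this refactor, using only the llama_sampling_* calls that appear in the diff below. The header names and the surrounding setup (model/context creation, decoding, guidance handling) are assumptions and are elided, so treat this as an illustration rather than the exact infill code:

    // Rough sketch of the post-refactor loop; header names and setup are assumed.
    #include "common.h"
    #include "sampling.h"

    static void generate_sketch(llama_context * ctx, gpt_params & params, int n_remain) {
        llama_sampling_params & sparams = params.sparams;

        // print the effective sampling settings in one call
        LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());

        // one context object now owns "prev", the penalty history and the grammar state
        struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

        while (n_remain-- > 0) {
            // pick the next token (a guidance context can be passed instead of NULL)
            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);

            // record it; the trailing "true" also advances the grammar state
            llama_sampling_accept(ctx_sampling, ctx, id, true);

            // ... feed `id` back through the model and print the piece (omitted) ...

            if (id == llama_token_eos(ctx)) {
                break;
            }
        }

        llama_sampling_free(ctx_sampling);
    }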
@@ -39,8 +39,8 @@ static gpt_params               * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream       * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;

static bool is_interacting = false;

static void write_logfile(
    const llama_context * ctx, const gpt_params & params, const llama_model * model,
@@ -104,7 +104,7 @@ static void sigint_handler(int signo) {

int main(int argc, char ** argv) {
    gpt_params params;
    llama_sampling_params & sparams = params.sampling_params;
    llama_sampling_params & sparams = params.sparams;
    g_params = &params;

    if (!gpt_params_parse(argc, argv, params)) {
@@ -358,36 +358,10 @@ int main(int argc, char ** argv) {
            LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
        }
    }
    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
    LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
    LOG_TEE("\n\n");

    struct llama_grammar * grammar = NULL;
    grammar_parser::parse_state parsed_grammar;

    if (!params.grammar.empty()) {
        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
        // will be empty (default) if there are parse errors
        if (parsed_grammar.rules.empty()) {
            return 1;
        }
        LOG_TEE("%s: grammar:\n", __func__);
        grammar_parser::print_grammar(stderr, parsed_grammar);
        LOG_TEE("\n");

        {
            auto it = sparams.logit_bias.find(llama_token_eos(ctx));
            if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
                LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
            }
        }

        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
        grammar = llama_grammar_init(
            grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
    }

    LOG_TEE("\n#####  Infill mode  #####\n\n");
    if (params.infill) {
        printf("\n************\n");
@@ -430,7 +404,7 @@ int main(int argc, char ** argv) {
    std::vector<llama_token> embd;
    std::vector<llama_token> embd_guidance;

    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    while (n_remain != 0 || params.interactive) {
        // predict
@@ -549,7 +523,7 @@ int main(int argc, char ** argv) {

            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

            llama_sampling_accept(ctx_sampling, ctx, id);
            llama_sampling_accept(ctx_sampling, ctx, id, true);

            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());

@@ -567,8 +541,11 @@ int main(int argc, char ** argv) {
            LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
            while ((int) embd_inp.size() > n_consumed) {
                embd.push_back(embd_inp[n_consumed]);
                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
                ctx_sampling->prev.push_back(embd_inp[n_consumed]);

                // push the prompt in the sampling context in order to apply repetition penalties later
                // for the prompt, we don't apply grammar rules
                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);

                ++n_consumed;
                if ((int) embd.size() >= params.n_batch) {
                    break;
@@ -600,7 +577,7 @@ int main(int argc, char ** argv) {
        if ((int) embd_inp.size() <= n_consumed) {

            // deal with eot token in infill mode
            if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
                if(is_interacting && !params.interactive_first) {
                    // print an eot token
                    printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -617,7 +594,7 @@ int main(int argc, char ** argv) {
                    buffer += line;
                } while (another_line);
                // check if we got an empty line, if so we use the old input
                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                    params.input_prefix = buffer;
                }
                buffer.clear();
@@ -627,7 +604,7 @@ int main(int argc, char ** argv) {
                    buffer += line;
                } while (another_line);
                // check if we got an empty line
                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                    params.input_suffix = buffer;
                }
                buffer.clear();
@@ -640,7 +617,7 @@ int main(int argc, char ** argv) {
                    process_escapes(params.input_suffix);
                }
                suff_rm_leading_spc = params.escape;
                if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
                if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                    params.input_suffix.erase(0, 1);
                    suff_rm_leading_spc = false;
                }
@@ -667,7 +644,7 @@ int main(int argc, char ** argv) {
                is_interacting = false;
            }
            // deal with end of text token in interactive mode
            else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                LOG("found EOS token\n");

                if (params.interactive) {
@@ -740,15 +717,7 @@ int main(int argc, char ** argv) {

            if (n_past > 0) {
                if (is_interacting) {
                    // reset grammar state if we're restarting generation
                    if (grammar != NULL) {
                        llama_grammar_free(grammar);

                        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
                        grammar = llama_grammar_init(
                            grammar_rules.data(), grammar_rules.size(),
                            parsed_grammar.symbol_ids.at("root"));
                    }
                    llama_sampling_reset(ctx_sampling);
                }
                is_interacting = false;
            }
@@ -778,9 +747,7 @@ int main(int argc, char ** argv) {
    llama_free(ctx);
    llama_free_model(model);

    if (grammar != NULL) {
        llama_grammar_free(grammar);
    }
    llama_sampling_free(ctx_sampling);
    llama_backend_free();

#ifndef LOG_DISABLE_LOGS
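Two behavioural points stand out in the hunks above: prompt tokens now go through llama_sampling_accept with the grammar flag set to false, so they still count toward the repetition penalties without advancing the grammar, and an interactive restart only needs llama_sampling_reset instead of freeing and re-initializing the grammar by hand. A hedged sketch of that pattern follows; the helper names are made up for illustration and the sampling header name is an assumption:

    #include <vector>
    #include "sampling.h"   // assumed header for the llama_sampling_* API added by this change

    // hypothetical helper: push the whole prompt into the sampling context
    static void feed_prompt(struct llama_sampling_context * ctx_sampling,
                            llama_context * ctx,
                            const std::vector<llama_token> & embd_inp) {
        for (const llama_token tok : embd_inp) {
            // counts toward repetition penalties, but grammar rules are not applied (last arg = false)
            llama_sampling_accept(ctx_sampling, ctx, tok, false);
        }
    }

    // hypothetical helper: called when generation restarts in interactive mode
    static void restart_generation(struct llama_sampling_context * ctx_sampling) {
        // replaces the old manual llama_grammar_free / llama_grammar_init sequence;
        // the sampling state is reset in one call
        llama_sampling_reset(ctx_sampling);
    }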