Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-30 08:42:00 +00:00
	llama : refactor sampling v2 (#9294)
- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
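For orientation, here is a minimal sketch of how the new chain API is meant to be used from `llama.h`. It assumes the post-refactor declarations (`llama_sampler_chain_default_params`, `llama_sampler_chain_init`, `llama_sampler_chain_add`, the `llama_sampler_init_*` constructors, `llama_sampler_sample`, `llama_sampler_accept`, `llama_sampler_free`); the exact constructor set and parameters may differ slightly from what this commit ships, so treat it as a sketch rather than a reference.

```cpp
// Sketch: sampling with the new llama_sampler / llama_sampler_chain_ API.
// Model loading, decoding and error handling are omitted.
#include "llama.h"

#include <cstdint>

static llama_token sample_one(llama_context * ctx, uint32_t seed) {
    // a chain is itself a llama_sampler; samplers are applied in the order they are added
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed)); // final "pick a token" step

    // sample from the logits of the last evaluated token (idx = -1)
    const llama_token id = llama_sampler_sample(chain, ctx, -1);

    // tokens chosen outside the sampler (e.g. prompt tokens) can be pushed into the
    // chain's state with llama_sampler_accept(chain, token)

    llama_sampler_free(chain); // frees the chain and the samplers it owns
    return id;
}
```

As used in the diff below, the `gpt_sampler` helper from the `common` library builds on this interface, wrapping a sampler chain together with grammar state.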
@@ -2,7 +2,6 @@
 
 #include "console.h"
 #include "llama.h"
-#include "grammar-parser.h"
 
 #include <cassert>
 #include <cinttypes>
@@ -34,6 +33,7 @@
 
 static llama_context           ** g_ctx;
 static llama_model             ** g_model;
+static gpt_sampler             ** g_smpl;
 static gpt_params               * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream       * g_output_ss;
@@ -81,7 +81,7 @@ static void write_logfile(
     yaml_dump_string_multiline(logfile, "output", output.c_str());
     yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
 
-    llama_dump_timing_info_yaml(logfile, ctx);
+    llama_perf_dump_yaml(logfile, ctx);
     fclose(logfile);
 }
 
@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
         } else {
             console::cleanup();
             printf("\n");
-            llama_print_timings(*g_ctx);
+            gpt_perf_print(*g_ctx, *g_smpl);
             write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
             _exit(130);
         }
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");
@@ -156,26 +157,21 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }
 
-    LOG_TEE("%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
-    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+    print_build_info();
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    LOG_TEE("%s: seed  = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
 
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
    llama_numa_init(params.numa);
 
-    llama_model * model;
-    llama_context * ctx;
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+    gpt_sampler  * smpl = nullptr;
 
     g_model = &model;
     g_ctx = &ctx;
+    g_smpl = &smpl;
 
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -305,7 +301,7 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+    LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -349,7 +345,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    smpl = gpt_sampler_init(model, sparams);
 
     while (n_remain != 0 || params.interactive) {
         // predict
@@ -421,11 +417,11 @@ int main(int argc, char ** argv) {
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            gpt_sampler_accept(smpl, id, true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
 
             embd.push_back(id);
 
@@ -444,7 +440,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
@@ -476,7 +472,7 @@ int main(int argc, char ** argv) {
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
             // deal with eot token in infill mode
-            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
+            if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
                 if (is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@@ -542,7 +538,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of generation tokens in interactive mode
-            else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+            else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -615,7 +611,7 @@ int main(int argc, char ** argv) {
 
             if (n_past > 0) {
                 if (is_interacting) {
-                    llama_sampling_reset(ctx_sampling);
+                    gpt_sampler_reset(smpl);
                 }
                 is_interacting = false;
             }
@@ -638,13 +634,14 @@ int main(int argc, char ** argv) {
         fflush(stdout);
     }
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    gpt_perf_print(ctx, smpl);
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_sampling_free(ctx_sampling);
+    gpt_sampler_free(smpl);
     llama_backend_free();
 
 #ifndef LOG_DISABLE_LOGS
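Consolidated, the migration in this example follows the pattern below. The snippet is a sketch assembled from the hunks above, not a complete program; the `common.h`/`sampling.h` include names and the `generate_one` wrapper are assumptions, while the `gpt_sampler_*` calls and their arguments are taken directly from the diff.

```cpp
// Sketch: the gpt_sampler call pattern this diff introduces, with the
// pre-refactor equivalents noted. Decoding and prompt handling are omitted.
#include "common.h"    // gpt_params (assumed header name)
#include "sampling.h"  // gpt_sampler (assumed header name)

static llama_token generate_one(llama_model * model, llama_context * ctx, gpt_params & params) {
    gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); // was: llama_sampling_init(sparams)

    // prompt tokens only update the sampler state, without applying grammar rules:
    //   gpt_sampler_accept(smpl, prompt_token, /*accept_grammar=*/false);

    const llama_token id = gpt_sampler_sample(smpl, ctx, -1);     // was: llama_sampling_sample(ctx_sampling, ctx, nullptr)
    gpt_sampler_accept(smpl, id, /*accept_grammar=*/true);        // was: llama_sampling_accept(ctx_sampling, ctx, id, true)

    gpt_perf_print(ctx, smpl);                                    // was: llama_print_timings(ctx)
    gpt_sampler_free(smpl);                                       // was: llama_sampling_free(ctx_sampling)

    return id;
}
```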