Merge branch 'master' into compilade/refactor-kv-cache

Francis Couture-Harpin
2024-09-14 16:08:52 -04:00
144 changed files with 11344 additions and 6693 deletions


@@ -1,8 +1,8 @@
#include "arg.h"
#include "common.h"
#include "console.h"
#include "sampling.h"
#include "llama.h"
#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
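The include changes track the sampling refactor merged in from master: the infill example drops `grammar-parser.h` and instead pulls argument parsing from the new `arg.h` and the `gpt_sampler` API from common's `sampling.h`. As a reading aid, here is a minimal sketch of the sampler lifecycle that replaces `llama_sampling_context` in this file, using only calls that appear later in this diff (signatures are inferred from the diff, not checked against the headers; `sample_one_token` is a hypothetical helper for illustration):

```cpp
#include "common.h"
#include "sampling.h"
#include "llama.h"

// Illustration only: the gpt_sampler lifecycle as suggested by this diff.
static void sample_one_token(llama_model * model, llama_context * ctx, gpt_params & params) {
    gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); // was llama_sampling_init(sparams)

    const llama_token id = gpt_sampler_sample(smpl, ctx, -1);     // was llama_sampling_sample(ctx_sampling, ctx, nullptr)
    gpt_sampler_accept(smpl, id, true);                           // was llama_sampling_accept(ctx_sampling, ctx, id, true)

    gpt_sampler_free(smpl);                                       // was llama_sampling_free(ctx_sampling)
}
```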
@@ -34,6 +34,7 @@
 static llama_context           ** g_ctx;
 static llama_model             ** g_model;
+static gpt_sampler             ** g_smpl;
 static gpt_params               * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream       * g_output_ss;
@@ -81,7 +82,7 @@ static void write_logfile(
     yaml_dump_string_multiline(logfile, "output", output.c_str());
     yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
-    llama_dump_timing_info_yaml(logfile, ctx);
+    llama_perf_dump_yaml(logfile, ctx);
     fclose(logfile);
 }
@@ -93,7 +94,7 @@ static void sigint_handler(int signo) {
     } else {
         console::cleanup();
         printf("\n");
-        llama_print_timings(*g_ctx);
+        gpt_perf_print(*g_ctx, *g_smpl);
         write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
         _exit(130);
     }
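The Ctrl+C path needs the new `g_smpl` global because `gpt_perf_print` takes the sampler in addition to the context, where the removed `llama_print_timings` only took the context. A simplified sketch of the interrupt handler as this diff leaves it (the interactive branch is omitted, and it relies on the `g_*` globals declared earlier in the file being set before a signal can arrive):

```cpp
// Simplified: only the cleanup branch shown in this hunk.
static void sigint_handler(int signo) {
    if (signo == SIGINT) {
        console::cleanup();
        printf("\n");
        gpt_perf_print(*g_ctx, *g_smpl);  // takes the sampler as well as the context now
        write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
        _exit(130);
    }
}
```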
@@ -103,14 +104,14 @@ static void sigint_handler(int signo) {
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
-    if (!gpt_params_parse(argc, argv, params)) {
-        gpt_params_print_usage(argc, argv, params);
+    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_INFILL)) {
         return 1;
     }
+    auto & sparams = params.sparams;
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");
@@ -156,26 +157,19 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
print_build_info();
LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model;
llama_context * ctx;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
gpt_sampler * smpl = nullptr;
g_model = &model;
g_ctx = &ctx;
g_smpl = &smpl;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -305,16 +299,14 @@ int main(int argc, char ** argv) {
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
}
}
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
smpl = gpt_sampler_init(model, sparams);
LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
printf("no need to specify '--infill', always running infill\n");
printf("************\n\n");
}
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
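Sampler construction replaces the old print of the sampling parameters: the sampler is built from the model plus `sparams`, the effective seed is read back from it, and the parameter struct now prints itself. The removed seed/`std::mt19937` setup earlier in main is consistent with the seed living inside the sampler. A short annotated restatement of the calls in this hunk, using the file's `model`, `sparams`, and `smpl` in place:

```cpp
// Setup as shown in this hunk; error handling omitted.
smpl = gpt_sampler_init(model, sparams);                     // built from model + sampling params

LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));  // seed is queried back from the sampler
LOG_TEE("sampling: \n%s\n", sparams.print().c_str());        // params describe themselves
```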
@@ -349,8 +341,6 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {
@@ -421,11 +411,11 @@ int main(int argc, char ** argv) {
         embd.clear();
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            gpt_sampler_accept(smpl, id, true);
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
             embd.push_back(id);
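The per-token step changes shape: sampling no longer needs a separate sampling context or the `llama_context` at accept time, and the old `nullptr` argument becomes an integer index (`-1` here, presumably selecting the logits of the last evaluated token). A hedged sketch of one sampling step built from this hunk, using the file's variables in place:

```cpp
// One sampling step, reduced to the calls that changed in this hunk.
// The -1 index is assumed to mean "sample from the most recent logits".
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
    const llama_token id = gpt_sampler_sample(smpl, ctx, -1);

    // true is taken to mean the token is also fed to any grammar the sampler holds
    gpt_sampler_accept(smpl, id, true);

    embd.push_back(id);
}
```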
@@ -444,7 +434,7 @@ int main(int argc, char ** argv) {
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
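Prompt tokens are still pushed into the sampler so repetition penalties can account for them, but with the last argument `false`, which the surrounding comment ties to not applying grammar rules to the prompt. A small sketch of that prompt-feeding path (the enclosing loop is reconstructed from context and may differ slightly from the actual file):

```cpp
// Feed queued prompt tokens to the sampler for penalty bookkeeping only;
// false is taken to mean "do not run the grammar over these tokens".
while ((int) embd_inp.size() > n_consumed) {
    embd.push_back(embd_inp[n_consumed]);
    gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
    ++n_consumed;

    if ((int) embd.size() >= params.n_batch) {
        break;
    }
}
```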
@@ -476,7 +466,7 @@ int main(int argc, char ** argv) {
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
             // deal with eot token in infill mode
-            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
+            if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
                 if (is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@@ -542,7 +532,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of generation tokens in interactive mode
-            else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+            else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
                 LOG("found EOS token\n");
                 if (params.interactive) {
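End-of-generation handling now asks the sampler for the last sampled token instead of reading it from the old sampling context: the infill EOT check in the previous hunk and the generic EOG check here both go through `gpt_sampler_last`. A sketch that combines the two conditions (the real code has more logic between them):

```cpp
// Both checks read the last sampled token back from the sampler.
const llama_token last = gpt_sampler_last(smpl);

if ((last == llama_token_eot(model) || is_interacting) && params.interactive) {
    // infill: an end-of-text token in interactive mode hands control back to the user
} else if (llama_token_is_eog(model, last)) {
    // generic end-of-generation token
    LOG("found EOS token\n");
}
```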
@@ -615,7 +605,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0) {
                 if (is_interacting) {
-                    llama_sampling_reset(ctx_sampling);
+                    gpt_sampler_reset(smpl);
                 }
                 is_interacting = false;
             }
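Between interactive turns the sampler is now reset through the sampler object itself, where the old code reset the sampling context; presumably this clears the accumulated token history (and grammar state, if any) before the next user turn. A minimal restatement with the old call noted:

```cpp
if (n_past > 0) {
    if (is_interacting) {
        gpt_sampler_reset(smpl);  // was llama_sampling_reset(ctx_sampling)
    }
    is_interacting = false;
}
```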
@@ -638,13 +628,14 @@ int main(int argc, char ** argv) {
         fflush(stdout);
     }
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    gpt_perf_print(ctx, smpl);
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
     llama_free(ctx);
     llama_free_model(model);
-    llama_sampling_free(ctx_sampling);
+    gpt_sampler_free(smpl);
     llama_backend_free();
 #ifndef LOG_DISABLE_LOGS
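Teardown follows the same pattern as the SIGINT path: perf output goes through `gpt_perf_print(ctx, smpl)` and the sampler is released with `gpt_sampler_free` before the backend is shut down. A sketch of the closing sequence as this hunk leaves it, using the file's variables in place (ordering taken from the diff):

```cpp
// End-of-run sequence reconstructed from this hunk.
LOG_TEE("\n");
gpt_perf_print(ctx, smpl);   // was llama_print_timings(ctx)
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

llama_free(ctx);
llama_free_model(model);

gpt_sampler_free(smpl);      // was llama_sampling_free(ctx_sampling)
llama_backend_free();
```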