llama : remove implicit recurrent state rollbacks

This commit is contained in:
Francis Couture-Harpin
2024-11-24 20:35:30 -05:00
parent 124c222f76
commit 8006f3b3c8
25 changed files with 411 additions and 1119 deletions

View File

@@ -406,7 +406,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_past_clear(ctx);
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -580,7 +580,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_past_clear(ctx);
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -955,7 +955,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
return;
}
llama_past_clear(ctx);
llama_kv_cache_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1232,7 +1232,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
return;
}
llama_past_clear(ctx);
llama_kv_cache_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1602,7 +1602,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
return;
}
llama_past_clear(ctx);
llama_kv_cache_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1789,7 +1789,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
// clear the KV cache
llama_past_clear(ctx);
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;