llama : remove implicit recurrent state rollbacks

This commit is contained in:
Francis Couture-Harpin
2024-11-24 20:35:30 -05:00
parent 124c222f76
commit 8006f3b3c8
25 changed files with 411 additions and 1119 deletions

View File

@@ -93,7 +93,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
for (int s = 1; s < W + G + 1; ++s) {
llama_past_seq_cp(ctx, 0, s, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
}
const auto t_enc_end = ggml_time_us();
@@ -436,17 +436,17 @@ int main(int argc, char ** argv) {
// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, -1, n_past, -1);
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
llama_past_seq_keep(ctx, seq_id_best);
llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_past_seq_rm (ctx, seq_id_best, -1, -1);
llama_kv_cache_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) {
llama_past_seq_cp(ctx, 0, s, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
}
}
}