examples : replace llama_kv_cache_seq_* with llama_past_seq_*

This commit is contained in:
Francis Couture-Harpin
2024-06-10 14:44:42 -04:00
parent 372482dffe
commit 43d8d4bf9e
23 changed files with 125 additions and 112 deletions

View File

@@ -394,14 +394,15 @@ int main(int argc, char ** argv) {
{
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
llama_past_seq_keep(ctx_dft, s_keep);
llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_past_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
// FIXME: recurrent and hybrid models
llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_past_seq_keep(ctx_tgt, s_keep);
llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_past_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
@@ -418,7 +419,8 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
@@ -474,8 +476,8 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@@ -553,9 +555,9 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_past_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_past_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());