From 8006f3b3c83d63995acfaff19cd2f9c3ffc52949 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 24 Nov 2024 20:35:30 -0500 Subject: [PATCH] llama : remove implicit recurrent state rollbacks --- common/common.cpp | 2 +- examples/batched-bench/batched-bench.cpp | 4 +- examples/batched.swift/Sources/main.swift | 2 +- examples/batched/batched.cpp | 2 +- .../cvector-generator/cvector-generator.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/gritlm/gritlm.cpp | 4 +- examples/imatrix/imatrix.cpp | 2 +- examples/infill/infill.cpp | 4 +- examples/llama-bench/llama-bench.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 8 +- .../llama.cpp.swift/LibLlama.swift | 8 +- examples/lookahead/lookahead.cpp | 12 +- examples/lookup/lookup.cpp | 2 +- examples/main/main.cpp | 17 +- examples/parallel/parallel.cpp | 10 +- examples/passkey/passkey.cpp | 28 +- examples/perplexity/perplexity.cpp | 12 +- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 4 +- examples/speculative/speculative.cpp | 22 +- ggml/src/ggml.c | 1 + include/llama.h | 59 +- src/llama.cpp | 1315 ++++------------- 25 files changed, 411 insertions(+), 1119 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d47b12acba..451307b554 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -966,7 +966,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); } - llama_past_clear(lctx); + llama_kv_cache_clear(lctx); llama_synchronize(lctx); llama_perf_context_reset(lctx); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index ecaa793baf..81c3220ada 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -133,7 +133,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_ERR("%s: llama_decode() failed\n", __func__); @@ -142,7 +142,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_past_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index e1552750cf..d3d156932e 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_past_seq_cp(context, 0, Int32(i), -1, -1) + llama_kv_cache_seq_cp(context, 0, Int32(i), -1, -1) } if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 83312ad969..3b554033e7 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -138,7 +138,7 @@ int main(int argc, char ** argv) { //// assign the system KV cache to all parallel sequences //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them //for (int32_t i = 1; i < n_parallel; ++i) { - // llama_past_seq_cp(ctx, 0, i, -1, -1); + // llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); //} if (n_parallel > 1) { diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 846905aacb..69e141ecb9 100644 --- 
a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { } static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 3f38667ae7..3f18fc6a70 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu const struct llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 36668b0e86..6e42fa0734 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,7 +44,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); @@ -99,7 +99,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std const llama_model * model = llama_get_model(ctx); llama_token eos_token = llama_token_eos(model); - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index d882d02de4..d1ff3e8bc4 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -494,7 +494,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 5d092f0002..f82c614f57 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -375,8 +375,8 @@ int main(int argc, char ** argv) { LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 6fbb97f85d..c22bdedcfa 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1566,7 +1566,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // cool off before the test if (params.delay) { @@ -1606,7 +1606,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_past_clear(ctx); + 
llama_kv_cache_clear(ctx); uint64_t t_start = get_time_ns(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 8e4ffd851e..f5ffd063f8 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_past_clear(context); + llama_kv_cache_clear(context); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_past_clear(context); + llama_kv_cache_clear(context); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_past_clear(context); + llama_kv_cache_clear(context); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -446,5 +446,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_past_clear(reinterpret_cast(context)); + llama_kv_cache_clear(reinterpret_cast(context)); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 570b4081c9..dcd9803a2a 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -209,7 +209,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_past_clear(context) + llama_kv_cache_clear(context) let t_pp_start = ggml_time_us() @@ -222,7 +222,7 @@ actor LlamaContext { // bench text generation - llama_past_clear(context) + llama_kv_cache_clear(context) let t_tg_start = ggml_time_us() @@ -241,7 +241,7 @@ actor LlamaContext { let t_tg_end = ggml_time_us() - llama_past_clear(context) + llama_kv_cache_clear(context) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -291,7 +291,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_past_clear(context) + llama_kv_cache_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 607c755fce..03cd63f3fe 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -93,7 +93,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); for (int s = 1; s < W + G + 1; ++s) { - llama_past_seq_cp(ctx, 0, s, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -436,17 +436,17 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation // FIXME: recurrent and hybrid models - llama_past_seq_rm(ctx, -1, n_past, -1); + llama_kv_cache_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_past_seq_keep(ctx, seq_id_best); - 
llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_past_seq_rm (ctx, seq_id_best, -1, -1); + llama_kv_cache_seq_keep(ctx, seq_id_best); + llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_past_seq_cp(ctx, 0, s, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 700d519717..e2c8c3828f 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -191,7 +191,7 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted // FIXME: recurrent and hybrid models - llama_past_seq_rm(ctx, 0, n_past, -1); + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); common_batch_clear(batch_tgt); common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4632da8344..fb10c20c5e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -332,10 +332,6 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } - - // remove any "future" tokens that we might have inherited from the previous session - n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1); - if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { LOG_INF("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { @@ -347,6 +343,9 @@ int main(int argc, char ** argv) { LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n", __func__, n_matching_session_tokens, embd_inp.size()); } + + // remove any "future" tokens that we might have inherited from the previous session + llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", @@ -358,8 +357,6 @@ int main(int argc, char ** argv) { LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); - } else { - session_tokens.resize(n_matching_session_tokens); } // number of tokens to keep when resetting context @@ -609,9 +606,9 @@ int main(int argc, char ** argv) { LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; @@ -625,8 +622,6 @@ int main(int argc, char ** argv) { if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { - // TODO: are the session tokens guaranteed to all be matching here? - // Should n_matching_session_tokens be re-used instead? 
if (embd[i] != session_tokens[n_session_consumed]) { session_tokens.resize(n_session_consumed); break; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index f3b54c2ee7..20274c1479 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_past_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_INF("\n"); @@ -231,9 +231,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_past_seq_rm(ctx, i, -1, -1); + llama_kv_cache_seq_rm(ctx, i, -1, -1); // but keep the system prompt - llama_past_seq_cp(ctx, 0, i, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } LOG_INF("%s: clearing the KV cache\n", __func__); @@ -370,8 +370,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_past_seq_rm(ctx, client.id + 1, -1, -1); - llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 87df5f2421..09bba708f6 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -130,11 +130,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update(ctx); + llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update (ctx); - n_past = llama_past_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } common_batch_clear(batch); @@ -164,12 +164,12 @@ int main(int argc, char ** argv) { LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag(ctx); - llama_kv_cache_update(ctx); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past = llama_past_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; common_batch_clear(batch); @@ -195,12 +195,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag(ctx); - llama_kv_cache_update(ctx); + llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag (ctx); + llama_kv_cache_update (ctx); - n_past = llama_past_seq_pos_max(ctx, 0) + 1; + n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 
707e98b0b6..efb41b80a3 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -406,7 +406,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -580,7 +580,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -955,7 +955,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { return; } - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1232,7 +1232,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) return; } - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1602,7 +1602,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par return; } - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1789,7 +1789,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } // clear the KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 3741f63d61..1768aae510 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index b936a2ec58..3866cfa27e 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_past_clear(ctx3); + llama_kv_cache_clear(ctx3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5f054dea40..f809c46d5a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1029,7 +1029,7 @@ struct server_context { SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); clean_kv_cache = false; } @@ -1760,7 +1760,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_past_seq_rm(ctx, slot->id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); 
server_task_result result; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8dbcdba305..33b469e8f5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -410,15 +410,15 @@ int main(int argc, char ** argv) { { LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_past_seq_keep(ctx_dft, s_keep); - llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_past_seq_keep(ctx_dft, 0); + llama_kv_cache_seq_keep(ctx_dft, s_keep); + llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_dft, 0); // FIXME: recurrent and hybrid models - llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_past_seq_keep(ctx_tgt, s_keep); - llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_past_seq_keep(ctx_tgt, 0); + llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_kv_cache_seq_keep(ctx_tgt, s_keep); + llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -495,8 +495,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -577,9 +577,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_past_seq_keep(ctx_tgt, 0); + llama_kv_cache_seq_keep(ctx_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_past_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); } // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index abe8d04ff8..3f01092d9f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -19825,6 +19825,7 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/include/llama.h b/include/llama.h index ebd00e771f..510e862caa 100644 --- a/include/llama.h +++ b/include/llama.h @@ -41,7 +41,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 9 +#define LLAMA_SESSION_VERSION 10 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 3 @@ -613,58 +613,35 @@ extern "C" { LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed - LLAMA_API void llama_past_clear( + LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); - LLAMA_API DEPRECATED(void llama_kv_cache_clear( - struct llama_context * ctx), - "use llama_past_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. 
Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past (one more than the largest remaining pos in the seq_id) - // which is only meaningful to handle for partial removals. - LLAMA_API llama_pos llama_past_seq_rm( + LLAMA_API bool llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); - LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "use llama_past_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past (one more than the largest remaining pos in the destination seq_id) - // which is only meaningful to handle when partially copying. - LLAMA_API llama_pos llama_past_seq_cp( + LLAMA_API void llama_kv_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "use llama_past_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_past_seq_keep( + LLAMA_API void llama_kv_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_past_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -672,19 +649,12 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_past_seq_add( + LLAMA_API void llama_kv_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta), - "use llama_past_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -692,28 +662,17 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_past_seq_div( + LLAMA_API void llama_kv_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); - LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d), - "use llama_past_seq_div instead"); // Returns the largest position present in the KV and/or RS cache for the specified sequence - LLAMA_API llama_pos llama_past_seq_pos_max( + LLAMA_API llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); - LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // 
This will be applied: diff --git a/src/llama.cpp b/src/llama.cpp index ee22ec394d..0b3b181f70 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2847,69 +2847,24 @@ struct llama_kv_self_cache { } }; -// for recurrent models, use a tree of sequences to simplify -// quickly finding the tail cell of each sequence -// TODO: drop the _rs_ infix -struct llama_rs_seq_node { - llama_seq_id seq_id = -1; - int32_t next_cell = -1; - - // needed for automatic typecasting from a llama_seq_id - llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} - - // needed for more convenient std::find - bool operator==(const llama_rs_seq_node & other) const { - return seq_id == other.seq_id; - } - - bool is_tail() const { - return next_cell < 0; - } -}; - struct llama_rs_cell { llama_pos pos = -1; int32_t src = -1; // copy source id (cleared next when -1) - // Link to previous cell in this sequence. - // Sequences can only diverge, never converge, - // so this works when there are multiple seq_ids per cell too. - int32_t prev = -1; - - // ref count of tails (should match the number of next_cell == -1 in seq_nodes) - uint32_t tail_rc = 0; - - // seq_ids by insertion order, to simplify updating n_cells compared to a set - std::vector seq_nodes; - - void insert_node(const llama_rs_seq_node & node) { - auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); - if (node_dest == seq_nodes.end()) { - seq_nodes.push_back(node); - } else { - // overwrite the pre-existing node with the same seq_id if it exists - *node_dest = node; - } - } + std::set seq_id; bool has_seq_id(const llama_seq_id & id) const { - return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); + return seq_id.find(id) != seq_id.end(); } bool is_empty() const { - return seq_nodes.empty(); + return seq_id.empty(); } }; struct llama_rs_seq_meta { // cell id of the latest state of this seq_id int32_t tail = -1; - // number of cells for which this seq_id is the first - // (useful to know if cells in this sequence should be pruned) - int32_t n_cells = 0; - // the last pos of this sequence if it is in the current ubatch, - // only set and used when finding a slot. - llama_pos ubatch_end_pos = -1; }; // ring-buffered tree of cached recurrent state data @@ -2922,32 +2877,17 @@ struct llama_rs_self_cache { // computed when finding a slot uint32_t n = 0; // range of states used for the last slot - // only counts cells which are tails of all of their sequences. - // useful to know the minimum reserved cell count per seq_id. - uint32_t n_seqs = 0; - // cells part of multiple sequences, - // but which are only the tail of some of them. - // useful to dismiss sequences used as a shared prompt - uint32_t n_shared_tail_cells = 0; - // with state models, a cell can hold the state for more than one past token - // TODO: it's probably not possible to always use contiguous cells std::vector cells; // find tail cells faster std::vector seq_tails; // map seq_ids to cell ids - // freeable cell ids, computed when finding a slot - // useful to find the smallest range to defrag - std::vector freeable; - // per layer // NOTE: the naming of r and s is arbitrary std::vector r_l; // rolling/shift states std::vector s_l; // ssm (recurrent) states - // TODO: maybe use a simpler data structure than a tree - // Inefficient, but thorough verification and rebuilding of the rs cache // from only the cells list with `pos` and seq_ids. // Should not be called in a hot loop except when desperate and/or debugging. 
@@ -2977,7 +2917,7 @@ struct llama_rs_self_cache { uint32_t used_verif = 0; for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { llama_rs_cell & cell = cells[cell_id]; - if (cell.seq_nodes.empty()) { + if (cell.is_empty()) { if (cell.pos >= 0) { if (debug) { LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", @@ -2986,6 +2926,8 @@ struct llama_rs_self_cache { cell.pos = -1; was_valid = false; } + } else { + used_verif += 1; } if (cell.pos < 0) { if (cell.pos != -1) { @@ -2996,30 +2938,19 @@ struct llama_rs_self_cache { cell.pos = -1; was_valid = false; } - if (!cell.seq_nodes.empty()) { + if (!cell.is_empty()) { if (debug) { LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", - __func__, cell_id, cell.seq_nodes.size()); + __func__, cell_id, cell.seq_id.size()); } - cell.seq_nodes.clear(); + cell.seq_id.clear(); was_valid = false; } cell.src = -1; - if (cell.prev != -1) { - if (debug) { - LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", - __func__, cell_id, cell.prev); - } - cell.prev = -1; - was_valid = false; - } } else if (!debug) { // Assuming the cache should be actually rebuilt when not debugging cell.src = cell_id; } - if (!cell.seq_nodes.empty()) { - used_verif += 1; - } } if (used != used_verif) { if (debug) { @@ -3051,480 +2982,10 @@ struct llama_rs_self_cache { seq.tail = tail; was_valid = false; } - int32_t prev = -1; - for (size_t i = 0; i < seq_cells.size(); ++i) { - uint32_t cell_id = seq_cells[i].second; - llama_rs_cell & cell = cells[cell_id]; - if (cell.prev != prev) { - // TODO: relax the error when multiple cells have the same pos - if (debug) { - LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", - __func__, cell_id, cell.prev, prev); - } - cell.prev = prev; - was_valid = false; - } - prev = cell_id; - } - int32_t n_cells = 0; - int32_t next = -1; - for (size_t i = seq_cells.size(); i-- > 0;) { - uint32_t cell_id = seq_cells[i].second; - llama_rs_cell & cell = cells[cell_id]; - // assuming it's always found, because how else would it end up in the list of cells for this seq_id? 
- auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); - if (seq_node == cell.seq_nodes.begin()) { - n_cells += 1; - } - if (seq_node->next_cell != next) { - // TODO: relax the error when multiple cells have the same pos - if (debug) { - LLAMA_LOG_ERROR("%s: invalid next cell for seq_id %d in cells[%u] (%d instead of %d)\n", - __func__, seq_id, cell_id, seq_node->next_cell, next); - } - seq_node->next_cell = next; - was_valid = false; - } - next = cell_id; - } - if (seq.n_cells != n_cells) { - if (debug) { - LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", - __func__, seq_id, seq.n_cells, n_cells); - } - seq.n_cells = n_cells; - } - } - // tail_rc - for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { - llama_rs_cell & cell = cells[cell_id]; - uint32_t tail_rc = 0; - for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { - auto & seq = seq_tails[seq_id]; - if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { - tail_rc += 1; - } - } - if (cell.tail_rc != tail_rc) { - if (debug) { - LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", - __func__, cell_id, cell.tail_rc, tail_rc); - } - cell.tail_rc = tail_rc; - was_valid = false; - } - } - // n_seqs - uint32_t n_seqs_verif = 0; - uint32_t n_shared_tail_cells_verif = 0; - for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { - llama_rs_cell & rs_cell = cells[cell_id]; - if (!rs_cell.seq_nodes.empty()) { - if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { - n_seqs_verif += 1; - } else if (rs_cell.tail_rc > 0) { - n_shared_tail_cells_verif += 1; - } - } - } - if (n_seqs != n_seqs_verif) { - if (debug) { - LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", - __func__, n_seqs, n_seqs_verif); - } - n_seqs = n_seqs_verif; - was_valid = false; - } - if (n_shared_tail_cells != n_shared_tail_cells_verif) { - if (debug) { - LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", - __func__, n_shared_tail_cells, n_shared_tail_cells_verif); - } - n_shared_tail_cells = n_shared_tail_cells_verif; - was_valid = false; } return was_valid; } - // each seq_id should have access to at least this many cells - // (to use when pruning (to avoid over-pruning)) - uint32_t min_cells_per_seq(const llama_ubatch & batch) const { - uint32_t seqs = n_seqs; - for (uint32_t i = 0; i < batch.n_seqs; ++i) { - llama_seq_id seq_id = batch.seq_id[i][0]; - const llama_rs_seq_meta & new_seq = seq_tails[seq_id]; - if (new_seq.tail < 0 || new_seq.n_cells == 0) { - seqs += 1; - } - } - return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); - } - - void freeable_for_batch(const llama_ubatch & batch, llama_pos checkpoint_interval) { - GGML_ASSERT(batch.equal_seqs); - int32_t min_cells = min_cells_per_seq(batch); - - // TODO: minimize work required to find freeable cells - // currently, this finds freeable cells by excluding non-freeable cells, - // because some conditions are more easily expressed this way. 
- - freeable.assign(size, 1); - - for (llama_rs_seq_meta & seq : seq_tails) { - seq.ubatch_end_pos = -1; - } - - for (uint32_t i = 0; i < batch.n_seqs; ++i) { - int32_t n_seq_id = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_id; j++) { - llama_seq_id seq_id = batch.seq_id[i][j]; - GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_tails.size()); - llama_rs_seq_meta & seq = seq_tails[seq_id]; - seq.ubatch_end_pos = batch.pos[i * batch.n_seq_tokens + batch.n_seq_tokens - 1]; - } - } - - for (llama_rs_seq_meta & seq : seq_tails) { - if (seq.tail >= 0 && freeable[seq.tail] != 0) { - llama_pos end_pos = seq.ubatch_end_pos; - // When is a tail cell not freeable? - if (end_pos < 0) { - // when any of its tails are not in the batch - freeable[seq.tail] = 0; - } else if (min_cells > 1) { - // TODO: fallback to this less often - llama_rs_cell & tail = cells[seq.tail]; - GGML_ASSERT(tail.pos < end_pos); - if (tail.prev < 0 || tail.pos + checkpoint_interval <= end_pos) { - // make a checkpoint before prompt processing - // TODO: should it always be done after instead? - freeable[seq.tail] = 0; - } else { - llama_rs_cell & prev = cells[tail.prev]; - if (prev.pos + checkpoint_interval <= end_pos) { - // make a checkpoint during text generation - freeable[seq.tail] = 0; - } - } - } - } - } - - for (uint32_t i = 0; i < size; ++i) { - llama_rs_cell & cell = cells[i]; - if (!cell.is_empty() && cell.tail_rc == 0) { - // TODO: reduce indirection here - llama_rs_seq_node & seq_node = cell.seq_nodes[0]; - llama_rs_seq_meta & seq = seq_tails[seq_node.seq_id]; - bool keep_tail = freeable[seq.tail] == 0; - // kept tails use an additional cell, so make them allow freeing a checkpoint - int32_t really_min_cells = keep_tail ? min_cells - 1 : min_cells; - // A checkpoint is kept if there's enough alloted space for this sequence - // or if it's the state right before the tail - if (seq.n_cells <= really_min_cells || (really_min_cells > 1 && seq_node.next_cell == seq.tail)) { - freeable[i] = 0; - } - } - } - } - - // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. - // Why an iterator? Because it allows using std::vector::erase. 
- std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - // The iterator needs to point inside the correct vector - GGML_ASSERT(&(*node_iter) >= rs_cell.seq_nodes.data() && &(*node_iter) < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); - if (node_iter != rs_cell.seq_nodes.end()) { - // update the tree - llama_rs_seq_node node = *node_iter; - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail - cells[node.next_cell].prev = rs_cell.prev; - } - if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { - llama_rs_cell & prev_cell = cells[rs_cell.prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); - // assuming the previous node is always found - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); - prev_node->next_cell = node.next_cell; - if (node.is_tail()) { - // move the tail back to the previous cell - prev_cell.tail_rc += 1; - if (prev_cell.seq_nodes.size() > 1) { - if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { - if (prev_cell.tail_rc == 1) { - n_shared_tail_cells += 1; - } - - if (rs_cell.tail_rc == 1) { - if (prev_cell.tail_rc != prev_cell.seq_nodes.size()) { - // o oo oo - // |/ -> o/ - // | | - // e.g. when removing the leaf of a split tree - n_seqs -= 1; - } else { - // o - // o -> oo - // | | - // e.g. when merging back with a previous tail - n_shared_tail_cells -= 1; - } - } - } - } - } - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (rs_cell.tail_rc == 1) { - if (seq.tail < 0) { - // no more tail, no more sequence - if (rs_cell.seq_nodes.size() > 1) { - n_shared_tail_cells -= 1; - } else { - n_seqs -= 1; - } - } - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; - } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { - // will fully become a tail cell - if (rs_cell.tail_rc > 0) { - n_seqs += 1; - n_shared_tail_cells -= 1; - } - } - if (node_iter == rs_cell.seq_nodes.begin()) { - // this seq_id was the first in the list - seq.n_cells -= 1; - - auto next_node = std::next(node_iter); - if (next_node != rs_cell.seq_nodes.end()) { - // the next node is the new first one, so update its n_cells - if ((uint32_t) next_node->seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node->seq_id]; - next_seq.n_cells += 1; - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } else { - // this was the last seq_id of the cell - used -= 1; - rs_cell.pos = -1; - rs_cell.src = -1; - rs_cell.prev = -1; - // the other fields *should* have already been updated elsewhere - } - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - return rs_cell.seq_nodes.erase(node_iter); - } - return node_iter; - } - - void clear_cell(llama_rs_cell & rs_cell) { - GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { - node_iter = remove_seq_node_from_cell(rs_cell, node_iter); - } - } - - // returns whether or not the seq_id was removed - bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { - if (i_cell < size && (size_t) id < size) { - llama_rs_cell & rs_cell = cells[i_cell]; - auto node_iter = std::find(rs_cell.seq_nodes.begin(), 
rs_cell.seq_nodes.end(), id); // search once - return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); - } - return false; - } - - bool swap_cells(uint32_t i_src, uint32_t i_dst) { - if (i_src < size && i_dst < size && i_src != i_dst) { - llama_rs_cell & src = cells[i_src]; - llama_rs_cell & dst = cells[i_dst]; - - for (llama_rs_seq_node & seq_node : src.seq_nodes) { - if (seq_node.next_cell >= 0) { - llama_rs_cell & next = cells[seq_node.next_cell]; - next.prev = i_dst; - if ((uint32_t) seq_node.next_cell == i_dst) { - seq_node.next_cell = i_src; - } - } else { - // this is a tail - seq_tails[seq_node.seq_id].tail = i_dst; - } - } - for (llama_rs_seq_node & seq_node : dst.seq_nodes) { - if (seq_node.next_cell >= 0) { - llama_rs_cell & next = cells[seq_node.next_cell]; - next.prev = i_src; - if ((uint32_t) seq_node.next_cell == i_src) { - seq_node.next_cell = i_dst; - } - } else { - // this is a tail - seq_tails[seq_node.seq_id].tail = i_src; - } - } - - if (src.prev == dst.prev) { - // avoid swapping them twice - if (src.prev >= 0) { - llama_rs_cell & prev = cells[src.prev]; - for (llama_rs_seq_node & seq_node : prev.seq_nodes) { - if ((uint32_t) seq_node.next_cell == i_src) { - seq_node.next_cell = i_dst; - } else if ((uint32_t) seq_node.next_cell == i_dst) { - seq_node.next_cell = i_src; - } - } - } - } else { - if (src.prev >= 0) { - llama_rs_cell & prev = cells[src.prev]; - for (llama_rs_seq_node & seq_node : prev.seq_nodes) { - if ((uint32_t) seq_node.next_cell == i_src) { - seq_node.next_cell = i_dst; - } - } - } - if (dst.prev >= 0) { - llama_rs_cell & prev = cells[dst.prev]; - for (llama_rs_seq_node & seq_node : prev.seq_nodes) { - if ((uint32_t) seq_node.next_cell == i_dst) { - seq_node.next_cell = i_src; - } - } - } - } - - std::swap(src.pos, dst.pos); - std::swap(src.src, dst.src); - std::swap(src.prev, dst.prev); - std::swap(src.tail_rc, dst.tail_rc); - std::swap(src.seq_nodes, dst.seq_nodes); - - return true; - } - return false; - } - - bool insert_seq_tail_to_cell_id(uint32_t i_cell, llama_seq_id id, llama_pos end_pos = -1) { - if (i_cell < size && (size_t) id < seq_tails.size()) { - llama_rs_cell & rs_cell = cells[i_cell]; - auto & seq = seq_tails[id]; - int32_t prev = rs_cell.prev; - if (end_pos >= 0) { - if (end_pos <= rs_cell.pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, end_pos, rs_cell.pos, id); - } - rs_cell.pos = end_pos; - } else { - // if no pos was specified, then the target cell should already have a valid one. 
- GGML_ASSERT(!rs_cell.is_empty()); - } - if ((uint32_t) seq.tail == i_cell) { - // the cell is already the tail of this seq_id - if (rs_cell.tail_rc != rs_cell.seq_nodes.size()) { - GGML_ASSERT(end_pos >= 0); // make sure this is the first re-added seq_id - // remove non-tail seq_ids (branch off them) - for (size_t i = rs_cell.seq_nodes.size(); i-- > 0;) { - if (!rs_cell.seq_nodes[i].is_tail()) { - remove_seq_node_from_cell(rs_cell, rs_cell.seq_nodes.begin() + i); - } - } - } - return true; - } - if (rs_cell.is_empty()) { - prev = seq.tail; - } - // ensure the new tail won't mess up the tree - GGML_ASSERT(seq.tail == -1 || seq.tail == prev); - if (prev >= 0 && (uint32_t) prev < size) { - // the targeted cell has a previous cell - llama_rs_cell & prev_cell = cells[prev]; - auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); - GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing - GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken - if (rs_cell.is_empty()) { - rs_cell.src = prev_cell.src; - } - prev_node->next_cell = i_cell; - rs_cell.prev = prev; - if (seq.tail == prev) { - // What to do when the tail moves... - // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) - // O -> oO (n_seqs--, n_shared_tail_cells++) - // O -> O (seq.n_cells++) - // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) - // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) - // _ -> oO (n_shared_tail_cells++) - // _ -> O (seq.n_cells++, n_seqs++) - // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) - // Oo -> OO+ (n_shared_tail_cell--) - // OOo -> O (seq.n_cells++, n_seqs++) - if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { - // from fully tail - if (prev_cell.tail_rc > 1) { - // the previous tail becomes shared with a non-tail - n_shared_tail_cells += 1; - } - if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { - // the new tail cell was previously a fully non-tail cell - n_shared_tail_cells += 1; - n_seqs -= 1; - } - } else { - if (rs_cell.is_empty()) { - // from shared to unique - n_seqs += 1; - } - if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) { - // from last shared to fully tail - n_shared_tail_cells -= 1; - } - } - } - prev_cell.tail_rc -= 1; - } - if (rs_cell.is_empty()) { - // to unique - seq.n_cells += 1; - if (seq.tail < 0) { - // from empty to unique - n_seqs += 1; - // make sure it's cleared - rs_cell.src = -1; - } - used += 1; - } else if (rs_cell.tail_rc == 0) { - // to shared - if (seq.tail < 0) { - // from empty to shared - n_shared_tail_cells += 1; - } - } - // the target cell was not already a tail of this seq_id - rs_cell.insert_node(id); // next_cell == -1 by default - rs_cell.tail_rc += 1; - seq.tail = i_cell; - return true; - } - return false; - } - size_t total_size() const { size_t size = 0; for (struct ggml_tensor * r : r_l) { @@ -4341,7 +3802,6 @@ static bool llama_kv_cache_init( cache.rs.cells.resize(rs_size); cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(rs_size); - cache.rs.freeable.reserve(rs_size); // count used buffer types std::map buft_layer_count; @@ -4429,88 +3889,80 @@ static bool llama_kv_cache_init( static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_ubatch & batch) { - const uint32_t kv_size = cache.kv.size; - const uint32_t rs_size = cache.rs.size; + struct llama_kv_self_cache & kv_self = cache.kv; + struct llama_rs_self_cache & rs_self = cache.rs; + 
const uint32_t kv_size = kv_self.size; + const uint32_t rs_size = rs_self.size; const uint32_t n_tokens = batch.n_tokens; const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seq_tokens = batch.n_seq_tokens; - // only check first, to allow failing gracefully - if (rs_size > 0) { - // everything should fit if all seq_ids are smaller than the max - for (uint32_t i = 0; i < n_seqs; ++i) { - int32_t n_seq_id = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id = batch.seq_id[i][j]; + // check only at first, to allow failing gracefully + { + if (rs_size > 0) { + if (!batch.equal_seqs) { + LLAMA_LOG_ERROR("%s: can't process batch with unequal new tokens per sequence for recurrent models\n", __func__); + return false; + } - if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { - // too big seq_id - // TODO: would it be possible to resize the rs cache size instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); + // everything should fit if all seq_ids are smaller than the max + for (uint32_t i = 0; i < n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { + // too big seq_id + // TODO: would it be possible to resize the rs cache size instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, rs_size); + return false; + } + } + } + } + + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > kv_size=%d\n", __func__, n_tokens, kv_size); + return false; + } + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*n_tokens) { + kv_self.head = 0; + } + + uint32_t n_tested = 0; + + while (true) { + if (kv_self.head + n_tokens > kv_size) { + n_tested += kv_size - kv_self.head; + kv_self.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (kv_self.cells[kv_self.head + i].pos >= 0) { + found = false; + kv_self.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); return false; } } } - // TODO: configurable checkpoint interval - cache.rs.freeable_for_batch(batch, 8); - { - uint32_t freeable_rs_cell_count = 0; - for (uint32_t is_freeable : cache.rs.freeable) { - freeable_rs_cell_count += (uint32_t) (is_freeable != 0); - if (freeable_rs_cell_count >= n_seqs) { - // there's enough, no need to count them all - break; - } - } - if (n_seqs > freeable_rs_cell_count) { - // This should not happen - LLAMA_LOG_ERROR("%s: n_seqs=%d > freeable_rs_cell_count=%d\n", __func__, n_seqs, freeable_rs_cell_count); - return false; - } - } - } - - if (kv_size > 0) { - // one KV cell per token - if (n_tokens > kv_size) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > kv_size=%d\n", __func__, n_tokens, kv_size); - return false; - } - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (cache.kv.head > cache.kv.used + 2*n_tokens) { - cache.kv.head = 0; - } - - uint32_t n_tested = 0; - - while (true) { - if (cache.kv.head + n_tokens > kv_size) { - n_tested += kv_size - 
cache.kv.head; - cache.kv.head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.kv.cells[cache.kv.head + i].pos >= 0) { - found = false; - cache.kv.head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= kv_size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } - } } // now modification can be done, and should NOT fail @@ -4520,154 +3972,142 @@ static bool llama_kv_cache_find_slot( // each cache cell can store the state for a whole sequence. // A slot should be always be contiguous. - uint32_t min_head = 0; - uint32_t min_n = cache.rs.size; - uint32_t min_free = 0; + int32_t min = rs_size - 1; + int32_t max = 0; - // compact the freeable cell list - // e.g. 0,1,0,0,1,1,0,1,0,1 -> 1,4,5,7,9 - // while also finding the smallest cell range for the slot - { - uint32_t next_free = 0; - for (size_t i = 0; i < cache.rs.freeable.size(); ++i) { - if (cache.rs.freeable[i]) { - cache.rs.freeable[next_free] = i; - next_free += 1; + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = batch.n_seq_id[s]; + for (uint32_t j = 1; j < n_seq_id; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; - if (next_free >= n_seqs) { - uint32_t head = cache.rs.freeable[next_free - n_seqs]; - // i is the last seen freeable cell id - uint32_t n = i - head + 1; - // keep the first smallest big enough slot - if (n < min_n) { - min_free = next_free - n_seqs; - min_head = head; - min_n = n; - if (n == n_seqs) { - // it's the smallest it can be - break; - } - } + llama_rs_seq_meta & seq = rs_self.seq_tails[seq_id]; + if (seq.tail >= 0) { + llama_rs_cell & cell = rs_self.cells[seq.tail]; + // Clear previous tail cells from seq_ids that become shared. + // Only happens on batches with multiple seq_ids per token, + // but the seq_ids each had their own tail cell. + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + rs_self.used -= 1; } } } } - // sanity check - GGML_ASSERT(min_head + min_n <= cache.rs.size); + // find next empty cell + uint32_t next_empty_cell = rs_self.head; - // keep only the necessary range - cache.rs.freeable.resize(min_free + n_seqs); - cache.rs.freeable.erase(cache.rs.freeable.begin(), cache.rs.freeable.begin() + min_free); - GGML_ASSERT(cache.rs.freeable.size() == n_seqs); - GGML_ASSERT(min_n >= n_seqs); - cache.rs.freeable.resize(min_n); - - // expand the free list - // e.g. 2,4,5,8 -> 1,0,1,1,0,0,1 - for (uint32_t i = n_seqs; i-- > 0;) { - uint32_t dst = cache.rs.freeable[i] - min_head; - if (dst != i) { - cache.rs.freeable[i] = 0; - } - GGML_ASSERT(dst >= i); - cache.rs.freeable[dst] = 1; + for (uint32_t i = 0; i < rs_size; ++i) { + if (next_empty_cell >= rs_size) { next_empty_cell -= rs_size; } + llama_rs_cell & cell = rs_self.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; } - // coalesce the free cells together - // e.g. 
1,0,1,1,0,0,1 -> 1,1,1,1,0,0,0 - // or 1,0,1,1,1,1 -> 1,1,1,1,1,0 - { - uint32_t top_free = min_n - 1; - for (uint32_t i = min_n; i-- > 1;) { - uint32_t is_free = cache.rs.freeable[i]; - if (!is_free) { - GGML_ASSERT(top_free > i); - cache.rs.swap_cells(min_head + i, min_head + top_free); - std::swap(cache.rs.freeable[i], cache.rs.freeable[top_free]); - // the previous one has to be free, - // otherwise it would already have been swapped. - top_free -= 1; + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = batch.seq_id[s][0]; + llama_rs_seq_meta & seq_meta = rs_self.seq_tails[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_rs_cell & cell = rs_self.cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_rs_cell & empty_cell = rs_self.cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_rs_cell & orig_cell = rs_self.cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten } - // stop early if all freeable cells have already been put at the beginning - if (top_free < n_seqs) { break; } - } - } - - // order the re-used cells identically to their batch order - // (and clear the non-reused cells) - { - for (uint32_t i = 0; i < n_seqs; ++i) { - // ignore the already-swapped cells - if (cache.rs.freeable[i]) { - llama_rs_cell & cell = cache.rs.cells[min_head + i]; - if (!cell.is_empty()) { - if (cell.tail_rc == 0) { - cache.rs.clear_cell(cell); - } else { - // Find the seq_id of the first tail of this cell - llama_seq_id seq_id = -1; - for (llama_rs_seq_node & seq_node : cell.seq_nodes) { - if (seq_node.is_tail()) { - seq_id = seq_node.seq_id; - break; - } - } - GGML_ASSERT(seq_id != -1); - - // Which seq_id of the batch is it? 
- int32_t nth_seq_id = -1; - for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { - if (seq_id == batch.seq_id[s][0]) { - nth_seq_id = s; - break; - } - } - GGML_ASSERT(nth_seq_id != -1); - - cache.rs.swap_cells(min_head + i, min_head + nth_seq_id); - cache.rs.freeable[i] = 0; - std::swap(cache.rs.freeable[i], cache.rs.freeable[nth_seq_id]); - i -= 1; // check this cell again, now that it was swapped - } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < rs_size; ++i) { + if (next_empty_cell >= rs_size) { next_empty_cell -= rs_size; } + llama_rs_cell & cell = rs_self.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; } } } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } } - // reserve - { - for (uint32_t i = 0; i < n_seqs; ++i) { - uint32_t i_cell = min_head + i; - int32_t n_seq_id = batch.n_seq_id[i]; - llama_pos end_pos = batch.pos[(i * n_seq_tokens) + n_seq_tokens - 1]; - // set the pos with the first seq_id - cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][0], end_pos); - // insert the rest of the seq_ids by re-using the cell's pos - for (int j = 1; j < n_seq_id; ++j) { - cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][j]); + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = rs_self.seq_tails[batch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_rs_cell & dst_cell = rs_self.cells[dst_id]; + llama_rs_cell & src_cell = rs_self.cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + rs_self.seq_tails[seq_id].tail = src_id; } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + rs_self.seq_tails[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_rs_cell & cell = rs_self.cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. 
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = batch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + rs_self.seq_tails[seq_id].tail = cell_id; } } // allow getting the range of used cells, from head to head + n - cache.rs.head = min_head; - cache.rs.n = min_n; + rs_self.head = min; + rs_self.n = max - min + 1; } if (kv_size > 0) { for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; - cache.kv.cells[cache.kv.head + k].pos = batch.pos[k]; + kv_self.cells[kv_self.head + k].pos = batch.pos[k]; for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.kv.cells[cache.kv.head + k].seq_id.insert(batch.seq_id[s][j]); + kv_self.cells[kv_self.head + k].seq_id.insert(batch.seq_id[s][j]); } } } - cache.kv.used += n_tokens; + kv_self.used += n_tokens; } return true; @@ -4686,20 +4126,7 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_self_cache & cache return 0; } -// find how many recurrent state cells are currently in use -static uint32_t llama_rs_cache_cell_max(const struct llama_rs_self_cache & cache) { - for (uint32_t i = cache.size; i > 0; --i) { - const llama_rs_cell & cell = cache.cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - -static void llama_past_clear(struct llama_kv_cache & cache) { +static void llama_kv_cache_clear(struct llama_kv_cache & cache) { if (cache.kv.size > 0) { for (uint32_t i = 0; i < cache.kv.size; ++i) { llama_kv_cell & kv_cell = cache.kv.cells[i]; @@ -4717,14 +4144,10 @@ static void llama_past_clear(struct llama_kv_cache & cache) { llama_rs_cell & rs_cell = cache.rs.cells[i]; rs_cell.pos = -1; rs_cell.src = -1; - rs_cell.prev = -1; - rs_cell.tail_rc = 0; - rs_cell.seq_nodes.clear(); + rs_cell.seq_id.clear(); } - cache.rs.head = 0; - cache.rs.used = 0; - cache.rs.n_seqs = 0; - cache.rs.n_shared_tail_cells = 0; + cache.rs.head = 0; + cache.rs.used = 0; cache.rs.seq_tails.clear(); cache.rs.seq_tails.resize(cache.rs.size); } @@ -4733,63 +4156,65 @@ static void llama_past_clear(struct llama_kv_cache & cache) { } } -static llama_pos llama_past_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { +static bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - + // models like Mamba or RWKV can't have a state partially erased + // TODO: refactor the recurrent state cache to allow partial rollbacks if (cache.rs.size > 0) { - if (seq_id >= (int64_t) cache.rs.size) { - // could be fatal - return n_past; - } uint32_t new_head = cache.rs.size; - // adjust p0 and p1 according to the states found - llama_pos new_p0 = 0; - llama_pos new_p1 = std::numeric_limits::max(); - // partial seq_id removal has to happen from the tail - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - int32_t cell_id = seq.tail; - - while (cell_id >= 0) { - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - // copy before the cell is potentially changed - int32_t prev_id = rs_cell.prev; - if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { - // non-tail removal for shared cells can only be done when clearing a cell - // 
(i.e. when the next cell's link to the previous cell can be safely changed) - p1 = rs_cell.pos + 1; - } - if (rs_cell.pos >= p0 && rs_cell.pos < p1) { - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); - // if the node isn't found, the sequence tree is malformed - GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); - cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - // get the smallest removed cell id - if (new_head > (uint32_t) cell_id) { new_head = cell_id; } - } else { - // one more than the biggest non-removed cell of this sequence - if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } - - if (rs_cell.pos < p0) { - // new_p0 should be right after the max pos in the states before p0 - if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } - } else { // (rs_cell.pos >= p1) - // new_p1 should be the min pos in the states after p1 - if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } + if (seq_id >= (int64_t) cache.rs.seq_tails.size()) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cache.rs.seq_tails[seq_id].tail; + if (tail_id >= 0) { + const llama_rs_cell & cell = cache.rs.cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + // Assume there's only one cell per seq_id + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (seq_id < 0 || i == (uint32_t) seq_id) { + int32_t tail_id = cache.rs.seq_tails[i].tail; + if (tail_id >= 0) { + llama_rs_cell rs_cell = cache.rs.cells[tail_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.seq_id.erase(i); + if (rs_cell.is_empty()) { + // keep count of the number of used cells + if (cache.rs.cells[i].pos >= 0) { cache.rs.used--; } + + cache.rs.cells[i].pos = -1; + cache.rs.cells[i].src = -1; + if (new_head == cache.rs.size) { new_head = i; } + } + } + cache.rs.seq_tails[i].tail = -1; } } - cell_id = prev_id; } - p0 = new_p0; - p1 = new_p1; // If we freed up a slot, set head to it so searching can start there. 
if (new_head != cache.rs.size && new_head < cache.rs.head) { @@ -4801,24 +4226,20 @@ static llama_pos llama_past_seq_rm( uint32_t new_head = cache.kv.size; for (uint32_t i = 0; i < cache.kv.size; ++i) { - llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (cache.kv.cells[i].pos >= p0 && cache.kv.cells[i].pos < p1) { + if (seq_id < 0) { + cache.kv.cells[i].seq_id.clear(); + } else if (cache.kv.cells[i].has_seq_id(seq_id)) { + cache.kv.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.kv.cells[i].is_empty()) { + // keep count of the number of used cells + if (cache.kv.cells[i].pos >= 0) { cache.kv.used--; } - if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { - if (kv_cell.pos >= p0 && kv_cell.pos < p1) { - if (seq_id < 0) { - kv_cell.seq_id.clear(); - } else { // (kv_cell.has_seq_id(seq_id)) - kv_cell.seq_id.erase(seq_id); - } - if (kv_cell.is_empty()) { - // keep count of the number of used cells - if (kv_cell.pos >= 0) { cache.kv.used--; } - - kv_cell.pos = -1; - if (new_head == cache.kv.size) { new_head = i; } - } - } else if (kv_cell.pos >= n_past) { - n_past = kv_cell.pos + 1; + cache.kv.cells[i].pos = -1; + if (new_head == cache.kv.size) { new_head = i; } } } } @@ -4829,59 +4250,29 @@ static llama_pos llama_past_seq_rm( } } - return n_past; + return true; } -static llama_pos llama_past_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { +static void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - // TODO: in practice this seems to be only used on whole sequences; - // should partial sequence copy support be removed? - // TODO: What if the destination sequence is not empty? - - llama_pos n_past = 0; - if (cache.rs.size > 0) { - // have to start from the beginning for recurrent models - p0 = 0; - if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { - int32_t src_head = -1; - int32_t head_pos = p1; - int32_t src_next = -1; - // find the start of the sequence - for (uint32_t i = 0; i < cache.rs.size; ++i) { - llama_rs_cell & rs_cell = cache.rs.cells[i]; - if (!rs_cell.is_empty() && rs_cell.prev < 0) { - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); - if (seq_node != rs_cell.seq_nodes.end()) { - src_head = i; - head_pos = rs_cell.pos; - src_next = seq_node->next_cell; - break; - } - } - } - while (src_head >= 0 && head_pos < p1) { - cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); - src_head = src_next; - if (head_pos >= n_past) { n_past = head_pos + 1; } - if (src_next >= 0) { - llama_rs_cell & rs_cell = cache.rs.cells[src_next]; - auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); - head_pos = rs_cell.pos; - // it should always be found if the seq tree is valid - GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); - src_next = seq_node->next_cell; - } + llama_rs_seq_meta & seq_meta = cache.rs.seq_tails[seq_id_src]; + if (seq_meta.tail >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[seq_meta.tail]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.seq_id.insert(seq_id_dst); + // TODO: What if the destination sequence is not empty? 
+ GGML_ASSERT(cache.rs.seq_tails[seq_id_dst].tail < 0); + cache.rs.seq_tails[seq_id_dst].tail = seq_meta.tail; } } - p1 = n_past; } if (cache.kv.size > 0) { @@ -4889,32 +4280,30 @@ static llama_pos llama_past_seq_cp( llama_kv_cell & kv_cell = cache.kv.cells[i]; if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { kv_cell.seq_id.insert(seq_id_dst); - if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } } } } - - return n_past; } -static void llama_past_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { if (cache.rs.size > 0) { uint32_t new_head = cache.rs.size; - // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos >= 0) { cache.rs.used--; } + rs_cell.pos = -1; + rs_cell.seq_id.clear(); + if (new_head == cache.rs.size) { new_head = i; } + } else { + rs_cell.seq_id.clear(); + rs_cell.seq_id.insert(seq_id); + } + } for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { - if (i == (uint32_t) seq_id) { continue; } - llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; - int32_t cell_id = seq.tail; - while (cell_id >= 0) { - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); - GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); - cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); - cell_id = rs_cell.prev; - if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { - new_head = cell_id; - } + if (i != (uint32_t) seq_id) { + cache.rs.seq_tails[i].tail = -1; } } @@ -4947,41 +4336,29 @@ static void llama_past_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_ } } -static void llama_past_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { +static void llama_kv_cache_seq_add( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } if (cache.rs.size > 0) { - // for Mamba-like or RKWV models, only the pos needs to be shifted - auto & seq = cache.rs.seq_tails[seq_id]; - // follow the sequence from its tail - int32_t cell_id = seq.tail; - uint32_t new_head = cache.rs.size; - while (cell_id >= 0) { - GGML_ASSERT((uint32_t) cell_id < cache.rs.size); - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - cell_id = rs_cell.prev; - if (rs_cell.pos >= p0 && rs_cell.pos < p1) { - rs_cell.pos += delta; - if (rs_cell.pos < 0) { - // NOTE: this affects the other sequences which share the cell - cache.rs.clear_cell(rs_cell); - if (new_head > (uint32_t) cell_id) { - new_head = cell_id; - } + // for recurrent states, the pos shift is faked + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + // TODO: handle deletion + // (but this should not happen anyway when only the last states are stored) + GGML_ASSERT(rs_cell.pos >= 0); } } } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.rs.head = new_head != cache.rs.size ? 
new_head : 0; } if (cache.kv.size > 0) { @@ -5015,26 +4392,24 @@ static void llama_past_seq_add( } } -static void llama_past_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { +static void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } if (cache.rs.size > 0) { - // for Mamba-like or RWKV models, only the pos needs to be changed - auto & seq = cache.rs.seq_tails[seq_id]; - int32_t cell_id = seq.tail; - while (cell_id >= 0) { - GGML_ASSERT((uint32_t) cell_id < cache.rs.size); - llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; - if (rs_cell.pos >= p0 && rs_cell.pos < p1) { - rs_cell.pos /= d; + // for recurrent states, the pos shift is faked + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; + } } - cell_id = rs_cell.prev; } } @@ -5056,7 +4431,7 @@ static void llama_past_seq_div( } } -static llama_pos llama_past_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { llama_pos result = -1; if (cache.rs.size > 0) { @@ -21341,86 +20716,49 @@ int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.rs.used; } -void llama_past_clear(struct llama_context * ctx) { - llama_past_clear(ctx->cache); -} - -// deprecated void llama_kv_cache_clear(struct llama_context * ctx) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx->cache); } -llama_pos llama_past_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - return llama_past_seq_rm(ctx->cache, seq_id, p0, p1); -} - -// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_pos n_past = llama_past_seq_rm(ctx, seq_id, p0, p1); - return n_past >= p0; + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_kv_cache_seq_rm(ctx->cache, seq_id, p0, p1); } - -llama_pos llama_past_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { - return 0; + // TODO: error? 
+ return; } if (seq_id_src == seq_id_dst) { - return llama_past_seq_pos_max(ctx->cache, seq_id_dst) + 1; + return; } - return llama_past_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_kv_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } -// deprecated -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - llama_past_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); -} - -void llama_past_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - llama_past_seq_keep(ctx->cache, seq_id); -} - -// deprecated void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_past_seq_keep(ctx, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + llama_kv_cache_seq_keep(ctx->cache, seq_id); } -void llama_past_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } if (delta == 0) { return; } - llama_past_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_kv_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } -// deprecated -void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_past_seq_add(ctx, seq_id, p0, p1, delta); -} - -void llama_past_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } if (d == 1) { return; } - llama_past_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_kv_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } -// deprecated -void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - llama_past_seq_div(ctx, seq_id, p0, p1, d); -} - -llama_pos llama_past_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } - return llama_past_seq_pos_max(ctx->cache, seq_id); -} - -// deprecated llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - llama_pos max_pos = llama_past_seq_pos_max(ctx, seq_id); - return max_pos < 0 ? 0 : max_pos; + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } + return llama_kv_cache_seq_pos_max(ctx->cache, seq_id); } void llama_kv_cache_defrag(struct llama_context * ctx) { @@ -21562,14 +20900,14 @@ struct llama_data_write { for (uint32_t i = range.first; i < range.second; ++i) { const auto & cell = rs_self.cells[i]; const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0; + const uint32_t n_seq_id = seq_id == -1 ? 
cell.seq_id.size() : 0; write(&pos, sizeof(pos)); write(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id) { - for (auto seq_node : cell.seq_nodes) { - write(&seq_node.seq_id, sizeof(seq_node.seq_id)); + for (auto seq_id : cell.seq_id) { + write(&seq_id, sizeof(seq_id)); } } } @@ -21968,8 +21306,7 @@ struct llama_data_read { return false; } - cell.insert_node(seq_id); - + cell.seq_id.insert(seq_id); } } @@ -22233,10 +21570,10 @@ struct llama_data_read { bool res = true; if (seq_id == -1) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count); } else { - llama_past_seq_rm(ctx, seq_id, -1, -1); + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); // Only a single recurrent cell at most, // because otherwise the cells can be shuffled when a slot is allocated if (rs_cell_count > 1) { @@ -22250,9 +21587,9 @@ struct llama_data_read { if (!res) { if (seq_id == -1) { - llama_past_clear(ctx); + llama_kv_cache_clear(ctx); } else { - llama_past_seq_rm(ctx, seq_id, -1, -1); + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); } throw std::runtime_error("failed to restore kv cache"); }
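
The slot-finding hunk above drops the freeable-cell compaction and the seq_nodes bookkeeping in favor of a much simpler invariant: each sequence keeps at most one recurrent-state cell, reachable through seq_tails[seq_id].tail, and the cells used by a batch are gathered into a single contiguous [min, max] range so that head and n describe one block. The following standalone sketch models only that invariant; rs_cell, rs_cache and assign_tail are illustrative names, not the llama.cpp types.

// Simplified model of the post-patch recurrent-state bookkeeping.
// rs_cell / rs_cache are illustrative stand-ins, not the llama.cpp structs.
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

struct rs_cell {
    int32_t pos = -1;              // position of the last token whose state is stored here
    std::set<int32_t> seq_id;      // sequences whose latest state lives in this cell
    bool is_empty() const { return seq_id.empty(); }
};

struct rs_cache {
    std::vector<rs_cell> cells;
    std::vector<int32_t> tails;    // one tail cell per sequence, or -1 (seq_tails equivalent)

    explicit rs_cache(size_t n) : cells(n), tails(n, -1) {}

    // Give seq_id its own tail cell, reusing the current tail when it is not shared.
    int32_t assign_tail(int32_t seq_id) {
        int32_t tail = tails[seq_id];
        if (tail >= 0 && cells[tail].seq_id.size() == 1) {
            return tail;                            // the sequence already owns a cell
        }
        // otherwise take the first empty cell (the patch searches from rs_self.head)
        for (int32_t i = 0; i < (int32_t) cells.size(); ++i) {
            if (cells[i].is_empty()) {
                if (tail >= 0) {
                    cells[tail].seq_id.erase(seq_id); // detach from the shared cell
                }
                cells[i].pos = tail >= 0 ? cells[tail].pos : -1;
                cells[i].seq_id.insert(seq_id);
                tails[seq_id] = i;
                return i;
            }
        }
        return -1; // no free cell: the batch cannot be placed
    }
};

int main() {
    rs_cache cache(4);
    int32_t a = cache.assign_tail(0);
    int32_t b = cache.assign_tail(1);
    assert(a != b);                    // each sequence ends up with its own cell
    assert(cache.assign_tail(0) == a); // an owned tail is reused on the next batch
    return 0;
}

Under this model, a sequence whose tail cell became shared (for example after a copy) is moved to its own empty cell before new tokens are written, which is what the "copy old tail into the empty cell" branch in the hunk does.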
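
Because a Mamba- or RWKV-style state cannot be partially erased, llama_kv_cache_seq_rm now reports failure instead of implicitly rolling the state back: a [p0, p1) range that only partially intersects the sequence's tail position returns false. Callers that previously depended on the rollback have to react to that result themselves. A minimal sketch against the patched API follows; the helper name is made up.

#include "llama.h"

// Drop everything from position p0 onward for one sequence, falling back to
// clearing the whole sequence when a partial removal is refused
// (as it is for recurrent states, which only keep their latest cell).
static void drop_tail_or_reset(llama_context * ctx, llama_seq_id seq_id, llama_pos p0) {
    if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, -1)) {
        // partial removal refused: erase the whole sequence and let the caller
        // re-evaluate the prompt from scratch
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
    }
}

Removing the full range (p0 = -1, p1 = -1) always succeeds for a valid seq_id in the patched code, which is why the fallback is safe.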
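
Similarly, for recurrent states llama_kv_cache_seq_cp only links the destination seq_id to the source's tail cell and, as the TODO in the hunk notes, asserts that the destination has no tail yet. Under that assumption, reusing a previously occupied sequence id as a fork target would need an explicit removal first, roughly as in this hypothetical helper:

#include "llama.h"

// Reuse seq id `dst` as a fork of `src`, clearing any state `dst` may still hold
// so that the recurrent-state copy does not trip the "destination not empty" assert.
static void fork_sequence(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    llama_kv_cache_seq_rm(ctx, dst, -1, -1);      // ensure dst has no tail cell
    llama_kv_cache_seq_cp(ctx, src, dst, -1, -1); // dst now shares src's latest state
}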