From efc33ea60d82af222ca58a4172fdb627ea08f4c9 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 24 Jun 2025 07:17:44 +0300
Subject: [PATCH] wip 2

---
 src/llama-kv-cache-unified.cpp | 90 +++++++++++++++++++++++++++++-----
 src/llama-kv-cache-unified.h   | 12 +++--
 2 files changed, 86 insertions(+), 16 deletions(-)

diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 26bdd390b6..ae2b7489e1 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -59,9 +59,25 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         return it->second;
     };
 
-    head = 0;
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        v_heads[s] = 0;
+    }
 
-    cells.resize(kv_size);
+    GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
+
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        v_cells[s].resize(kv_size);
+    }
+
+    // by default, all sequence ids are mapped to the 0th virtual sequence
+    seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
+
+    if (n_seq_virt > 1) {
+        seq_virt_idx.resize(n_seq_virt, 0);
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            seq_virt_idx[s] = s;
+        }
+    }
 
     for (uint32_t il = 0; il < hparams.n_layer; il++) {
         if (filter && !filter(il)) {
@@ -141,9 +157,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }
 
 void llama_kv_cache_unified::clear(bool data) {
-    cells.reset();
-
-    head = 0;
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        v_cells[s].reset();
+        v_heads[s] = 0;
+    }
 
     if (data) {
         for (auto & buf : bufs) {
@@ -153,6 +170,9 @@ void llama_kv_cache_unified::clear(bool data) {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head = v_heads[seq_virt_idx[seq_id]];
+
     uint32_t new_head = cells.size();
 
     if (p0 < 0) {
@@ -199,6 +219,10 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 }
 
 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(n_seq_virt == 1 && "TODO: implement seq_cp() for n_seq_virt > 1");
+
+    auto & cells = v_cells[seq_virt_idx[seq_id_src]];
+
     if (seq_id_src == seq_id_dst) {
         return;
     }
@@ -223,6 +247,9 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head = v_heads[seq_virt_idx[seq_id]];
+
     uint32_t new_head = cells.size();
 
     for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -240,6 +267,9 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 }
 
 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head = v_heads[seq_virt_idx[seq_id]];
+
     if (shift == 0) {
         return;
     }
@@ -279,6 +309,9 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    auto & cells = v_cells[seq_virt_idx[seq_id]];
+    auto & head = v_heads[seq_virt_idx[seq_id]];
+
     if (d == 1) {
         return;
     }
@@ -308,10 +341,14 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+    const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
     return cells.seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+    const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
     return cells.seq_pos_max(seq_id);
 }
 
@@ -357,7 +394,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
     defrag_info dinfo;
 
     // see if we need to defrag
-    {
+    if (n_seq_virt == 1) {
+        // note: for now do not consider defrag for n_seq_virt > 1
+        const auto & cells = v_cells[seq_virt_idx[0]];
+
         bool do_defrag = optimize;
 
         const auto thold = lctx->get_cparams().defrag_thold;
@@ -473,12 +513,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
             updated = true;
         }
 
-        cells.reset_shift();
+        for (uint32_t s = 0; s < n_seq_virt; ++s) {
+            auto & cells = v_cells[seq_virt_idx[s]];
+
+            cells.reset_shift();
+        }
     }
 
     if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
+        // note: for now do not consider defrag for n_seq_virt > 1
+        auto & cells = v_cells[seq_virt_idx[0]];
+        auto & head = v_heads[seq_virt_idx[0]];
+
         // apply moves:
         {
             const auto n_kv = dinfo.ids.size();
@@ -532,7 +580,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    uint32_t head_cur = this->head;
+    // TODO: implement
+    auto & cells = v_cells[seq_virt_idx[0]];
+
+    uint32_t head_cur = v_heads[0];
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
@@ -546,7 +597,8 @@
     }
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
             std::string ss;
@@ -748,15 +800,31 @@ bool llama_kv_cache_unified::get_can_shift() const {
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
+    auto & cells = v_cells[seq_virt_idx[0]];
+
     return cells.size();
 }
 
 bool llama_kv_cache_unified::get_has_shift() const {
-    return cells.get_has_shift();
+    bool result = false;
+
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        result |= v_cells[seq_virt_idx[s]].get_has_shift();
+    }
+
+    return result;
 }
 
 uint32_t llama_kv_cache_unified::get_n_kv() const {
-    return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+    uint32_t result = 0;
+
+    for (uint32_t s = 0; s < n_seq_virt; ++s) {
+        const auto & cells = v_cells[seq_virt_idx[s]];
+
+        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+    }
+
+    return result;
 }
 
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 2d361549fe..ef28b7b848 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -167,10 +167,6 @@ private:
 
     bool v_trans = true; // the value tensor is transposed
 
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    uint32_t head = 0;
-
     const uint32_t n_seq_max = 1;
     const uint32_t n_seq_virt = 1;
@@ -192,7 +188,13 @@
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    llama_kv_cells_unified cells;
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    uint32_t v_heads[LLAMA_MAX_SEQ];
+
+    llama_kv_cells_unified v_cells[LLAMA_MAX_SEQ];
+
+    std::vector<uint32_t> seq_virt_idx;
 
    std::vector<kv_layer> layers;
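
Note for reviewers on the data-structure change above: seq_virt_idx maps a user-visible llama_seq_id to a "virtual sequence", and each virtual sequence owns its own ring of KV cells (v_cells[s]) plus its own slot-search head (v_heads[s]). With n_seq_virt == 1 every sequence id shares ring 0, reproducing the previous single cells/head behavior; with n_seq_virt == n_seq_max, sequence id s owns ring s. Below is a minimal standalone sketch of that indirection under those assumed semantics; kv_cells_sketch and MAX_SEQ are illustrative stand-ins for llama_kv_cells_unified and LLAMA_MAX_SEQ, not the real llama.cpp types.

// sketch.cpp - illustrates the seq_id -> virtual-sequence indirection (assumed
// semantics, not the actual llama.cpp implementation)
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

static const uint32_t MAX_SEQ = 64; // stand-in for LLAMA_MAX_SEQ

// stand-in for llama_kv_cells_unified: one ring buffer of KV cell metadata
struct kv_cells_sketch {
    std::vector<int32_t> pos; // position stored in each cell, -1 = unused

    void resize(uint32_t n) { pos.assign(n, -1); }
    void reset()            { std::fill(pos.begin(), pos.end(), -1); }
};

struct kv_cache_sketch {
    uint32_t n_seq_virt;

    kv_cells_sketch v_cells[MAX_SEQ]; // one cell ring per virtual sequence
    uint32_t        v_heads[MAX_SEQ]; // per-ring slot-search head

    std::vector<uint32_t> seq_virt_idx; // seq_id -> virtual sequence index

    kv_cache_sketch(uint32_t kv_size, uint32_t n_seq_max, uint32_t n_seq_virt) : n_seq_virt(n_seq_virt) {
        // either one shared ring, or one ring per sequence
        assert(n_seq_virt == 1 || n_seq_virt == n_seq_max);

        for (uint32_t s = 0; s < MAX_SEQ; ++s) {
            v_heads[s] = 0;
        }

        for (uint32_t s = 0; s < n_seq_virt; ++s) {
            v_cells[s].resize(kv_size);
        }

        // by default every sequence id maps to the 0th virtual sequence;
        // in the split case, sequence id s owns ring s
        seq_virt_idx.assign(MAX_SEQ, 0);
        if (n_seq_virt > 1) {
            for (uint32_t s = 0; s < n_seq_virt; ++s) {
                seq_virt_idx[s] = s;
            }
        }
    }

    // the lookup pattern repeated at the top of seq_rm()/seq_add()/seq_keep():
    kv_cells_sketch & cells(int32_t seq_id) { return v_cells[seq_virt_idx[seq_id]]; }
    uint32_t &        head (int32_t seq_id) { return v_heads[seq_virt_idx[seq_id]]; }

    // heads are not part of the logical KV state, so clearing can reset them freely
    void clear() {
        for (uint32_t s = 0; s < n_seq_virt; ++s) {
            v_cells[s].reset();
            v_heads[s] = 0;
        }
    }
};

A call site then reads cache.cells(seq_id) where the old code read the single cells member. Keeping the heads outside the logical KV state, as the relocated header comment says, is what lets clear() reset every v_heads[s] to 0 without losing any cached data.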