This commit is contained in:
Georgi Gerganov
2025-06-24 07:17:44 +03:00
parent 7664390bc8
commit efc33ea60d
2 changed files with 86 additions and 16 deletions

View File

@@ -59,9 +59,25 @@ llama_kv_cache_unified::llama_kv_cache_unified(
return it->second;
};
head = 0;
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
v_heads[s] = 0;
}
cells.resize(kv_size);
GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_cells[s].resize(kv_size);
}
// by default, all sequence ids are mapped to the 0th virtual sequence
seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
if (n_seq_virt > 1) {
seq_virt_idx.resize(n_seq_virt, 0);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
seq_virt_idx[s] = s;
}
}
for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
@@ -141,9 +157,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
}
void llama_kv_cache_unified::clear(bool data) {
cells.reset();
head = 0;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_cells[s].reset();
v_heads[s] = 0;
}
if (data) {
for (auto & buf : bufs) {
@@ -153,6 +170,9 @@ void llama_kv_cache_unified::clear(bool data) {
}
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
uint32_t new_head = cells.size();
if (p0 < 0) {
@@ -199,6 +219,10 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
}
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
GGML_ASSERT(n_seq_virt == 1 && "TODO: implement seq_cp() for n_seq_virt > 1");
auto & cells = v_cells[seq_virt_idx[seq_id_src]];
if (seq_id_src == seq_id_dst) {
return;
}
@@ -223,6 +247,9 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
}
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
uint32_t new_head = cells.size();
for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -240,6 +267,9 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
}
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
if (shift == 0) {
return;
}
@@ -279,6 +309,9 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
}
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];
if (d == 1) {
return;
}
@@ -308,10 +341,14 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
// Minimum position stored for `seq_id`, delegated to the virtual-cell bank
// this sequence is mapped to through seq_virt_idx (all sequences map to
// bank 0 unless n_seq_virt > 1 — see the constructor hunk above).
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
const auto & cells = v_cells[seq_virt_idx[seq_id]];
return cells.seq_pos_min(seq_id);
}
// Maximum position stored for `seq_id`; mirror of seq_pos_min(), resolving
// the per-sequence virtual-cell bank before delegating.
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
const auto & cells = v_cells[seq_virt_idx[seq_id]];
return cells.seq_pos_max(seq_id);
}
@@ -357,7 +394,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
defrag_info dinfo;
// see if we need to defrag
{
if (n_seq_virt == 1) {
// note : for now do not consider defrag for n_seq_virt > 1
const auto & cells = v_cells[seq_virt_idx[0]];
bool do_defrag = optimize;
const auto thold = lctx->get_cparams().defrag_thold;
@@ -473,12 +513,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
updated = true;
}
cells.reset_shift();
for (uint32_t s = 0; s < n_seq_virt; ++s) {
auto & cells = v_cells[seq_virt_idx[s]];
cells.reset_shift();
}
}
if (!dinfo.empty()) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
// note: for now do not consider defrag for n_seq_virt > 1
auto & cells = v_cells[seq_virt_idx[0]];
auto & head = v_heads[seq_virt_idx[0]];
// apply moves:
{
const auto n_kv = dinfo.ids.size();
@@ -532,7 +580,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
const uint32_t n_tokens = ubatch.n_tokens;
uint32_t head_cur = this->head;
// TODO: implement
auto & cells = v_cells[seq_virt_idx[0]];
uint32_t head_cur = v_heads[0];
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
@@ -546,7 +597,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
}
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
__func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
if ((debug == 2 && n_swa > 0) || debug > 2) {
std::string ss;
@@ -748,15 +800,31 @@ bool llama_kv_cache_unified::get_can_shift() const {
}
// Capacity (cell count) of the cache. All virtual-cell banks are resized to
// the same kv_size in the constructor, so bank 0 is representative.
// NOTE(review): `auto &` (non-const) on a const path — presumably fine since
// v_cells is a member array; confirm constness compiles as intended.
uint32_t llama_kv_cache_unified::get_size() const {
auto & cells = v_cells[seq_virt_idx[0]];
return cells.size();
}
// True if ANY virtual-cell bank has a pending position shift to apply.
bool llama_kv_cache_unified::get_has_shift() const {
// (diff: old single-`cells` query below is removed, replaced by an OR
// across all n_seq_virt banks)
return cells.get_has_shift();
bool result = false;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
result |= v_cells[seq_virt_idx[s]].get_has_shift();
}
return result;
}
// Effective number of KV cells to consider for attention: per bank, the
// highest used cell index + 1, rounded up to n_pad (at least n_pad, at most
// the bank size); the result is the maximum over all virtual banks.
uint32_t llama_kv_cache_unified::get_n_kv() const {
// (diff: old single-`cells` computation below is removed, replaced by a
// max over all n_seq_virt banks of the same padded expression)
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
uint32_t result = 0;
for (uint32_t s = 0; s < n_seq_virt; ++s) {
const auto & cells = v_cells[seq_virt_idx[s]];
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
}
return result;
}
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {

View File

@@ -167,10 +167,6 @@ private:
bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1;
const uint32_t n_seq_virt = 1;
@@ -192,7 +188,13 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
llama_kv_cells_unified cells;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t v_heads[LLAMA_MAX_SEQ];
llama_kv_cells_unified v_cells[LLAMA_MAX_SEQ];
std::vector<uint32_t> seq_virt_idx;
std::vector<kv_layer> layers;