kv-cache : rework kv_cell (#13706)

* kv-cache : rework kv_cell

ggml-ci

* kv-cells : use "shift" instead of "delta" consistently

ggml-ci

* llama : add llama_max_parallel_sequences()

ggml-ci

* kv-cells : update comments [no ci]

* context : fail upon construction if sequences exceed max value

ggml-ci

* kv-cells : get_pos() -> pos_get() + comments

ggml-ci

* kv-cells : fix tracking of "used" cells

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-05-25 16:34:36 +03:00
committed by GitHub
parent c508256db2
commit de2ef53a4b
8 changed files with 470 additions and 253 deletions

View File

@@ -4,6 +4,7 @@
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"
#include "llama-kv-cells.h"
#include "ggml-cpp.h"
@@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i {
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
// TODO: remove
virtual void set_full() = 0;
//
@@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i {
//
// =============================================================================================================
// TODO: refactor and simplify this
// TODO: refactor and simplify this [TAG: KV_API]
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
@@ -121,7 +123,7 @@ public:
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -159,7 +161,7 @@ public:
// llama_kv_cache_unified specific API
//
uint32_t get_n() const;
uint32_t get_n() const;
uint32_t get_size() const;
// get views of the current state of the cache
@@ -180,26 +182,6 @@ private:
const llama_model & model;
const llama_hparams & hparams;
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
// TODO: replace with bitset uint64_t
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
struct kv_layer {
// layer index in the model
// note: can be different from the layer index in the KV cache
@@ -209,15 +191,13 @@ private:
ggml_tensor * v;
};
bool has_shift = false;
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
// computed before each graph build
// TODO: cells should start to maintain this value dynamically based on the edits
uint32_t n = 0;
const uint32_t n_seq_max = 1;
@@ -233,19 +213,29 @@ private:
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
llama_kv_cells_unified cells;
std::vector<kv_layer> layers;
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// recovery information used to restore the KV cells to their original state in case of a failure
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
struct {
void clear() {
cells.clear();
states.clear();
}
std::unordered_map<uint32_t, kv_cell> cells;
struct state {
uint32_t i;
llama_kv_cells_unified cells;
};
// stack with the partial states before each ubatch
std::vector<state> states;
} recovery;
// defrag
@@ -257,6 +247,7 @@ private:
bool defrag_prepare(int32_t n_max_nodes);
// find how many cells are currently in use
// TODO: optimize
uint32_t cell_max() const;
size_t total_size() const;
@@ -325,7 +316,7 @@ public:
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -431,7 +422,7 @@ public:
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;