kv-cache : basic abstraction

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-27 15:54:44 +02:00
parent 82675a0180
commit 828effd9d7
4 changed files with 244 additions and 198 deletions

View File

@@ -45,12 +45,39 @@ struct llama_kv_cache_slot_info {
operator bool() const { return found; }
};
struct llama_kv_cache {
public:
virtual int32_t n_tokens() const = 0;
virtual uint32_t used_cells() const = 0; // TODO: remove
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual void defrag() = 0;
virtual bool get_can_shift() const = 0;
};
// C++ alias
class llama_kv_cache_i : public llama_kv_cache {
public:
using llama_kv_cache::llama_kv_cache;
};
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
struct llama_kv_cache {
llama_kv_cache(const llama_hparams & hparams);
virtual ~llama_kv_cache() = default;
class llama_kv_cache_unified : public llama_kv_cache_i {
public:
llama_kv_cache_unified(const llama_hparams & hparams);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor
bool init(
@@ -61,24 +88,26 @@ struct llama_kv_cache {
uint32_t kv_size,
bool offload);
int32_t n_tokens() const;
int32_t n_tokens() const override;
uint32_t used_cells() const override;
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
void clear();
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1);
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
void seq_keep(llama_seq_id seq_id);
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d);
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id);
llama_pos seq_pos_max(llama_seq_id seq_id) override;
void defrag();
void defrag() override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
@@ -143,9 +172,10 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache
struct llama_kv_cache_recurrent : public llama_kv_cache {
using llama_kv_cache::llama_kv_cache;
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
class llama_kv_cache_recurrent : public llama_kv_cache_unified {
public:
using llama_kv_cache_unified::llama_kv_cache_unified;
};
//
@@ -166,9 +196,9 @@ struct llama_kv_slot_restorer {
bool do_restore = false;
llama_kv_cache & cache;
llama_kv_cache_unified & cache;
explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
@@ -249,4 +279,4 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv);