mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-16 11:27:03 +00:00
kv-cache : basic abstraction
ggml-ci
This commit is contained in:
@@ -45,12 +45,39 @@ struct llama_kv_cache_slot_info {
|
||||
operator bool() const { return found; }
|
||||
};
|
||||
|
||||
struct llama_kv_cache {
|
||||
public:
|
||||
virtual int32_t n_tokens() const = 0;
|
||||
virtual uint32_t used_cells() const = 0; // TODO: remove
|
||||
|
||||
virtual void clear() = 0;
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
|
||||
|
||||
virtual void defrag() = 0;
|
||||
virtual bool get_can_shift() const = 0;
|
||||
};
|
||||
|
||||
|
||||
// C++ alias
|
||||
class llama_kv_cache_i : public llama_kv_cache {
|
||||
public:
|
||||
using llama_kv_cache::llama_kv_cache;
|
||||
};
|
||||
|
||||
|
||||
// ring-buffer of cached KV data
|
||||
// TODO: pimpl
|
||||
// TODO: add notion of max sequences
|
||||
struct llama_kv_cache {
|
||||
llama_kv_cache(const llama_hparams & hparams);
|
||||
virtual ~llama_kv_cache() = default;
|
||||
class llama_kv_cache_unified : public llama_kv_cache_i {
|
||||
public:
|
||||
llama_kv_cache_unified(const llama_hparams & hparams);
|
||||
virtual ~llama_kv_cache_unified() = default;
|
||||
|
||||
// TODO: become constructor
|
||||
bool init(
|
||||
@@ -61,24 +88,26 @@ struct llama_kv_cache {
|
||||
uint32_t kv_size,
|
||||
bool offload);
|
||||
|
||||
int32_t n_tokens() const;
|
||||
int32_t n_tokens() const override;
|
||||
uint32_t used_cells() const override;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
// TODO: better data structures to reduce the cost of this operation
|
||||
llama_pos pos_max() const;
|
||||
|
||||
void clear();
|
||||
void clear() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1);
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
|
||||
void seq_keep(llama_seq_id seq_id);
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d);
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id);
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) override;
|
||||
|
||||
void defrag();
|
||||
void defrag() override;
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
// updates the cache head
|
||||
@@ -143,9 +172,10 @@ private:
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache
|
||||
struct llama_kv_cache_recurrent : public llama_kv_cache {
|
||||
using llama_kv_cache::llama_kv_cache;
|
||||
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
|
||||
class llama_kv_cache_recurrent : public llama_kv_cache_unified {
|
||||
public:
|
||||
using llama_kv_cache_unified::llama_kv_cache_unified;
|
||||
};
|
||||
|
||||
//
|
||||
@@ -166,9 +196,9 @@ struct llama_kv_slot_restorer {
|
||||
|
||||
bool do_restore = false;
|
||||
|
||||
llama_kv_cache & cache;
|
||||
llama_kv_cache_unified & cache;
|
||||
|
||||
explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
|
||||
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
|
||||
old_state.head = cache.head;
|
||||
old_state.n = cache.n;
|
||||
}
|
||||
@@ -249,4 +279,4 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
|
||||
|
||||
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
|
||||
|
||||
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
|
||||
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv);
|
||||
|
||||
Reference in New Issue
Block a user