kv-cache : prepare for abstraction

ggml-ci
Georgi Gerganov
2025-02-18 21:26:42 +02:00
parent 2bffc2d514
commit f5cedbcaaa
7 changed files with 594 additions and 534 deletions

src/llama-kv-cache.h

@@ -2,12 +2,12 @@

 #include "llama.h"
 #include "llama-io.h"
+#include "llama-graph.h"

 #include "ggml-cpp.h"

 #include <set>
 #include <vector>
-#include <functional>

 struct llama_cparams;
 struct llama_hparams;
@@ -49,31 +49,13 @@ struct llama_kv_cache_slot_info {
 // TODO: pimpl
 // TODO: add notion of max sequences
 // TODO: add llama_hparams &
-struct llama_kv_cache {
-    bool has_shift = false;
-    bool do_defrag = false;
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans   = true;  // the value tensor is transposed
-    bool can_shift = false;
-
-    // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_impl also uses it, so it
-    // cannot be freely changed after a slot has been allocated.
-    uint32_t head = 0;
-    uint32_t size = 0;
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    std::vector<llama_kv_cell> cells;
-
-    std::vector<struct ggml_tensor *> k_l; // per layer
-    std::vector<struct ggml_tensor *> v_l;
+struct llama_kv_cache : public llama_graph_kv_cache_i {
+    llama_kv_cache(const llama_hparams & hparams);
+    virtual ~llama_kv_cache() = default;

     // TODO: become constructor
     bool init(
-            const llama_model & model,
+            const llama_model & model, // TODO: do not reference the model
             const llama_cparams & cparams,
             ggml_type type_k,
             ggml_type type_v,
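The hunk above changes the class head: llama_kv_cache now derives from llama_graph_kv_cache_i and binds a llama_hparams reference at construction time, while init() still performs the actual allocation. A minimal construction sketch under those assumptions; the trailing init() arguments are hypothetical, since the hunk truncates the signature after type_v:

// Sketch only: `model` and `cparams` are assumed to come from the
// surrounding llama.cpp context; the trailing init() parameters
// (kv_size, offload) are guesses, not shown in this hunk.
llama_kv_cache kv(model.hparams); // hparams is now bound at construction

if (!kv.init(model, cparams, GGML_TYPE_F16, GGML_TYPE_F16, /*kv_size*/ 4096, /*offload*/ true)) {
    // per-layer K/V tensor allocation failed
}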
@@ -115,8 +97,48 @@ struct llama_kv_cache {
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;

-    void state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
-    void state_read (llama_io_read_i  & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
+    // graph build API
+
+    virtual void build_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf,
+            llama_graph_i * lgf) override;
+
+    virtual void build_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf,
+            int32_t max_nodes,
+            bool v_trans) override;
+
+    // state save/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
+
+    // members
+
+    const llama_hparams & hparams;
+
+    bool has_shift = false;
+    bool do_defrag = false;
+    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+    bool v_trans   = true;  // the value tensor is transposed
+    bool can_shift = false;
+
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_impl also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
+    uint32_t head = 0;
+    uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<llama_kv_cell> cells;
+
+    std::vector<struct ggml_tensor *> k_l; // per layer
+    std::vector<struct ggml_tensor *> v_l;

 private:
     ggml_type type_k = GGML_TYPE_F16;
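This diff never shows llama_graph_kv_cache_i itself (it lives in llama-graph.h, pulled in by the new include above), but the two overrides pin down its shape. A hedged reconstruction, inferred only from the override signatures in this hunk and not the verbatim declaration:

// Inferred sketch of the interface in llama-graph.h:
struct llama_graph_kv_cache_i {
    virtual ~llama_graph_kv_cache_i() = default;

    // add the graph nodes that apply a pending K-shift to the cache
    virtual void build_shift(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            llama_graph_i * lgf) = 0;

    // add the graph nodes that compact fragmented cache cells
    virtual void build_defrag(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            int32_t max_nodes,
            bool v_trans) = 0;
};

Routing graph construction through this interface is the "prepare for abstraction" step of the commit title: callers can build shift/defrag graphs against the interface without depending on the concrete cache layout.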
@@ -126,10 +148,10 @@ private:
     std::vector<ggml_backend_buffer_ptr> bufs;

     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

     bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };

 //
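Since hparams is now a member reference, the public and private state save/load entry points all drop their const llama_hparams & parameter, as the last two hunks show. The effect at a hypothetical call site:

// Illustrative only — the surrounding caller code is assumed:
// before: kv.state_write(io, model.hparams, seq_id);
// after:
kv.state_write(io, seq_id); // hparams is taken from the member reference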