mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-16 11:27:03 +00:00
kv-cache : prepare for abstraction
ggml-ci
This commit is contained in:
@@ -2,12 +2,12 @@
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-graph.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
@@ -49,31 +49,13 @@ struct llama_kv_cache_slot_info {
|
||||
// TODO: pimpl
|
||||
// TODO: add notion of max sequences
|
||||
// TODO: add llama_hparams &
|
||||
struct llama_kv_cache {
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
std::vector<struct ggml_tensor *> k_l; // per layer
|
||||
std::vector<struct ggml_tensor *> v_l;
|
||||
struct llama_kv_cache : public llama_graph_kv_cache_i {
|
||||
llama_kv_cache(const llama_hparams & hparams);
|
||||
virtual ~llama_kv_cache() = default;
|
||||
|
||||
// TODO: become constructor
|
||||
bool init(
|
||||
const llama_model & model,
|
||||
const llama_model & model, // TODO: do not reference the model
|
||||
const llama_cparams & cparams,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
@@ -115,8 +97,48 @@ struct llama_kv_cache {
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
|
||||
void state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
|
||||
void state_read (llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
|
||||
// graph build API
|
||||
|
||||
virtual void build_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
llama_graph_i * lgf) override;
|
||||
|
||||
virtual void build_defrag(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
int32_t max_nodes,
|
||||
bool v_trans) override;
|
||||
|
||||
// state save/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
||||
|
||||
// members
|
||||
|
||||
const llama_hparams & hparams;
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
bool can_shift = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_impl also uses it, so it
|
||||
// cannot be freely changed after a slot has been allocated.
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
||||
// computed before each graph build
|
||||
uint32_t n = 0;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
std::vector<struct ggml_tensor *> k_l; // per layer
|
||||
std::vector<struct ggml_tensor *> v_l;
|
||||
|
||||
private:
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
@@ -126,10 +148,10 @@ private:
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user