kv-cache : refactor + add llama_memory_state_i (#13746)

* kv-cache : simplify the "struct llama_kv_cache" interface

ggml-ci

* kv-cache : revert the (n_swa + n_ubatch) change (for next PR)

ggml-ci

* kv-cache : some comments

ggml-ci

* context : fix graph reserve for multiple sequences

ggml-ci

* kv-cache : fix typo [no ci]

* kv-cache : fix find_slot() logic for free slots

ggml-ci

* llama : add TODO for deprecating the defrag API in the future

* kv-cache : improve find_slot() using min/max seq pos info

ggml-ci

* llama : handle aborts and compute errors

ggml-ci

* memory : extract state into llama_memory_state

ggml-ci

* kv-cache : add comments

ggml-ci

* server : update batching logic to reset n_batch on successful decode

* server : upon full re-processing, remove the sequence from the cache

* kv-cache : add TODO for doing split_equal when split_simple fails

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-05-31 10:24:04 +03:00
committed by GitHub
parent eb3949938e
commit 12d0188c0d
14 changed files with 1304 additions and 655 deletions

View File

@@ -2,6 +2,7 @@
#include "llama.h"
#include "llama-io.h"
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-memory.h"
#include "llama-kv-cells.h"
@@ -14,48 +15,35 @@
struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
// call if batch processing fails - restores the cache state
virtual void restore() = 0;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// call after successful batch processing - clears any pending state
virtual void commit() = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// process any pending defrag/shift/etc. operations
// optionally call once before processing a new batch
// return true if any operations were performed
virtual bool update(llama_context & lctx) = 0;
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
// TODO: change to
// llama_memory_state_ptr init_defrag(float thold) = 0;
//
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
// TODO: remove
virtual void set_full() = 0;
//
// batch processing
//
// =============================================================================================================
// TODO: refactor and simplify this [TAG: KV_API]
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;
// =============================================================================================================
// getters
virtual bool get_can_shift() const = 0;
@@ -69,25 +57,6 @@ struct llama_kv_cache : public llama_memory_i {
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
//
// llama_kv_cache_guard
//
struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
~llama_kv_cache_guard() {
kv->restore();
}
void commit() {
kv->commit();
}
private:
llama_kv_cache * kv;
};
//
// llama_kv_cache_unified
//
@@ -133,23 +102,18 @@ public:
// llama_kv_cache
//
void restore() override;
void commit() override;
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
bool update(llama_context & ctx) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;
bool get_can_shift() const override;
// state write/load
@@ -161,18 +125,40 @@ public:
// llama_kv_cache_unified specific API
//
uint32_t get_n() const;
uint32_t get_size() const;
//
// graph_build API
//
uint32_t get_n_kv() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
// store k_cur and v_cur in the cache based on the current head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
//
// preparation API
//
// find places for the provided ubatches in the cache, returns the head locations
// return empty vector on failure
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
// return the cell position where we can insert the ubatch
// return -1 on failure to find a contiguous slot of kv cells
int32_t find_slot(const llama_ubatch & ubatch) const;
// emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
//
// set_input API
//
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_shift (ggml_tensor * dst) const;
@@ -194,11 +180,9 @@ private:
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
// computed before each graph build
// TODO: cells should start to maintain this value dynamically based on the edits
uint32_t n = 0;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0;
const uint32_t n_seq_max = 1;
@@ -220,24 +204,6 @@ private:
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// recovery information used to restore the KV cells to their original state in case of a failure
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
struct {
void clear() {
states.clear();
}
struct state {
uint32_t i;
llama_kv_cells_unified cells;
};
// stack with the partial states before each ubatch
std::vector<state> states;
} recovery;
// defrag
struct {
std::vector<uint32_t> ids;
@@ -279,13 +245,88 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
class llama_kv_cache_unified_state : public llama_memory_state_i {
public:
// used for errors
llama_kv_cache_unified_state(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv);
// used to create a state from a batch
llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state();
//
// llama_memory_state_i
//
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_state specific API
//
uint32_t get_n_kv() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
const llama_memory_status status;
llama_kv_cache_unified * kv;
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<uint32_t> heads;
std::vector<llama_ubatch> ubatches;
//
// data needed for building the compute graph for the current ubatch:
//
// a heuristic, to avoid attending the full cache if it is not yet utilized
// as the cache gets filled, the benefit from this heuristic disappears
int32_t n_kv;
// the beginning of the current slot in which the ubatch will be inserted
int32_t head;
};
//
// llama_kv_cache_unified_iswa
//
// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
// upon successful commit, the SWA cache removes old tokens outside the n_swa window
class llama_kv_cache_unified_iswa : public llama_kv_cache {
public:
@@ -322,20 +363,18 @@ public:
// llama_kv_cache
//
void restore() override;
void commit() override;
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
bool update(llama_context & ctx) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
bool get_can_shift() const override;
// state write/load
@@ -347,58 +386,80 @@ public:
// llama_kv_cache_unified_iswa specific API
//
llama_kv_cache_unified * get_kv_base() const;
llama_kv_cache_unified * get_kv_swa () const;
llama_kv_cache_unified * get_base() const;
llama_kv_cache_unified * get_swa () const;
private:
const llama_hparams & hparams;
bool do_prune = true;
struct {
struct entry {
llama_pos pmin;
llama_pos pmax;
};
void clear() {
pos.clear();
}
// used to perform SWA pruning of old tokens
std::unordered_map<llama_seq_id, entry> pos;
} pending;
std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
public:
// used for errors
llama_kv_cache_unified_iswa_state(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_iswa_state(
llama_memory_status status,
llama_kv_cache_unified_iswa * kv);
// used to create a state from a batch
llama_kv_cache_unified_iswa_state(
llama_memory_status status,
llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_state();
//
// llama_memory_state_i
//
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_iswa_state specific API
//
const llama_kv_cache_unified_state * get_base() const;
const llama_kv_cache_unified_state * get_swa() const;
private:
const llama_memory_status status;
//llama_kv_cache_unified_iswa * kv;
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
std::unique_ptr<llama_kv_cache_unified_state> state_base;
std::unique_ptr<llama_kv_cache_unified_state> state_swa;
};
//
// llama_kv_cache_recurrent
//
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
int32_t src = -1; // used to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
@@ -428,19 +489,22 @@ public:
// llama_kv_cache
//
void restore() override;
void commit() override;
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
bool update(llama_context & ctx) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
bool prepare(const std::vector<llama_ubatch> & ubatches);
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
// find a contiguous slot of kv cells and emplace the ubatch there
bool find_slot(const llama_ubatch & ubatch);
bool get_can_shift() const override;
@@ -460,6 +524,27 @@ public:
// computed before each graph build
uint32_t n = 0;
// TODO: optimize for recurrent state needs
struct kv_cell {
llama_pos pos = -1;
int32_t src = -1; // used to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
@@ -469,26 +554,11 @@ private:
//const llama_model & model;
const llama_hparams & hparams;
// commit/restore cache
// TODO: rework for recurrent cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
const uint32_t n_seq_max = 1;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
@@ -500,3 +570,67 @@ private:
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
public:
// used for errors
llama_kv_cache_recurrent_state(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv);
// used to create a state from a batch
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_recurrent_state();
//
// llama_memory_state_i
//
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_recurrent_state specific API
//
uint32_t get_n_kv() const;
uint32_t get_head() const;
uint32_t get_size() const;
ggml_tensor * get_k_l(int32_t il) const;
ggml_tensor * get_v_l(int32_t il) const;
int32_t s_copy(int i) const;
float s_mask(int i) const;
private:
const llama_memory_status status;
llama_kv_cache_recurrent * kv;
llama_sbatch sbatch;
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
//
// data needed for building the compute graph for the current ubatch:
// TODO: extract all the state like `head` and `n` here
//
const bool is_full = false;
};