context : add llama_kv_cache_recurrent prototype

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-20 20:54:18 +02:00
parent ad870c49f4
commit 08011c2ca1
3 changed files with 477 additions and 102 deletions

View File

@@ -359,17 +359,17 @@ int32_t llama_context::max_nodes() const {
}
llama_kv_cache * llama_context::get_kv_self() {
LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
return nullptr;
}
const llama_kv_cache * llama_context::get_kv_self() const {
LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
return nullptr;
}
void llama_context::kv_self_update() {
LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
}
enum llama_pooling_type llama_context::pooling_type() const {
@@ -2246,14 +2246,7 @@ llama_context_kv_self::llama_context_kv_self(
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(&model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(!llama_model_is_recurrent(&model));
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
@@ -2286,6 +2279,61 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
return &kv_self;
}
void llama_context_kv_self::kv_self_update() {
auto & kv = kv_self;
if (kv.has_shift) {
if (!kv.can_shift) {
GGML_ABORT("The current context does not support K-shift");
}
// apply K-shift if needed
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
build_kv_self_shift(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set({});
graph_compute(gf, false);
need_reserve = true;
}
{
kv.has_shift = false;
for (uint32_t i = 0; i < kv.size; ++i) {
kv.cells[i].delta = 0;
}
}
}
// defragment the KV cache if needed
if (kv.do_defrag) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
build_kv_self_defrag(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf);
// no input
//input_set({});
graph_compute(gf, false);
kv.do_defrag = false;
need_reserve = true;
}
}
ggml_cgraph * llama_context_kv_self::graph_init() {
inp_embd_enc = nullptr;
inp_pos_bucket = nullptr;
@@ -2310,7 +2358,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;
@@ -2470,7 +2518,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
@@ -2552,7 +2600,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const bool logits_all = n_outputs_all == n_tokens_all;
sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* simple_split */ true,
/* logits_all */ logits_all);
// reserve output buffer
@@ -2569,18 +2617,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const auto & n_ubatch = cparams.n_ubatch;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(n_ubatch);
}
} else {
ubatch = sbatch.split_simple(n_ubatch);
}
ubatch = sbatch.split_simple(n_ubatch);
// count the outputs in this u_batch
{
@@ -2617,7 +2654,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
bg.save(slot_info);
if (!kv_self.recurrent) {
{
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
@@ -2821,10 +2858,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
return 0;
}
llama_pos llama_context_kv_self::pos_max() const {
return kv_self.pos_max();
}
uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const {
return kv_self.get_padding(cparams);
}
@@ -3062,61 +3095,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
}
}
void llama_context_kv_self::kv_self_update() {
auto & kv = kv_self;
if (kv.has_shift) {
if (!kv.can_shift) {
GGML_ABORT("The current context does not support K-shift");
}
// apply K-shift if needed
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
build_kv_self_shift(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set({});
graph_compute(gf, false);
need_reserve = true;
}
{
kv.has_shift = false;
for (uint32_t i = 0; i < kv.size; ++i) {
kv.cells[i].delta = 0;
}
}
}
// defragment the KV cache if needed
if (kv.do_defrag) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
build_kv_self_defrag(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf);
// no input
//input_set({});
graph_compute(gf, false);
kv.do_defrag = false;
need_reserve = true;
}
}
ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) {
inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx());
ggml_set_input(inp_self_k_shift);
@@ -3176,7 +3154,9 @@ ggml_tensor * llama_context_kv_self::build_attn(
// store to KV cache
{
const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
GGML_ASSERT(!kv_self.recurrent);
const auto kv_head = worst_case ? kv_self.size - n_tokens : kv_self.head;
GGML_ASSERT(kv_self.size == n_ctx);
@@ -3684,22 +3664,406 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
llama_context_recurrent::llama_context_recurrent(
const llama_model & model,
const llama_context_params & params) :
llama_context_kv_self(model, params) {
llama_context(model, params),
kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
const auto & hparams = model.hparams;
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
// Mamba only needs a constant number of KV cache cells per sequence
GGML_ASSERT(llama_model_is_recurrent(&model));
// Mamba needs at least as many KV cells as there are sequences kept at any time
uint32_t kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
ggml_type type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
ggml_type type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!hparams.vocab_only) {
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
throw std::runtime_error("failed to initialize self-attention cache");
}
{
const size_t memory_size_k = kv_self.size_k_bytes();
const size_t memory_size_v = kv_self.size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
}
}
llama_context_recurrent::~llama_context_recurrent() = default;
ggml_cgraph * llama_context_recurrent::graph_init() {
inp_s_copy = nullptr;
inp_s_mask = nullptr;
llama_kv_cache * llama_context_recurrent::get_kv_self() {
return &kv_self;
}
return llama_context_kv_self::graph_init();
const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
return &kv_self;
}
void llama_context_recurrent::kv_self_update() {
// noop
}
ggml_cgraph * llama_context_recurrent::graph_init() {
inp_s_copy = nullptr;
inp_s_mask = nullptr;
return llama_context::graph_init();
}
int llama_context_recurrent::encode(llama_batch & inp_batch) {
GGML_UNUSED(inp_batch);
LLAMA_LOG_ERROR("%s: encode() not supported for recurrent models\n", __func__);
return -1;
}
int llama_context_recurrent::decode(llama_batch & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
// TODO: remove this stuff
class batch_guard {
public:
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
}
~batch_guard() {
if (!is_done) {
kv_slot_restorer.restore();
}
}
void done() {
is_done = true;
}
void save(const llama_kv_cache_slot_info & slot_info) {
kv_slot_restorer.save(slot_info);
}
private:
bool is_done = false;
llama_kv_slot_restorer kv_slot_restorer;
};
batch_guard bg(kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
}
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
int64_t n_outputs_all = 0;
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
sbatch.from_batch(batch, n_embd,
/* simple_split */ false,
/* logits_all */ logits_all);
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
return -2;
};
int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = llama_ubatch();
const auto & n_ubatch = cparams.n_ubatch;
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(n_ubatch);
}
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
kv_self_update();
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3;
}
bg.save(slot_info);
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
// reserve a worst case graph if needed
if (need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
auto * gf = graph_init();
graph_build(ctx_compute.get(), gf, ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(sched.get());
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
need_reserve = false;
}
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, false);
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
ggml_backend_sched_alloc_graph(sched.get(), gf);
input_set(ubatch);
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
case GGML_STATUS_ALLOC_FAILED:
return -2;
case GGML_STATUS_FAILED:
default:
return -3;
}
}
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
auto * t_embd = cparams.embeddings ? res.t_embd : nullptr;
if (t_embd && res.t_embd_pooled) {
t_embd = res.t_embd_pooled;
}
// extract logits
if (t_logits && n_outputs > 0) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
float * logits_out = logits + n_outputs_prev*n_vocab;
if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
}
}
// extract embeddings
if (t_embd && n_outputs > 0) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
float * embd_out = embd + n_outputs_prev*n_embd;
if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - a single float per sequence
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
}
}
n_outputs_prev += n_outputs;
}
// finalize the batch processing
bg.done();
// set output mappings
{
bool sorted_output = true;
GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
for (int64_t i = 0; i < n_outputs_all; ++i) {
int64_t out_id = sbatch.out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
sorted_output = false;
}
}
if (sorted_output) {
sbatch.out_ids.clear();
}
}
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
return 0;
}
void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
// call base functionality
llama_context_kv_self::input_set(ubatch);
llama_context::input_set(ubatch);
GGML_ASSERT(kv_self.recurrent);

View File

@@ -374,9 +374,6 @@ public:
virtual int encode(llama_batch & inp_batch) override;
virtual int decode(llama_batch & inp_batch) override;
// max token position across all sequences in the current context
llama_pos pos_max() const;
// certain implementations could require a padding for the context size
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
@@ -453,9 +450,7 @@ protected:
};
// a recurrent transformer (ie.e RWKV, Mamba)
// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
//class llama_context_recurrent : public llama_context {
class llama_context_recurrent : public llama_context_kv_self {
class llama_context_recurrent : public llama_context {
public:
llama_context_recurrent(
const llama_model & model,
@@ -463,8 +458,16 @@ public:
virtual ~llama_context_recurrent();
virtual llama_kv_cache * get_kv_self() override;
virtual const llama_kv_cache * get_kv_self() const override;
virtual void kv_self_update() override;
virtual ggml_cgraph * graph_init() override;
virtual int encode(llama_batch & inp_batch) override;
virtual int decode(llama_batch & inp_batch) override;
virtual ggml_tensor * build_inp_s_copy(
ggml_context * ctx0,
bool worst_case) override;
@@ -524,10 +527,11 @@ public:
protected:
virtual void input_set(const llama_ubatch & ubatch) override;
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
llama_kv_cache_recurrent kv_self;
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
// TODO: add recurrent cache
};
// For internal test use

View File

@@ -48,7 +48,6 @@ struct llama_kv_cache_slot_info {
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
// TODO: add llama_hparams &
struct llama_kv_cache {
llama_kv_cache(const llama_hparams & hparams);
virtual ~llama_kv_cache() = default;
@@ -108,7 +107,10 @@ struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
@@ -141,6 +143,11 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache
struct llama_kv_cache_recurrent : public llama_kv_cache {
using llama_kv_cache::llama_kv_cache;
};
//
// kv cache restore
//