kv-cache : basic abstraction

ggml-ci
Author: Georgi Gerganov
Date:   2025-02-27 15:54:44 +02:00
Parent: 82675a0180
Commit: 828effd9d7
4 changed files with 244 additions and 198 deletions
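
In short, this commit turns `llama_kv_cache` into an abstract interface, renames the previous concrete implementation to `llama_kv_cache_unified`, and has the contexts own their cache through `std::unique_ptr`. A condensed sketch of the resulting shape, with member lists trimmed (all names appear in the hunks below):

```cpp
// condensed from the header changes below; not the full declarations
struct llama_kv_cache {
    virtual int32_t  n_tokens()   const = 0;
    virtual uint32_t used_cells() const = 0; // TODO: remove
    virtual void clear() = 0;
    virtual bool seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
    // ... seq_cp, seq_keep, seq_add, seq_div, seq_pos_max, defrag ...
    virtual bool get_can_shift() const = 0;
};

// the former llama_kv_cache ring buffer, now one implementation of the interface
class llama_kv_cache_unified : public llama_kv_cache_i { /* ... */ };

// temporary: the recurrent cache still reuses the unified implementation
class llama_kv_cache_recurrent : public llama_kv_cache_unified { /* ... */ };

// the contexts switch from a by-value member to owning the cache polymorphically
class llama_context_kv_self : public llama_context_base {
    // ...
    std::unique_ptr<llama_kv_cache_unified> kv_self;
};
```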

@@ -2384,15 +2384,16 @@ llama_context_kv_self::llama_context_kv_self(
const llama_model & model, const llama_model & model,
llama_context_params params, llama_context_params params,
llama_graph_type gtype) : llama_graph_type gtype) :
llama_context_base(model, params, gtype), llama_context_base(model, params, gtype) {
kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__); LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__);
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
kv_self = std::make_unique<llama_kv_cache_unified>(hparams);
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self.get_padding(cparams)); cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
@@ -2406,14 +2407,14 @@ llama_context_kv_self::llama_context_kv_self(
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!hparams.vocab_only) { if (!hparams.vocab_only) {
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
throw std::runtime_error("failed to initialize self-attention cache"); throw std::runtime_error("failed to initialize self-attention cache");
} }
{ {
const size_t memory_size_k = kv_self.size_k_bytes(); const size_t memory_size_k = kv_self->size_k_bytes();
const size_t memory_size_v = kv_self.size_v_bytes(); const size_t memory_size_v = kv_self->size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
@@ -2427,19 +2428,19 @@ llama_context_kv_self::~llama_context_kv_self() = default;
void llama_context_kv_self::reserve() { void llama_context_kv_self::reserve() {
// simulate full KV cache // simulate full KV cache
kv_self.n = kv_self.size; kv_self->n = kv_self->size;
LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n); LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self->n);
llama_context_base::reserve(); llama_context_base::reserve();
} }
llama_kv_cache * llama_context_kv_self::get_kv_self() { llama_kv_cache * llama_context_kv_self::get_kv_self() {
return &kv_self; return kv_self.get();
} }
const llama_kv_cache * llama_context_kv_self::get_kv_self() const { const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
return &kv_self; return kv_self.get();
} }
void llama_context_kv_self::kv_self_update() { void llama_context_kv_self::kv_self_update() {
@@ -2449,8 +2450,8 @@ void llama_context_kv_self::kv_self_update() {
bool need_reserve = false; bool need_reserve = false;
if (kv.has_shift) { if (kv->has_shift) {
if (!kv.can_shift) { if (!kv->get_can_shift()) {
GGML_ABORT("The current context does not support K-shift"); GGML_ABORT("The current context does not support K-shift");
} }
@@ -2474,16 +2475,16 @@ void llama_context_kv_self::kv_self_update() {
} }
{ {
kv.has_shift = false; kv->has_shift = false;
for (uint32_t i = 0; i < kv.size; ++i) { for (uint32_t i = 0; i < kv->size; ++i) {
kv.cells[i].delta = 0; kv->cells[i].delta = 0;
} }
} }
} }
// defragment the KV cache if needed // defragment the KV cache if needed
if (kv.do_defrag) { if (kv->do_defrag) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
@@ -2499,7 +2500,7 @@ void llama_context_kv_self::kv_self_update() {
graph_compute(gf, false); graph_compute(gf, false);
kv.do_defrag = false; kv->do_defrag = false;
need_reserve = true; need_reserve = true;
} }
@@ -2513,7 +2514,7 @@ void llama_context_kv_self::kv_self_update() {
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// simulate full KV cache // simulate full KV cache
kv_self.n = kv_self.size; kv_self->n = kv_self->size;
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -2537,7 +2538,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens; const int32_t n_tokens = batch.n_tokens;
@@ -2674,7 +2675,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
@@ -2689,7 +2690,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// TODO: remove this stuff // TODO: remove this stuff
class batch_guard { class batch_guard {
public: public:
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) { batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
} }
~batch_guard() { ~batch_guard() {
@@ -2712,7 +2713,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
llama_kv_slot_restorer kv_slot_restorer; llama_kv_slot_restorer kv_slot_restorer;
}; };
batch_guard bg(kv_self); batch_guard bg(*kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@@ -2797,11 +2798,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it // better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) { if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
kv_self.head = 0; kv_self->head = 0;
} }
const auto slot_info = kv_self.find_slot(ubatch); const auto slot_info = kv_self->find_slot(ubatch);
if (!slot_info) { if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3; return -3;
@@ -2813,12 +2814,12 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// a heuristic, to avoid attending the full cache if it is not yet utilized // a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears // after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important // if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams); const uint32_t pad = kv_self->get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad))); kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
} }
} }
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
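
The interesting line in the hunk above is the update of `kv_self->n`: the attended window is the current `cell_max()` rounded up to the padding, never smaller than the padding itself and never larger than the cache. A self-contained restatement with illustrative numbers:

```cpp
#include <algorithm>
#include <cstdint>

// GGML_PAD-style round-up (n is a power of two here)
static uint32_t pad_up(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

// mirrors: kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)))
static uint32_t attended_cells(uint32_t size, uint32_t cell_max, uint32_t pad) {
    return std::min(size, std::max(pad, pad_up(cell_max, pad)));
}

// with size = 4096 and pad = 32 (no flash attention):
//   cell_max =    0 -> 32     never attend fewer than `pad` cells
//   cell_max =  100 -> 128    rounded up to the next multiple of 32
//   cell_max = 4090 -> 4096   clamped to the cache size
```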
@@ -2847,11 +2848,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// update the kv ring buffer // update the kv ring buffer
{ {
kv_self.head += ubatch.n_tokens; kv_self->head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index. // Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) { if (kv_self->head >= kv_self->size) {
kv_self.head = 0; kv_self->head = 0;
} }
} }
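
After a successful decode the head simply advances like a ring buffer and wraps to zero at the end of the cache, so the next `find_slot()` starts searching from the beginning again. A toy restatement:

```cpp
#include <cstdint>

struct ring_head {
    uint32_t head = 0;
    uint32_t size = 0;
};

// toy version of the head update in the hunk above
static void advance_head(ring_head & kv, uint32_t n_tokens) {
    kv.head += n_tokens;
    if (kv.head >= kv.size) {
        kv.head = 0; // wrap around; find_slot() will scan from the start next time
    }
}
```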
@@ -2972,13 +2973,13 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens) // - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens // - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + kv_self.get_padding(cparams))/float(kv_self.n)) : 0.0f; const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update // queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) { if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self.defrag(); kv_self->defrag();
} }
} }
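
Defragmentation is only queued when the fragmentation of the attended window exceeds `defrag_thold`, and windows under 2048 cells are exempt. The check, restated as a standalone helper with an illustrative threshold:

```cpp
#include <algorithm>
#include <cstdint>

// same formula as above: the padding counts towards the "used" cells,
// and small attended windows (< 2048 cells) never trigger a defrag
static bool should_defrag(uint32_t n, uint32_t used, uint32_t pad, float defrag_thold) {
    const float fragmentation = n >= 2048
        ? std::max(0.0f, 1.0f - float(used + pad)/float(n))
        : 0.0f;
    return fragmentation > defrag_thold;
}

// e.g. n = 4096, used = 1000, pad = 32, defrag_thold = 0.1:
//   fragmentation = 1 - 1032/4096 ≈ 0.75  -> defrag requested on the next kv_self_update()
```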
@@ -2997,8 +2998,8 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
int32_t * data = (int32_t *) inp.self_k_shift->data; int32_t * data = (int32_t *) inp.self_k_shift->data;
for (uint32_t i = 0; i < kv_self.size; ++i) { for (uint32_t i = 0; i < kv_self->size; ++i) {
data[i] = kv_self.cells[i].delta; data[i] = kv_self->cells[i].delta;
} }
// the K-shift graph requires just this input // the K-shift graph requires just this input
@@ -3011,7 +3012,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
if (inp.self_kq_mask || inp.self_kq_mask_swa) { if (inp.self_kq_mask || inp.self_kq_mask_swa) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) { if (cparams.causal_attn) {
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self->n;
const int64_t n_tokens = ubatch.n_tokens; const int64_t n_tokens = ubatch.n_tokens;
const int64_t n_seq_tokens = ubatch.n_seq_tokens; const int64_t n_seq_tokens = ubatch.n_seq_tokens;
const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
@@ -3041,11 +3042,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
float f; float f;
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
f = -INFINITY; f = -INFINITY;
} else { } else {
if (hparams.use_alibi) { if (hparams.use_alibi) {
f = -std::abs(kv_self.cells[i].pos - pos); f = -std::abs(kv_self->cells[i].pos - pos);
} else { } else {
f = 0.0f; f = 0.0f;
} }
@@ -3057,7 +3058,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
// may need to cut off old tokens for sliding window // may need to cut off old tokens for sliding window
if (data_swa) { if (data_swa) {
if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
f = -INFINITY; f = -INFINITY;
} }
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
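
Taken together, the two hunks above decide the mask value for every (query token, cache cell) pair. A compact sketch that folds both into one function for illustration; in the actual code the sliding-window cutoff is only written into the separate `self_kq_mask_swa` tensor, not the regular mask:

```cpp
#include <cmath>
#include <cstdint>

// sketch: mask value for one (query position, cache cell) pair
static float kq_mask_value(bool same_seq, int32_t cell_pos, int32_t q_pos,
                           bool use_alibi, bool swa, int32_t n_swa) {
    if (!same_seq || cell_pos > q_pos) {
        return -INFINITY;                 // different sequence or future token: masked out
    }
    if (swa && q_pos - cell_pos >= n_swa) {
        return -INFINITY;                 // outside the sliding window (SWA mask only)
    }
    // visible: either a plain 0 or the ALiBi linear distance bias
    return use_alibi ? -std::abs(float(cell_pos - q_pos)) : 0.0f;
}
```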
@@ -3137,11 +3138,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
int32_t * data = (int32_t *) inp.self_pos_bucket->data; int32_t * data = (int32_t *) inp.self_pos_bucket->data;
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self->n;
for (int h = 0; h < 1; ++h) { for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) { for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false); data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false);
} }
} }
} }
@@ -3164,7 +3165,7 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0)
ggml_tensor * llama_context_kv_self::build_inp_pos_bucket( ggml_tensor * llama_context_kv_self::build_inp_pos_bucket(
ggml_context * ctx0, ggml_context * ctx0,
int32_t n_tokens) { int32_t n_tokens) {
const auto n_kv = kv_self.n; const auto n_kv = kv_self->n;
inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
ggml_set_input(inp.self_pos_bucket); ggml_set_input(inp.self_pos_bucket);
@@ -3177,7 +3178,7 @@ void llama_context_kv_self::build_attn_inp(
int32_t n_tokens, int32_t n_tokens,
bool causal, bool causal,
bool swa) { bool swa) {
const auto n_kv = kv_self.n; const auto n_kv = kv_self->n;
inp.self_kq_mask = causal inp.self_kq_mask = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
@@ -3224,13 +3225,13 @@ ggml_tensor * llama_context_kv_self::build_attn(
// store to KV cache // store to KV cache
{ {
GGML_ASSERT(!kv_self.recurrent); GGML_ASSERT(!kv_self->recurrent);
const auto kv_head = kv_self.head; const auto kv_head = kv_self->head;
GGML_ASSERT(kv_self.size == n_ctx); GGML_ASSERT(kv_self->size == n_ctx);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head); struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
//cb(k_cache_view, "k_cache_view", il); //cb(k_cache_view, "k_cache_view", il);
// note: storing RoPE-ed version of K in the KV cache // note: storing RoPE-ed version of K in the KV cache
@@ -3241,12 +3242,12 @@ ggml_tensor * llama_context_kv_self::build_attn(
struct ggml_tensor * v_cache_view = nullptr; struct ggml_tensor * v_cache_view = nullptr;
if (!v_trans) { if (!v_trans) {
v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head); v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
} else { } else {
// note: the V cache is transposed when not using flash attention // note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa, v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv_self.v_l[il]), ( n_ctx)*ggml_element_size(kv_self->v_l[il]),
(kv_head)*ggml_element_size(kv_self.v_l[il])); (kv_head)*ggml_element_size(kv_self->v_l[il]));
v_cur = ggml_transpose(ctx0, v_cur); v_cur = ggml_transpose(ctx0, v_cur);
} }
@@ -3281,7 +3282,7 @@ ggml_tensor * llama_context_kv_self::build_attn(
const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv; const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv;
const auto n_kv = kv_self.n; const auto n_kv = kv_self->n;
const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_head_kv = hparams.n_head_kv(il);
@@ -3292,23 +3293,23 @@ ggml_tensor * llama_context_kv_self::build_attn(
//cb(q, "q", il); //cb(q, "q", il);
ggml_tensor * k = ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il], ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_kv, n_head_kv, n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
0); 0);
//cb(k, "k", il); //cb(k, "k", il);
ggml_tensor * v = !v_trans ? ggml_tensor * v = !v_trans ?
ggml_view_3d(ctx0, kv_self.v_l[il], ggml_view_3d(ctx0, kv_self->v_l[il],
n_embd_head_v, n_kv, n_head_kv, n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
0) : 0) :
ggml_view_3d(ctx0, kv_self.v_l[il], ggml_view_3d(ctx0, kv_self->v_l[il],
n_kv, n_embd_head_v, n_head_kv, n_kv, n_embd_head_v, n_head_kv,
ggml_element_size(kv_self.v_l[il])*n_ctx, ggml_element_size(kv_self->v_l[il])*n_ctx,
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
0); 0);
struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
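
For orientation, these are the shapes of the three views built above (ne0 is the fastest-varying dimension in ggml); the swapped first two dimensions of the transposed V view are what the non-flash-attention storage path produces:

```cpp
// shapes of the cache views in the hunk above
//   k             : [n_embd_head_k, n_kv,          n_head_kv]
//   v (!v_trans)  : [n_embd_head_v, n_kv,          n_head_kv]   // flash attention: V stored row-major
//   v ( v_trans)  : [n_kv,          n_embd_head_v, n_head_kv]   // otherwise: V stored transposed
```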
@@ -3326,7 +3327,7 @@ void llama_context_kv_self::build_kv_self_shift(
const auto & n_embd_head_k = hparams.n_embd_head_k; const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v; //const auto & n_embd_head_v = hparams.n_embd_head_v;
//GGML_ASSERT(kv_self.size == n_ctx); //GGML_ASSERT(kv_self->size == n_ctx);
ggml_tensor * inp_self_k_shift = build_inp_self_k_shift(ctx0); ggml_tensor * inp_self_k_shift = build_inp_self_k_shift(ctx0);
@@ -3337,13 +3338,13 @@ void llama_context_kv_self::build_kv_self_shift(
struct ggml_tensor * rope_factors = build_rope_factors(il); struct ggml_tensor * rope_factors = build_rope_factors(il);
struct ggml_tensor * k = struct ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il], ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_head_kv, kv_self.size, n_embd_head_k, n_head_kv, kv_self->size,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
0); 0);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self.k_l[il]->buffer); ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self->k_l[il]->buffer);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
} }
@@ -3356,8 +3357,8 @@ void llama_context_kv_self::build_kv_self_defrag(
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_kv = kv_self.cell_max(); const uint32_t n_kv = kv_self->cell_max();
const uint32_t n_used = kv_self.used; const uint32_t n_used = kv_self->used;
assert(n_used <= n_kv); assert(n_used <= n_kv);
@@ -3382,7 +3383,7 @@ void llama_context_kv_self::build_kv_self_defrag(
std::vector<uint32_t> ids(n_kv, n_kv); std::vector<uint32_t> ids(n_kv, n_kv);
for (uint32_t i0 = 0; i0 < n_used; ++i0) { for (uint32_t i0 = 0; i0 < n_used; ++i0) {
const auto & cell0 = kv_self.cells[i0]; const auto & cell0 = kv_self->cells[i0];
if (!cell0.is_empty()) { if (!cell0.is_empty()) {
ids[i0] = i0; ids[i0] = i0;
@@ -3395,7 +3396,7 @@ void llama_context_kv_self::build_kv_self_defrag(
uint32_t nh = 1; uint32_t nh = 1;
// determine the size of the hole // determine the size of the hole
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) { while (i0 + nh < n_used && kv_self->cells[i0 + nh].is_empty()) {
nh++; nh++;
} }
@@ -3404,7 +3405,7 @@ void llama_context_kv_self::build_kv_self_defrag(
// starting from the end, find nh non-empty cells // starting from the end, find nh non-empty cells
for (; is > i0; --is) { for (; is > i0; --is) {
const auto & cell1 = kv_self.cells[is]; const auto & cell1 = kv_self->cells[is];
if (cell1.is_empty() || ids[is] != n_kv) { if (cell1.is_empty() || ids[is] != n_kv) {
continue; continue;
@@ -3433,7 +3434,7 @@ void llama_context_kv_self::build_kv_self_defrag(
// go back and move the nf cells to the hole // go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) { for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1]; auto & cell1 = kv_self->cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) { if (cell1.is_empty() || ids[i1] != n_kv) {
if (n_moves == max_moves) { if (n_moves == max_moves) {
@@ -3449,11 +3450,11 @@ void llama_context_kv_self::build_kv_self_defrag(
ids[i1] = i0 + nf; ids[i1] = i0 + nf;
// move the cell meta data // move the cell meta data
kv_self.cells[i0 + nf] = cell1; kv_self->cells[i0 + nf] = cell1;
// clear the old cell and move the head there // clear the old cell and move the head there
cell1 = llama_kv_cell(); cell1 = llama_kv_cell();
kv_self.head = n_used; kv_self->head = n_used;
if (!cont) { if (!cont) {
n_moves++; n_moves++;
@@ -3572,40 +3573,40 @@ void llama_context_kv_self::build_kv_self_defrag(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
n_embd_k_gqa, nm, n_embd_k_gqa, nm,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i)); ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il], ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
n_embd_k_gqa, nm, n_embd_k_gqa, nm,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
ggml_tensor * view_v_src; ggml_tensor * view_v_src;
ggml_tensor * view_v_dst; ggml_tensor * view_v_dst;
if (cparams.flash_attn) { if (cparams.flash_attn) {
// NOTE: the V cache is not transposed when using flash attention // NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
n_embd_v_gqa, nm, n_embd_v_gqa, nm,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
n_embd_v_gqa, nm, n_embd_v_gqa, nm,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
} else { } else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
nm, n_embd_v_gqa, nm, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size), ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self.v_l[il]->type, i)); ggml_row_size(kv_self->v_l[il]->type, i));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
nm, n_embd_v_gqa, nm, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size), ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self.v_l[il]->type, id)); ggml_row_size(kv_self->v_l[il]->type, id));
} }
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
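
The defrag graph is built in two stages: a planning loop computes an `ids` mapping (for each used cell, where it should move), then the per-layer `ggml_cpy` calls above apply those moves to the K and V buffers. A compact, simplified sketch of the planning stage; the real code additionally moves the cell metadata in place, tracks contiguous runs, and stops at `max_moves`:

```cpp
#include <cstdint>
#include <vector>

// ids[i] == n_kv means "cell i does not move"; otherwise it is the destination index
static std::vector<uint32_t> plan_defrag(const std::vector<bool> & is_empty, uint32_t n_used) {
    const uint32_t n_kv = (uint32_t) is_empty.size();
    std::vector<uint32_t> ids(n_kv, n_kv);

    uint32_t is = n_kv - 1; // donor cells are taken from the tail of the cache

    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
        if (!is_empty[i0]) {
            ids[i0] = i0;   // already in place
            continue;
        }

        // measure the hole of empty cells starting at i0
        uint32_t nh = 1;
        while (i0 + nh < n_used && is_empty[i0 + nh]) {
            nh++;
        }

        // fill the hole with up to nh non-empty cells taken from the end
        uint32_t nf = 0;
        for (; is > i0 && nf < nh; --is) {
            if (!is_empty[is] && ids[is] == n_kv) {
                ids[is] = i0 + nf;
                nf++;
            }
        }

        i0 += nh - 1; // skip past the hole
    }

    return ids;
}
```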
@@ -3625,7 +3626,7 @@ size_t llama_context_kv_self::state_write_data(llama_io_write_i & io) {
llama_context_base::state_write_data(io); llama_context_base::state_write_data(io);
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
kv_self.state_write(io); kv_self->state_write(io);
return io.n_bytes(); return io.n_bytes();
} }
@@ -3634,7 +3635,7 @@ size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) {
llama_context_base::state_read_data(io); llama_context_base::state_read_data(io);
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
kv_self.state_read(io); kv_self->state_read(io);
return io.n_bytes(); return io.n_bytes();
} }
@@ -3642,7 +3643,7 @@ size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) {
size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
llama_context_base::state_seq_write_data(io, seq_id); llama_context_base::state_seq_write_data(io, seq_id);
kv_self.state_write(io, seq_id); kv_self->state_write(io, seq_id);
return io.n_bytes(); return io.n_bytes();
} }
@@ -3650,7 +3651,7 @@ size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_
size_t llama_context_kv_self::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { size_t llama_context_kv_self::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
llama_context_base::state_seq_read_data(io, seq_id); llama_context_base::state_seq_read_data(io, seq_id);
kv_self.state_read(io, seq_id); kv_self->state_read(io, seq_id);
return io.n_bytes(); return io.n_bytes();
} }
@@ -3663,12 +3664,13 @@ llama_context_recurrent::llama_context_recurrent(
const llama_model & model, const llama_model & model,
llama_context_params params, llama_context_params params,
llama_graph_type gtype) : llama_graph_type gtype) :
llama_context_base(model, params, gtype), llama_context_base(model, params, gtype) {
kv_self(model.hparams) {
LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__); LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
kv_self = std::make_unique<llama_kv_cache_recurrent>(hparams);
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
// Mamba only needs a constant number of KV cache cells per sequence // Mamba only needs a constant number of KV cache cells per sequence
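
The recurrent context goes through the same constructor motions, except it builds a `llama_kv_cache_recurrent` (which for now simply reuses the unified implementation) and sizes it by sequences rather than by context length. As the later hunks show, the per-layer K/V tensors are repurposed to hold recurrent state rather than attention keys/values; a rough map:

```cpp
// rough map of how the recurrent cache reuses the K/V tensors (from the hunks below)
//   kv_self->k_l[il] : Mamba conv states / RWKV token-shift states
//   kv_self->v_l[il] : Mamba ssm states  / RWKV wkv states
// a constant number of cells per sequence suffices, so kv_size does not scale with n_ctx
```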
@@ -3684,14 +3686,14 @@ llama_context_recurrent::llama_context_recurrent(
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!hparams.vocab_only) { if (!hparams.vocab_only) {
if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
throw std::runtime_error("failed to initialize self-attention cache"); throw std::runtime_error("failed to initialize self-attention cache");
} }
{ {
const size_t memory_size_k = kv_self.size_k_bytes(); const size_t memory_size_k = kv_self->size_k_bytes();
const size_t memory_size_v = kv_self.size_v_bytes(); const size_t memory_size_v = kv_self->size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
@@ -3705,20 +3707,20 @@ llama_context_recurrent::~llama_context_recurrent() = default;
void llama_context_recurrent::reserve() { void llama_context_recurrent::reserve() {
// simulate full KV cache // simulate full KV cache
kv_self.n = kv_self.size; kv_self->n = kv_self->size;
LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n); LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self->n);
// TODO: implement recurrent-specific reserve logic // TODO: implement recurrent-specific reserve logic
llama_context_base::reserve(); llama_context_base::reserve();
} }
llama_kv_cache * llama_context_recurrent::get_kv_self() { llama_kv_cache * llama_context_recurrent::get_kv_self() {
return &kv_self; return kv_self.get();
} }
const llama_kv_cache * llama_context_recurrent::get_kv_self() const { const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
return &kv_self; return kv_self.get();
} }
void llama_context_recurrent::kv_self_update() { void llama_context_recurrent::kv_self_update() {
@@ -3740,7 +3742,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
@@ -3755,7 +3757,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
// TODO: remove this stuff // TODO: remove this stuff
class batch_guard { class batch_guard {
public: public:
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) { batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
} }
~batch_guard() { ~batch_guard() {
@@ -3778,7 +3780,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
llama_kv_slot_restorer kv_slot_restorer; llama_kv_slot_restorer kv_slot_restorer;
}; };
batch_guard bg(kv_self); batch_guard bg(*kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@@ -3870,11 +3872,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
// if we have enough unused cells before the current head -> // if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it // better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) { if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
kv_self.head = 0; kv_self->head = 0;
} }
const auto slot_info = kv_self.find_slot(ubatch); const auto slot_info = kv_self->find_slot(ubatch);
if (!slot_info) { if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3; return -3;
@@ -3883,7 +3885,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
bg.save(slot_info); bg.save(slot_info);
} }
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
@@ -3912,11 +3914,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) {
// update the kv ring buffer // update the kv ring buffer
{ {
kv_self.head += ubatch.n_tokens; kv_self->head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index. // Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) { if (kv_self->head >= kv_self->size) {
kv_self.head = 0; kv_self->head = 0;
} }
} }
@@ -4044,9 +4046,9 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
// call base functionality // call base functionality
llama_context_base::input_set(ubatch); llama_context_base::input_set(ubatch);
GGML_ASSERT(kv_self.recurrent); GGML_ASSERT(kv_self->recurrent);
const int64_t n_kv = kv_self.n; const int64_t n_kv = kv_self->n;
if (inp.s_mask) { if (inp.s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_mask->buffer)); GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_mask->buffer));
@@ -4054,8 +4056,8 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
// clear unused states // clear unused states
for (int i = 0; i < n_kv; ++i) { for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head; const uint32_t cell_id = i + kv_self->head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id]; llama_kv_cell & kv_cell = kv_self->cells[cell_id];
data[i] = (float) (kv_cell.src >= 0); data[i] = (float) (kv_cell.src >= 0);
@@ -4073,11 +4075,11 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) { for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head; const uint32_t cell_id = i + kv_self->head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id]; llama_kv_cell & kv_cell = kv_self->cells[cell_id];
// prevent out-of-bound sources // prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
kv_cell.src = cell_id; kv_cell.src = cell_id;
} }
@@ -4101,7 +4103,7 @@ ggml_cgraph * llama_context_recurrent::graph_init() {
ggml_tensor * llama_context_recurrent::build_inp_s_copy( ggml_tensor * llama_context_recurrent::build_inp_s_copy(
ggml_context * ctx0) { ggml_context * ctx0) {
const auto n_kv = kv_self.n; const auto n_kv = kv_self->n;
inp.s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); inp.s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
//cb(inp.s_copy, "inp_s_copy", -1); //cb(inp.s_copy, "inp_s_copy", -1);
@@ -4112,7 +4114,7 @@ ggml_tensor * llama_context_recurrent::build_inp_s_copy(
ggml_tensor * llama_context_recurrent::build_inp_s_mask( ggml_tensor * llama_context_recurrent::build_inp_s_mask(
ggml_context * ctx0) { ggml_context * ctx0) {
const auto n_kv = kv_self.n; const auto n_kv = kv_self->n;
inp.s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); inp.s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
//cb(inp.s_mask, "inp_s_mask", -1); //cb(inp.s_mask, "inp_s_mask", -1);
@@ -4129,10 +4131,10 @@ ggml_tensor * llama_context_recurrent::build_copy_mask_state(
ggml_tensor * state_mask, ggml_tensor * state_mask,
int32_t n_state, int32_t n_state,
int32_t n_seqs) { int32_t n_seqs) {
const auto n_kv = kv_self.n; const auto n_kv = kv_self->n;
const auto kv_head = kv_self.head; const auto kv_head = kv_self->head;
struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self.size); struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
// copy states // copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
@@ -4164,7 +4166,7 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer(
int il) { int il) {
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
const auto kv_head = kv_self.head; const auto kv_head = kv_self->head;
const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_inner = hparams.ssm_d_inner;
@@ -4182,8 +4184,8 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer(
GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
struct ggml_tensor * conv_states_all = kv_self.k_l[il]; struct ggml_tensor * conv_states_all = kv_self->k_l[il];
struct ggml_tensor * ssm_states_all = kv_self.v_l[il]; struct ggml_tensor * ssm_states_all = kv_self->v_l[il];
// (ab)using the KV cache to store the states // (ab)using the KV cache to store the states
struct ggml_tensor * conv = build_copy_mask_state( struct ggml_tensor * conv = build_copy_mask_state(
@@ -4300,7 +4302,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_load(
const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
struct ggml_tensor * token_shift_all = kv_self.k_l[il]; struct ggml_tensor * token_shift_all = kv_self->k_l[il];
struct ggml_tensor * token_shift = build_copy_mask_state( struct ggml_tensor * token_shift = build_copy_mask_state(
ctx0, gf, token_shift_all, state_copy, state_mask, ctx0, gf, token_shift_all, state_copy, state_mask,
@@ -4323,12 +4325,12 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store(
const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seqs = ubatch.n_seqs;
const auto kv_head = kv_self.head; const auto kv_head = kv_self->head;
return ggml_cpy( return ggml_cpy(
ctx0, ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
); );
} }
@@ -4350,7 +4352,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
const auto n_head = n_embd / head_size; const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il); const auto n_head_kv = hparams.n_head_kv(il);
const auto kv_head = kv_self.head; const auto kv_head = kv_self->head;
const auto & layer = model.layers[il]; const auto & layer = model.layers[il];
@@ -4458,7 +4460,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
} }
struct ggml_tensor * wkv_state = build_copy_mask_state( struct ggml_tensor * wkv_state = build_copy_mask_state(
ctx0, gf, kv_self.v_l[il], state_copy, state_mask, ctx0, gf, kv_self->v_l[il], state_copy, state_mask,
hparams.n_embd_v_s(), n_seqs); hparams.n_embd_v_s(), n_seqs);
struct ggml_tensor * wkv_output; struct ggml_tensor * wkv_output;
@@ -4477,9 +4479,9 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
wkv_state, wkv_state,
ggml_view_1d( ggml_view_1d(
ctx0, ctx0,
kv_self.v_l[il], kv_self->v_l[il],
hparams.n_embd_v_s() * n_seqs, hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
) )
) )
); );
@@ -4507,7 +4509,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) { size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) {
llama_context_base::state_write_data(io); llama_context_base::state_write_data(io);
kv_self.state_write(io); kv_self->state_write(io);
return io.n_bytes(); return io.n_bytes();
} }
@@ -4515,7 +4517,7 @@ size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) {
size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) { size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) {
llama_context_base::state_read_data(io); llama_context_base::state_read_data(io);
kv_self.state_read(io); kv_self->state_read(io);
return io.n_bytes(); return io.n_bytes();
} }
@@ -4523,7 +4525,7 @@ size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) {
size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
llama_context_base::state_seq_write_data(io, seq_id); llama_context_base::state_seq_write_data(io, seq_id);
kv_self.state_write(io, seq_id); kv_self->state_write(io, seq_id);
return io.n_bytes(); return io.n_bytes();
} }
@@ -4531,7 +4533,7 @@ size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llam
size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
llama_context_base::state_seq_read_data(io, seq_id); llama_context_base::state_seq_read_data(io, seq_id);
kv_self.state_read(io, seq_id); kv_self->state_read(io, seq_id);
return io.n_bytes(); return io.n_bytes();
} }
@@ -5211,7 +5213,7 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view *
return; return;
} }
llama_kv_cache_view_update(view, *kv); llama_kv_cache_view_update(view, kv);
} }
// //

@@ -630,7 +630,7 @@ private:
// members // members
// //
llama_kv_cache kv_self; std::unique_ptr<llama_kv_cache_unified> kv_self;
}; };
// a recurrent transformer (i.e. RWKV, Mamba) // a recurrent transformer (i.e. RWKV, Mamba)
@@ -745,7 +745,7 @@ private:
// //
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models? // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
llama_kv_cache_recurrent kv_self; std::unique_ptr<llama_kv_cache_recurrent> kv_self;
}; };
// TODO: tmp - need something better to pass the data from the encoder to the decoder // TODO: tmp - need something better to pass the data from the encoder to the decoder

@@ -6,17 +6,16 @@
#include "llama-model.h" #include "llama-model.h"
#include <algorithm> #include <algorithm>
#include <cassert>
#include <limits> #include <limits>
#include <map> #include <map>
#include <stdexcept> #include <stdexcept>
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
llama_kv_cache::llama_kv_cache(const llama_hparams & hparams) : hparams(hparams) { llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams) : hparams(hparams) {
} }
bool llama_kv_cache::init( bool llama_kv_cache_unified::init(
const llama_model & model, const llama_model & model,
const llama_cparams & cparams, const llama_cparams & cparams,
ggml_type type_k, ggml_type type_k,
@@ -123,7 +122,7 @@ bool llama_kv_cache::init(
return true; return true;
} }
int32_t llama_kv_cache::n_tokens() const { int32_t llama_kv_cache_unified::n_tokens() const {
int32_t result = 0; int32_t result = 0;
for (uint32_t i = 0; i < size; i++) { for (uint32_t i = 0; i < size; i++) {
@@ -133,7 +132,11 @@ int32_t llama_kv_cache::n_tokens() const {
return result; return result;
} }
size_t llama_kv_cache::total_size() const { uint32_t llama_kv_cache_unified::used_cells() const {
return used;
}
size_t llama_kv_cache_unified::total_size() const {
size_t size = 0; size_t size = 0;
for (const auto & buf : bufs) { for (const auto & buf : bufs) {
size += ggml_backend_buffer_get_size(buf.get()); size += ggml_backend_buffer_get_size(buf.get());
@@ -142,7 +145,7 @@ size_t llama_kv_cache::total_size() const {
return size; return size;
} }
llama_pos llama_kv_cache::pos_max() const { llama_pos llama_kv_cache_unified::pos_max() const {
llama_pos pos_max = -1; llama_pos pos_max = -1;
for (const auto & cell : cells) { for (const auto & cell : cells) {
pos_max = std::max(pos_max, cell.pos); pos_max = std::max(pos_max, cell.pos);
@@ -151,7 +154,7 @@ llama_pos llama_kv_cache::pos_max() const {
return pos_max; return pos_max;
} }
void llama_kv_cache::clear() { void llama_kv_cache_unified::clear() {
for (int32_t i = 0; i < (int32_t) size; ++i) { for (int32_t i = 0; i < (int32_t) size; ++i) {
cells[i].pos = -1; cells[i].pos = -1;
cells[i].seq_id.clear(); cells[i].seq_id.clear();
@@ -166,7 +169,7 @@ void llama_kv_cache::clear() {
} }
} }
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
uint32_t new_head = size; uint32_t new_head = size;
if (p0 < 0) { if (p0 < 0) {
@@ -237,7 +240,7 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
return true; return true;
} }
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
if (seq_id_src == seq_id_dst) { if (seq_id_src == seq_id_dst) {
return; return;
} }
@@ -288,7 +291,7 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
} }
} }
void llama_kv_cache::seq_keep(llama_seq_id seq_id) { void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
uint32_t new_head = size; uint32_t new_head = size;
for (uint32_t i = 0; i < size; ++i) { for (uint32_t i = 0; i < size; ++i) {
@@ -320,7 +323,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
} }
} }
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (delta == 0) { if (delta == 0) {
return; return;
} }
@@ -378,7 +381,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
head = new_head != size ? new_head : 0; head = new_head != size ? new_head : 0;
} }
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (d == 1) { if (d == 1) {
return; return;
} }
@@ -424,7 +427,7 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
} }
} }
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) { llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
llama_pos result = 0; llama_pos result = 0;
for (uint32_t i = 0; i < size; ++i) { for (uint32_t i = 0; i < size; ++i) {
@@ -436,13 +439,17 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) {
return result; return result;
} }
void llama_kv_cache::defrag() { void llama_kv_cache_unified::defrag() {
if (!recurrent) { if (!recurrent) {
do_defrag = true; do_defrag = true;
} }
} }
struct llama_kv_cache_slot_info llama_kv_cache::find_slot( bool llama_kv_cache_unified::get_can_shift() const {
return can_shift;
}
struct llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
const struct llama_ubatch & ubatch) { const struct llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seqs = ubatch.n_seqs;
@@ -663,12 +670,12 @@ struct llama_kv_cache_slot_info llama_kv_cache::find_slot(
return llama_kv_cache_slot_info(head, head + n_tokens); return llama_kv_cache_slot_info(head, head + n_tokens);
} }
uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) const { uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
// the FA kernels require padding to avoid extra runtime boundary checks // the FA kernels require padding to avoid extra runtime boundary checks
return cparams.flash_attn ? 256u : 32u; return cparams.flash_attn ? 256u : 32u;
} }
uint32_t llama_kv_cache::cell_max() const { uint32_t llama_kv_cache_unified::cell_max() const {
for (uint32_t i = size; i > 0; --i) { for (uint32_t i = size; i > 0; --i) {
const llama_kv_cell & cell = cells[i - 1]; const llama_kv_cell & cell = cells[i - 1];
@@ -680,7 +687,7 @@ uint32_t llama_kv_cache::cell_max() const {
return 0; return 0;
} }
size_t llama_kv_cache::size_k_bytes() const { size_t llama_kv_cache_unified::size_k_bytes() const {
size_t size_k_bytes = 0; size_t size_k_bytes = 0;
for (const auto & k : k_l) { for (const auto & k : k_l) {
@@ -690,7 +697,7 @@ size_t llama_kv_cache::size_k_bytes() const {
return size_k_bytes; return size_k_bytes;
} }
size_t llama_kv_cache::size_v_bytes() const { size_t llama_kv_cache_unified::size_v_bytes() const {
size_t size_v_bytes = 0; size_t size_v_bytes = 0;
for (const auto & v : v_l) { for (const auto & v : v_l) {
@@ -700,7 +707,7 @@ size_t llama_kv_cache::size_v_bytes() const {
return size_v_bytes; return size_v_bytes;
} }
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0; uint32_t cell_count = 0;
@@ -738,7 +745,7 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id) con
state_write_data(io, cell_ranges); state_write_data(io, cell_ranges);
} }
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id) { void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
uint32_t cell_count; uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count)); io.read_to(&cell_count, sizeof(cell_count));
@@ -756,7 +763,7 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
} }
} }
void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const { void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
for (const auto & range : cell_ranges) { for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) { for (uint32_t i = range.first; i < range.second; ++i) {
const auto & cell = cells[i]; const auto & cell = cells[i];
@@ -775,7 +782,7 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<s
} }
} }
void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const { void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
const uint32_t v_trans = this->v_trans ? 1 : 0; const uint32_t v_trans = this->v_trans ? 1 : 0;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
@@ -855,7 +862,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<s
} }
} }
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
if (dest_seq_id != -1) { if (dest_seq_id != -1) {
// single sequence // single sequence
@@ -921,7 +928,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count,
llama_seq_id seq_id; llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id)); io.read_to(&seq_id, sizeof(seq_id));
// TODO: llama_kv_cache should have a notion of max sequences // TODO: llama_kv_cache_unified should have a notion of max sequences
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
if (seq_id < 0) { if (seq_id < 0) {
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -957,7 +964,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count,
return true; return true;
} }
bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t cell_count) { bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
uint32_t v_trans; uint32_t v_trans;
uint32_t n_layer; uint32_t n_layer;
io.read_to(&v_trans, sizeof(v_trans)); io.read_to(&v_trans, sizeof(v_trans));
@@ -1092,7 +1099,7 @@ int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
return 0; return 0;
} }
return kv->used; return kv->used_cells();
} }
void llama_kv_cache_clear(llama_kv_cache * kv) { void llama_kv_cache_clear(llama_kv_cache * kv) {
@@ -1183,7 +1190,7 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
return false; return false;
} }
return kv->can_shift; return kv->get_can_shift();
} }
// //
@@ -1216,9 +1223,16 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
} }
} }
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) { void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv) {
if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) { // TODO: rework this in the future, for now quick hack
view->n_cells = int32_t(kv.size); const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
if (kvu == nullptr) {
LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
return;
}
if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
view->n_cells = int32_t(kvu->size);
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p; view->cells = (struct llama_kv_cache_view_cell *)p;
@@ -1227,7 +1241,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
view->cells_sequences = (llama_seq_id *)p; view->cells_sequences = (llama_seq_id *)p;
} }
const std::vector<llama_kv_cell> & kv_cells = kv.cells; const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
llama_kv_cache_view_cell * c_curr = view->cells; llama_kv_cache_view_cell * c_curr = view->cells;
llama_seq_id * cs_curr = view->cells_sequences; llama_seq_id * cs_curr = view->cells_sequences;
int32_t used_cells = 0; int32_t used_cells = 0;
@@ -1236,7 +1250,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
uint32_t max_contig = 0; uint32_t max_contig = 0;
int32_t max_contig_idx = -1; int32_t max_contig_idx = -1;
for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) { for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
const size_t curr_size = kv_cells[i].seq_id.size(); const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size; token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -1274,8 +1288,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
view->max_contiguous_idx = max_contig_idx; view->max_contiguous_idx = max_contig_idx;
view->token_count = token_count; view->token_count = token_count;
view->used_cells = used_cells; view->used_cells = used_cells;
if (uint32_t(used_cells) != kv.used) { if (uint32_t(used_cells) != kvu->used) {
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
__func__, kv.used, used_cells); __func__, kvu->used, used_cells);
} }
} }

@@ -45,12 +45,39 @@ struct llama_kv_cache_slot_info {
operator bool() const { return found; } operator bool() const { return found; }
}; };
struct llama_kv_cache {
public:
virtual int32_t n_tokens() const = 0;
virtual uint32_t used_cells() const = 0; // TODO: remove
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual void defrag() = 0;
virtual bool get_can_shift() const = 0;
};
// C++ alias
class llama_kv_cache_i : public llama_kv_cache {
public:
using llama_kv_cache::llama_kv_cache;
};
// ring-buffer of cached KV data // ring-buffer of cached KV data
// TODO: pimpl // TODO: pimpl
// TODO: add notion of max sequences // TODO: add notion of max sequences
struct llama_kv_cache { class llama_kv_cache_unified : public llama_kv_cache_i {
llama_kv_cache(const llama_hparams & hparams); public:
virtual ~llama_kv_cache() = default; llama_kv_cache_unified(const llama_hparams & hparams);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor // TODO: become constructor
bool init( bool init(
@@ -61,24 +88,26 @@ struct llama_kv_cache {
uint32_t kv_size, uint32_t kv_size,
bool offload); bool offload);
int32_t n_tokens() const; int32_t n_tokens() const override;
uint32_t used_cells() const override;
size_t total_size() const; size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation // TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const; llama_pos pos_max() const;
void clear(); void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1); bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id); void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id); llama_pos seq_pos_max(llama_seq_id seq_id) override;
void defrag(); void defrag() override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache // find an empty slot of size "n_tokens" in the cache
// updates the cache head // updates the cache head
@@ -143,9 +172,10 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache // TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
struct llama_kv_cache_recurrent : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache_unified {
using llama_kv_cache::llama_kv_cache; public:
using llama_kv_cache_unified::llama_kv_cache_unified;
}; };
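
With the interface split out, code that only needs sequence-level operations can stay agnostic of the concrete cache type. A hypothetical helper written purely against the abstraction (the function itself is illustrative, not part of this commit; the negative `p1` follows the existing open-ended-range convention of `seq_rm`):

```cpp
// hypothetical caller: compiles against the abstract interface only
static void keep_only_sequence(llama_kv_cache & kv, llama_seq_id seq_id, llama_pos p0) {
    kv.seq_rm(seq_id, p0, -1);  // drop the tail of this sequence from p0 onwards
    kv.seq_keep(seq_id);        // and drop every other sequence entirely

    kv.defrag();                // request compaction; honored on the next kv_self_update()

    if (kv.get_can_shift()) {
        // this cache supports K-shifts, so position deltas (RoPE re-rotation) are allowed
    }
}
```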
// //
@@ -166,9 +196,9 @@ struct llama_kv_slot_restorer {
bool do_restore = false; bool do_restore = false;
llama_kv_cache & cache; llama_kv_cache_unified & cache;
explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) { explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
old_state.head = cache.head; old_state.head = cache.head;
old_state.n = cache.n; old_state.n = cache.n;
} }
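
`llama_kv_slot_restorer` (and the `batch_guard` wrappers in llama-context.cpp) is a rollback guard: it snapshots `head` and `n` before a batch touches the cache and puts them back if decoding bails out early. A stripped-down sketch of the pattern; the commit/rollback naming and the exact restore details are assumptions, not the literal implementation:

```cpp
// simplified rollback guard in the spirit of llama_kv_slot_restorer / batch_guard
struct kv_state_guard {
    llama_kv_cache_unified & cache;
    uint32_t old_head;
    uint32_t old_n;
    bool committed = false;

    explicit kv_state_guard(llama_kv_cache_unified & c)
        : cache(c), old_head(c.head), old_n(c.n) {}

    void commit() { committed = true; }  // call once the batch has fully succeeded

    ~kv_state_guard() {
        if (!committed) {
            // decode failed part-way: roll the ring buffer back to the snapshot
            cache.head = old_head;
            cache.n    = old_n;
        }
    }
};
```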
@@ -249,4 +279,4 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max); struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv); void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv);