From 828effd9d74d770e03852b6123d54f12e92bb950 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Feb 2025 15:54:44 +0200 Subject: [PATCH] kv-cache : basic abstraction ggml-ci --- src/llama-context.cpp | 288 +++++++++++++++++++++-------------------- src/llama-context.h | 4 +- src/llama-kv-cache.cpp | 84 +++++++----- src/llama-kv-cache.h | 66 +++++++--- 4 files changed, 244 insertions(+), 198 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4341c571e3..5c77b29c13 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2384,15 +2384,16 @@ llama_context_kv_self::llama_context_kv_self( const llama_model & model, llama_context_params params, llama_graph_type gtype) : - llama_context_base(model, params, gtype), - kv_self(model.hparams) { + llama_context_base(model, params, gtype) { LLAMA_LOG_INFO("%s: constructing llama_context_kv_self\n", __func__); const auto & hparams = model.hparams; + kv_self = std::make_unique(hparams); + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self.get_padding(cparams)); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams)); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -2406,14 +2407,14 @@ llama_context_kv_self::llama_context_kv_self( GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); if (!hparams.vocab_only) { - if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); throw std::runtime_error("failed to initialize self-attention cache"); } { - const size_t memory_size_k = kv_self.size_k_bytes(); - const size_t memory_size_v = kv_self.size_v_bytes(); + const size_t memory_size_k = kv_self->size_k_bytes(); + const size_t memory_size_v = kv_self->size_v_bytes(); LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), @@ -2427,19 +2428,19 @@ llama_context_kv_self::~llama_context_kv_self() = default; void llama_context_kv_self::reserve() { // simulate full KV cache - kv_self.n = kv_self.size; + kv_self->n = kv_self->size; - LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n); + LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self->n); llama_context_base::reserve(); } llama_kv_cache * llama_context_kv_self::get_kv_self() { - return &kv_self; + return kv_self.get(); } const llama_kv_cache * llama_context_kv_self::get_kv_self() const { - return &kv_self; + return kv_self.get(); } void llama_context_kv_self::kv_self_update() { @@ -2449,8 +2450,8 @@ void llama_context_kv_self::kv_self_update() { bool need_reserve = false; - if (kv.has_shift) { - if (!kv.can_shift) { + if (kv->has_shift) { + if (!kv->get_can_shift()) { GGML_ABORT("The current context does not support K-shift"); } @@ -2474,16 +2475,16 @@ void llama_context_kv_self::kv_self_update() { } { - kv.has_shift = false; + kv->has_shift = false; - for (uint32_t i = 0; i < kv.size; ++i) { - kv.cells[i].delta = 0; + for (uint32_t i = 0; i < kv->size; ++i) { + kv->cells[i].delta = 0; } } } // defragment the KV cache if needed - if (kv.do_defrag) { + if (kv->do_defrag) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); ggml_backend_sched_reset(sched.get()); @@ -2499,7 +2500,7 @@ void llama_context_kv_self::kv_self_update() { 
graph_compute(gf, false); - kv.do_defrag = false; + kv->do_defrag = false; need_reserve = true; } @@ -2513,7 +2514,7 @@ void llama_context_kv_self::kv_self_update() { uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // simulate full KV cache - kv_self.n = kv_self.size; + kv_self->n = kv_self->size; llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; @@ -2537,7 +2538,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { // temporary allocate memory for the input batch if needed // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1); + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); const llama_batch & batch = batch_allocr.batch; const int32_t n_tokens = batch.n_tokens; @@ -2674,7 +2675,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // temporary allocate memory for the input batch if needed // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1); + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); const llama_batch & batch = batch_allocr.batch; @@ -2689,7 +2690,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // TODO: remove this stuff class batch_guard { public: - batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) { + batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) { } ~batch_guard() { @@ -2712,7 +2713,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { llama_kv_slot_restorer kv_slot_restorer; }; - batch_guard bg(kv_self); + batch_guard bg(*kv_self); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -2797,11 +2798,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) { - kv_self.head = 0; + if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) { + kv_self->head = 0; } - const auto slot_info = kv_self.find_slot(ubatch); + const auto slot_info = kv_self->find_slot(ubatch); if (!slot_info) { LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); return -3; @@ -2813,12 +2814,12 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = kv_self.get_padding(cparams); - kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad))); + const uint32_t pad = kv_self->get_padding(cparams); + kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad))); } } - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, 
kv_self->head); ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -2847,11 +2848,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { // update the kv ring buffer { - kv_self.head += ubatch.n_tokens; + kv_self->head += ubatch.n_tokens; // Ensure kv cache head points to a valid index. - if (kv_self.head >= kv_self.size) { - kv_self.head = 0; + if (kv_self->head >= kv_self->size) { + kv_self->head = 0; } } @@ -2972,13 +2973,13 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + kv_self.get_padding(cparams))/float(kv_self.n)) : 0.0f; + const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > cparams.defrag_thold) { LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - kv_self.defrag(); + kv_self->defrag(); } } @@ -2997,8 +2998,8 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { int32_t * data = (int32_t *) inp.self_k_shift->data; - for (uint32_t i = 0; i < kv_self.size; ++i) { - data[i] = kv_self.cells[i].delta; + for (uint32_t i = 0; i < kv_self->size; ++i) { + data[i] = kv_self->cells[i].delta; } // the K-shift graph requires just this input @@ -3011,7 +3012,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { if (inp.self_kq_mask || inp.self_kq_mask_swa) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. 
if (cparams.causal_attn) { - const int64_t n_kv = kv_self.n; + const int64_t n_kv = kv_self->n; const int64_t n_tokens = ubatch.n_tokens; const int64_t n_seq_tokens = ubatch.n_seq_tokens; const int64_t n_seqs = ubatch.n_seqs; @@ -3041,11 +3042,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) { f = -INFINITY; } else { if (hparams.use_alibi) { - f = -std::abs(kv_self.cells[i].pos - pos); + f = -std::abs(kv_self->cells[i].pos - pos); } else { f = 0.0f; } @@ -3057,7 +3058,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { // may need to cut off old tokens for sliding window if (data_swa) { - if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { f = -INFINITY; } data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; @@ -3137,11 +3138,11 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { int32_t * data = (int32_t *) inp.self_pos_bucket->data; - const int64_t n_kv = kv_self.n; + const int64_t n_kv = kv_self->n; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false); + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, false); } } } @@ -3164,7 +3165,7 @@ ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) ggml_tensor * llama_context_kv_self::build_inp_pos_bucket( ggml_context * ctx0, int32_t n_tokens) { - const auto n_kv = kv_self.n; + const auto n_kv = kv_self->n; inp.self_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); ggml_set_input(inp.self_pos_bucket); @@ -3177,7 +3178,7 @@ void llama_context_kv_self::build_attn_inp( int32_t n_tokens, bool causal, bool swa) { - const auto n_kv = kv_self.n; + const auto n_kv = kv_self->n; inp.self_kq_mask = causal ? 
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)) @@ -3224,13 +3225,13 @@ ggml_tensor * llama_context_kv_self::build_attn( // store to KV cache { - GGML_ASSERT(!kv_self.recurrent); + GGML_ASSERT(!kv_self->recurrent); - const auto kv_head = kv_self.head; + const auto kv_head = kv_self->head; - GGML_ASSERT(kv_self.size == n_ctx); + GGML_ASSERT(kv_self->size == n_ctx); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa)*kv_head); + struct ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head); //cb(k_cache_view, "k_cache_view", il); // note: storing RoPE-ed version of K in the KV cache @@ -3241,12 +3242,12 @@ ggml_tensor * llama_context_kv_self::build_attn( struct ggml_tensor * v_cache_view = nullptr; if (!v_trans) { - v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head); + v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head); } else { // note: the V cache is transposed when not using flash attention - v_cache_view = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv_self.v_l[il]), - (kv_head)*ggml_element_size(kv_self.v_l[il])); + v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv_self->v_l[il]), + (kv_head)*ggml_element_size(kv_self->v_l[il])); v_cur = ggml_transpose(ctx0, v_cur); } @@ -3281,7 +3282,7 @@ ggml_tensor * llama_context_kv_self::build_attn( const auto & kq_mask = is_sliding ? inp.self_kq_mask_swa_cnv : inp.self_kq_mask_cnv; - const auto n_kv = kv_self.n; + const auto n_kv = kv_self->n; const int64_t n_head_kv = hparams.n_head_kv(il); @@ -3292,23 +3293,23 @@ ggml_tensor * llama_context_kv_self::build_attn( //cb(q, "q", il); ggml_tensor * k = - ggml_view_3d(ctx0, kv_self.k_l[il], + ggml_view_3d(ctx0, kv_self->k_l[il], n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), 0); //cb(k, "k", il); ggml_tensor * v = !v_trans ? 
- ggml_view_3d(ctx0, kv_self.v_l[il], + ggml_view_3d(ctx0, kv_self->v_l[il], n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v), + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v), 0) : - ggml_view_3d(ctx0, kv_self.v_l[il], + ggml_view_3d(ctx0, kv_self->v_l[il], n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv_self.v_l[il])*n_ctx, - ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, + ggml_element_size(kv_self->v_l[il])*n_ctx, + ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, 0); struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); @@ -3326,7 +3327,7 @@ void llama_context_kv_self::build_kv_self_shift( const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - //GGML_ASSERT(kv_self.size == n_ctx); + //GGML_ASSERT(kv_self->size == n_ctx); ggml_tensor * inp_self_k_shift = build_inp_self_k_shift(ctx0); @@ -3337,13 +3338,13 @@ void llama_context_kv_self::build_kv_self_shift( struct ggml_tensor * rope_factors = build_rope_factors(il); struct ggml_tensor * k = - ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_head_kv, kv_self.size, - ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_view_3d(ctx0, kv_self->k_l[il], + n_embd_head_k, n_head_kv, kv_self->size, + ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), 0); - ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self.k_l[il]->buffer); + ggml_tensor * cur = build_rope_shift(ctx0, k, inp_self_k_shift, rope_factors, kv_self->k_l[il]->buffer); ggml_build_forward_expand(gf, cur); } @@ -3356,8 +3357,8 @@ void llama_context_kv_self::build_kv_self_defrag( const uint32_t n_layer = hparams.n_layer; - const uint32_t n_kv = kv_self.cell_max(); - const uint32_t n_used = kv_self.used; + const uint32_t n_kv = kv_self->cell_max(); + const uint32_t n_used = kv_self->used; assert(n_used <= n_kv); @@ -3382,7 +3383,7 @@ void llama_context_kv_self::build_kv_self_defrag( std::vector ids(n_kv, n_kv); for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const auto & cell0 = kv_self.cells[i0]; + const auto & cell0 = kv_self->cells[i0]; if (!cell0.is_empty()) { ids[i0] = i0; @@ -3395,7 +3396,7 @@ void llama_context_kv_self::build_kv_self_defrag( uint32_t nh = 1; // determine the size of the hole - while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) { + while (i0 + nh < n_used && kv_self->cells[i0 + nh].is_empty()) { nh++; } @@ -3404,7 +3405,7 @@ void llama_context_kv_self::build_kv_self_defrag( // starting from the end, find nh non-empty cells for (; is > i0; --is) { - const auto & cell1 = kv_self.cells[is]; + const auto & cell1 = kv_self->cells[is]; if (cell1.is_empty() || ids[is] != n_kv) { continue; @@ -3433,7 +3434,7 @@ void llama_context_kv_self::build_kv_self_defrag( // go back and move the nf cells to the hole for (; i1 < n_kv; ++i1) { - auto & cell1 = kv_self.cells[i1]; + auto & cell1 = kv_self->cells[i1]; if (cell1.is_empty() || ids[i1] != n_kv) { if (n_moves == max_moves) { @@ -3449,11 +3450,11 @@ void llama_context_kv_self::build_kv_self_defrag( ids[i1] = i0 + nf; // move the cell meta data - kv_self.cells[i0 + nf] = cell1; + kv_self->cells[i0 + nf] = cell1; // clear the old cell and move the head there cell1 = llama_kv_cell(); - kv_self.head = 
n_used; + kv_self->head = n_used; if (!cont) { n_moves++; @@ -3572,40 +3573,40 @@ void llama_context_kv_self::build_kv_self_defrag( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il], n_embd_k_gqa, nm, - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i)); + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i)); - ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il], + ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il], n_embd_k_gqa, nm, - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id)); ggml_tensor * view_v_src; ggml_tensor * view_v_dst; if (cparams.flash_attn) { // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], n_embd_v_gqa, nm, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i)); - view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], n_embd_v_gqa, nm, - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id)); } else { - view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_row_size(kv_self->v_l[il]->type, kv_self->size), + ggml_row_size(kv_self->v_l[il]->type, i)); - view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + ggml_row_size(kv_self->v_l[il]->type, kv_self->size), + ggml_row_size(kv_self->v_l[il]->type, id)); } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); @@ -3625,7 +3626,7 @@ size_t llama_context_kv_self::state_write_data(llama_io_write_i & io) { llama_context_base::state_write_data(io); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); - kv_self.state_write(io); + kv_self->state_write(io); return io.n_bytes(); } @@ -3634,7 +3635,7 @@ size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) { llama_context_base::state_read_data(io); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - kv_self.state_read(io); + kv_self->state_read(io); return io.n_bytes(); } @@ -3642,7 +3643,7 @@ size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) { size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { llama_context_base::state_seq_write_data(io, seq_id); - kv_self.state_write(io, seq_id); + kv_self->state_write(io, seq_id); return io.n_bytes(); } @@ -3650,7 +3651,7 @@ size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_ size_t 
llama_context_kv_self::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { llama_context_base::state_seq_read_data(io, seq_id); - kv_self.state_read(io, seq_id); + kv_self->state_read(io, seq_id); return io.n_bytes(); } @@ -3663,12 +3664,13 @@ llama_context_recurrent::llama_context_recurrent( const llama_model & model, llama_context_params params, llama_graph_type gtype) : - llama_context_base(model, params, gtype), - kv_self(model.hparams) { + llama_context_base(model, params, gtype) { LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__); const auto & hparams = model.hparams; + kv_self = std::make_unique(hparams); + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); // Mamba only needs a constant number of KV cache cells per sequence @@ -3684,14 +3686,14 @@ llama_context_recurrent::llama_context_recurrent( GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); if (!hparams.vocab_only) { - if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); throw std::runtime_error("failed to initialize self-attention cache"); } { - const size_t memory_size_k = kv_self.size_k_bytes(); - const size_t memory_size_v = kv_self.size_v_bytes(); + const size_t memory_size_k = kv_self->size_k_bytes(); + const size_t memory_size_v = kv_self->size_v_bytes(); LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), @@ -3705,20 +3707,20 @@ llama_context_recurrent::~llama_context_recurrent() = default; void llama_context_recurrent::reserve() { // simulate full KV cache - kv_self.n = kv_self.size; + kv_self->n = kv_self->size; - LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self.n); + LLAMA_LOG_DEBUG("%s: kv_self.n = %u\n", __func__, kv_self->n); // TODO: implement recurrent-specific reserve logic llama_context_base::reserve(); } llama_kv_cache * llama_context_recurrent::get_kv_self() { - return &kv_self; + return kv_self.get(); } const llama_kv_cache * llama_context_recurrent::get_kv_self() const { - return &kv_self; + return kv_self.get(); } void llama_context_recurrent::kv_self_update() { @@ -3740,7 +3742,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { // temporary allocate memory for the input batch if needed // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1); + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->pos_max() + 1); const llama_batch & batch = batch_allocr.batch; @@ -3755,7 +3757,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { // TODO: remove this stuff class batch_guard { public: - batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) { + batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) { } ~batch_guard() { @@ -3778,7 +3780,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { llama_kv_slot_restorer kv_slot_restorer; }; - batch_guard bg(kv_self); + batch_guard bg(*kv_self); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -3870,11 +3872,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) { - kv_self.head = 0; + if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) { + kv_self->head = 0; } - const auto slot_info = kv_self.find_slot(ubatch); + const auto slot_info = kv_self->find_slot(ubatch); if (!slot_info) { LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); return -3; @@ -3883,7 +3885,7 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { bg.save(slot_info); } - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head); ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -3912,11 +3914,11 @@ int llama_context_recurrent::decode(llama_batch & inp_batch) { // update the kv ring buffer { - kv_self.head += ubatch.n_tokens; + kv_self->head += ubatch.n_tokens; // Ensure kv cache head points to a valid index. 
- if (kv_self.head >= kv_self.size) { - kv_self.head = 0; + if (kv_self->head >= kv_self->size) { + kv_self->head = 0; } } @@ -4044,9 +4046,9 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) { // call base functionality llama_context_base::input_set(ubatch); - GGML_ASSERT(kv_self.recurrent); + GGML_ASSERT(kv_self->recurrent); - const int64_t n_kv = kv_self.n; + const int64_t n_kv = kv_self->n; if (inp.s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(inp.s_mask->buffer)); @@ -4054,8 +4056,8 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = kv_self.cells[cell_id]; + const uint32_t cell_id = i + kv_self->head; + llama_kv_cell & kv_cell = kv_self->cells[cell_id]; data[i] = (float) (kv_cell.src >= 0); @@ -4073,11 +4075,11 @@ void llama_context_recurrent::input_set(const llama_ubatch & ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = kv_self.cells[cell_id]; + const uint32_t cell_id = i + kv_self->head; + llama_kv_cell & kv_cell = kv_self->cells[cell_id]; // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { kv_cell.src = cell_id; } @@ -4101,7 +4103,7 @@ ggml_cgraph * llama_context_recurrent::graph_init() { ggml_tensor * llama_context_recurrent::build_inp_s_copy( ggml_context * ctx0) { - const auto n_kv = kv_self.n; + const auto n_kv = kv_self->n; inp.s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); //cb(inp.s_copy, "inp_s_copy", -1); @@ -4112,7 +4114,7 @@ ggml_tensor * llama_context_recurrent::build_inp_s_copy( ggml_tensor * llama_context_recurrent::build_inp_s_mask( ggml_context * ctx0) { - const auto n_kv = kv_self.n; + const auto n_kv = kv_self->n; inp.s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); //cb(inp.s_mask, "inp_s_mask", -1); @@ -4129,10 +4131,10 @@ ggml_tensor * llama_context_recurrent::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) { - const auto n_kv = kv_self.n; - const auto kv_head = kv_self.head; + const auto n_kv = kv_self->n; + const auto kv_head = kv_self->head; - struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self.size); + struct ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv @@ -4164,7 +4166,7 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer( int il) { const auto & hparams = model.hparams; - const auto kv_head = kv_self.head; + const auto kv_head = kv_self->head; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -4182,8 +4184,8 @@ ggml_tensor * llama_context_recurrent::build_mamba_layer( GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - struct ggml_tensor * conv_states_all = kv_self.k_l[il]; - struct ggml_tensor * ssm_states_all = kv_self.v_l[il]; + struct ggml_tensor * conv_states_all = kv_self->k_l[il]; + struct ggml_tensor * ssm_states_all = kv_self->v_l[il]; // (ab)using the KV cache to store the states struct ggml_tensor * conv = build_copy_mask_state( @@ -4300,7 +4302,7 @@ ggml_tensor * 
llama_context_recurrent::build_rwkv_token_shift_load( const int64_t n_seqs = ubatch.n_seqs; - struct ggml_tensor * token_shift_all = kv_self.k_l[il]; + struct ggml_tensor * token_shift_all = kv_self->k_l[il]; struct ggml_tensor * token_shift = build_copy_mask_state( ctx0, gf, token_shift_all, state_copy, state_mask, @@ -4323,12 +4325,12 @@ ggml_tensor * llama_context_recurrent::build_rwkv_token_shift_store( const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_self.head; + const auto kv_head = kv_self->head; return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) + ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) ); } @@ -4350,7 +4352,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_self.head; + const auto kv_head = kv_self->head; const auto & layer = model.layers[il]; @@ -4458,7 +4460,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( } struct ggml_tensor * wkv_state = build_copy_mask_state( - ctx0, gf, kv_self.v_l[il], state_copy, state_mask, + ctx0, gf, kv_self->v_l[il], state_copy, state_mask, hparams.n_embd_v_s(), n_seqs); struct ggml_tensor * wkv_output; @@ -4477,9 +4479,9 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( wkv_state, ggml_view_1d( ctx0, - kv_self.v_l[il], + kv_self->v_l[il], hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) ) ) ); @@ -4507,7 +4509,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) { llama_context_base::state_write_data(io); - kv_self.state_write(io); + kv_self->state_write(io); return io.n_bytes(); } @@ -4515,7 +4517,7 @@ size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) { size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) { llama_context_base::state_read_data(io); - kv_self.state_read(io); + kv_self->state_read(io); return io.n_bytes(); } @@ -4523,7 +4525,7 @@ size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) { size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { llama_context_base::state_seq_write_data(io, seq_id); - kv_self.state_write(io, seq_id); + kv_self->state_write(io, seq_id); return io.n_bytes(); } @@ -4531,7 +4533,7 @@ size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llam size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { llama_context_base::state_seq_read_data(io, seq_id); - kv_self.state_read(io, seq_id); + kv_self->state_read(io, seq_id); return io.n_bytes(); } @@ -5211,7 +5213,7 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * return; } - llama_kv_cache_view_update(view, *kv); + llama_kv_cache_view_update(view, kv); } // diff --git a/src/llama-context.h b/src/llama-context.h index 1b807ccf84..d74db70c77 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -630,7 +630,7 @@ private: // members // - llama_kv_cache kv_self; + std::unique_ptr kv_self; }; // a recurrent transformer (ie.e 
RWKV, Mamba) @@ -745,7 +745,7 @@ private: // // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models? - llama_kv_cache_recurrent kv_self; + std::unique_ptr kv_self; }; // TODO: tmp - need something better to pass the data from the encoder to the decoder diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index e1b07c9932..0cd4142d5f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -6,17 +6,16 @@ #include "llama-model.h" #include -#include #include #include #include static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; -llama_kv_cache::llama_kv_cache(const llama_hparams & hparams) : hparams(hparams) { +llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams) : hparams(hparams) { } -bool llama_kv_cache::init( +bool llama_kv_cache_unified::init( const llama_model & model, const llama_cparams & cparams, ggml_type type_k, @@ -123,7 +122,7 @@ bool llama_kv_cache::init( return true; } -int32_t llama_kv_cache::n_tokens() const { +int32_t llama_kv_cache_unified::n_tokens() const { int32_t result = 0; for (uint32_t i = 0; i < size; i++) { @@ -133,7 +132,11 @@ int32_t llama_kv_cache::n_tokens() const { return result; } -size_t llama_kv_cache::total_size() const { +uint32_t llama_kv_cache_unified::used_cells() const { + return used; +} + +size_t llama_kv_cache_unified::total_size() const { size_t size = 0; for (const auto & buf : bufs) { size += ggml_backend_buffer_get_size(buf.get()); @@ -142,7 +145,7 @@ size_t llama_kv_cache::total_size() const { return size; } -llama_pos llama_kv_cache::pos_max() const { +llama_pos llama_kv_cache_unified::pos_max() const { llama_pos pos_max = -1; for (const auto & cell : cells) { pos_max = std::max(pos_max, cell.pos); @@ -151,7 +154,7 @@ llama_pos llama_kv_cache::pos_max() const { return pos_max; } -void llama_kv_cache::clear() { +void llama_kv_cache_unified::clear() { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); @@ -166,7 +169,7 @@ void llama_kv_cache::clear() { } } -bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { uint32_t new_head = size; if (p0 < 0) { @@ -237,7 +240,7 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { return true; } -void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { if (seq_id_src == seq_id_dst) { return; } @@ -288,7 +291,7 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } } -void llama_kv_cache::seq_keep(llama_seq_id seq_id) { +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { uint32_t new_head = size; for (uint32_t i = 0; i < size; ++i) { @@ -320,7 +323,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } } -void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (delta == 0) { return; } @@ -378,7 +381,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll head = new_head != size ? 
new_head : 0; } -void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { if (d == 1) { return; } @@ -424,7 +427,7 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in } } -llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) { +llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) { llama_pos result = 0; for (uint32_t i = 0; i < size; ++i) { @@ -436,13 +439,17 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) { return result; } -void llama_kv_cache::defrag() { +void llama_kv_cache_unified::defrag() { if (!recurrent) { do_defrag = true; } } -struct llama_kv_cache_slot_info llama_kv_cache::find_slot( +bool llama_kv_cache_unified::get_can_shift() const { + return can_shift; +} + +struct llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( const struct llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; @@ -663,12 +670,12 @@ struct llama_kv_cache_slot_info llama_kv_cache::find_slot( return llama_kv_cache_slot_info(head, head + n_tokens); } -uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) const { +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; } -uint32_t llama_kv_cache::cell_max() const { +uint32_t llama_kv_cache_unified::cell_max() const { for (uint32_t i = size; i > 0; --i) { const llama_kv_cell & cell = cells[i - 1]; @@ -680,7 +687,7 @@ uint32_t llama_kv_cache::cell_max() const { return 0; } -size_t llama_kv_cache::size_k_bytes() const { +size_t llama_kv_cache_unified::size_k_bytes() const { size_t size_k_bytes = 0; for (const auto & k : k_l) { @@ -690,7 +697,7 @@ size_t llama_kv_cache::size_k_bytes() const { return size_k_bytes; } -size_t llama_kv_cache::size_v_bytes() const { +size_t llama_kv_cache_unified::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & v : v_l) { @@ -700,7 +707,7 @@ size_t llama_kv_cache::size_v_bytes() const { return size_v_bytes; } -void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -738,7 +745,7 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id) con state_write_data(io, cell_ranges); } -void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { uint32_t cell_count; io.read_to(&cell_count, sizeof(cell_count)); @@ -756,7 +763,7 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id) { } } -void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { for (const auto & range : cell_ranges) { for (uint32_t i = range.first; i < range.second; ++i) { const auto & cell = cells[i]; @@ -775,7 +782,7 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges) const { +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const 
std::vector> & cell_ranges) const { const uint32_t v_trans = this->v_trans ? 1 : 0; const uint32_t n_layer = hparams.n_layer; @@ -855,7 +862,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector= llama_n_seq_max(ctx)) { if (seq_id < 0) { //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); @@ -957,7 +964,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count, return true; } -bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t cell_count) { +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t v_trans; uint32_t n_layer; io.read_to(&v_trans, sizeof(v_trans)); @@ -1092,7 +1099,7 @@ int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) { return 0; } - return kv->used; + return kv->used_cells(); } void llama_kv_cache_clear(llama_kv_cache * kv) { @@ -1183,7 +1190,7 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv) { return false; } - return kv->can_shift; + return kv->get_can_shift(); } // @@ -1216,9 +1223,16 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } } -void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) { - if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) { - view->n_cells = int32_t(kv.size); +void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv) { + // TODO: rework this in the future, for now quick hack + const llama_kv_cache_unified * kvu = dynamic_cast(kv); + if (kvu == nullptr) { + LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__); + return; + } + + if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) { + view->n_cells = int32_t(kvu->size); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); view->cells = (struct llama_kv_cache_view_cell *)p; @@ -1227,7 +1241,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = kv.cells; + const std::vector & kv_cells = kvu->cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; @@ -1236,7 +1250,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct uint32_t max_contig = 0; int32_t max_contig_idx = -1; - for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) { + for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) { const size_t curr_size = kv_cells[i].seq_id.size(); token_count += curr_size; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; @@ -1274,8 +1288,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct view->max_contiguous_idx = max_contig_idx; view->token_count = token_count; view->used_cells = used_cells; - if (uint32_t(used_cells) != kv.used) { + if (uint32_t(used_cells) != kvu->used) { LLAMA_LOG_ERROR("%s: used cells mismatch. 
kv_cache says %d but we calculated %d\n", - __func__, kv.used, used_cells); + __func__, kvu->used, used_cells); } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index dda9bfec48..99eb0be3c7 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -45,12 +45,39 @@ struct llama_kv_cache_slot_info { operator bool() const { return found; } }; +struct llama_kv_cache { +public: + virtual int32_t n_tokens() const = 0; + virtual uint32_t used_cells() const = 0; // TODO: remove + + virtual void clear() = 0; + virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; + virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; + virtual void seq_keep(llama_seq_id seq_id) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0; + virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + + virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0; + + virtual void defrag() = 0; + virtual bool get_can_shift() const = 0; +}; + + +// C++ alias +class llama_kv_cache_i : public llama_kv_cache { +public: + using llama_kv_cache::llama_kv_cache; +}; + + // ring-buffer of cached KV data // TODO: pimpl // TODO: add notion of max sequences -struct llama_kv_cache { - llama_kv_cache(const llama_hparams & hparams); - virtual ~llama_kv_cache() = default; +class llama_kv_cache_unified : public llama_kv_cache_i { +public: + llama_kv_cache_unified(const llama_hparams & hparams); + virtual ~llama_kv_cache_unified() = default; // TODO: become constructor bool init( @@ -61,24 +88,26 @@ struct llama_kv_cache { uint32_t kv_size, bool offload); - int32_t n_tokens() const; + int32_t n_tokens() const override; + uint32_t used_cells() const override; size_t total_size() const; // TODO: better data structures to reduce the cost of this operation llama_pos pos_max() const; - void clear(); + void clear() override; - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1); - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); - void seq_keep(llama_seq_id seq_id); - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - llama_pos seq_pos_max(llama_seq_id seq_id); + llama_pos seq_pos_max(llama_seq_id seq_id) override; - void defrag(); + void defrag() override; + bool get_can_shift() const override; // find an empty slot of size "n_tokens" in the cache // updates the cache head @@ -143,9 +172,10 @@ private: bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -// TODO: temporary reusing llama_kv_cache -- implement recurrent cache and simplify llama_kv_cache -struct llama_kv_cache_recurrent : public llama_kv_cache { - using llama_kv_cache::llama_kv_cache; +// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified +class llama_kv_cache_recurrent : public llama_kv_cache_unified { +public: + using llama_kv_cache_unified::llama_kv_cache_unified; }; // @@ -166,9 +196,9 @@ struct 
llama_kv_slot_restorer { bool do_restore = false; - llama_kv_cache & cache; + llama_kv_cache_unified & cache; - explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) { + explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) { old_state.head = cache.head; old_state.n = cache.n; } @@ -249,4 +279,4 @@ bool llama_kv_cache_can_shift(const llama_kv_cache * kv); struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max); -void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv); +void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache * kv);
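
Note (illustrative, not part of the patch): the new abstract llama_kv_cache base introduced above exposes only sequence-level operations (n_tokens, used_cells, clear, seq_rm/seq_cp/seq_keep/seq_add/seq_div, seq_pos_max, defrag, get_can_shift), so code that only manipulates sequences can stay agnostic of whether a context owns a llama_kv_cache_unified or a llama_kv_cache_recurrent. A minimal sketch against that interface follows; the helper name prune_sequence is hypothetical, and it assumes the usual llama.cpp convention that a negative p1 means "to the end of the sequence":

    #include "llama-kv-cache.h"

    // keep only the last n_keep positions of seq_id, shift the survivors back to
    // the front of the sequence (if the cache supports K-shift) and queue a
    // defragmentation; only the virtual interface from this patch is used
    static void prune_sequence(llama_kv_cache & kv, llama_seq_id seq_id, llama_pos n_keep) {
        const llama_pos p_max     = kv.seq_pos_max(seq_id);
        const llama_pos n_discard = p_max + 1 - n_keep;

        if (n_discard <= 0) {
            return; // nothing to prune
        }

        // remove positions [0, n_discard)
        kv.seq_rm(seq_id, 0, n_discard);

        // move the remaining positions [n_discard, p_max] down by n_discard
        if (kv.get_can_shift()) {
            kv.seq_add(seq_id, n_discard, -1, -n_discard);
        }

        // request defragmentation (a no-op for a recurrent cache)
        kv.defrag();
    }

The queued shift/defrag is only materialized later, e.g. when llama_context_kv_self::kv_self_update() builds and runs the corresponding graphs before the next decode.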