From 08011c2ca12ee95b2041561f69ef0cc0be865dca Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 20 Feb 2025 20:54:18 +0200
Subject: [PATCH] context : add llama_kv_cache_recurrent prototype

ggml-ci
---
 src/llama-context.cpp | 550 +++++++++++++++++++++++++++++++++++-------
 src/llama-context.h   |  20 +-
 src/llama-kv-cache.h  |   9 +-
 3 files changed, 477 insertions(+), 102 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 648a669b16..64728e8b59 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -359,17 +359,17 @@ int32_t llama_context::max_nodes() const {
 }
 
 llama_kv_cache * llama_context::get_kv_self() {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
     return nullptr;
 }
 
 const llama_kv_cache * llama_context::get_kv_self() const {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
     return nullptr;
 }
 
 void llama_context::kv_self_update() {
-    LLAMA_LOG_DEBUG("%s: llama_context does not have a KV cache\n", __func__);
+    LLAMA_LOG_WARN("%s: llama_context does not have a KV cache\n", __func__);
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -2246,14 +2246,7 @@ llama_context_kv_self::llama_context_kv_self(
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
 
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(&model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
+    GGML_ASSERT(!llama_model_is_recurrent(&model));
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
@@ -2286,6 +2279,61 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
     return &kv_self;
 }
 
+void llama_context_kv_self::kv_self_update() {
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched.get());
+
+            auto * gf = graph_init();
+
+            build_kv_self_shift(ctx_compute.get(), gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            input_set({});
+
+            graph_compute(gf, false);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        ggml_backend_sched_reset(sched.get());
+
+        auto * gf = graph_init();
+
+        build_kv_self_defrag(ctx_compute.get(), gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //input_set({});
+
+        graph_compute(gf, false);
+
+        kv.do_defrag = false;
+
+        need_reserve = true;
+    }
+}
+
 ggml_cgraph * llama_context_kv_self::graph_init() {
     inp_embd_enc = nullptr;
     inp_pos_bucket = nullptr;
@@ -2310,7 +2358,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -2470,7 +2518,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -2552,7 +2600,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     const bool logits_all = n_outputs_all == n_tokens_all;
 
     sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
+            /* simple_split */ true,
             /* logits_all */   logits_all);
 
     // reserve output buffer
@@ -2569,18 +2617,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
         const auto & n_ubatch = cparams.n_ubatch;
 
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = sbatch.split_simple(n_ubatch);
-        }
+        ubatch = sbatch.split_simple(n_ubatch);
 
         // count the outputs in this u_batch
         {
@@ -2617,7 +2654,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
             bg.save(slot_info);
 
-            if (!kv_self.recurrent) {
+            {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
@@ -2821,10 +2858,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-llama_pos llama_context_kv_self::pos_max() const {
-    return kv_self.pos_max();
-}
-
 uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const {
     return kv_self.get_padding(cparams);
 }
@@ -3062,61 +3095,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
     }
 }
 
-void llama_context_kv_self::kv_self_update() {
-    auto & kv = kv_self;
-
-    if (kv.has_shift) {
-        if (!kv.can_shift) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        // apply K-shift if needed
-        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            build_kv_self_shift(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            input_set({});
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        {
-            kv.has_shift = false;
-
-            for (uint32_t i = 0; i < kv.size; ++i) {
-                kv.cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv.do_defrag) {
-        ggml_backend_sched_reset(sched.get());
-
-        auto * gf = graph_init();
-
-        build_kv_self_defrag(ctx_compute.get(), gf);
-
-        ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-        // no input
-        //input_set({});
-
-        graph_compute(gf, false);
-
-        kv.do_defrag = false;
-
-        need_reserve = true;
-    }
-}
-
 ggml_tensor * llama_context_kv_self::build_inp_self_k_shift(ggml_context * ctx0) {
     inp_self_k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx());
     ggml_set_input(inp_self_k_shift);
@@ -3176,7 +3154,9 @@ ggml_tensor * llama_context_kv_self::build_attn(
 
     // store to KV cache
     {
-        const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;
+        GGML_ASSERT(!kv_self.recurrent);
+
+        const auto kv_head = worst_case ? kv_self.size - n_tokens : kv_self.head;
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -3684,22 +3664,406 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
 llama_context_recurrent::llama_context_recurrent(
         const llama_model & model,
         const llama_context_params & params) :
-    llama_context_kv_self(model, params) {
+    llama_context(model, params),
+    kv_self(model.hparams) {
     LLAMA_LOG_INFO("%s: constructing llama_context_recurrent\n", __func__);
+
+    const auto & hparams = model.hparams;
+
+    LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+
+    // Mamba only needs a constant number of KV cache cells per sequence
+    GGML_ASSERT(llama_model_is_recurrent(&model));
+
+    // Mamba needs at least as many KV cells as there are sequences kept at any time
+    uint32_t kv_size = std::max((uint32_t) 1, params.n_seq_max);
+    // it's probably best to keep as much precision as possible for the states
+    ggml_type type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+    ggml_type type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
+
+    if (!hparams.vocab_only) {
+        if (!kv_self.init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+            throw std::runtime_error("failed to initialize self-attention cache");
+        }
+
+        {
+            const size_t memory_size_k = kv_self.size_k_bytes();
+            const size_t memory_size_v = kv_self.size_v_bytes();
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
+    }
 }
 
 llama_context_recurrent::~llama_context_recurrent() = default;
 
-ggml_cgraph * llama_context_recurrent::graph_init() {
-    inp_s_copy = nullptr;
-    inp_s_mask = nullptr;
+llama_kv_cache * llama_context_recurrent::get_kv_self() {
+    return &kv_self;
+}
 
-    return llama_context_kv_self::graph_init();
+const llama_kv_cache * llama_context_recurrent::get_kv_self() const {
+    return &kv_self;
+}
+
+void llama_context_recurrent::kv_self_update() {
+    // noop
+}
+
+ggml_cgraph * llama_context_recurrent::graph_init() {
+    inp_s_copy = nullptr;
+    inp_s_mask = nullptr;
+
+    return llama_context::graph_init();
+}
+
+int llama_context_recurrent::encode(llama_batch & inp_batch) {
+    GGML_UNUSED(inp_batch);
+
+    LLAMA_LOG_ERROR("%s: encode() not supported for recurrent models\n", __func__);
+    return -1;
+}
+
+int llama_context_recurrent::decode(llama_batch & inp_batch) {
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self.pos_max() + 1);
+
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+
+    const int32_t n_vocab = vocab.n_tokens();
+
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd       = hparams.n_embd;
+
+    // TODO: remove this stuff
+    class batch_guard {
+    public:
+        batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
+        }
+
+        ~batch_guard() {
+            if (!is_done) {
+                kv_slot_restorer.restore();
+            }
+        }
+
+        void done() {
+            is_done = true;
+        }
+
+        void save(const llama_kv_cache_slot_info & slot_info) {
+            kv_slot_restorer.save(slot_info);
+        }
+
+    private:
+        bool is_done = false;
+
+        llama_kv_slot_restorer kv_slot_restorer;
+    };
+
+    batch_guard bg(kv_self);
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                throw std::runtime_error("invalid token");
+            }
+        }
+    }
+
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+    n_queued_tokens += n_tokens_all;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    embd_seq.clear();
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs_all += batch.logits[i] != 0;
+        }
+    } else if (logits_all || embd_pooled) {
+        n_outputs_all = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs_all = 1;
+    }
+
+    const bool logits_all = n_outputs_all == n_tokens_all;
+
+    sbatch.from_batch(batch, n_embd,
+            /* simple_split */ false,
+            /* logits_all */   logits_all);
+
+    // reserve output buffer
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+        return -2;
+    };
+
+    int64_t n_outputs_prev = 0;
+
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch = llama_ubatch();
+
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            kv_self_update();
+
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
+                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+                return -3;
+            }
+
+            bg.save(slot_info);
+        }
+
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+
+        // reserve a worst case graph if needed
+        if (need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            auto * gf = graph_init();
+            graph_build(ctx_compute.get(), gf, ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(sched.get());
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            need_reserve = false;
+        }
+
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+        auto * gf = graph_init();
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, false);
+
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        input_set(ubatch);
+
+        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
+
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
+
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
+
+        auto * t_logits = cparams.embeddings ? nullptr    : res.t_logits;
+        auto * t_embd   = cparams.embeddings ? res.t_embd : nullptr;
+
+        if (t_embd && res.t_embd_pooled) {
+            t_embd = res.t_embd_pooled;
+        }
+
+        // extract logits
+        if (t_logits && n_outputs > 0) {
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(logits != nullptr);
+
+            float * logits_out = logits + n_outputs_prev*n_vocab;
+
+            if (n_outputs) {
+                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+            }
+        }
+
+        // extract embeddings
+        if (t_embd && n_outputs > 0) {
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+            GGML_ASSERT(backend_embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd + n_outputs_prev*n_embd;
+
+                        if (n_outputs) {
+                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings (cleared before processing each batch)
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
+            }
+        }
+
+        n_outputs_prev += n_outputs;
+    }
+
+    // finalize the batch processing
+    bg.done();
+
+    // set output mappings
+    {
+        bool sorted_output = true;
+
+        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+
+        for (int64_t i = 0; i < n_outputs_all; ++i) {
+            int64_t out_id = sbatch.out_ids[i];
+            output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
+            }
+        }
+
+        if (sorted_output) {
+            sbatch.out_ids.clear();
+        }
+    }
+
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    n_outputs = n_outputs_all;
+
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //synchronize();
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
+
+    return 0;
 }
 
 void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
     // call base functionality
-    llama_context_kv_self::input_set(ubatch);
+    llama_context::input_set(ubatch);
 
     GGML_ASSERT(kv_self.recurrent);
 
diff --git a/src/llama-context.h b/src/llama-context.h
index c605cec6f6..df6acb265d 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -374,9 +374,6 @@ public:
     virtual int encode(llama_batch & inp_batch) override;
     virtual int decode(llama_batch & inp_batch) override;
 
-    // max token position across all sequences in the current context
-    llama_pos pos_max() const;
-
     // certain implementations could require a padding for the context size
     uint32_t get_ctx_padding(const llama_cparams & cparams) const;
 
@@ -453,9 +450,7 @@ protected:
 };
 
 // a recurrent transformer (ie.e RWKV, Mamba)
-// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
-//class llama_context_recurrent : public llama_context {
-class llama_context_recurrent : public llama_context_kv_self {
+class llama_context_recurrent : public llama_context {
 public:
     llama_context_recurrent(
         const llama_model & model,
@@ -463,8 +458,16 @@ public:
 
     virtual ~llama_context_recurrent();
 
+    virtual llama_kv_cache * get_kv_self() override;
+    virtual const llama_kv_cache * get_kv_self() const override;
+
+    virtual void kv_self_update() override;
+
     virtual ggml_cgraph * graph_init() override;
 
+    virtual int encode(llama_batch & inp_batch) override;
+    virtual int decode(llama_batch & inp_batch) override;
+
     virtual ggml_tensor * build_inp_s_copy(
             ggml_context * ctx0,
             bool worst_case) override;
@@ -524,10 +527,11 @@ public:
 protected:
     virtual void input_set(const llama_ubatch & ubatch) override;
 
+    // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
+    llama_kv_cache_recurrent kv_self;
+
     struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
     struct ggml_tensor * inp_s_mask;  // F32 [1, n_kv]
-
-    // TODO: add recurrent cache
 };
 
 // For internal test use
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 049193fd0f..dda9bfec48 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -48,7 +48,6 @@ struct llama_kv_cache_slot_info {
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
-// TODO: add llama_hparams &
 struct llama_kv_cache {
     llama_kv_cache(const llama_hparams & hparams);
     virtual ~llama_kv_cache() = default;
@@ -108,7 +107,10 @@ struct llama_kv_cache {
 
     bool has_shift = false;
    bool do_defrag = false;
+
+    // TODO: remove this and implement llama_kv_cache_recurrent instead
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+
     bool v_trans = true; // the value tensor is transposed
 
     bool can_shift = false;
@@ -141,6 +143,11 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+// TODO: temporarily reusing llama_kv_cache -- implement a recurrent cache and simplify llama_kv_cache
+struct llama_kv_cache_recurrent : public llama_kv_cache {
+    using llama_kv_cache::llama_kv_cache;
+};
+
 //
 // kv cache restore
 //
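---

Illustrative notes (editor sketches, not part of the patch):

The llama_context_recurrent constructor above sizes the cache by sequence count rather than by context length. A condensed, compilable sketch of that rule, using stand-in types (cell_type and recurrent_cache_cfg are hypothetical; the real logic lives in the constructor):

    #include <algorithm>
    #include <cstdint>

    // Stand-in for ggml's type enum; the real cache stores the states in F32
    // because ggml_ssm_conv and ggml_ssm_scan require full precision.
    enum class cell_type { f32 };

    struct recurrent_cache_cfg {
        uint32_t  kv_size; // number of cells == number of tracked sequences
        cell_type type_k;  // conv_states
        cell_type type_v;  // ssm_states
    };

    // Mirrors the sizing rule from the patch: a recurrent model (e.g. Mamba)
    // keeps one state per sequence, so the cache needs n_seq_max cells, not n_ctx.
    recurrent_cache_cfg make_recurrent_cfg(uint32_t n_seq_max) {
        return { std::max<uint32_t>(1, n_seq_max), cell_type::f32, cell_type::f32 };
    }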
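After the split, the two decode() implementations choose ubatches differently: llama_context_kv_self always takes split_simple, while the recurrent context takes split_seq for pooled embeddings and split_equal otherwise. A sketch of the combined policy (llama_sbatch/llama_ubatch are llama.cpp's internal types and their split_* methods appear in the patch; the wrapper function itself is hypothetical and not buildable outside the source tree):

    // Illustration only: assumes llama.cpp's internal llama_sbatch/llama_ubatch.
    static llama_ubatch next_ubatch(llama_sbatch & sbatch, uint32_t n_ubatch,
                                    bool recurrent, bool embd_pooled) {
        if (!recurrent) {
            // attention KV cache: simple contiguous slices of the batch are fine
            return sbatch.split_simple(n_ubatch);
        }
        if (embd_pooled) {
            // pooled embeddings cannot be split across ubatches (yet)
            return sbatch.split_seq(n_ubatch);
        }
        // recurrent architectures are easier to implement with equal-length sequences
        return sbatch.split_equal(n_ubatch);
    }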
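decode() claims KV cells before knowing whether the whole batch will succeed; the batch_guard in the patch rolls those claims back on any early return. A self-contained sketch of the same RAII pattern (slot_restorer here is a stand-in for llama_kv_slot_restorer):

    #include <vector>

    // Stand-in restorer: remembers claimed slots and can release them again.
    struct slot_restorer {
        std::vector<int> claimed;
        void save(int slot) { claimed.push_back(slot); }
        void restore()      { claimed.clear(); /* the real code frees the KV cells */ }
    };

    class batch_guard {
    public:
        explicit batch_guard(slot_restorer & r) : restorer(r) {}

        // any early return (error path) runs the destructor and rolls back
        ~batch_guard() { if (!is_done) { restorer.restore(); } }

        void done()         { is_done = true; } // commit: keep the claimed slots
        void save(int slot) { restorer.save(slot); }

    private:
        bool is_done = false;
        slot_restorer & restorer;
    };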
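The "set output mappings" block at the end of decode() inverts sbatch.out_ids: row i of the output buffers came from batch position out_ids[i], and output_ids lets llama_get_logits_ith() find the row for a given batch position. A self-contained sketch of that inversion (the helper name is hypothetical):

    #include <cstdint>
    #include <vector>

    // Returns true when the rows are already in batch order, in which case the
    // caller can drop out_ids entirely (as the patch does with out_ids.clear()).
    static bool invert_output_ids(const std::vector<int64_t> & out_ids,
                                  std::vector<int32_t> & output_ids) {
        bool sorted = true;
        for (int64_t i = 0; i < (int64_t) out_ids.size(); ++i) {
            output_ids[out_ids[i]] = (int32_t) i; // batch position -> output row
            sorted = sorted && (out_ids[i] == i);
        }
        return sorted;
    }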