#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"

#include <unordered_map>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;

//
// llama_kv_cache
//

class llama_kv_cache : public llama_memory_i {
public:
    static uint32_t get_padding(const llama_cparams & cparams);

    struct stream_copy_info {
        bool empty() const {
            assert(ssrc.size() == sdst.size());
            return ssrc.empty();
        }

        std::vector<uint32_t> ssrc;
        std::vector<uint32_t> sdst;
    };

    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
    // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
    struct slot_info {
        // data for ggml_set_rows
        using idx_vec_t = std::vector<uint32_t>;

        // number of streams: ns = s1 - s0 + 1
        uint32_t s0;
        uint32_t s1;

        std::vector<uint32_t>  strm; // [ns]
        std::vector<idx_vec_t> idxs; // [ns]

        uint32_t head() const {
            GGML_ASSERT(idxs.size() == 1);
            GGML_ASSERT(!idxs[0].empty());

            return idxs[0][0];
        }

        void resize(size_t n) {
            strm.resize(n);
            idxs.resize(n);
        }

        size_t size() const {
            GGML_ASSERT(idxs.size() == strm.size());
            GGML_ASSERT(!idxs.empty());

            return idxs[0].size();
        }

        size_t n_stream() const {
            return strm.size();
        }

        bool empty() const {
            return idxs.empty();
        }

        void clear() {
            idxs.clear();
        }
    };

    using slot_info_vec_t = std::vector<slot_info>;

    llama_kv_cache(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                         bool   unified,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_pad,
                     uint32_t   n_swa,
               llama_swa_type   swa_type,
        const layer_filter_cb & filter,
        const layer_reuse_cb  & reuse);

    ~llama_kv_cache() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;

    //
    // llama_kv_cache specific API
    //

    uint32_t get_size()     const;
    uint32_t get_n_stream() const;

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv(const slot_info & sinfo) const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    // store k_cur and v_cur in the cache based on the provided head location
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
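
    // illustrative sketch (not from the upstream sources): one way the calls above could be
    // combined when building the attention graph for a single layer. the names gf, ctx, ubatch,
    // k_cur, v_cur and il are placeholders for state owned by the graph builder, and sinfo is a
    // slot_info obtained via prepare()/find_slot() below:
    //
    //     const uint32_t n_kv = get_n_kv(sinfo);
    //
    //     ggml_tensor * k_idxs = build_input_k_idxs(ctx, ubatch);
    //     ggml_tensor * v_idxs = build_input_v_idxs(ctx, ubatch);
    //
    //     ggml_build_forward_expand(gf, cpy_k(ctx, k_cur, k_idxs, il, sinfo));
    //     ggml_build_forward_expand(gf, cpy_v(ctx, v_cur, v_idxs, il, sinfo));
    //
    //     ggml_tensor * k = get_k(ctx, il, n_kv, sinfo);
    //     ggml_tensor * v = get_v(ctx, il, n_kv, sinfo);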

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache, returns the slot infos
    // return empty vector on failure
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

    // find a slot of kv cells that can hold the ubatch
    // if cont == true, then the slot must be contiguous
    // return empty slot_info on failure
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

    //
    // input API
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    const llama_model & model;
    const llama_hparams & hparams;

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        ggml_tensor * k;
        ggml_tensor * v;

        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
    };

    bool v_trans = true; // the value tensor is transposed

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream  = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

    // this is the SWA type of the cache - not to be confused with the model SWA type
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
    std::vector<uint32_t> v_heads;

    std::vector<llama_kv_cells_unified> v_cells;

    // maps from a sequence id to a stream id
    std::vector<uint32_t> seq_to_stream;

    // pending stream copies that will be applied during the next update
    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    bool is_masked_swa(llama_pos p0, llama_pos p1) const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_tensor * cur,
                    ggml_tensor * shift,
                    ggml_tensor * factors,
                          float   freq_base,
                          float   freq_scale) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
               llama_context * lctx) const;

    struct cell_ranges_t {
        uint32_t strm;

        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
    };

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
};
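
// illustrative sketch (not from the upstream sources): the intended flow of the preparation API,
// assuming the caller has already split a batch into `ubatches` and holds a pointer `kv` to the cache:
//
//     llama_kv_cache::slot_info_vec_t sinfos = kv->prepare(ubatches);
//     if (sinfos.empty()) {
//         // no suitable slots were found for all ubatches
//     }
//
//     for (size_t i = 0; i < ubatches.size(); ++i) {
//         kv->apply_ubatch(sinfos[i], ubatches[i]);
//         // ... build and compute the graph for ubatches[i] ...
//     }
//
// for example (values made up), a 3-token ubatch placed in a unified cache could yield a slot_info with
// s0 = 0, s1 = 0, strm = { 0 } and idxs = { { 17, 18, 19 } }, i.e. token[i] goes to cell idxs[0][i].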

class llama_kv_cache_context : public llama_memory_context_i {
public:
    // some shorthands
    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
    using stream_copy_info = llama_kv_cache::stream_copy_info;

    // used for errors
    llama_kv_cache_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_context(
            llama_kv_cache * kv);

    // used to create an update context
    llama_kv_cache_context(
            llama_kv_cache * kv,
            llama_context * lctx,
            bool do_shift,
            stream_copy_info sc_info);

    // used to create a batch processing context from a batch
    llama_kv_cache_context(
            llama_kv_cache * kv,
            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_context();

    //
    // llama_memory_context_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_context specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
    // note: the heads in k_cur and v_cur should be laid out contiguously in memory
    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
    //   - k_idxs [n_tokens]
    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
    //   - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

    // create destination indices for each head of the current batch for where it would be written in the KV cache
    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
    // helps understand the implementation logic of cpy_k and cpy_v
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift   (ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

private:
    llama_memory_status status;

    llama_kv_cache * kv;
    llama_context * lctx;

    //
    // update context
    //

    bool do_shift = false;

    stream_copy_info sc_info;

    //
    // batch processing context
    //

    // the index of the current ubatch to process
    size_t i_cur = 0;

    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    int32_t n_kv;
};
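
// illustrative sketch (not from the upstream sources, simplified): how a context produced by
// llama_kv_cache::init_batch() is typically consumed; `kv`, `balloc` and `n_ubatch` are
// placeholders for caller-provided state, and error handling is omitted:
//
//     llama_memory_context_ptr mctx = kv->init_batch(balloc, n_ubatch, /*embd_all=*/false);
//
//     // check mctx->get_status() before proceeding
//
//     do {
//         mctx->apply(); // write the pending slot info into the cache
//
//         const llama_ubatch & ubatch = mctx->get_ubatch();
//         // ... build and compute the graph for this ubatch ...
//     } while (mctx->next());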