	kv-cache : refactor the update/defrag mechanism (#13988)
* kv-cache : refactor update mechanism ggml-ci
* memory : improve status handling
* defrag : reset head + add comments ggml-ci
* cont : minor fixes ggml-ci
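
In short: the previous imperative pair update()/defrag_sched() on the KV cache is replaced by a two-phase init_update() API that returns a memory-state object; the caller inspects its status and then applies it. A minimal caller-side sketch of the resulting control flow, mirroring llama_context::kv_self_update() in the diff below (all names come from the patch):

    // sketch: drive any pending K-shift/defrag through the new two-phase API
    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

    const auto kv_state = kv_self->init_update(this, /*optimize =*/ false);

    switch (kv_state->get_status()) {
        case LLAMA_MEMORY_STATUS_SUCCESS:        break;        // work is pending - apply it below
        case LLAMA_MEMORY_STATUS_NO_UPDATE:      return false; // nothing to do
        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: return false; // preparation failed
    }

    if (!kv_state->apply()) { // runs the K-shift and/or defrag graphs
        LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
    }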
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -429,22 +429,54 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }
 
     const uint32_t n_seqs   = cparams.n_seq_max;
@@ -452,7 +484,7 @@ bool llama_context::kv_self_update() {
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }
 
     return true;
@@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);
 
     llama_memory_state_ptr kv_state;
 
-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                             continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
                     return -2;
                 }
         }
@@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
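
With the scheduling flag in place, the deprecated public defrag API reduces to setting and consuming a single bit. A caller-side sketch of the path, as wired up above (hypothetical usage, not part of the patch):

    // sketch: forcing a defrag through the deprecated entry points
    llama_kv_self_defrag(ctx);  // -> kv_self_defrag_sched(): sets memory_force_optimize = true
    llama_kv_self_update(ctx);  // -> kv_self_update(false): ORs the flag in and optimizes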

--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -52,7 +52,8 @@ struct llama_context {
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -231,6 +232,9 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;

--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"
 
@@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
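
A recurrent cache has no K-shift or defragmentation to perform, so its init_update() always returns a bare NO_UPDATE state, and the NO_UPDATE branch in kv_self_update() turns the whole update step into a no-op for recurrent models. A hypothetical check (kv_recurrent and lctx are assumed names):

    // sketch: a recurrent cache never has pending update work
    auto st = kv_recurrent->init_update(lctx, /*optimize =*/ true);
    assert(st->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE);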

--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-kv-cache-recurrent.h
@@ -52,9 +52,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 

--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     assert(heads_base.size() == heads_swa.size());
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-        // note: here we copy the ubatches. not sure if this is ideal
-        state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-        state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
-    }
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+        sbatch(std::move(sbatch)),
+        ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
 
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
 
 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa()  const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
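
The iSWA wrapper fans each init_ call out to its two child caches and merges the child statuses with llama_memory_status_combine(). The merge matters because the children can disagree, e.g. the base cache may have a pending K-shift while the SWA cache has nothing to do. An illustrative scenario (hypothetical values):

    // sketch: base cache has pending work, SWA cache does not
    // state_base->get_status() == LLAMA_MEMORY_STATUS_SUCCESS
    // state_swa ->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE
    // combined status          == LLAMA_MEMORY_STATUS_SUCCESS
    // -> apply() still runs, the base cache gets its shift, and the
    //    NO_UPDATE child simply contributes nothing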

--- a/src/llama-kv-cache-unified-iswa.h
+++ b/src/llama-kv-cache-unified-iswa.h
@@ -54,9 +54,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -86,12 +84,16 @@ public:
 
     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ public:
     const llama_kv_cache_unified_state * get_swa()  const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     //llama_kv_cache_unified_iswa * kv;
 
@@ -131,6 +133,6 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };

--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-unified.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-model.h"
 #include "llama-context.h"
 
@@ -320,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_state>(
             this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_state>(this);
 }
 
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    std::vector<uint32_t> res;
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    {
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+}
+
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::ubatch_heads res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
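
The threshold check that used to live in defrag_sched() moves into init_update() unchanged. For intuition, a worked example of the fragmentation estimate (illustrative numbers, not from the patch):

    // sketch: used_max_p1() == 4096 scanned cells, 3000 used, n_pad == 32
    // fragmentation = 1.0f - (3000 + 32)/4096.0f ≈ 0.26
    // -> defrag_thold = 0.1 triggers a defrag; contexts with fewer than
    //    2048 scanned cells never do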
@@ -374,12 +408,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
     bool updated = false;
 
-    auto * sched = lctx.get_sched();
+    auto * sched = lctx->get_sched();
 
-    if (cells.get_has_shift()) {
+    if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -390,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx.graph_init();
+            auto * gf = lctx->graph_init();
 
-            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
             if (!res) {
                 LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                 return updated;
@@ -405,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             res->set_inputs(nullptr);
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                 return updated;
             }
@@ -416,56 +450,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         cells.reset_shift();
     }
 
-    if (do_defrag) {
+    if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched);
-
-            auto * gf = lctx.graph_init();
-
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-                return updated;
-            }
-
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-                return updated;
-            }
-
-            res->set_inputs(nullptr);
-
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-                return updated;
-            }
-
-            updated = true;
-        }
-
-        do_defrag = false;
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
+
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
+
+                if (dinfo.ids[i] == n_kv) {
+                    continue;
+                }
+
+                cells.mv(i, dinfo.ids[i]);
+            }
+
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
+        }
+
+        ggml_backend_sched_reset(sched);
+
+        auto * gf = lctx->graph_init();
+
+        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+        if (!res) {
+            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+            return updated;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
+
+        res->set_inputs(nullptr);
+
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
+
+        updated = true;
     }
 
     return updated;
 }
 
-void llama_kv_cache_unified::defrag_sched(float thold) {
-    const auto n_kv = cells.used_max_p1();
-
-    // - do not defrag small contexts (i.e. < 2048 tokens)
-    // - count the padding towards the number of used tokens
-    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-    // queue defragmentation for next llama_kv_cache_update
-    if (fragmentation > thold) {
-        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-        do_defrag = true;
-    }
-}
-
 int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
@@ -612,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return cells.get_has_shift();
+}
+
 uint32_t llama_kv_cache_unified::get_n_kv() const {
     return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 }
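
The new get_has_shift() getter lets init_update() query pending shifts without reaching into the cells directly. For reference, the unchanged get_n_kv() next to it rounds the highest used cell index up to the padding; a worked example (illustrative numbers):

    // sketch: cells.size() == 4096, n_pad == 32, used_max_p1() == 100
    // GGML_PAD(100, 32) == 128 -> get_n_kv() == 128 (clamped to at most 4096)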
@@ -941,12 +978,13 @@
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                const llama_cparams & cparams,
+                       ggml_context * ctx,
+                        ggml_cgraph * gf,
+                  const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
+    const auto & ids = dinfo.ids;
 
 #if 0
     // CPU defrag
@@ -1087,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     return res;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1108,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
-
-    ids.clear();
+    defrag_info res;
+    auto & ids = res.ids;
+
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1179,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             // this cell goes to (i0 + nf)
             ids[i1] = i0 + nf;
 
-            // move the cell meta data
-            cells.mv(i1, i0 + nf);
-
-            head = n_used;
-
             if (!cont) {
                 n_moves++;
                 cont = true;
@@ -1206,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     }
 
     if (n_moves == 0) {
-        return false;
+        return {};
     }
 
     LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
     LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
-    return true;
+    return res;
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
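
defrag_prepare() is now a const function of the current cell layout: it returns the move table instead of mutating the cache, and the moves are applied later in update(). Interpreting the returned ids for a toy layout (hypothetical) with n_kv == 4, where cell 1 is a hole and cells 2 and 3 are used:

    // sketch: reading a defrag_info move table
    // ids == { 0, 4, 1, 2 }
    //   ids[0] == 0 -> cell 0 stays put        (ids[i] == i)
    //   ids[1] == 4 -> cell 1 is not moved     (ids[i] == n_kv; it is a hole)
    //   ids[2] == 1 -> cell 2 moves to slot 1
    //   ids[3] == 2 -> cell 3 moves to slot 2
    // update() walks this table, calls cells.mv(i, ids[i]) and resets head = 0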
@@ -1636,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-            llama_memory_status status,
-            llama_kv_cache_unified * kv) : status(status), kv(kv) {
-        n_kv = kv->get_size();
-        head = 0;
-    }
+        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+    head = 0;
+}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-            llama_memory_status status,
-            llama_kv_cache_unified * kv,
-            llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
-            std::vector<llama_ubatch> ubatches)
-            : status(status),
-              kv(kv),
-              sbatch(std::move(sbatch)),
-              heads(std::move(heads)),
-              ubatches(std::move(ubatches)) {
-    }
+        llama_kv_cache_unified * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
+    if (!do_shift && dinfo.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
+    }
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+        llama_kv_cache_unified * kv,
+        llama_sbatch sbatch,
+        llama_kv_cache_unified::ubatch_heads heads,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+}
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
 
@@ -1670,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
 bool llama_kv_cache_unified_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo);
+
+        return true;
+    }
+
     kv->apply_ubatch(heads[i_next], ubatches[i_next]);
 
     n_kv = kv->get_n_kv();
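
apply() is now the dispatch point: an update state created by init_update() carries no ubatches and routes to kv->update(), while a decode state applies its ubatches one at a time. A sketch of the assumed driver loop for the decode case (process() is a hypothetical compute step, not from the patch):

    // sketch: consuming a decode state produced by init_batch()
    if (kv_state->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
        do {
            kv_state->apply();               // place the current ubatch into the cache
            process(kv_state->get_ubatch()); // hypothetical: run the compute graph
        } while (kv_state->next());          // advance to the next ubatch
    }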

--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -24,6 +24,19 @@ public:
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
+    using ubatch_heads = std::vector<uint32_t>;
+
+    struct defrag_info {
+        bool empty() const {
+            return ids.empty();
+        }
+
+        // contains information about which cell moves where:
+        //  - cell i moves to ids[i]
+        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
+        std::vector<uint32_t> ids;
+    };
+
     llama_kv_cache_unified(
             const llama_model &  model,
               layer_filter_cb && filter,
@@ -66,9 +79,7 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -83,6 +94,8 @@ public:
 
     uint32_t get_size() const;
 
+    bool get_has_shift() const;
+
     //
     // graph_build API
     //
@@ -103,7 +116,9 @@ public:
 
     // find places for the provided ubatches in the cache, returns the head locations
     // return empty vector on failure
-    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
+    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
     // return the cell position where we can insert the ubatch
     // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +148,7 @@ private:
         ggml_tensor * v;
     };
 
-    bool do_defrag = false;
-    bool v_trans   = true;  // the value tensor is transposed
+    bool v_trans = true;  // the value tensor is transposed
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -160,13 +174,8 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // defrag
-    struct {
-        std::vector<uint32_t> ids;
-    } defrag_info;
-
-    // return true if cells have been moved
-    bool defrag_prepare(int32_t n_max_nodes);
+    // return non-empty vector if cells have been moved
+    defrag_info defrag_prepare(int32_t n_max_nodes) const;
 
     size_t total_size() const;
 
@@ -192,7 +201,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+                    ggml_cgraph * gf,
+              const defrag_info & dinfo) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +213,29 @@ private:
 
 class llama_kv_cache_unified_state : public llama_memory_state_i {
 public:
+    // some shorthands
+    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using defrag_info  = llama_kv_cache_unified::defrag_info;
+
     // used for errors
     llama_kv_cache_unified_state(llama_memory_status status);
 
     // used to create a full-cache state
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv);
 
-    // used to create a state from a batch
+    // used to create an update state
+    llama_kv_cache_unified_state(
+            llama_kv_cache_unified * kv,
+            llama_context * lctx,
+            bool do_shift,
+            defrag_info dinfo);
+
+    // used to create a decode state from a batch
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv,
             llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
+            ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_state();
@@ -253,16 +272,30 @@ public:
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     llama_kv_cache_unified * kv;
+    llama_context * lctx;
+
+    //
+    // update state
+    //
+
+    bool do_shift = false;
+
+    defrag_info dinfo;
+
+    //
+    // batch processing state
+    //
 
     llama_sbatch sbatch;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
 
-    std::vector<uint32_t> heads;
+    ubatch_heads heads;
 
     std::vector<llama_ubatch> ubatches;
 
     //

--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -1,12 +1,16 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-io.h"
 #include "llama-memory.h"
 
+class llama_io_write_i;
+class llama_io_read_i;
+
 struct llama_kv_cache : public llama_memory_i {
     virtual ~llama_kv_cache() = default;
 
+    // TODO: move the init_ interfaces to llama_memory_i
+
     // split the input batch into a set of ubatches and verify that they can fit into the cache
     // return a state object containing the ubatches and KV cache state required to process them
     // check the llama_memory_state_i::get_status() for the result
@@ -19,16 +23,9 @@ struct llama_kv_cache : public llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
 
-    // process any pending defrag/shift/etc. operations
-    // optionally call once before processing a new batch
-    // return true if any operations were performed
-    virtual bool update(llama_context & lctx) = 0;
-
-    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
-    // TODO: change to
-    //   llama_memory_state_ptr init_defrag(float thold) = 0;
-    //
-    virtual void defrag_sched(float thold) = 0;
+    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
     // getters
     virtual bool get_can_shift() const = 0;

--- a/src/llama-memory.cpp
+++ b/src/llama-memory.cpp
@@ -1 +1,42 @@
 #include "llama-memory.h"
+
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
+    bool has_update = false;
+
+    switch (s0) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s0;
+            }
+    }
+
+    switch (s1) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s1;
+            }
+    }
+
+    // if either status has an update, then the combined status has an update
+    return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
+}
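
The combination rule at a glance, derived directly from the function above: failures dominate (s0 is checked first), and a SUCCESS on either side outweighs NO_UPDATE:

    // llama_memory_status_combine(s0, s1) truth table
    //   SUCCESS   , SUCCESS   -> SUCCESS
    //   SUCCESS   , NO_UPDATE -> SUCCESS    (one sub-state still has work)
    //   NO_UPDATE , SUCCESS   -> SUCCESS
    //   NO_UPDATE , NO_UPDATE -> NO_UPDATE
    //   FAILED_*  , anything  -> s0         (first failure wins)
    //   SUCCESS/NO_UPDATE , FAILED_* -> s1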

--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -36,12 +36,19 @@ public:
     virtual bool get_can_edit() const = 0;
 };
 
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
 enum llama_memory_status {
     LLAMA_MEMORY_STATUS_SUCCESS = 0,
+    LLAMA_MEMORY_STATUS_NO_UPDATE,
     LLAMA_MEMORY_STATUS_FAILED_PREPARE,
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
+// helper function for combining the status of two memory states
+// useful for implementing hybrid memory types (e.g. iSWA)
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
+
 // the interface for managing the memory state during batch processing
 // this interface is implemented per memory type. see:
 //   - llama_kv_cache_unified_state
@@ -69,7 +76,7 @@ public:
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
-    // get the status of the memory state
+    // get the status of the memory state - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };