mirror of https://github.com/ggml-org/llama.cpp.git
	fix: Update recurrent cache for changes to remove intermediate kv_cache interface
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
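For context, here is a rough caller-side sketch of how the reworked entry points are meant to be driven. This is hypothetical usage, not part of the commit: lctx, batch and n_ubatch stand in for values owned by the surrounding llama_context code, the relevant internal llama.cpp headers are assumed to be included, and get_status() comes from the wider llama_memory_state_i interface rather than from this diff.

static void decode_with_hybrid_cache(
        llama_kv_cache_hybrid_recurrent & kv,
        llama_context                   * lctx,
        const llama_batch               & batch,
        uint32_t                          n_ubatch) {
    // cache maintenance (K-shift, defrag) is now requested through
    // init_update(), which replaces the removed update()/defrag_sched() pair;
    // the returned state would be applied by the caller before decoding
    llama_memory_state_ptr upd = kv.init_update(lctx, /*optimize=*/false);
    (void) upd;

    // split the batch into ubatches and reserve slots in both child caches
    llama_memory_state_ptr st = kv.init_batch(batch, n_ubatch,
                                              /*embd_pooled=*/false,
                                              /*logits_all=*/false);
    if (st->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        // init_batch() reports LLAMA_MEMORY_STATUS_FAILED_PREPARE when either
        // child cache cannot make room for the prepared ubatches
        return;
    }

    // walk the prepared ubatches; graph construction reads the per-cache
    // views via get_state_attn()/get_state_recurrent() on the hybrid state
    do {
        const llama_ubatch & ubatch = st->get_ubatch();
        // ... build and evaluate the graph for this ubatch ...
        (void) ubatch;
    } while (st->next());
}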
@@ -49,6 +49,59 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
         n_seq_max
     )) {}

+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
+}
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this,
+        static_cast<llama_kv_cache_unified_state *>(  kv_attn     ->init_update(lctx, optimize).release()),
+        static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return kv_attn->get_can_shift();
+}
+
 void llama_kv_cache_hybrid_recurrent::clear() {
     kv_attn     ->clear();
     kv_recurrent->clear();
@@ -93,67 +146,6 @@ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) cons
     return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
 }

-llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
-
-    // since this includes a recurrent cache, we cannot use split_simple
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
-
-    // follow the recurrent pattern for creating the ubatch splits
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
-            ubatch = sbatch.split_seq(n_ubatch);
-        } else {
-            ubatch = sbatch.split_equal(n_ubatch);
-        }
-
-        ubatches.push_back(ubatch);
-    }
-
-    // prepare the recurrent batches first
-    if (!kv_recurrent->prepare(ubatches)) {
-        // TODO: will the recurrent cache be in an undefined state at this point?
-        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    // prepare the attention cache
-    auto heads_attn = kv_attn->prepare(ubatches);
-    if (heads_attn.empty()) {
-        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
-}
-
-llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
-}
-
-bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_attn     ->update(lctx);
-    res = res | kv_recurrent->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
-    kv_attn     ->defrag_sched(thold);
-    kv_recurrent->defrag_sched(thold);
-}
-
-bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
-    // Shifting is trivially supported for recurrent
-    return kv_attn->get_can_shift();
-}
-
 void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     kv_attn     ->state_write(io, seq_id);
     kv_recurrent->state_write(io, seq_id);
@@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
 }

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
-    : status(status), state_attn(status), state_recurrent(status) {}
+    : status(status),
+      state_attn(new llama_kv_cache_unified_state(status)),
+      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
       kv(kv),
-      state_attn(status, kv->get_kv_attn()),
-      state_recurrent(status, kv->get_kv_recurrent()) {}
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
+      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+           llama_kv_cache_unified_state * state_unified,
+         llama_kv_cache_recurrent_state * state_recurrent)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+      kv(kv),
+      state_attn(state_unified),
+      state_recurrent(state_recurrent) {}

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
     llama_kv_cache_hybrid_recurrent * kv,
@@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
       // NOTE: these child states are only used as wrapper APIs for the
       //    const methods, so we use the "init full" signature since the
       //    actual state is not used.
-      state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
-      state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {}
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
+      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}


 bool llama_kv_cache_hybrid_recurrent_state::next() {
@@ -233,9 +236,9 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
 }

 const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
-    return &state_attn;
+    return state_attn.get();
 }

 const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
-    return &state_recurrent;
+    return state_recurrent.get();
 }
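Beyond the moved functions, the constructor changes at the end of this file switch state_attn and state_recurrent from by-value members to heap-allocated ones: init_update() receives the child states from the child caches' own init_update() calls as base-class pointers, so the wrapper has to release() them, downcast, and adopt them. Below is a generic, self-contained illustration of that ownership pattern (plain C++, not llama.cpp code; all names in it are made up):

#include <memory>

struct base_state { virtual ~base_state() = default; };

struct unified_state   : base_state {};
struct recurrent_state : base_state {};

// stand-ins for the child caches' update factories, which hand back
// ownership through a pointer to the base interface
std::unique_ptr<base_state> make_unified_update()   { return std::make_unique<unified_state>(); }
std::unique_ptr<base_state> make_recurrent_update() { return std::make_unique<recurrent_state>(); }

class hybrid_state {
public:
    // adopt already-constructed child states (raw pointers transfer ownership)
    hybrid_state(unified_state * attn, recurrent_state * rec)
        : state_attn(attn), state_recurrent(rec) {}

private:
    // unique_ptr members let one wrapper either build its own children or
    // adopt children produced elsewhere; const value members cannot do both
    const std::unique_ptr<unified_state>   state_attn;
    const std::unique_ptr<recurrent_state> state_recurrent;
};

int main() {
    // release() the base pointer and downcast to the concrete type the
    // wrapper stores, mirroring the init_update() implementation above
    hybrid_state st(
        static_cast<unified_state   *>(make_unified_update().release()),
        static_cast<recurrent_state *>(make_recurrent_update().release()));
    (void) st;
    return 0;
}

The corresponding header changes follow.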
@@ -2,9 +2,10 @@

 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
 #include "llama-kv-cache-recurrent.h"
 #include "llama-kv-cache-unified.h"
+#include "llama-kv-cells.h"
+#include "llama-memory.h"

 #include <memory>
 #include <vector>
@@ -16,7 +17,7 @@
 // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
 //   support models where each layer may be either attention-based or recurrent

-class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
+class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_hybrid_recurrent(
             const llama_model & model,
@@ -42,6 +43,18 @@ public:
     // llama_memory_i
     //

+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
     void clear() override;

     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
@@ -53,24 +66,6 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;

-    //
-    // llama_kv_cache
-    //
-
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
-
-    llama_memory_state_ptr init_full() override;
-
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
-
-    bool get_can_shift() const override;
-
     // state write/load

     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -92,12 +87,21 @@ private:

 class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
 public:
+    using llama_kv_cache_unified_state_ptr   = std::unique_ptr<llama_kv_cache_unified_state>;
+    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
+
     // init failure
     explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);

     // init full
     explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);

+    // init update
+    explicit llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+           llama_kv_cache_unified_state * state_unified,
+         llama_kv_cache_recurrent_state * state_recurrent);
+
     // init success
     llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
@@ -116,7 +120,7 @@ public:
     const llama_ubatch & get_ubatch() const override;

     //
-    // llama_kv_cache_hybrid_recurrent_state_i
+    // llama_kv_cache_hybrid_recurrent_state
     //

     const llama_kv_cache_unified_state   * get_state_attn     () const;
@@ -135,6 +139,6 @@ private:
     std::vector<uint32_t>     heads_attn;
     std::vector<llama_ubatch> ubatches;

-    const llama_kv_cache_unified_state   state_attn;
-    const llama_kv_cache_recurrent_state state_recurrent;
+    const llama_kv_cache_unified_state_ptr   state_attn;
+    const llama_kv_cache_recurrent_state_ptr state_recurrent;
 };
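On the consuming side, graph construction recovers the per-cache views through the two getters declared above. A rough sketch of a hypothetical call site (the mstate pointer and the downcast location are assumptions; only the getter signatures come from this diff):

// mstate is assumed to be the llama_memory_state_i produced by the hybrid
// cache's init_batch()/init_full(); the model's graph code knows the concrete
// type and downcasts to reach the wrapped child states.
const auto * hstate = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

const llama_kv_cache_unified_state   * s_attn = hstate->get_state_attn();
const llama_kv_cache_recurrent_state * s_rec  = hstate->get_state_recurrent();

// attention layers take their KV view from s_attn, while recurrent layers
// take their slot bookkeeping from s_rec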