Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-05 09:36:52 +00:00)
fix: Update recurrent cache for changes to remove intermediate kv_cache interface
Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@@ -2,9 +2,10 @@
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
+#include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-unified.h"
 #include "llama-kv-cells.h"
 #include "llama-memory.h"
 
 #include <memory>
 #include <vector>
@@ -16,7 +17,7 @@
 // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
 // support models where each layer may be either attention-based or recurrent
 
-class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
+class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_hybrid_recurrent(
             const llama_model & model,
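Note: the comment above describes a composition pattern (one attention-style KV cache plus one recurrent-state cache behind a single memory interface). The sketch below illustrates that idea only; it is a minimal, self-contained stand-in, and every name in it (hybrid_memory, attn_cache, recurrent_cache, uses_recurrent) is hypothetical rather than an actual llama.cpp type.

    // Minimal sketch of the composition the hybrid cache uses: two per-layer-type
    // sub-caches owned by one object, with a per-layer flag deciding which child
    // serves a given layer. Not the real implementation.
    #include <cstdint>
    #include <memory>
    #include <vector>

    struct attn_cache      { /* KV cells for attention layers        */ };
    struct recurrent_cache { /* state slots for recurrent layers     */ };

    class hybrid_memory {
    public:
        // layer_is_recurrent[il] == true  -> layer il is served by the recurrent cache
        // layer_is_recurrent[il] == false -> layer il is served by the attention cache
        explicit hybrid_memory(std::vector<bool> layer_is_recurrent)
            : layer_is_recurrent(std::move(layer_is_recurrent)),
              attn(std::make_unique<attn_cache>()),
              recurrent(std::make_unique<recurrent_cache>()) {}

        bool uses_recurrent(uint32_t il) const { return layer_is_recurrent.at(il); }

        attn_cache      & get_attn()      { return *attn; }
        recurrent_cache & get_recurrent() { return *recurrent; }

    private:
        std::vector<bool> layer_is_recurrent;
        std::unique_ptr<attn_cache>      attn;
        std::unique_ptr<recurrent_cache> recurrent;
    };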
@@ -42,6 +43,18 @@ public:
     // llama_memory_i
     //
 
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
     void clear() override;
 
     bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
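Note: with the entry points now declared under the llama_memory_i section, a composite memory like this one typically succeeds only when both children can accept the batch. The toy below shows that fan-out shape under hypothetical stand-in types (child_memory, memory_status, hybrid_state); it is a sketch of the pattern, not the llama.cpp code.

    // Sketch: a composite init_batch delegates to both children and reports
    // success only if both succeed. All types here are hypothetical.
    #include <memory>

    enum class memory_status { success, failed_prepare };

    struct child_state { memory_status status; };

    // stand-ins for the unified (attention) and recurrent caches
    struct child_memory {
        child_state init_batch(int n_tokens) {
            // this toy accepts any non-empty batch
            return { n_tokens > 0 ? memory_status::success : memory_status::failed_prepare };
        }
    };

    struct hybrid_state {
        memory_status status;
        child_state   attn;
        child_state   recurrent;
    };

    hybrid_state hybrid_init_batch(child_memory & attn, child_memory & recurrent, int n_tokens) {
        hybrid_state res;
        res.attn      = attn.init_batch(n_tokens);
        res.recurrent = recurrent.init_batch(n_tokens);

        const bool ok = res.attn.status      == memory_status::success &&
                        res.recurrent.status == memory_status::success;

        res.status = ok ? memory_status::success : memory_status::failed_prepare;
        return res;
    }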
@@ -53,24 +66,6 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
-    //
-    // llama_kv_cache
-    //
-
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
-
-    llama_memory_state_ptr init_full() override;
-
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
-
-    bool get_can_shift() const override;
-
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
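Note: the block removed here is the old llama_kv_cache-specific section, in which update() and defrag_sched() mutated the cache directly; in the declarations added above, that role is taken by init_update(lctx, optimize), which hands back a state object like the other init_* calls. The toy below sketches that "everything returns a state" shape with hypothetical types (toy_memory, memory_state, update_state); it is not the actual interface.

    // Sketch: instead of mutating the cache via update()/defrag_sched(), the
    // caller asks for an update state and inspects its status. Hypothetical types.
    #include <memory>

    enum class status { success, no_update };

    struct memory_state {
        virtual ~memory_state() = default;
        virtual status get_status() const = 0;
    };

    using memory_state_ptr = std::unique_ptr<memory_state>;

    struct update_state : memory_state {
        explicit update_state(status s) : s(s) {}
        status get_status() const override { return s; }
        status s;
    };

    struct toy_memory {
        bool needs_shift = false;

        // conceptual replacement for the removed update()/defrag_sched() pair:
        // return a state describing whether there is any update work to apply
        memory_state_ptr init_update(bool optimize) {
            const bool has_work = needs_shift || optimize;
            return std::make_unique<update_state>(has_work ? status::success : status::no_update);
        }
    };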
@@ -92,12 +87,21 @@ private:
 
 class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
 public:
+    using llama_kv_cache_unified_state_ptr   = std::unique_ptr<llama_kv_cache_unified_state>;
+    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
+
     // init failure
     explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);
 
     // init full
     explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);
 
+    // init update
+    explicit llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_kv_cache_unified_state * state_unified,
+        llama_kv_cache_recurrent_state * state_recurrent);
+
     // init success
     llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
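Note: the pattern added here is one wrapper state with a constructor per initialization path (failure, full, update, per-batch), where each path owns only the sub-states it actually needs. A compact sketch of that pattern, using hypothetical types (hybrid_state, attn_state, recurrent_state) rather than the real classes:

    // Sketch: one state wrapper, several construction paths. The failure path
    // carries only a status; the other paths adopt externally created sub-states.
    #include <memory>

    enum class mem_status { success, failed_prepare };

    struct attn_state      {};
    struct recurrent_state {};

    class hybrid_state {
    public:
        // failure path: no sub-states
        explicit hybrid_state(mem_status status) : status(status) {}

        // full / update paths: take ownership of sub-states built elsewhere
        hybrid_state(std::unique_ptr<attn_state> attn,
                     std::unique_ptr<recurrent_state> recurrent)
            : status(mem_status::success),
              attn(std::move(attn)),
              recurrent(std::move(recurrent)) {}

        mem_status get_status() const { return status; }

        const attn_state      * get_state_attn()      const { return attn.get(); }
        const recurrent_state * get_state_recurrent() const { return recurrent.get(); }

    private:
        mem_status status;
        std::unique_ptr<attn_state>      attn;       // may be null on the failure path
        std::unique_ptr<recurrent_state> recurrent;  // may be null on the failure path
    };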
@@ -116,7 +120,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_hybrid_recurrent_state_i
+    // llama_kv_cache_hybrid_recurrent_state
     //
 
     const llama_kv_cache_unified_state * get_state_attn () const;
@@ -135,6 +139,6 @@ private:
     std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;
 
-    const llama_kv_cache_unified_state state_attn;
-    const llama_kv_cache_recurrent_state state_recurrent;
+    const llama_kv_cache_unified_state_ptr   state_attn;
+    const llama_kv_cache_recurrent_state_ptr state_recurrent;
 };
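Note: switching these members from values to the unique_ptr aliases means a consumer of the state can no longer assume both sub-states exist and must check before dereferencing. The fragment below illustrates that caller-side guard with hypothetical types and fields (hybrid_state, attn_state::n_kv, recurrent_state::head); it is an assumption-laden sketch, not code from the repository.

    // Sketch: downstream code checks which sub-states are present before use.
    #include <cstdio>
    #include <memory>

    struct attn_state      { int n_kv = 0; };
    struct recurrent_state { int head = 0; };

    struct hybrid_state {
        std::unique_ptr<attn_state>      attn;
        std::unique_ptr<recurrent_state> recurrent;
    };

    void build_graph_inputs(const hybrid_state & st) {
        if (const attn_state * a = st.attn.get()) {
            std::printf("attention layers see n_kv = %d\n", a->n_kv);
        }
        if (const recurrent_state * r = st.recurrent.get()) {
            std::printf("recurrent layers read slot head = %d\n", r->head);
        }
        // with by-value members, an "empty" state would still carry fully
        // constructed sub-states and neither branch could be skipped
    }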