mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
context : add llama_kv_cache_recurrent prototype
ggml-ci
This commit is contained in:
@@ -374,9 +374,6 @@ public:
|
||||
virtual int encode(llama_batch & inp_batch) override;
|
||||
virtual int decode(llama_batch & inp_batch) override;
|
||||
|
||||
// max token position across all sequences in the current context
|
||||
llama_pos pos_max() const;
|
||||
|
||||
// certain implementations could require a padding for the context size
|
||||
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
|
||||
|
||||
@@ -453,9 +450,7 @@ protected:
|
||||
};
|
||||
|
||||
// a recurrent transformer (ie.e RWKV, Mamba)
|
||||
// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
|
||||
//class llama_context_recurrent : public llama_context {
|
||||
class llama_context_recurrent : public llama_context_kv_self {
|
||||
class llama_context_recurrent : public llama_context {
|
||||
public:
|
||||
llama_context_recurrent(
|
||||
const llama_model & model,
|
||||
@@ -463,8 +458,16 @@ public:
|
||||
|
||||
virtual ~llama_context_recurrent();
|
||||
|
||||
virtual llama_kv_cache * get_kv_self() override;
|
||||
virtual const llama_kv_cache * get_kv_self() const override;
|
||||
|
||||
virtual void kv_self_update() override;
|
||||
|
||||
virtual ggml_cgraph * graph_init() override;
|
||||
|
||||
virtual int encode(llama_batch & inp_batch) override;
|
||||
virtual int decode(llama_batch & inp_batch) override;
|
||||
|
||||
virtual ggml_tensor * build_inp_s_copy(
|
||||
ggml_context * ctx0,
|
||||
bool worst_case) override;
|
||||
@@ -524,10 +527,11 @@ public:
|
||||
protected:
|
||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||
|
||||
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
|
||||
llama_kv_cache_recurrent kv_self;
|
||||
|
||||
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
||||
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
||||
|
||||
// TODO: add recurrent cache
|
||||
};
|
||||
|
||||
// For internal test use
|
||||
|
||||
Reference in New Issue
Block a user