context : add llama_kv_cache_recurrent prototype

ggml-ci
Author: Georgi Gerganov
Date: 2025-02-20 20:54:18 +02:00
Commit: 08011c2ca1 (parent ad870c49f4)
3 changed files with 477 additions and 102 deletions


@@ -374,9 +374,6 @@ public:
virtual int encode(llama_batch & inp_batch) override;
virtual int decode(llama_batch & inp_batch) override;
// max token position across all sequences in the current context
llama_pos pos_max() const;
// certain implementations could require a padding for the context size
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
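The padding hook above lets an implementation round the effective context size up to a multiple its KV layout prefers. A minimal sketch of how a caller might apply it, using the real GGML_PAD macro; the surrounding names are illustrative, not the commit's code:

    // illustrative: round the requested n_ctx up to the implementation's padding
    const uint32_t pad = ctx->get_ctx_padding(cparams);
    cparams.n_ctx = GGML_PAD(cparams.n_ctx, pad); // (x + n - 1) & ~(n - 1), n a power of two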
@@ -453,9 +450,7 @@ protected:
 };
 // a recurrent transformer (i.e. RWKV, Mamba)
-// TODO: temporary reuse kv_self, but in the future, implement recurrent-specific context with specific cache
-//class llama_context_recurrent : public llama_context {
-class llama_context_recurrent : public llama_context_kv_self {
+class llama_context_recurrent : public llama_context {
 public:
     llama_context_recurrent(
             const llama_model & model,
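With this change the recurrent context derives from the base llama_context directly instead of piggy-backing on llama_context_kv_self. A rough outline of the resulting hierarchy as this header suggests it (member lists abbreviated, base-class member types assumed):

    class llama_context { /* common state, no cache */ };

    class llama_context_kv_self : public llama_context {
        llama_kv_cache kv_self;           // attention KV cache
    };

    class llama_context_recurrent : public llama_context {
        llama_kv_cache_recurrent kv_self; // recurrent state, see below
    };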
@@ -463,8 +458,16 @@ public:
     virtual ~llama_context_recurrent();
+    virtual llama_kv_cache * get_kv_self() override;
+    virtual const llama_kv_cache * get_kv_self() const override;
+    virtual void kv_self_update() override;
+    virtual ggml_cgraph * graph_init() override;
+    virtual int encode(llama_batch & inp_batch) override;
+    virtual int decode(llama_batch & inp_batch) override;
     virtual ggml_tensor * build_inp_s_copy(
             ggml_context * ctx0,
             bool worst_case) override;
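build_inp_s_copy is responsible for the I32 [kv_size] input tensor declared further down. A plausible body, sketched from the tensor shapes in this header; the kv_self.n and kv_self.size fields are assumptions about the prototype cache, not confirmed by the diff:

    ggml_tensor * llama_context_recurrent::build_inp_s_copy(
            ggml_context * ctx0,
            bool worst_case) {
        // in the worst case, reserve for every cell; otherwise only the active window
        const int64_t n_kv = worst_case ? kv_self.size : kv_self.n;

        // one I32 entry per cell: the index of the cell to copy the state from
        inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
        ggml_set_input(inp_s_copy);

        return inp_s_copy;
    }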
@@ -524,10 +527,11 @@ public:
 protected:
     virtual void input_set(const llama_ubatch & ubatch) override;
+    // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
+    llama_kv_cache_recurrent kv_self;
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
-    // TODO: add recurrent cache
 };
 // For internal test use
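input_set is where the host-side contents of inp_s_copy and inp_s_mask would be written before the graph runs. A sketch assuming the prototype cache keeps per-cell bookkeeping similar to the unified cache; cells[].src, head, n, and the base-class call are all assumptions:

    void llama_context_recurrent::input_set(const llama_ubatch & ubatch) {
        // fill the inputs shared with the base class first (assumption)
        llama_context::input_set(ubatch);

        if (inp_s_copy) {
            int32_t * data = (int32_t *) inp_s_copy->data;
            for (uint32_t i = 0; i < kv_self.n; ++i) {
                // each cell records which cell its recurrent state comes from
                data[i] = kv_self.cells[kv_self.head + i].src;
            }
        }

        if (inp_s_mask) {
            float * data = (float *) inp_s_mask->data;
            for (uint32_t i = 0; i < kv_self.n; ++i) {
                // 1.0f keeps a valid state, 0.0f clears a fresh/unused cell
                data[i] = kv_self.cells[kv_self.head + i].src >= 0 ? 1.0f : 0.0f;
            }
        }
    }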