llama : introduce concept of llama_memory

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-28 10:51:17 +02:00
parent 828effd9d7
commit 38db8a5861
6 changed files with 1345 additions and 45 deletions

View File

@@ -2,7 +2,7 @@
#include "llama.h"
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
@@ -13,6 +13,17 @@ struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
operator bool() const { return found; }
};
struct llama_kv_cache {
public:
virtual int32_t n_tokens() const = 0;
virtual uint32_t used_cells() const = 0; // TODO: remove
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual void defrag() = 0;
virtual bool get_can_shift() const = 0;
};
// C++ alias
class llama_kv_cache_i : public llama_kv_cache {
public:
using llama_kv_cache::llama_kv_cache;
};
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache_i {
class llama_kv_cache_unified : public llama_kv_cache {
public:
llama_kv_cache_unified(const llama_hparams & hparams);
virtual ~llama_kv_cache_unified() = default;
@@ -88,8 +73,8 @@ public:
uint32_t kv_size,
bool offload);
int32_t n_tokens() const override;
uint32_t used_cells() const override;
int32_t get_n_tokens() const override;
uint32_t get_used_cells() const override;
size_t total_size() const;
@@ -97,6 +82,7 @@ public:
llama_pos pos_max() const;
void clear() override;
void defrag() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -106,7 +92,6 @@ public:
llama_pos seq_pos_max(llama_seq_id seq_id) override;
void defrag() override;
bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache