llama : introduce concept of llama_memory

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-28 10:51:17 +02:00
parent 828effd9d7
commit 38db8a5861
6 changed files with 1345 additions and 45 deletions

View File

@@ -49,7 +49,7 @@ llama_context_base::llama_context_base(
const llama_model & model, const llama_model & model,
llama_context_params params, llama_context_params params,
llama_graph_type gtype) : llama_graph_type gtype) :
llama_context_i(), llama_context(),
llama_graph_i(gtype), llama_graph_i(gtype),
model(model) { model(model) {
LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype); LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype);

View File

@@ -21,10 +21,10 @@ class llama_io_write_i;
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>; using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
// abstract interface corresponding to the public C API // abstract interface corresponding to the public C API
struct llama_context { class llama_context_i {
public: public:
llama_context() = default; llama_context_i() = default;
virtual ~llama_context() = default; virtual ~llama_context_i() = default;
virtual void init() = 0; virtual void init() = 0;
@@ -157,14 +157,13 @@ public:
size_t n_token_count) = 0; size_t n_token_count) = 0;
}; };
// C++ alias // C alias
class llama_context_i : public llama_context { struct llama_context : public llama_context_i {
public: using llama_context_i::llama_context_i;
using llama_context::llama_context;
}; };
// basic transformer without KV cache // basic transformer without KV cache
class llama_context_base : public llama_context_i, public llama_graph_i { class llama_context_base : public llama_context, public llama_graph_i {
public: public:
llama_context_base( llama_context_base(
const llama_model & model, const llama_model & model,
@@ -821,7 +820,7 @@ public:
llama_cross * cross = nullptr; llama_cross * cross = nullptr;
}; };
class llama_context_enc_dec : public llama_context_i { class llama_context_enc_dec : public llama_context {
public: public:
llama_context_enc_dec( llama_context_enc_dec(
const llama_model & model, const llama_model & model,

View File

@@ -122,7 +122,7 @@ bool llama_kv_cache_unified::init(
return true; return true;
} }
int32_t llama_kv_cache_unified::n_tokens() const { int32_t llama_kv_cache_unified::get_n_tokens() const {
int32_t result = 0; int32_t result = 0;
for (uint32_t i = 0; i < size; i++) { for (uint32_t i = 0; i < size; i++) {
@@ -132,7 +132,7 @@ int32_t llama_kv_cache_unified::n_tokens() const {
return result; return result;
} }
uint32_t llama_kv_cache_unified::used_cells() const { uint32_t llama_kv_cache_unified::get_used_cells() const {
return used; return used;
} }
@@ -1091,7 +1091,7 @@ int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
return 0; return 0;
} }
return kv->n_tokens(); return kv->get_n_tokens();
} }
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) { int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
@@ -1099,7 +1099,7 @@ int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
return 0; return 0;
} }
return kv->used_cells(); return kv->get_used_cells();
} }
void llama_kv_cache_clear(llama_kv_cache * kv) { void llama_kv_cache_clear(llama_kv_cache * kv) {

View File

@@ -2,7 +2,7 @@
#include "llama.h" #include "llama.h"
#include "llama-io.h" #include "llama-io.h"
#include "llama-graph.h" #include "llama-memory.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
@@ -13,6 +13,17 @@ struct llama_cparams;
struct llama_hparams; struct llama_hparams;
struct llama_ubatch; struct llama_ubatch;
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
};
struct llama_kv_cell { struct llama_kv_cell {
llama_pos pos = -1; llama_pos pos = -1;
llama_pos delta = 0; llama_pos delta = 0;
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
operator bool() const { return found; } operator bool() const { return found; }
}; };
struct llama_kv_cache {
public:
virtual int32_t n_tokens() const = 0;
virtual uint32_t used_cells() const = 0; // TODO: remove
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual void defrag() = 0;
virtual bool get_can_shift() const = 0;
};
// C++ alias
class llama_kv_cache_i : public llama_kv_cache {
public:
using llama_kv_cache::llama_kv_cache;
};
// ring-buffer of cached KV data // ring-buffer of cached KV data
// TODO: pimpl // TODO: pimpl
// TODO: add notion of max sequences // TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache_i { class llama_kv_cache_unified : public llama_kv_cache {
public: public:
llama_kv_cache_unified(const llama_hparams & hparams); llama_kv_cache_unified(const llama_hparams & hparams);
virtual ~llama_kv_cache_unified() = default; virtual ~llama_kv_cache_unified() = default;
@@ -88,8 +73,8 @@ public:
uint32_t kv_size, uint32_t kv_size,
bool offload); bool offload);
int32_t n_tokens() const override; int32_t get_n_tokens() const override;
uint32_t used_cells() const override; uint32_t get_used_cells() const override;
size_t total_size() const; size_t total_size() const;
@@ -97,6 +82,7 @@ public:
llama_pos pos_max() const; llama_pos pos_max() const;
void clear() override; void clear() override;
void defrag() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -106,7 +92,6 @@ public:
llama_pos seq_pos_max(llama_seq_id seq_id) override; llama_pos seq_pos_max(llama_seq_id seq_id) override;
void defrag() override;
bool get_can_shift() const override; bool get_can_shift() const override;
// find an empty slot of size "n_tokens" in the cache // find an empty slot of size "n_tokens" in the cache

1295
src/llama-memory.cpp Normal file

File diff suppressed because it is too large Load Diff

21
src/llama-memory.h Normal file
View File

@@ -0,0 +1,21 @@
#pragma once
#include "llama.h"
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual void clear() = 0;
virtual void defrag() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual bool get_can_edit() const = 0;
};