mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-10 10:27:03 +00:00
llama : introduce concept of llama_memory
ggml-ci
This commit is contained in:
@@ -49,7 +49,7 @@ llama_context_base::llama_context_base(
|
||||
const llama_model & model,
|
||||
llama_context_params params,
|
||||
llama_graph_type gtype) :
|
||||
llama_context_i(),
|
||||
llama_context(),
|
||||
llama_graph_i(gtype),
|
||||
model(model) {
|
||||
LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype);
|
||||
|
||||
@@ -21,10 +21,10 @@ class llama_io_write_i;
|
||||
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
||||
|
||||
// abstract interface corresponding to the public C API
|
||||
struct llama_context {
|
||||
class llama_context_i {
|
||||
public:
|
||||
llama_context() = default;
|
||||
virtual ~llama_context() = default;
|
||||
llama_context_i() = default;
|
||||
virtual ~llama_context_i() = default;
|
||||
|
||||
virtual void init() = 0;
|
||||
|
||||
@@ -157,14 +157,13 @@ public:
|
||||
size_t n_token_count) = 0;
|
||||
};
|
||||
|
||||
// C++ alias
|
||||
class llama_context_i : public llama_context {
|
||||
public:
|
||||
using llama_context::llama_context;
|
||||
// C alias
|
||||
struct llama_context : public llama_context_i {
|
||||
using llama_context_i::llama_context_i;
|
||||
};
|
||||
|
||||
// basic transformer without KV cache
|
||||
class llama_context_base : public llama_context_i, public llama_graph_i {
|
||||
class llama_context_base : public llama_context, public llama_graph_i {
|
||||
public:
|
||||
llama_context_base(
|
||||
const llama_model & model,
|
||||
@@ -821,7 +820,7 @@ public:
|
||||
llama_cross * cross = nullptr;
|
||||
};
|
||||
|
||||
class llama_context_enc_dec : public llama_context_i {
|
||||
class llama_context_enc_dec : public llama_context {
|
||||
public:
|
||||
llama_context_enc_dec(
|
||||
const llama_model & model,
|
||||
|
||||
@@ -122,7 +122,7 @@ bool llama_kv_cache_unified::init(
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_unified::n_tokens() const {
|
||||
int32_t llama_kv_cache_unified::get_n_tokens() const {
|
||||
int32_t result = 0;
|
||||
|
||||
for (uint32_t i = 0; i < size; i++) {
|
||||
@@ -132,7 +132,7 @@ int32_t llama_kv_cache_unified::n_tokens() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::used_cells() const {
|
||||
uint32_t llama_kv_cache_unified::get_used_cells() const {
|
||||
return used;
|
||||
}
|
||||
|
||||
@@ -1091,7 +1091,7 @@ int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return kv->n_tokens();
|
||||
return kv->get_n_tokens();
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
||||
@@ -1099,7 +1099,7 @@ int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return kv->used_cells();
|
||||
return kv->get_used_cells();
|
||||
}
|
||||
|
||||
void llama_kv_cache_clear(llama_kv_cache * kv) {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
@@ -13,6 +13,17 @@ struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_ubatch;
|
||||
|
||||
struct llama_kv_cache : public llama_memory_i {
|
||||
using llama_memory_i::llama_memory_i;
|
||||
|
||||
virtual int32_t get_n_tokens() const = 0;
|
||||
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||
|
||||
virtual bool get_can_shift() const = 0;
|
||||
|
||||
bool get_can_edit() const override { return get_can_shift(); }
|
||||
};
|
||||
|
||||
struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
|
||||
operator bool() const { return found; }
|
||||
};
|
||||
|
||||
struct llama_kv_cache {
|
||||
public:
|
||||
virtual int32_t n_tokens() const = 0;
|
||||
virtual uint32_t used_cells() const = 0; // TODO: remove
|
||||
|
||||
virtual void clear() = 0;
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
|
||||
|
||||
virtual void defrag() = 0;
|
||||
virtual bool get_can_shift() const = 0;
|
||||
};
|
||||
|
||||
|
||||
// C++ alias
|
||||
class llama_kv_cache_i : public llama_kv_cache {
|
||||
public:
|
||||
using llama_kv_cache::llama_kv_cache;
|
||||
};
|
||||
|
||||
|
||||
// ring-buffer of cached KV data
|
||||
// TODO: pimpl
|
||||
// TODO: add notion of max sequences
|
||||
class llama_kv_cache_unified : public llama_kv_cache_i {
|
||||
class llama_kv_cache_unified : public llama_kv_cache {
|
||||
public:
|
||||
llama_kv_cache_unified(const llama_hparams & hparams);
|
||||
virtual ~llama_kv_cache_unified() = default;
|
||||
@@ -88,8 +73,8 @@ public:
|
||||
uint32_t kv_size,
|
||||
bool offload);
|
||||
|
||||
int32_t n_tokens() const override;
|
||||
uint32_t used_cells() const override;
|
||||
int32_t get_n_tokens() const override;
|
||||
uint32_t get_used_cells() const override;
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
@@ -97,6 +82,7 @@ public:
|
||||
llama_pos pos_max() const;
|
||||
|
||||
void clear() override;
|
||||
void defrag() override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
@@ -106,7 +92,6 @@ public:
|
||||
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) override;
|
||||
|
||||
void defrag() override;
|
||||
bool get_can_shift() const override;
|
||||
|
||||
// find an empty slot of size "n_tokens" in the cache
|
||||
|
||||
1295
src/llama-memory.cpp
Normal file
1295
src/llama-memory.cpp
Normal file
File diff suppressed because it is too large
Load Diff
21
src/llama-memory.h
Normal file
21
src/llama-memory.h
Normal file
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
// general concept of LLM memory
|
||||
// the KV cache is a type of LLM memory, but there can be other types
|
||||
class llama_memory_i {
|
||||
public:
|
||||
virtual void clear() = 0;
|
||||
virtual void defrag() = 0;
|
||||
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
|
||||
|
||||
virtual bool get_can_edit() const = 0;
|
||||
};
|
||||
Reference in New Issue
Block a user