mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-10 10:27:03 +00:00
llama : introduce concept of llama_memory
ggml-ci
This commit is contained in:
@@ -49,7 +49,7 @@ llama_context_base::llama_context_base(
|
|||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
llama_context_params params,
|
llama_context_params params,
|
||||||
llama_graph_type gtype) :
|
llama_graph_type gtype) :
|
||||||
llama_context_i(),
|
llama_context(),
|
||||||
llama_graph_i(gtype),
|
llama_graph_i(gtype),
|
||||||
model(model) {
|
model(model) {
|
||||||
LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype);
|
LLAMA_LOG_INFO("%s: constructing llama_context_base, gtype = %d\n", __func__, gtype);
|
||||||
|
|||||||
@@ -21,10 +21,10 @@ class llama_io_write_i;
|
|||||||
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
||||||
|
|
||||||
// abstract interface corresponding to the public C API
|
// abstract interface corresponding to the public C API
|
||||||
struct llama_context {
|
class llama_context_i {
|
||||||
public:
|
public:
|
||||||
llama_context() = default;
|
llama_context_i() = default;
|
||||||
virtual ~llama_context() = default;
|
virtual ~llama_context_i() = default;
|
||||||
|
|
||||||
virtual void init() = 0;
|
virtual void init() = 0;
|
||||||
|
|
||||||
@@ -157,14 +157,13 @@ public:
|
|||||||
size_t n_token_count) = 0;
|
size_t n_token_count) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// C++ alias
|
// C alias
|
||||||
class llama_context_i : public llama_context {
|
struct llama_context : public llama_context_i {
|
||||||
public:
|
using llama_context_i::llama_context_i;
|
||||||
using llama_context::llama_context;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// basic transformer without KV cache
|
// basic transformer without KV cache
|
||||||
class llama_context_base : public llama_context_i, public llama_graph_i {
|
class llama_context_base : public llama_context, public llama_graph_i {
|
||||||
public:
|
public:
|
||||||
llama_context_base(
|
llama_context_base(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
@@ -821,7 +820,7 @@ public:
|
|||||||
llama_cross * cross = nullptr;
|
llama_cross * cross = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
class llama_context_enc_dec : public llama_context_i {
|
class llama_context_enc_dec : public llama_context {
|
||||||
public:
|
public:
|
||||||
llama_context_enc_dec(
|
llama_context_enc_dec(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ bool llama_kv_cache_unified::init(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_cache_unified::n_tokens() const {
|
int32_t llama_kv_cache_unified::get_n_tokens() const {
|
||||||
int32_t result = 0;
|
int32_t result = 0;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; i++) {
|
for (uint32_t i = 0; i < size; i++) {
|
||||||
@@ -132,7 +132,7 @@ int32_t llama_kv_cache_unified::n_tokens() const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified::used_cells() const {
|
uint32_t llama_kv_cache_unified::get_used_cells() const {
|
||||||
return used;
|
return used;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1091,7 +1091,7 @@ int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv->n_tokens();
|
return kv->get_n_tokens();
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
||||||
@@ -1099,7 +1099,7 @@ int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv->used_cells();
|
return kv->get_used_cells();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_clear(llama_kv_cache * kv) {
|
void llama_kv_cache_clear(llama_kv_cache * kv) {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "llama-io.h"
|
#include "llama-io.h"
|
||||||
#include "llama-graph.h"
|
#include "llama-memory.h"
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
#include "ggml-cpp.h"
|
||||||
|
|
||||||
@@ -13,6 +13,17 @@ struct llama_cparams;
|
|||||||
struct llama_hparams;
|
struct llama_hparams;
|
||||||
struct llama_ubatch;
|
struct llama_ubatch;
|
||||||
|
|
||||||
|
struct llama_kv_cache : public llama_memory_i {
|
||||||
|
using llama_memory_i::llama_memory_i;
|
||||||
|
|
||||||
|
virtual int32_t get_n_tokens() const = 0;
|
||||||
|
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||||
|
|
||||||
|
virtual bool get_can_shift() const = 0;
|
||||||
|
|
||||||
|
bool get_can_edit() const override { return get_can_shift(); }
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_kv_cell {
|
struct llama_kv_cell {
|
||||||
llama_pos pos = -1;
|
llama_pos pos = -1;
|
||||||
llama_pos delta = 0;
|
llama_pos delta = 0;
|
||||||
@@ -45,36 +56,10 @@ struct llama_kv_cache_slot_info {
|
|||||||
operator bool() const { return found; }
|
operator bool() const { return found; }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_kv_cache {
|
|
||||||
public:
|
|
||||||
virtual int32_t n_tokens() const = 0;
|
|
||||||
virtual uint32_t used_cells() const = 0; // TODO: remove
|
|
||||||
|
|
||||||
virtual void clear() = 0;
|
|
||||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
|
||||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
|
||||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
|
||||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
|
||||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
|
||||||
|
|
||||||
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
|
|
||||||
|
|
||||||
virtual void defrag() = 0;
|
|
||||||
virtual bool get_can_shift() const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// C++ alias
|
|
||||||
class llama_kv_cache_i : public llama_kv_cache {
|
|
||||||
public:
|
|
||||||
using llama_kv_cache::llama_kv_cache;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
// ring-buffer of cached KV data
|
// ring-buffer of cached KV data
|
||||||
// TODO: pimpl
|
// TODO: pimpl
|
||||||
// TODO: add notion of max sequences
|
// TODO: add notion of max sequences
|
||||||
class llama_kv_cache_unified : public llama_kv_cache_i {
|
class llama_kv_cache_unified : public llama_kv_cache {
|
||||||
public:
|
public:
|
||||||
llama_kv_cache_unified(const llama_hparams & hparams);
|
llama_kv_cache_unified(const llama_hparams & hparams);
|
||||||
virtual ~llama_kv_cache_unified() = default;
|
virtual ~llama_kv_cache_unified() = default;
|
||||||
@@ -88,8 +73,8 @@ public:
|
|||||||
uint32_t kv_size,
|
uint32_t kv_size,
|
||||||
bool offload);
|
bool offload);
|
||||||
|
|
||||||
int32_t n_tokens() const override;
|
int32_t get_n_tokens() const override;
|
||||||
uint32_t used_cells() const override;
|
uint32_t get_used_cells() const override;
|
||||||
|
|
||||||
size_t total_size() const;
|
size_t total_size() const;
|
||||||
|
|
||||||
@@ -97,6 +82,7 @@ public:
|
|||||||
llama_pos pos_max() const;
|
llama_pos pos_max() const;
|
||||||
|
|
||||||
void clear() override;
|
void clear() override;
|
||||||
|
void defrag() override;
|
||||||
|
|
||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
@@ -106,7 +92,6 @@ public:
|
|||||||
|
|
||||||
llama_pos seq_pos_max(llama_seq_id seq_id) override;
|
llama_pos seq_pos_max(llama_seq_id seq_id) override;
|
||||||
|
|
||||||
void defrag() override;
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
// find an empty slot of size "n_tokens" in the cache
|
// find an empty slot of size "n_tokens" in the cache
|
||||||
|
|||||||
1295
src/llama-memory.cpp
Normal file
1295
src/llama-memory.cpp
Normal file
File diff suppressed because it is too large
Load Diff
21
src/llama-memory.h
Normal file
21
src/llama-memory.h
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
// general concept of LLM memory
|
||||||
|
// the KV cache is a type of LLM memory, but there can be other types
|
||||||
|
class llama_memory_i {
|
||||||
|
public:
|
||||||
|
virtual void clear() = 0;
|
||||||
|
virtual void defrag() = 0;
|
||||||
|
|
||||||
|
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||||
|
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||||
|
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||||
|
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||||
|
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||||
|
|
||||||
|
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
|
||||||
|
|
||||||
|
virtual bool get_can_edit() const = 0;
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user