mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
wip 2
This commit is contained in:
@@ -59,9 +59,25 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
return it->second;
|
||||
};
|
||||
|
||||
head = 0;
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
v_heads[s] = 0;
|
||||
}
|
||||
|
||||
cells.resize(kv_size);
|
||||
GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_virt; ++s) {
|
||||
v_cells[s].resize(kv_size);
|
||||
}
|
||||
|
||||
// by default, all sequence ids are mapped to the 0th virtual sequence
|
||||
seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
|
||||
|
||||
if (n_seq_virt > 1) {
|
||||
seq_virt_idx.resize(n_seq_virt, 0);
|
||||
for (uint32_t s = 0; s < n_seq_virt; ++s) {
|
||||
seq_virt_idx[s] = s;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t il = 0; il < hparams.n_layer; il++) {
|
||||
if (filter && !filter(il)) {
|
||||
@@ -141,9 +157,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::clear(bool data) {
|
||||
cells.reset();
|
||||
|
||||
head = 0;
|
||||
for (uint32_t s = 0; s < n_seq_virt; ++s) {
|
||||
v_cells[s].reset();
|
||||
v_heads[s] = 0;
|
||||
}
|
||||
|
||||
if (data) {
|
||||
for (auto & buf : bufs) {
|
||||
@@ -153,6 +170,9 @@ void llama_kv_cache_unified::clear(bool data) {
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
auto & cells = v_cells[seq_virt_idx[seq_id]];
|
||||
auto & head = v_heads[seq_virt_idx[seq_id]];
|
||||
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
if (p0 < 0) {
|
||||
@@ -199,6 +219,10 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
GGML_ASSERT(n_seq_virt == 1 && "TODO: implement seq_cp() for n_seq_virt > 1");
|
||||
|
||||
auto & cells = v_cells[seq_virt_idx[seq_id_src]];
|
||||
|
||||
if (seq_id_src == seq_id_dst) {
|
||||
return;
|
||||
}
|
||||
@@ -223,6 +247,9 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
||||
auto & cells = v_cells[seq_virt_idx[seq_id]];
|
||||
auto & head = v_heads[seq_virt_idx[seq_id]];
|
||||
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
@@ -240,6 +267,9 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
auto & cells = v_cells[seq_virt_idx[seq_id]];
|
||||
auto & head = v_heads[seq_virt_idx[seq_id]];
|
||||
|
||||
if (shift == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -279,6 +309,9 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
auto & cells = v_cells[seq_virt_idx[seq_id]];
|
||||
auto & head = v_heads[seq_virt_idx[seq_id]];
|
||||
|
||||
if (d == 1) {
|
||||
return;
|
||||
}
|
||||
@@ -308,10 +341,14 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
const auto & cells = v_cells[seq_virt_idx[seq_id]];
|
||||
|
||||
return cells.seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
const auto & cells = v_cells[seq_virt_idx[seq_id]];
|
||||
|
||||
return cells.seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
@@ -357,7 +394,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
|
||||
defrag_info dinfo;
|
||||
|
||||
// see if we need to defrag
|
||||
{
|
||||
if (n_seq_virt == 1) {
|
||||
// note : for now do not consider defrag for n_seq_virt > 1
|
||||
const auto & cells = v_cells[seq_virt_idx[0]];
|
||||
|
||||
bool do_defrag = optimize;
|
||||
|
||||
const auto thold = lctx->get_cparams().defrag_thold;
|
||||
@@ -473,12 +513,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
|
||||
updated = true;
|
||||
}
|
||||
|
||||
cells.reset_shift();
|
||||
for (uint32_t s = 0; s < n_seq_virt; ++s) {
|
||||
auto & cells = v_cells[seq_virt_idx[s]];
|
||||
|
||||
cells.reset_shift();
|
||||
}
|
||||
}
|
||||
|
||||
if (!dinfo.empty()) {
|
||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
||||
|
||||
// note: for now do not consider defrag for n_seq_virt > 1
|
||||
auto & cells = v_cells[seq_virt_idx[0]];
|
||||
auto & head = v_heads[seq_virt_idx[0]];
|
||||
|
||||
// apply moves:
|
||||
{
|
||||
const auto n_kv = dinfo.ids.size();
|
||||
@@ -532,7 +580,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
||||
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
uint32_t head_cur = this->head;
|
||||
// TODO: implement
|
||||
auto & cells = v_cells[seq_virt_idx[0]];
|
||||
|
||||
uint32_t head_cur = v_heads[0];
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
@@ -546,7 +597,8 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
||||
}
|
||||
|
||||
if (debug > 0) {
|
||||
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
|
||||
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
|
||||
__func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
|
||||
|
||||
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
||||
std::string ss;
|
||||
@@ -748,15 +800,31 @@ bool llama_kv_cache_unified::get_can_shift() const {
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_size() const {
|
||||
auto & cells = v_cells[seq_virt_idx[0]];
|
||||
|
||||
return cells.size();
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::get_has_shift() const {
|
||||
return cells.get_has_shift();
|
||||
bool result = false;
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_virt; ++s) {
|
||||
result |= v_cells[seq_virt_idx[s]].get_has_shift();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
||||
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
||||
uint32_t result = 0;
|
||||
|
||||
for (uint32_t s = 0; s < n_seq_virt; ++s) {
|
||||
const auto & cells = v_cells[seq_virt_idx[s]];
|
||||
|
||||
result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
|
||||
|
||||
@@ -167,10 +167,6 @@ private:
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
@@ -192,7 +188,13 @@ private:
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t v_heads[LLAMA_MAX_SEQ];
|
||||
|
||||
llama_kv_cells_unified v_cells[LLAMA_MAX_SEQ];
|
||||
|
||||
std::vector<uint32_t> seq_virt_idx;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user