mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-27 08:21:30 +00:00)

Commit: wip 3
@@ -59,12 +59,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
return it->second;
};

for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);

v_heads.resize(n_seq_virt);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_heads[s] = 0;
}

GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);

v_cells.resize(n_seq_virt);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
v_cells[s].resize(kv_size);
}
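For orientation, here is a minimal, self-contained sketch (stand-in types, not the actual llama.cpp classes) of the per-virtual-sequence layout this constructor sets up: one head index and one cell array per virtual sequence, plus a seq_id -> virtual-sequence mapping; with n_seq_virt == 1 all sequences share a single cell array.

// --- illustrative sketch, not part of the commit ---
// "cells_t" stands in for llama_kv_cells_unified; kv_size, n_seq_virt, n_seq_max are assumed inputs.
#include <cassert>
#include <cstdint>
#include <vector>

struct cells_t {
    explicit cells_t(uint32_t n = 0) : pos(n, -1) {}
    std::vector<int32_t> pos;
};

struct kv_virt_layout {
    std::vector<uint32_t> v_heads;      // search start per virtual sequence
    std::vector<cells_t>  v_cells;      // one cell array per virtual sequence
    std::vector<uint32_t> seq_virt_idx; // maps seq_id -> virtual sequence index

    kv_virt_layout(uint32_t kv_size, uint32_t n_seq_virt, uint32_t n_seq_max) {
        assert(n_seq_virt == 1 || n_seq_virt == n_seq_max);

        v_heads.assign(n_seq_virt, 0);
        v_cells.assign(n_seq_virt, cells_t(kv_size));

        // either all seq_ids share virtual sequence 0, or each gets its own
        seq_virt_idx.resize(n_seq_max);
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            seq_virt_idx[s] = n_seq_virt == 1 ? 0 : s;
        }
    }
};
// --- end sketch ---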
@@ -310,7 +312,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po

void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
auto & cells = v_cells[seq_virt_idx[seq_id]];
auto & head = v_heads[seq_virt_idx[seq_id]];

if (d == 1) {
return;
@@ -427,16 +428,16 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::slot_info_vec_t res;

struct state {
uint32_t head_old; // old position of the head, before placing the ubatch

struct state_t {
slot_info sinfo; // slot info for the ubatch

llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch

std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
};

// remember the old state of the cells so we can restore it in the end
std::vector<state> states;
std::vector<state_t> states;

bool success = true;

@@ -455,16 +456,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
res.push_back(sinfo_new);

// store the old state of the cells in the recovery stack
states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
{
state_t state = { sinfo_new, v_heads, {} };

for (uint32_t s = 0; s < sinfo_new.n_seq_virt(); ++s) {
auto & cells = v_cells[sinfo_new.seq_id_virt[s]];

state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
}

states.push_back(std::move(state));
}

// now emplace the ubatch
apply_ubatch(sinfo_new, ubatch);
}

GGML_ASSERT(!states.empty());

// iterate backwards and restore the cells to their original state
for (auto it = states.rbegin(); it != states.rend(); ++it) {
cells.set(it->sinfo.idxs, it->cells);
head = it->head_old;
const auto & sinfo = it->sinfo;

for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
auto & cells = v_cells[sinfo.seq_id_virt[s]];
auto & head = v_heads[sinfo.seq_id_virt[s]];

cells.set(sinfo.idxs[s], it->v_cells[s]);
head = it->v_heads_old[s];
}
}

if (!success) {
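The recovery stack above generalizes the old single head/cells snapshot to one snapshot per virtual sequence: before each ubatch is applied, the touched cell contents and all heads are copied; on failure the stack is unwound in reverse. A condensed, self-contained sketch of the restore step (stand-in types; the real code uses llama_kv_cells_unified::cp()/set()):

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <vector>

struct cells_t { std::vector<int32_t> pos; };            // stand-in for llama_kv_cells_unified

struct snapshot_t {
    std::vector<std::vector<uint32_t>> idxs;             // cell indices touched, per virtual sequence
    std::vector<uint32_t>              heads_old;        // head positions before the ubatch
    std::vector<std::vector<int32_t>>  pos_old;          // copies of the overwritten positions
};

static void restore(std::vector<cells_t> & v_cells, std::vector<uint32_t> & v_heads,
                    std::vector<snapshot_t> & states) {
    // iterate backwards so the oldest snapshot is applied last
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
        for (size_t s = 0; s < it->idxs.size(); ++s) {
            for (size_t k = 0; k < it->idxs[s].size(); ++k) {
                v_cells[s].pos[it->idxs[s][k]] = it->pos_old[s][k];
            }
            v_heads[s] = it->heads_old[s];
        }
    }
    states.clear();
}
// --- end sketch ---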
@@ -514,7 +534,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
}

for (uint32_t s = 0; s < n_seq_virt; ++s) {
auto & cells = v_cells[seq_virt_idx[s]];
auto & cells = v_cells[s];

cells.reset_shift();
}
@@ -574,29 +594,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
}

llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
if (n_seq_virt > 1) {
GGML_ASSERT(!cont && "n_seq_virt > 1 does not support continuous slots");
}
if (debug > 0 && n_seq_virt == 1) {
const auto & cells = v_cells[seq_virt_idx[0]];

const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t head_cur = v_heads[0];

// TODO: implement
auto & cells = v_cells[seq_virt_idx[0]];

uint32_t head_cur = v_heads[0];

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
head_cur = 0;
}

if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return { };
}

if (debug > 0) {
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
__func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);

@@ -655,29 +657,64 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
}
}

uint32_t n_found = 0;
uint32_t n_tested = 0;
uint32_t n_tokens = ubatch.n_tokens;
uint32_t n_seqs = 1;

const uint32_t n_test = cont ? n_tokens : 1;
if (n_seq_virt > 1) {
GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);

n_seqs = ubatch.n_seqs_unq;
n_tokens = n_tokens / n_seqs;
}

slot_info res;

res.idxs.resize(n_tokens);
res.resize(n_seqs);

while (true) {
if (head_cur + n_test > cells.size()) {
n_tested += cells.size() - head_cur;
head_cur = 0;
continue;
for (uint32_t s = 0; s < n_seqs; ++s) {
const auto seq_id = ubatch.seq_id_unq[s];

if (n_seq_virt > 1) {
GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
}

for (uint32_t i = 0; i < n_test; i++) {
const auto idx = head_cur;
res.seq_id_virt[s] = seq_virt_idx[seq_id];
res.idxs[s].resize(n_tokens);

head_cur++;
n_tested++;
const auto & cells = v_cells[seq_virt_idx[seq_id]];

uint32_t head_cur = v_heads[seq_virt_idx[seq_id]];

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head_cur > cells.get_used() + 2*n_tokens) {
head_cur = 0;
}

if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return { };
}

uint32_t n_found = 0;
uint32_t n_tested = 0;

const uint32_t n_test = cont ? n_tokens : 1;

while (true) {
if (head_cur + n_test > cells.size()) {
n_tested += cells.size() - head_cur;
head_cur = 0;
continue;
}

for (uint32_t i = 0; i < n_test; i++) {
const auto idx = head_cur;

head_cur++;
n_tested++;

if (n_seq_virt == 1) {
//const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0];

@@ -709,7 +746,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
}

if (can_use) {
res.idxs[n_found] = idx;
res.idxs[s][n_found] = idx;

n_found++;
} else {
@@ -717,30 +754,28 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
break;
}
}
} else {
GGML_ABORT("WIP");
}

if (n_found == n_tokens) {
break;
}

if (cont) {
n_found = 0;
}

if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return { };
}
}

if (n_found == n_tokens) {
break;
}

if (cont) {
n_found = 0;
}

if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
// we didn't find a suitable slot - return empty result
if (n_found < n_tokens) {
return { };
}
}

// we didn't find a suitable slot - return empty result
if (n_found < n_tokens) {
res.clear();
}

return res;
}
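The inner search itself keeps the original ring-buffer logic, now run once per virtual sequence: scan from head_cur, wrap around at the end, and collect n_tokens usable cells (a contiguous run when cont is set). A stand-alone sketch of that loop with the per-cell check reduced to "cell is empty":

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <optional>
#include <vector>

static std::optional<std::vector<uint32_t>> find_slot_sketch(
        const std::vector<bool> & cell_empty, uint32_t head_cur, uint32_t n_tokens, bool cont) {
    const uint32_t size = (uint32_t) cell_empty.size();

    if (n_tokens > size) {
        return std::nullopt;                       // cannot possibly fit
    }

    const uint32_t n_test = cont ? n_tokens : 1;   // contiguous slots are tested as a block

    std::vector<uint32_t> idxs(n_tokens);

    uint32_t n_found  = 0;
    uint32_t n_tested = 0;

    while (true) {
        if (head_cur + n_test > size) {            // wrap around the ring buffer
            n_tested += size - head_cur;
            head_cur  = 0;
            continue;
        }

        for (uint32_t i = 0; i < n_test; i++) {
            const uint32_t idx = head_cur++;
            n_tested++;

            if (cell_empty[idx]) {
                idxs[n_found++] = idx;
            } else if (cont) {
                break;                             // a used cell breaks a contiguous run
            }
        }

        if (n_found == n_tokens) {
            return idxs;
        }
        if (cont) {
            n_found = 0;
        }
        if (n_tested >= size) {
            return std::nullopt;                   // scanned the whole cache, no slot
        }
    }
}
// --- end sketch ---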
@@ -748,41 +783,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
seq_pos_max_rm[s] = -1;
}

assert(ubatch.n_tokens == sinfo.idxs.size());
assert(ubatch.n_tokens == sinfo.n_seq_virt()*sinfo.size());

for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto idx = sinfo.idxs[i];
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
const uint32_t i = s*sinfo.size() + ii;

if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
auto & cells = v_cells[sinfo.seq_id_virt[s]];

const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
const auto idx = sinfo.idxs[s][ii];

seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);

cells.rm(idx);
}
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);

cells.pos_set(idx, ubatch.pos[i]);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
cells.rm(idx);
}

cells.pos_set(idx, ubatch.pos[i]);

for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
}
}

// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq_pos_max_rm[s] == -1) {
continue;
}

GGML_ASSERT(s < seq_virt_idx.size());

auto & cells = v_cells[seq_virt_idx[s]];

if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -792,7 +837,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
}

// move the head at the end of the slot
head = sinfo.idxs.back() + 1;
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
auto & head = v_heads[sinfo.seq_id_virt[s]];

head = sinfo.idxs[s].back() + 1;
}
}

bool llama_kv_cache_unified::get_can_shift() const {
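With multiple virtual sequences the ubatch is laid out as n_seq_virt consecutive groups of sinfo.size() tokens, so the flat token index is recovered as i = s*sinfo.size() + ii before touching ubatch.pos / ubatch.seq_id. A tiny runnable illustration of that indexing (example sizes are assumed):

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_seq_virt = 2; // assumed: virtual sequences in the ubatch
    const uint32_t n_per_seq  = 3; // assumed: sinfo.size(), tokens per virtual sequence

    for (uint32_t s = 0; s < n_seq_virt; ++s) {
        for (uint32_t ii = 0; ii < n_per_seq; ++ii) {
            const uint32_t i = s*n_per_seq + ii; // flat index into ubatch.pos / ubatch.seq_id
            std::printf("virtual seq %u, local token %u -> ubatch token %u\n", s, ii, i);
        }
    }
    return 0;
}
// --- end sketch ---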
@@ -878,6 +927,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends

GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");

ggml_tensor * k_view = ggml_view_1d(ctx, k,
n_tokens*n_embd_k_gqa,
ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -921,6 +972,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
// TODO: fallback to old ggml_cpy() method for backwards compatibility
// will be removed when ggml_set_rows() is adopted by all backends

GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");

ggml_tensor * v_view = nullptr;

if (!v_trans) {
@@ -944,12 +997,20 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
}

const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == sinfo.size()*sinfo.n_seq_virt());

GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
int64_t * data = (int64_t *) dst->data;

for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = sinfo.idxs[i];
//for (int64_t i = 0; i < n_tokens; ++i) {
// data[i] = sinfo.idxs[i];
//}
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
const int64_t offs = sinfo.seq_id_virt[s]*get_size();

for (uint32_t i = 0; i < sinfo.size(); ++i) {
data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
}
}
}
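The new loop flattens the per-sequence cell indices into row indices for ggml_set_rows: each virtual sequence owns a kv_size-sized window of the K/V buffers, so a row index is seq_id_virt*kv_size + cell index. A self-contained sketch of that computation (helper name and signature are illustrative):

// --- illustrative sketch, not part of the commit ---
#include <cstdint>
#include <vector>

static std::vector<int64_t> flatten_kv_idxs(
        const std::vector<std::vector<uint32_t>> & idxs,        // per-virtual-sequence cell indices
        const std::vector<int64_t>               & seq_id_virt, // virtual sequence of each list
        int64_t kv_size) {
    std::vector<int64_t> data;
    for (size_t s = 0; s < idxs.size(); ++s) {
        const int64_t offs = seq_id_virt[s]*kv_size;            // window owned by this virtual sequence
        for (uint32_t idx : idxs[s]) {
            data.push_back(offs + idx);                         // global row index
        }
    }
    return data;
}
// --- end sketch ---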
@@ -959,7 +1020,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;

const int64_t n_kv = dst->ne[0];
const int64_t n_kv = dst->ne[0];
const int64_t n_seq_virt = dst->ne[2]; // num virtual sequences in the current ubatch

GGML_ASSERT(n_tokens%n_seq_virt == 0);

const int64_t n_tokens_per_seq = n_tokens/n_seq_virt;
const int64_t n_tokens_per_seq_pad = GGML_PAD(n_tokens_per_seq, GGML_KQ_MASK_PAD);

// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -974,48 +1041,54 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (uint32_t h = 0; h < 1; ++h) {
for (uint32_t i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = ubatch->seq_id[i][0];
for (uint32_t s = 0; s < n_seq_virt; ++s) {
for (uint32_t ii = 0; ii < n_tokens_per_seq; ++ii) {
const uint32_t i = s*n_tokens_per_seq + ii;

const llama_pos p1 = ubatch->pos[i];
const llama_seq_id seq_id = ubatch->seq_id[i][0];

for (uint32_t j = 0; j < n_kv; ++j) {
float f = 0.0f;
const auto & cells = v_cells[seq_virt_idx[seq_id]];

bool masked = false;
const llama_pos p1 = ubatch->pos[i];

if (cells.is_empty(j)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(j);

// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));

// mask future tokens
masked = masked || (causal_attn && p0 > p1);

// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));

if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
}

if (masked) {
f = -INFINITY;
}

data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
}
}

// mask padded tokens
if (data) {
for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
float f = 0.0f;

bool masked = false;

if (cells.is_empty(j)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(j);

// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));

// mask future tokens
masked = masked || (causal_attn && p0 > p1);

// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));

if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
}

if (masked) {
f = -INFINITY;
}

data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = f;
}

// mask padded tokens
if (data) {
for (uint32_t ii = n_tokens_per_seq; ii < n_tokens_per_seq_pad; ++ii) {
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = -INFINITY;
}
}
}
}
}
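The KQ mask is now effectively indexed as [n_kv, n_tokens_per_seq_pad, n_seq_virt] (plus the broadcast dimension h), which is what the long offset expression above computes. Pulling it into a small helper makes the layout explicit (helper name is illustrative):

// --- illustrative sketch, not part of the commit ---
#include <cstdint>

// flat offset for broadcast index h, virtual sequence s, local token ii and KV cell j
static int64_t kq_mask_offset(int64_t h, int64_t s, int64_t ii, int64_t j,
                              int64_t n_kv, int64_t n_tokens_per_seq_pad, int64_t n_seq_virt) {
    return h*n_seq_virt*n_tokens_per_seq_pad*n_kv
         + s*n_tokens_per_seq_pad*n_kv
         + ii*n_kv
         + j;
}
// --- end sketch ---

Rows with ii >= n_tokens_per_seq are padding and are filled with -INFINITY, as in the loop above.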
@@ -1027,14 +1100,21 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {

int32_t * data = (int32_t *) dst->data;

for (uint32_t i = 0; i < cells.size(); ++i) {
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
for (uint32_t s = 0; s < n_seq_virt; ++s) {
const auto & cells = v_cells[s];

for (uint32_t i = 0; i < cells.size(); ++i) {
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
}
}
}

void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
const int64_t n_tokens = ubatch->n_tokens;

GGML_ASSERT(n_seq_virt == 1 && "TODO: support multiple virtual sequences");
const auto & cells = v_cells[0];

GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing

@@ -1141,7 +1221,7 @@ public:

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * k_shift; // I32 [kv_size]
ggml_tensor * k_shift; // I32 [kv_size*n_seq_virt]

const llama_kv_cache_unified * kv_self;
};
@@ -1165,7 +1245,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(

auto inp = std::make_unique<llm_graph_input_k_shift>(this);

inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_seq_virt);
ggml_set_input(inp->k_shift);

for (const auto & layer : layers) {
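Since k_shift is now declared and allocated as an I32 buffer of kv_size*n_seq_virt elements, the natural layout is one kv_size block per virtual sequence. A sketch of filling such a flattened buffer follows; the per-sequence offset here is an assumption about the intended indexing, not code from the commit:

// --- illustrative sketch (assumed layout), not part of the commit ---
#include <cstdint>
#include <vector>

static void fill_k_shift(int32_t * data,                                   // kv_size*n_seq_virt elements
                         const std::vector<std::vector<int32_t>> & shifts, // per virtual sequence, kv_size each
                         uint32_t kv_size) {
    for (uint32_t s = 0; s < shifts.size(); ++s) {
        const uint32_t offs = s*kv_size; // block owned by virtual sequence s (assumed)
        for (uint32_t i = 0; i < kv_size; ++i) {
            data[offs + i] = shifts[s][i];
        }
    }
}
// --- end sketch ---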
@@ -1181,7 +1261,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(

ggml_tensor * k =
ggml_view_3d(ctx, layer.k,
n_embd_head_k, n_head_kv, cells.size(),
n_embd_head_k, n_head_kv, get_size()*n_seq_virt,
ggml_row_size(layer.k->type, n_embd_head_k),
ggml_row_size(layer.k->type, n_embd_k_gqa),
0);
@@ -1203,6 +1283,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
const defrag_info & dinfo) const {
auto res = std::make_unique<llm_graph_result>();

GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");

const auto & cells = v_cells[0];

const auto & ids = dinfo.ids;

#if 0
@@ -1345,6 +1429,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
}

llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");

const auto & cells = v_cells[0];

const uint32_t n_layer = layers.size();

const uint32_t n_kv = cells.used_max_p1();
@@ -1493,6 +1581,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;

GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
const auto & cells = v_cells[0];

// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id (or all, when -1)
uint32_t cell_range_begin = cells.size();
@@ -1547,6 +1638,9 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
}

void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
const auto & cells = v_cells[0];

for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
std::vector<llama_seq_id> seq_ids;
@@ -1573,6 +1667,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
}

void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
const auto & cells = v_cells[0];

const uint32_t v_trans = this->v_trans ? 1 : 0;
const uint32_t n_layer = layers.size();

@@ -1660,6 +1757,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
}

bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
auto & cells = v_cells[0];
auto & head = v_heads[0];

if (dest_seq_id != -1) {
// single sequence

@@ -1751,6 +1852,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
}

bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
auto & cells = v_cells[0];
auto & head = v_heads[0];

uint32_t v_trans;
uint32_t n_layer;

@@ -1888,8 +1993,9 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
n_kv = kv->get_size();

sinfos.resize(1);
sinfos[0].seq_id_virt.resize(1, 0);
sinfos[0].idxs.resize(1);
sinfos[0].idxs[0] = 0;
sinfos[0].idxs[0].resize(1, 0);
}

llama_kv_cache_unified_context::llama_kv_cache_unified_context(

@@ -39,10 +39,28 @@ public:
// data for ggml_set_rows
using idx_vec_t = std::vector<uint32_t>;

idx_vec_t idxs;
std::vector<llama_seq_id> seq_id_virt;
std::vector<idx_vec_t> idxs;

uint32_t head() const {
return idxs[0];
GGML_ASSERT(idxs.size() == 1);

return idxs[0][0];
}

void resize(size_t n) {
seq_id_virt.resize(n);
idxs.resize(n);
}

size_t size() const {
GGML_ASSERT(idxs.size() == seq_id_virt.size());

return idxs[0].size();
}

size_t n_seq_virt() const {
return seq_id_virt.size();
}

bool empty() const {
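Consolidating the slot_info changes scattered through this hunk: there is now one index list per virtual sequence, all of equal length, and head() is only meaningful when a single list is present. A condensed, self-contained version with a short usage example (member bodies follow the diff where shown, otherwise they are assumptions):

// --- illustrative sketch, not part of the commit ---
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct slot_info_sketch {
    using idx_vec_t = std::vector<uint32_t>;

    std::vector<int32_t>   seq_id_virt; // virtual sequence of each index list
    std::vector<idx_vec_t> idxs;        // KV cell indices, one list per virtual sequence

    uint32_t head() const {             // only valid with a single virtual sequence
        assert(idxs.size() == 1);
        return idxs[0][0];
    }

    void   resize(size_t n)   { seq_id_virt.resize(n); idxs.resize(n); }
    size_t size()       const { assert(idxs.size() == seq_id_virt.size()); return idxs[0].size(); }
    size_t n_seq_virt() const { return seq_id_virt.size(); }
    bool   empty()      const { return idxs.empty(); }   // assumption
};

int main() {
    slot_info_sketch sinfo;
    sinfo.resize(2);                    // two virtual sequences in this ubatch
    for (int32_t s = 0; s < 2; ++s) {
        sinfo.seq_id_virt[s] = s;
        sinfo.idxs[s] = { 4, 5, 6, 7 }; // cells found for this sequence
    }
    assert(sinfo.n_seq_virt() == 2 && sinfo.size() == 4);
    return 0;
}
// --- end sketch ---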
@@ -190,9 +208,9 @@ private:

// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t v_heads[LLAMA_MAX_SEQ];
std::vector<uint32_t> v_heads;

llama_kv_cells_unified v_cells[LLAMA_MAX_SEQ];
std::vector<llama_kv_cells_unified> v_cells;

std::vector<uint32_t> seq_virt_idx;