cont : migrate to using set of indices instead of slot head

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-21 11:57:07 +03:00
parent db2bb378b1
commit f875d6cb72
6 changed files with 144 additions and 86 deletions

View File

@@ -334,13 +334,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
ubatches.push_back(std::move(ubatch)); // NOLINT
}
auto heads = prepare(ubatches);
if (heads.empty()) {
auto sinfos = prepare(ubatches);
if (sinfos.empty()) {
break;
}
return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(heads), std::move(ubatches));
this, std::move(sinfos), std::move(ubatches));
} while (false);
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -383,8 +383,8 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
}
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::ubatch_heads res;
llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::slot_info_vec_t res;
struct state {
uint32_t head_old; // old position of the head, before placing the ubatch
@@ -400,20 +400,25 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
for (const auto & ubatch : ubatches) {
// only find a suitable slot for the ubatch. don't modify the cells yet
const int32_t head_new = find_slot(ubatch);
if (head_new < 0) {
const auto sinfo_new = find_slot(ubatch);
if (sinfo_new.empty()) {
success = false;
break;
}
// remeber the position that we found
res.push_back(head_new);
res.push_back(sinfo_new);
// TODO: temporary
if (supports_set_rows) {
GGML_ASSERT(sinfo_new.is_cont());
}
// store the old state of the cells in the recovery stack
states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)});
// now emplace the ubatch
apply_ubatch(head_new, ubatch);
apply_ubatch(sinfo_new, ubatch);
}
// iterate backwards and restore the cells to their original state
@@ -520,7 +525,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
return updated;
}
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
uint32_t head_cur = this->head;
@@ -533,7 +538,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return -1;
return { };
}
if (debug > 0) {
@@ -649,14 +654,21 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return -1;
return { };
}
}
return head_cur;
slot_info res;
res.idxs.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; ++i) {
res.idxs[i] = head_cur + i;
}
return res;
}
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -664,22 +676,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_pos_max_rm[s] = -1;
}
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
if (!cells.is_empty(head_cur + i)) {
assert(cells.seq_count(head_cur + i) == 1);
assert(ubatch.n_tokens == sinfo.idxs.size());
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
const llama_pos pos = cells.pos_get(head_cur + i);
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto idx = sinfo.idxs[i];
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(head_cur + i);
cells.rm(idx);
}
cells.pos_set(head_cur + i, ubatch.pos[i]);
cells.pos_set(idx, ubatch.pos[i]);
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
}
@@ -700,7 +716,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
}
// move the head at the end of the slot
head = head_cur + ubatch.n_tokens;
head = sinfo.idxs.back() + 1;
}
bool llama_kv_cache_unified::get_can_shift() const {
@@ -753,7 +769,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
0);
}
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * k = layers[ikv].k;
@@ -772,12 +788,12 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
ggml_tensor * k_view = ggml_view_1d(ctx, k,
n_tokens*n_embd_k_gqa,
ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
return ggml_cpy(ctx, k_cur, k_view);
}
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * v = layers[ikv].v;
@@ -814,19 +830,19 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
if (!v_trans) {
v_view = ggml_view_1d(ctx, v,
n_tokens*n_embd_v_gqa,
ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
} else {
v_cur = ggml_transpose(ctx, v_cur);
v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
(v->ne[1])*ggml_element_size(v),
(head_cur)*ggml_element_size(v));
(v->ne[1] )*ggml_element_size(v),
(sinfo.head())*ggml_element_size(v));
}
return ggml_cpy(ctx, v_cur, v_view);
}
void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
if (!supports_set_rows) {
return;
}
@@ -837,7 +853,7 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
int64_t * data = (int64_t *) dst->data;
for (int64_t i = 0; i < n_tokens; ++i) {
data[i] = head_cur + i;
data[i] = sinfo.idxs[i];
}
}
@@ -1580,13 +1596,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
ubatch.seq_id[i] = &dest_seq_id;
}
const auto head_cur = find_slot(ubatch);
if (head_cur < 0) {
const auto sinfo = find_slot(ubatch);
if (sinfo.empty()) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
apply_ubatch(head_cur, ubatch);
apply_ubatch(sinfo, ubatch);
const auto head_cur = sinfo.head();
// keep the head at the old position because we will read the KV data into it in state_read_data()
head = head_cur;
@@ -1772,7 +1790,10 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size();
head = 0;
sinfos.resize(1);
sinfos[0].idxs.resize(1);
sinfos[0].idxs[0] = 0;
}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1787,8 +1808,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_kv_cache_unified::ubatch_heads heads,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
llama_kv_cache_unified::slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@@ -1796,7 +1817,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) {
if (++i_cur >= ubatches.size()) {
return false;
}
@@ -1813,10 +1834,9 @@ bool llama_kv_cache_unified_context::apply() {
return true;
}
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
n_kv = kv->get_n_kv();
head = heads[i_next];
return true;
}
@@ -1828,7 +1848,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
return ubatches[i_cur];
}
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1844,11 +1864,11 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
}
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
@@ -1856,7 +1876,7 @@ void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const
}
void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_kv_idxs(dst, ubatch, head);
kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
}
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {