cont : support non-continuous slots

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-21 16:23:31 +03:00
parent 39d0b1e8df
commit 332f073589
2 changed files with 46 additions and 29 deletions

View File

@@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     bool success = true;

     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch);
+        const auto sinfo_new = find_slot(ubatch, cont);
         if (sinfo_new.empty()) {
             success = false;
             break;
@@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }

-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;

     uint32_t head_cur = this->head;
@@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }

+    uint32_t n_found  = 0;
     uint32_t n_tested = 0;

+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }

-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos    pos    = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
@@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
            //  - (disabled) mask causally, if the sequence is the same as the one we are inserting
            //  - mask SWA, using current max pos for that sequence in the cache
            //    always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);

-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);

                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}

                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);

                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
                     }
                 }
             }

-            if (!can_use) {
-                found = false;
-
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                res.idxs[n_found] = idx;
+
+                n_found++;
+            } else {
                 break;
             }
         }

-        if (found) {
+        if (n_found == n_tokens) {
             break;
         }

+        if (cont) {
+            n_found = 0;
+        }
+
         if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
             return { };
         }
     }

-    slot_info res;
-
-    res.idxs.resize(n_tokens);
-
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        res.idxs[i] = head_cur + i;
+    // we didn't find a suitable slot - return empty result
+    if (n_found < n_tokens) {
+        res.clear();
     }

     return res;
@@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         ubatch.seq_id[i] = &dest_seq_id;
     }

-    const auto sinfo = find_slot(ubatch);
+    const auto sinfo = find_slot(ubatch, true);
     if (sinfo.empty()) {
         LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
         return false;

View File

@@ -49,6 +49,10 @@ public:
            return idxs.empty();
        }

+        void clear() {
+            idxs.clear();
+        }
+
        // TODO: implement
        //std::vector<idx_vec_t> seq_idxs;
    };
@@ -133,14 +137,10 @@ public:
    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

-    // find a continuous slot of kv cells that can hold the ubatch
-    //   return the cell position where we can insert the ubatch
-    //   return -1 on failure to find a slot
-    slot_info find_slot(const llama_ubatch & ubatch) const;
-
-    // find a set of kv cells that can hold the ubatch
-    // TODO: implement
-    //slot_info find_slot_ext(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    //   if cont == true, then the slot must be continuous
+    //   return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);