mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
cont : support non-continuous slots
ggml-ci
This commit is contained in:
@@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
|
|||||||
bool success = true;
|
bool success = true;
|
||||||
|
|
||||||
for (const auto & ubatch : ubatches) {
|
for (const auto & ubatch : ubatches) {
|
||||||
|
// non-continuous slots require support for ggml_set_rows()
|
||||||
|
const bool cont = supports_set_rows ? false : true;
|
||||||
|
|
||||||
// only find a suitable slot for the ubatch. don't modify the cells yet
|
// only find a suitable slot for the ubatch. don't modify the cells yet
|
||||||
const auto sinfo_new = find_slot(ubatch);
|
const auto sinfo_new = find_slot(ubatch, cont);
|
||||||
if (sinfo_new.empty()) {
|
if (sinfo_new.empty()) {
|
||||||
success = false;
|
success = false;
|
||||||
break;
|
break;
|
||||||
@@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
|
|||||||
return updated;
|
return updated;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
|
||||||
const uint32_t n_tokens = ubatch.n_tokens;
|
const uint32_t n_tokens = ubatch.n_tokens;
|
||||||
|
|
||||||
uint32_t head_cur = this->head;
|
uint32_t head_cur = this->head;
|
||||||
@@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t n_found = 0;
|
||||||
uint32_t n_tested = 0;
|
uint32_t n_tested = 0;
|
||||||
|
|
||||||
|
const uint32_t n_test = cont ? n_tokens : 1;
|
||||||
|
|
||||||
|
slot_info res;
|
||||||
|
|
||||||
|
res.idxs.resize(n_tokens);
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
if (head_cur + n_tokens > cells.size()) {
|
if (head_cur + n_test > cells.size()) {
|
||||||
n_tested += cells.size() - head_cur;
|
n_tested += cells.size() - head_cur;
|
||||||
head_cur = 0;
|
head_cur = 0;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool found = true;
|
for (uint32_t i = 0; i < n_test; i++) {
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
const auto idx = head_cur;
|
||||||
|
|
||||||
//const llama_pos pos = ubatch.pos[i];
|
//const llama_pos pos = ubatch.pos[i];
|
||||||
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||||
|
|
||||||
@@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
|||||||
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
|
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
|
||||||
// - mask SWA, using current max pos for that sequence in the cache
|
// - mask SWA, using current max pos for that sequence in the cache
|
||||||
// always insert in the cell with minimum pos
|
// always insert in the cell with minimum pos
|
||||||
bool can_use = cells.is_empty(head_cur + i);
|
bool can_use = cells.is_empty(idx);
|
||||||
|
|
||||||
if (!can_use && cells.seq_count(head_cur + i) == 1) {
|
if (!can_use && cells.seq_count(idx) == 1) {
|
||||||
const llama_pos pos_cell = cells.pos_get(head_cur + i);
|
const llama_pos pos_cell = cells.pos_get(idx);
|
||||||
|
|
||||||
// (disabled) causal mask
|
// (disabled) causal mask
|
||||||
// note: it's better to purge any "future" tokens beforehand
|
// note: it's better to purge any "future" tokens beforehand
|
||||||
//if (cells.seq_has(head_cur + i, seq_id)) {
|
//if (cells.seq_has(idx, seq_id)) {
|
||||||
// can_use = pos_cell >= pos;
|
// can_use = pos_cell >= pos;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
if (!can_use) {
|
if (!can_use) {
|
||||||
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
|
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
||||||
|
|
||||||
// SWA mask
|
// SWA mask
|
||||||
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||||
@@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!can_use) {
|
head_cur++;
|
||||||
found = false;
|
n_tested++;
|
||||||
head_cur += i + 1;
|
|
||||||
n_tested += i + 1;
|
if (can_use) {
|
||||||
|
res.idxs[n_found] = idx;
|
||||||
|
|
||||||
|
n_found++;
|
||||||
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (found) {
|
if (n_found == n_tokens) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cont) {
|
||||||
|
n_found = 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (n_tested >= cells.size()) {
|
if (n_tested >= cells.size()) {
|
||||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||||
return { };
|
return { };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slot_info res;
|
// we didn't find a suitable slot - return empty result
|
||||||
|
if (n_found < n_tokens) {
|
||||||
res.idxs.resize(n_tokens);
|
res.clear();
|
||||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
|
||||||
res.idxs[i] = head_cur + i;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
@@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
ubatch.seq_id[i] = &dest_seq_id;
|
ubatch.seq_id[i] = &dest_seq_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto sinfo = find_slot(ubatch);
|
const auto sinfo = find_slot(ubatch, true);
|
||||||
if (sinfo.empty()) {
|
if (sinfo.empty()) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ public:
|
|||||||
return idxs.empty();
|
return idxs.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
idxs.clear();
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: implement
|
// TODO: implement
|
||||||
//std::vector<idx_vec_t> seq_idxs;
|
//std::vector<idx_vec_t> seq_idxs;
|
||||||
};
|
};
|
||||||
@@ -133,14 +137,10 @@ public:
|
|||||||
|
|
||||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||||
|
|
||||||
// find a continuous slot of kv cells that can hold the ubatch
|
// find a slot of kv cells that can hold the ubatch
|
||||||
// return the cell position where we can insert the ubatch
|
// if cont == true, then the slot must be continuous
|
||||||
// return -1 on failure to find a slot
|
// return empty slot_info on failure
|
||||||
slot_info find_slot(const llama_ubatch & ubatch) const;
|
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
|
||||||
|
|
||||||
// find a set of kv cells that can hold the ubatch
|
|
||||||
// TODO: implement
|
|
||||||
//slot_info find_slot_ext(const llama_ubatch & ubatch) const;
|
|
||||||
|
|
||||||
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
|
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
|
||||||
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
|
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
|
||||||
|
|||||||
Reference in New Issue
Block a user