diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 19bc16cd54..8fac081d11 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -400,8 +400,11 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     bool success = true;
     for (const auto & ubatch : ubatches) {
+        // non-continuous slots require support for ggml_set_rows()
+        const bool cont = supports_set_rows ? false : true;
+
         // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch);
+        const auto sinfo_new = find_slot(ubatch, cont);
         if (sinfo_new.empty()) {
             success = false;
             break;
         }
@@ -521,7 +524,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     uint32_t head_cur = this->head;
@@ -595,17 +598,25 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
         }
     }
 
+    uint32_t n_found  = 0;
     uint32_t n_tested = 0;
 
+    const uint32_t n_test = cont ? n_tokens : 1;
+
+    slot_info res;
+
+    res.idxs.resize(n_tokens);
+
     while (true) {
-        if (head_cur + n_tokens > cells.size()) {
+        if (head_cur + n_test > cells.size()) {
             n_tested += cells.size() - head_cur;
             head_cur = 0;
             continue;
         }
 
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_test; i++) {
+            const auto idx = head_cur;
+
             //const llama_pos    pos    = ubatch.pos[i];
             //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
@@ -615,19 +626,19 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
             //  - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //  - mask SWA, using current max pos for that sequence in the cache
             //    always insert in the cell with minimum pos
-            bool can_use = cells.is_empty(head_cur + i);
+            bool can_use = cells.is_empty(idx);
 
-            if (!can_use && cells.seq_count(head_cur + i) == 1) {
-                const llama_pos pos_cell = cells.pos_get(head_cur + i);
+            if (!can_use && cells.seq_count(idx) == 1) {
+                const llama_pos pos_cell = cells.pos_get(idx);
 
                 // (disabled) causal mask
                 // note: it's better to purge any "future" tokens beforehand
-                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //if (cells.seq_has(idx, seq_id)) {
                 //    can_use = pos_cell >= pos;
                 //}
 
                 if (!can_use) {
-                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
+                    const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                     // SWA mask
                     if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -636,29 +647,35 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
                 }
             }
 
-            if (!can_use) {
-                found = false;
-                head_cur += i + 1;
-                n_tested += i + 1;
+            head_cur++;
+            n_tested++;
+
+            if (can_use) {
+                res.idxs[n_found] = idx;
+
+                n_found++;
+            } else {
                 break;
             }
         }
 
-        if (found) {
+        if (n_found == n_tokens) {
             break;
         }
 
+        if (cont) {
+            n_found = 0;
+        }
+
         if (n_tested >= cells.size()) {
            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
            return { };
        }
    }
 
-    slot_info res;
-
-    res.idxs.resize(n_tokens);
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        res.idxs[i] = head_cur + i;
+    // we didn't find a suitable slot - return empty result
+    if (n_found < n_tokens) {
+        res.clear();
     }
 
     return res;
@@ -1592,7 +1609,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             ubatch.seq_id[i] = &dest_seq_id;
         }
 
-        const auto sinfo = find_slot(ubatch);
+        const auto sinfo = find_slot(ubatch, true);
         if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 6edf7b588c..1f72afd21c 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -49,6 +49,10 @@ public:
             return idxs.empty();
         }
 
+        void clear() {
+            idxs.clear();
+        }
+
         // TODO: implement
         //std::vector seq_idxs;
     };
@@ -133,14 +137,10 @@ public:
 
     bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
-    // find a continuous slot of kv cells that can hold the ubatch
-    // return the cell position where we can insert the ubatch
-    // return -1 on failure to find a slot
-    slot_info find_slot(const llama_ubatch & ubatch) const;
-
-    // find a set of kv cells that can hold the ubatch
-    // TODO: implement
-    //slot_info find_slot_ext(const llama_ubatch & ubatch) const;
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
 
     // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
     void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
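
For context, below is a minimal standalone sketch of the two search modes that `find_slot(ubatch, cont)` implements after this patch: `cont == true` requires one contiguous run of free cells (the old behavior), while `cont == false` accepts any set of free cells, which only works when the backend can scatter K/V rows to arbitrary indices via `ggml_set_rows()`. The names `toy_find_slot` and `occupied` are illustrative only and not part of the patch; sequence bookkeeping, SWA masking, and the causal-mask special case are omitted.

```cpp
// Toy model of the slot search over a fixed-size cell array.
// Assumes n_tokens <= occupied.size(), mirroring the real code.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<uint32_t> toy_find_slot(const std::vector<bool> & occupied, uint32_t n_tokens, uint32_t head, bool cont) {
    const uint32_t n_cells = (uint32_t) occupied.size();

    // in contiguous mode every probe must yield n_tokens usable cells in a row;
    // in non-contiguous mode each probe only needs to find a single free cell
    const uint32_t n_test = cont ? n_tokens : 1;

    std::vector<uint32_t> idxs;
    idxs.reserve(n_tokens);

    uint32_t head_cur = head;
    uint32_t n_tested = 0;

    while (true) {
        // wrap around when the probe window does not fit at the end of the cache
        if (head_cur + n_test > n_cells) {
            n_tested += n_cells - head_cur;
            head_cur = 0;
            continue;
        }

        for (uint32_t i = 0; i < n_test; i++) {
            const uint32_t idx = head_cur;

            head_cur++;
            n_tested++;

            if (!occupied[idx]) {
                idxs.push_back(idx);
            } else {
                break; // probe window interrupted by an occupied cell
            }
        }

        if (idxs.size() == n_tokens) {
            return idxs; // success
        }

        if (cont) {
            idxs.clear(); // contiguous mode: one occupied cell invalidates the whole run
        }

        if (n_tested >= n_cells) {
            return {}; // scanned the whole cache without success -> empty result
        }
    }
}

int main() {
    // cells 1 and 4 are occupied
    const std::vector<bool> occupied = {false, true, false, false, true, false, false, false};

    const auto non_cont = toy_find_slot(occupied, 3, 0, /*cont =*/ false); // -> 0 2 3
    const auto contig   = toy_find_slot(occupied, 3, 0, /*cont =*/ true);  // -> 5 6 7

    for (auto i : non_cont) { printf("%u ", i); } printf("\n");
    for (auto i : contig)   { printf("%u ", i); } printf("\n");
}
```

This is why `prepare()` sets `cont = supports_set_rows ? false : true` and `state_read_meta()` keeps `cont = true`: restored cache contents must still land in one contiguous block, while regular decoding can fill holes anywhere in the cache once `ggml_set_rows()` is available.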