mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	kv-cache : rework kv_cell (#13706)
* kv-cache : rework kv_cell ggml-ci * kv-cells : use "shift" instead of "delta" consistently ggml-ci * llama : add llama_max_parallel_sequences() ggml-ci * kv-cells : update comments [no ci] * context : fail upon construction if sequences exceed max value ggml-ci * kv-cells : get_pos() -> pos_get() + comments ggml-ci * kv-cells : fix tracking of "used" cells ggml-ci
This commit is contained in:
		| @@ -471,6 +471,7 @@ extern "C" { | ||||
|     LLAMA_API int64_t llama_time_us(void); | ||||
|  | ||||
|     LLAMA_API size_t llama_max_devices(void); | ||||
|     LLAMA_API size_t llama_max_parallel_sequences(void); | ||||
|  | ||||
|     LLAMA_API bool llama_supports_mmap       (void); | ||||
|     LLAMA_API bool llama_supports_mlock      (void); | ||||
|   | ||||
| @@ -26,6 +26,10 @@ llama_context::llama_context( | ||||
|     const auto & hparams = model.hparams; | ||||
|  | ||||
|     cparams.n_seq_max = std::max(1u, params.n_seq_max); | ||||
|     if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) { | ||||
|         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES)); | ||||
|     } | ||||
|  | ||||
|     cparams.n_threads        = params.n_threads; | ||||
|     cparams.n_threads_batch  = params.n_threads_batch; | ||||
|     cparams.yarn_ext_factor  = params.yarn_ext_factor; | ||||
|   | ||||
| @@ -1 +1,5 @@ | ||||
| #include "llama-cparams.h" | ||||
|  | ||||
| size_t llama_max_parallel_sequences(void) { | ||||
|     return LLAMA_MAX_PARALLEL_SEQUENCES; | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,8 @@ | ||||
|  | ||||
| #include <cstdint> | ||||
|  | ||||
| #define LLAMA_MAX_PARALLEL_SEQUENCES 64 | ||||
|  | ||||
| struct llama_cparams { | ||||
|     uint32_t n_ctx;           // context size used during inference | ||||
|     uint32_t n_batch; | ||||
|   | ||||
| @@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified( | ||||
|     }; | ||||
|  | ||||
|     head = 0; | ||||
|     size = kv_size; | ||||
|     used = 0; | ||||
|  | ||||
|     cells.resize(kv_size); | ||||
|  | ||||
| @@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::clear() { | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         cells[i].pos = -1; | ||||
|         cells[i].seq_id.clear(); | ||||
|     } | ||||
|     cells.reset(); | ||||
|  | ||||
|     head = 0; | ||||
|     used = 0; | ||||
|  | ||||
|     for (auto & buf : bufs) { | ||||
|         ggml_backend_buffer_clear(buf.get(), 0); | ||||
| @@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() { | ||||
| } | ||||
|  | ||||
| bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { | ||||
|     uint32_t new_head = size; | ||||
|     uint32_t new_head = cells.size(); | ||||
|  | ||||
|     if (p0 < 0) { | ||||
|         p0 = 0; | ||||
| @@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos | ||||
|         p1 = std::numeric_limits<llama_pos>::max(); | ||||
|     } | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (cells[i].pos >= p0 && cells[i].pos < p1) { | ||||
|             if (seq_id < 0) { | ||||
|                 cells[i].seq_id.clear(); | ||||
|             } else if (cells[i].has_seq_id(seq_id)) { | ||||
|                 cells[i].seq_id.erase(seq_id); | ||||
|             } else { | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (!cells.pos_in(i, p0, p1)) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|             if (cells[i].is_empty()) { | ||||
|                 // keep count of the number of used cells | ||||
|                 if (cells[i].pos >= 0) { | ||||
|                     used--; | ||||
|                 } | ||||
|  | ||||
|                 cells[i].pos = -1; | ||||
|  | ||||
|                 if (new_head == size) { | ||||
|         if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { | ||||
|             if (new_head == cells.size()) { | ||||
|                 new_head = i; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     } | ||||
|  | ||||
|     // If we freed up a slot, set head to it so searching can start there. | ||||
|     if (new_head != size && new_head < head) { | ||||
|     if (new_head != cells.size() && new_head < head) { | ||||
|         head = new_head; | ||||
|     } | ||||
|  | ||||
| @@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id | ||||
|         p1 = std::numeric_limits<llama_pos>::max(); | ||||
|     } | ||||
|  | ||||
|     // otherwise, this is the KV of a Transformer-like model | ||||
|     head = 0; | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (!cells.pos_in(i, p0, p1)) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { | ||||
|             cells[i].seq_id.insert(seq_id_dst); | ||||
|         if (cells.seq_has(i, seq_id_src)) { | ||||
|             cells.seq_add(i, seq_id_dst); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { | ||||
|     uint32_t new_head = size; | ||||
|     uint32_t new_head = cells.size(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (!cells[i].has_seq_id(seq_id)) { | ||||
|             if (cells[i].pos >= 0) { | ||||
|                 used--; | ||||
|             } | ||||
|  | ||||
|             cells[i].pos = -1; | ||||
|             cells[i].seq_id.clear(); | ||||
|  | ||||
|             if (new_head == size){ | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (cells.seq_keep(i, seq_id)) { | ||||
|             if (new_head == cells.size()) { | ||||
|                 new_head = i; | ||||
|             } | ||||
|         } else { | ||||
|             cells[i].seq_id.clear(); | ||||
|             cells[i].seq_id.insert(seq_id); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // If we freed up a slot, set head to it so searching can start there. | ||||
|     if (new_head != size && new_head < head) { | ||||
|     if (new_head != cells.size() && new_head < head) { | ||||
|         head = new_head; | ||||
|     } | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { | ||||
|     if (delta == 0) { | ||||
| void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { | ||||
|     if (shift == 0) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     uint32_t new_head = size; | ||||
|     uint32_t new_head = cells.size(); | ||||
|  | ||||
|     if (p0 < 0) { | ||||
|         p0 = 0; | ||||
| @@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po | ||||
|         p1 = std::numeric_limits<llama_pos>::max(); | ||||
|     } | ||||
|  | ||||
|     // If there is no range then return early to avoid looping over the | ||||
|     // If there is no range then return early to avoid looping over all cells. | ||||
|     if (p0 == p1) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { | ||||
|             has_shift = true; | ||||
|  | ||||
|             cells[i].pos   += delta; | ||||
|             cells[i].delta += delta; | ||||
|  | ||||
|             if (cells[i].pos < 0) { | ||||
|                 if (!cells[i].is_empty()) { | ||||
|                     used--; | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (!cells.pos_in(i, p0, p1)) { | ||||
|             continue; | ||||
|         } | ||||
|                 cells[i].pos = -1; | ||||
|                 cells[i].seq_id.clear(); | ||||
|                 if (new_head == size) { | ||||
|  | ||||
|         if (cells.seq_has(i, seq_id)) { | ||||
|             if (cells.pos_add(i, shift)) { | ||||
|                 if (new_head == cells.size()) { | ||||
|                     new_head = i; | ||||
|                 } | ||||
|             } | ||||
| @@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po | ||||
|  | ||||
|     // If we freed up a slot, set head to it so searching can start there. | ||||
|     // Otherwise we just start the next search from the beginning. | ||||
|     head = new_head != size ? new_head : 0; | ||||
|     head = new_head != cells.size() ? new_head : 0; | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { | ||||
| @@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { | ||||
|             has_shift = true; | ||||
|  | ||||
|             { | ||||
|                 llama_pos p_old = cells[i].pos; | ||||
|                 cells[i].pos   /= d; | ||||
|                 cells[i].delta += cells[i].pos - p_old; | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (!cells.pos_in(i, p0, p1)) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if (cells.seq_has(i, seq_id)) { | ||||
|             cells.pos_div(i, d); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po | ||||
| llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { | ||||
|     llama_pos result = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (cells[i].has_seq_id(seq_id)) { | ||||
|             result = std::min(result, cells[i].pos); | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (cells.seq_has(i, seq_id)) { | ||||
|             result = std::min(result, cells.pos_get(i)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { | ||||
| llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { | ||||
|     llama_pos result = -1; | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         if (cells[i].has_seq_id(seq_id)) { | ||||
|             result = std::max(result, cells[i].pos); | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (cells.seq_has(i, seq_id)) { | ||||
|             result = std::max(result, cells.pos_get(i)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::restore() { | ||||
|     for (const auto & [id, cell] : recovery.cells) { | ||||
|         // TODO: move to new `struct kv_cells` | ||||
|         const bool is_empty0 = cells[id].is_empty(); | ||||
|         const bool is_empty1 = cell.is_empty(); | ||||
|  | ||||
|         if (!is_empty0 && is_empty1) { | ||||
|             used--; | ||||
|         } else if (is_empty0 && !is_empty1) { | ||||
|             used++; | ||||
|         } | ||||
|  | ||||
|         cells[id] = cell; | ||||
|     for (auto & state : recovery.states) { | ||||
|         cells.set(state.i, state.cells); | ||||
|     } | ||||
|  | ||||
|     recovery.clear(); | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::commit() { | ||||
|     if (recovery.cells.empty()) { | ||||
|     if (recovery.states.empty()) { | ||||
|         LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", | ||||
|                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); | ||||
|         return; | ||||
| @@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { | ||||
|  | ||||
|     auto * sched = lctx.get_sched(); | ||||
|  | ||||
|     if (has_shift) { | ||||
|     if (cells.get_has_shift()) { | ||||
|         if (!get_can_shift()) { | ||||
|             GGML_ABORT("The current KV cache / model configuration does not support K-shift"); | ||||
|         } | ||||
| @@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { | ||||
|             need_reserve = true; | ||||
|         } | ||||
|  | ||||
|         { | ||||
|             has_shift = false; | ||||
|  | ||||
|             for (uint32_t i = 0; i < size; ++i) { | ||||
|                 cells[i].delta = 0; | ||||
|             } | ||||
|         } | ||||
|         cells.reset_shift(); | ||||
|     } | ||||
|  | ||||
|     if (do_defrag) { | ||||
| @@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { | ||||
| void llama_kv_cache_unified::defrag_sched(float thold) { | ||||
|     // - do not defrag small contexts (i.e. < 2048 tokens) | ||||
|     // - count the padding towards the number of used tokens | ||||
|     const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f; | ||||
|     const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; | ||||
|  | ||||
|     // queue defragmentation for next llama_kv_cache_update | ||||
|     if (fragmentation > thold) { | ||||
| @@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) { | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::set_full() { | ||||
|     n = size; | ||||
|     n = cells.size(); | ||||
|  | ||||
|     // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not | ||||
|     //   affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. | ||||
| @@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | ||||
|  | ||||
|     // if we have enough unused cells before the current head -> | ||||
|     //   better to start searching from the beginning of the cache, hoping to fill it | ||||
|     if (head > used + 2*ubatch.n_tokens) { | ||||
|     if (head > cells.get_used() + 2*ubatch.n_tokens) { | ||||
|         head = 0; | ||||
|     } | ||||
|  | ||||
|     // otherwise, one cell per token. | ||||
|  | ||||
|     if (n_tokens > size) { | ||||
|         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); | ||||
|     if (n_tokens > cells.size()) { | ||||
|         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); | ||||
|         return false; | ||||
|     } | ||||
|  | ||||
| @@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | ||||
|         std::string ss; | ||||
|         if (n_swa > 0) { | ||||
|             for (uint32_t i = 0; i < size; ++i) { | ||||
|                 if (cells[i].pos == -1) { | ||||
|                 if (cells.is_empty(i)) { | ||||
|                     ss += '.'; | ||||
|                 } else { | ||||
|                     ss += std::to_string(*cells[i].seq_id.begin()); | ||||
|                     ss += 'x'; | ||||
|                 } | ||||
|                 if (i%256 == 255) { | ||||
|                     ss += '\n'; | ||||
| @@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | ||||
|     uint32_t n_tested = 0; | ||||
|  | ||||
|     while (true) { | ||||
|         if (head + n_tokens > size) { | ||||
|             n_tested += size - head; | ||||
|         if (head + n_tokens > cells.size()) { | ||||
|             n_tested += cells.size() - head; | ||||
|             head = 0; | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         bool found = true; | ||||
|         for (uint32_t i = 0; i < n_tokens; i++) { | ||||
|             if (cells[head + i].pos >= 0) { | ||||
|             // TODO: improve to accept cells that are masked by the SWA | ||||
|             if (!cells.is_empty(head + i)) { | ||||
|                 found = false; | ||||
|                 head     += i + 1; | ||||
|                 n_tested += i + 1; | ||||
| @@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         if (n_tested >= size) { | ||||
|         if (n_tested >= cells.size()) { | ||||
|             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); | ||||
|             return false; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for (uint32_t i = 0; i < n_tokens; ++i) { | ||||
|         // remember the original state | ||||
|         if (recovery.cells.find(head + i) == recovery.cells.end()) { | ||||
|             recovery.cells[head + i] = cells[head + i]; | ||||
|         } | ||||
|     // store the old state of the cells in the recovery stack | ||||
|     recovery.states.push_back({head, cells.cp(head, n_tokens)}); | ||||
|  | ||||
|         cells[head + i].pos = ubatch.pos[i]; | ||||
|     for (uint32_t i = 0; i < n_tokens; ++i) { | ||||
|         cells.pos_set(head + i, ubatch.pos[i]); | ||||
|  | ||||
|         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { | ||||
|             cells[head + i].seq_id.insert(ubatch.seq_id[i][j]); | ||||
|             cells.seq_add(head + i, ubatch.seq_id[i][j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     used += n_tokens; | ||||
|  | ||||
|     // a heuristic, to avoid attending the full cache if it is not yet utilized | ||||
|     // after enough generations, the benefit from this heuristic disappears | ||||
|     // if we start defragmenting the cache, the benefit from this will be more important | ||||
|     n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); | ||||
|     n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad))); | ||||
|  | ||||
| #ifdef FIND_SLOT_DEBUG | ||||
|     LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); | ||||
| @@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const { | ||||
| } | ||||
|  | ||||
| uint32_t llama_kv_cache_unified::get_size() const { | ||||
|     return size; | ||||
|     return cells.size(); | ||||
| } | ||||
|  | ||||
| ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { | ||||
| @@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam | ||||
|  | ||||
|     int n_attended = 0; | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         const llama_pos p0 = cells[i].pos; | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (!cells.seq_has(i, seq_id)) { | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         const llama_pos p0 = cells.pos_get(i); | ||||
|  | ||||
|         if (p0 <= pmin && !is_masked_swa(p0, pmin)) { | ||||
|             n_attended++; | ||||
|         } | ||||
|  | ||||
|         if (is_masked_swa(p0, pmax)) { | ||||
|             if (seq_id < 0) { | ||||
|                 cells[i].seq_id.clear(); | ||||
|             } else if (cells[i].has_seq_id(seq_id)) { | ||||
|                 cells[i].seq_id.erase(seq_id); | ||||
|             } else { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             if (cells[i].is_empty()) { | ||||
|                 // keep count of the number of used cells | ||||
|                 if (cells[i].pos >= 0) { | ||||
|                     used--; | ||||
|                 } | ||||
|  | ||||
|                 cells[i].pos = -1; | ||||
|             } | ||||
|             cells.seq_rm(i, seq_id); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -723,12 +657,17 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub | ||||
|                 const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; | ||||
|  | ||||
|                 for (int i = 0; i < n_kv; ++i) { | ||||
|                     const llama_pos p0 = cells[i].pos; | ||||
|                     float f = 0.0f; | ||||
|  | ||||
|                     bool masked = false; | ||||
|  | ||||
|                     if (cells.is_empty(i)) { | ||||
|                         masked = true; | ||||
|                     } else { | ||||
|                         const llama_pos p0 = cells.pos_get(i); | ||||
|  | ||||
|                         // mask the token if not the same sequence | ||||
|                     masked = masked || (!cells[i].has_seq_id(seq_id)); | ||||
|                         masked = masked || (!cells.seq_has(i, seq_id)); | ||||
|  | ||||
|                         // mask future tokens | ||||
|                         masked = masked || (causal_attn && p0 > p1); | ||||
| @@ -736,12 +675,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub | ||||
|                         // apply SWA if any | ||||
|                         masked = masked || (is_masked_swa(p0, p1)); | ||||
|  | ||||
|                     float f = 0.0f; | ||||
|                         if (!masked && hparams.use_alibi) { | ||||
|                             f = -std::abs(p0 - p1); | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     if (masked) { | ||||
|                         f = -INFINITY; | ||||
|                     } else if (hparams.use_alibi) { | ||||
|                         f = -std::abs(p0 - p1); | ||||
|                     } | ||||
|  | ||||
|                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; | ||||
| @@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { | ||||
|  | ||||
|     int32_t * data = (int32_t *) dst->data; | ||||
|  | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         data[i] = cells[i].delta; | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama | ||||
|     for (int h = 0; h < 1; ++h) { | ||||
|         for (int j = 0; j < n_tokens; ++j) { | ||||
|             for (int i = 0; i < n_kv; ++i) { | ||||
|                 data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); | ||||
|                 // the position when the cells is empty is irrelevant - it will be masked out later in the attention | ||||
|                 const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); | ||||
|  | ||||
|                 data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( | ||||
|  | ||||
|         ggml_tensor * k = | ||||
|             ggml_view_3d(ctx, layer.k, | ||||
|                 n_embd_head_k, n_head_kv, size, | ||||
|                 n_embd_head_k, n_head_kv, cells.size(), | ||||
|                 ggml_row_size(layer.k->type, n_embd_head_k), | ||||
|                 ggml_row_size(layer.k->type, n_embd_k_gqa), | ||||
|                 0); | ||||
| @@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( | ||||
|             } else { | ||||
|                 view_v_src = ggml_view_2d(ctx, layer.v, | ||||
|                         nm, n_embd_v_gqa, | ||||
|                         ggml_row_size(layer.v->type, size), | ||||
|                         ggml_row_size(layer.v->type, cells.size()), | ||||
|                         ggml_row_size(layer.v->type, i)); | ||||
|  | ||||
|                 view_v_dst = ggml_view_2d(ctx, layer.v, | ||||
|                         nm, n_embd_v_gqa, | ||||
|                         ggml_row_size(layer.v->type, size), | ||||
|                         ggml_row_size(layer.v->type, cells.size()), | ||||
|                         ggml_row_size(layer.v->type, id)); | ||||
|             } | ||||
|  | ||||
| @@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|     const uint32_t n_layer = layers.size(); | ||||
|  | ||||
|     const uint32_t n_kv   = cell_max(); | ||||
|     const uint32_t n_used = used; | ||||
|     const uint32_t n_used = cells.get_used(); | ||||
|  | ||||
|     assert(n_used <= n_kv); | ||||
|  | ||||
| @@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|     ids.resize(n_kv, n_kv); | ||||
|  | ||||
|     for (uint32_t i0 = 0; i0 < n_used; ++i0) { | ||||
|         const auto & cell0 = cells[i0]; | ||||
|  | ||||
|         if (!cell0.is_empty()) { | ||||
|         if (!cells.is_empty(i0)) { | ||||
|             ids[i0] = i0; | ||||
|  | ||||
|             continue; | ||||
| @@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|         uint32_t nh = 1; | ||||
|  | ||||
|         // determine the size of the hole | ||||
|         while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { | ||||
|         while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { | ||||
|             nh++; | ||||
|         } | ||||
|  | ||||
| @@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|  | ||||
|         // starting from the end, find nh non-empty cells | ||||
|         for (; is > i0; --is) { | ||||
|             const auto & cell1 = cells[is]; | ||||
|  | ||||
|             if (cell1.is_empty() || ids[is] != n_kv) { | ||||
|             if (cells.is_empty(is) || ids[is] != n_kv) { | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
| @@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|  | ||||
|         // go back and move the nf cells to the hole | ||||
|         for (; i1 < n_kv; ++i1) { | ||||
|             auto & cell1 = cells[i1]; | ||||
|  | ||||
|             if (cell1.is_empty() || ids[i1] != n_kv) { | ||||
|             if (cells.is_empty(i1) || ids[i1] != n_kv) { | ||||
|                 if (n_moves == max_moves) { | ||||
|                     stop = true; | ||||
|                     break; | ||||
| @@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|             ids[i1] = i0 + nf; | ||||
|  | ||||
|             // move the cell meta data | ||||
|             cells[i0 + nf] = cell1; | ||||
|             cells.mv(i1, i0 + nf); | ||||
|  | ||||
|             // clear the old cell and move the head there | ||||
|             cell1 = kv_cell(); | ||||
|             head = n_used; | ||||
|  | ||||
|             if (!cont) { | ||||
| @@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
| } | ||||
|  | ||||
| uint32_t llama_kv_cache_unified::cell_max() const { | ||||
|     for (uint32_t i = size; i > 0; --i) { | ||||
|         const kv_cell & cell = cells[i - 1]; | ||||
|  | ||||
|         if (cell.pos >= 0 && !cell.is_empty()) { | ||||
|     for (uint32_t i = cells.size(); i > 0; --i) { | ||||
|         if (!cells.is_empty(i - 1)) { | ||||
|             return i; | ||||
|         } | ||||
|     } | ||||
| @@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const { | ||||
| } | ||||
|  | ||||
| bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { | ||||
|     if (p0 < 0) { | ||||
|         return true; | ||||
|     } | ||||
|     assert(p0 >= 0 && p1 >= 0); | ||||
|  | ||||
|     switch (swa_type) { | ||||
|         case LLAMA_SWA_TYPE_NONE: | ||||
| @@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq | ||||
|  | ||||
|     // Count the number of cells with the specified seq_id | ||||
|     // Find all the ranges of cells with this seq id (or all, when -1) | ||||
|     uint32_t cell_range_begin = size; | ||||
|     for (uint32_t i = 0; i < size; ++i) { | ||||
|         const auto & cell = cells[i]; | ||||
|         if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { | ||||
|     uint32_t cell_range_begin = cells.size(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { | ||||
|             ++cell_count; | ||||
|             if (cell_range_begin == size) { | ||||
|             if (cell_range_begin == cells.size()) { | ||||
|                 cell_range_begin = i; | ||||
|             } | ||||
|         } else { | ||||
|             if (cell_range_begin != size) { | ||||
|             if (cell_range_begin != cells.size()) { | ||||
|                 cell_ranges.emplace_back(cell_range_begin, i); | ||||
|                 cell_range_begin = size; | ||||
|                 cell_range_begin = cells.size(); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     if (cell_range_begin != size) { | ||||
|         cell_ranges.emplace_back(cell_range_begin, size); | ||||
|  | ||||
|     if (cell_range_begin != cells.size()) { | ||||
|         cell_ranges.emplace_back(cell_range_begin, cells.size()); | ||||
|     } | ||||
|  | ||||
|     // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count | ||||
| @@ -1308,20 +1240,27 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i | ||||
| void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const { | ||||
|     for (const auto & range : cell_ranges) { | ||||
|         for (uint32_t i = range.first; i < range.second; ++i) { | ||||
|             const auto & cell = cells[i]; | ||||
|             const llama_pos pos      = cell.pos; | ||||
|             const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; | ||||
|             std::vector<llama_seq_id> seq_ids; | ||||
|  | ||||
|             for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { | ||||
|                 if (cur == seq_id || seq_id == -1) { | ||||
|                     if (cells.seq_has(i, cur)) { | ||||
|                         seq_ids.push_back(cur); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             const llama_pos pos     = cells.pos_get(i); | ||||
|             const uint32_t n_seq_id = seq_ids.size(); | ||||
|  | ||||
|             io.write(&pos,      sizeof(pos)); | ||||
|             io.write(&n_seq_id, sizeof(n_seq_id)); | ||||
|  | ||||
|             if (n_seq_id) { | ||||
|                 for (auto seq_id : cell.seq_id) { | ||||
|             for (const auto & seq_id : seq_ids) { | ||||
|                 io.write(&seq_id, sizeof(seq_id)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     } | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const { | ||||
| @@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: | ||||
|         } | ||||
|     } else { | ||||
|         // When v is transposed, we also need the element size and get the element ranges from each row | ||||
|         const uint32_t kv_size = size; | ||||
|         const uint32_t kv_size = cells.size(); | ||||
|  | ||||
|         for (const auto & layer : layers) { | ||||
|             const uint32_t il = layer.il; | ||||
| @@ -1429,13 +1368,19 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | ||||
|             io.read_to(&pos,      sizeof(pos)); | ||||
|             io.read_to(&n_seq_id, sizeof(n_seq_id)); | ||||
|  | ||||
|             if (n_seq_id != 0) { | ||||
|             if (n_seq_id != 1) { | ||||
|                 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); | ||||
|                 return false; | ||||
|             } | ||||
|  | ||||
|             // read the sequence id, but directly discard it - we will use dest_seq_id instead | ||||
|             { | ||||
|                 llama_seq_id seq_id; | ||||
|                 io.read_to(&seq_id, sizeof(seq_id)); | ||||
|             } | ||||
|  | ||||
|             batch.pos[i]      = pos; | ||||
|             batch.n_seq_id[i] = 1; | ||||
|             batch.n_seq_id[i] = n_seq_id; | ||||
|             batch.seq_id[i]   = &dest_seq_id; | ||||
|         } | ||||
|  | ||||
| @@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | ||||
|  | ||||
|         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) | ||||
|         // Assume that this is one contiguous block of cells | ||||
|         GGML_ASSERT(head + cell_count <= size); | ||||
|         GGML_ASSERT(cells[head].pos == batch.pos[0]); | ||||
|         GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); | ||||
|         GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); | ||||
|         GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); | ||||
|         GGML_ASSERT(head + cell_count <= cells.size()); | ||||
|         GGML_ASSERT(cells.pos_get(head)                  == batch.pos[0]); | ||||
|         GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); | ||||
|         GGML_ASSERT(cells.seq_has(head,                  dest_seq_id)); | ||||
|         GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); | ||||
|     } else { | ||||
|         // whole KV cache restore | ||||
|  | ||||
|         if (cell_count > size) { | ||||
|         if (cell_count > cells.size()) { | ||||
|             LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); | ||||
|             return false; | ||||
|         } | ||||
| @@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | ||||
|         clear(); | ||||
|  | ||||
|         for (uint32_t i = 0; i < cell_count; ++i) { | ||||
|             kv_cell & cell = cells[i]; | ||||
|  | ||||
|             llama_pos pos; | ||||
|             uint32_t  n_seq_id; | ||||
|  | ||||
|             io.read_to(&pos,      sizeof(pos)); | ||||
|             io.read_to(&n_seq_id, sizeof(n_seq_id)); | ||||
|  | ||||
|             cell.pos = pos; | ||||
|             cells.pos_set(i, pos); | ||||
|  | ||||
|             for (uint32_t j = 0; j < n_seq_id; ++j) { | ||||
|                 llama_seq_id seq_id; | ||||
| @@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | ||||
|                     return false; | ||||
|                 } | ||||
|  | ||||
|                 cell.seq_id.insert(seq_id); | ||||
|                 cells.seq_add(i, seq_id); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         head = 0; | ||||
|         used = cell_count; | ||||
|     } | ||||
|  | ||||
|     return true; | ||||
| @@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell | ||||
|         LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); | ||||
|         return false; | ||||
|     } | ||||
|     if (cell_count > size) { | ||||
|         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); | ||||
|     if (cell_count > cells.size()) { | ||||
|         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); | ||||
|         return false; | ||||
|     } | ||||
|     if (this->v_trans != (bool) v_trans) { | ||||
| @@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell | ||||
|             if (cell_count) { | ||||
|                 // For each row in the transposed matrix, read the values for the whole cell range | ||||
|                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { | ||||
|                     const size_t dst_offset = (head + j * size) * v_size_el; | ||||
|                     const size_t dst_offset = (head + j * cells.size()) * v_size_el; | ||||
|                     ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); | ||||
|                 } | ||||
|             } | ||||
| @@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { | ||||
|     kv_swa ->seq_keep(seq_id); | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { | ||||
|     kv_base->seq_add(seq_id, p0, p1, delta); | ||||
|     kv_swa ->seq_add(seq_id, p0, p1, delta); | ||||
| void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { | ||||
|     kv_base->seq_add(seq_id, p0, p1, shift); | ||||
|     kv_swa ->seq_add(seq_id, p0, p1, shift); | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { | ||||
| @@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { | ||||
|     if (delta == 0) { | ||||
| void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { | ||||
|     if (shift == 0) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
| @@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_ | ||||
|         if (tail_id >= 0) { | ||||
|             kv_cell & cell = cells[tail_id]; | ||||
|             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { | ||||
|                 cell.pos += delta; | ||||
|                 cell.pos += shift; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| #include "llama-io.h" | ||||
| #include "llama-graph.h" | ||||
| #include "llama-memory.h" | ||||
| #include "llama-kv-cells.h" | ||||
|  | ||||
| #include "ggml-cpp.h" | ||||
|  | ||||
| @@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i { | ||||
|     virtual void defrag_sched(float thold) = 0; | ||||
|  | ||||
|     // simulate full cache, used for allocating worst-case compute buffers | ||||
|     // TODO: remove | ||||
|     virtual void set_full() = 0; | ||||
|  | ||||
|     // | ||||
| @@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i { | ||||
|     // | ||||
|  | ||||
|     // ============================================================================================================= | ||||
|     // TODO: refactor  and simplify this | ||||
|     // TODO: refactor and simplify this [TAG: KV_API] | ||||
|  | ||||
|     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; | ||||
|  | ||||
| @@ -121,7 +123,7 @@ public: | ||||
|     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; | ||||
|     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; | ||||
|     void seq_keep(llama_seq_id seq_id)                                                          override; | ||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override; | ||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override; | ||||
|     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; | ||||
|  | ||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const override; | ||||
| @@ -180,26 +182,6 @@ private: | ||||
|     const llama_model & model; | ||||
|     const llama_hparams & hparams; | ||||
|  | ||||
|     struct kv_cell { | ||||
|         llama_pos pos   = -1; | ||||
|         llama_pos delta =  0; | ||||
|  | ||||
|         // TODO: replace with bitset uint64_t | ||||
|         std::set<llama_seq_id> seq_id; | ||||
|  | ||||
|         bool has_seq_id(const llama_seq_id & id) const { | ||||
|             return seq_id.find(id) != seq_id.end(); | ||||
|         } | ||||
|  | ||||
|         bool is_empty() const { | ||||
|             return seq_id.empty(); | ||||
|         } | ||||
|  | ||||
|         bool is_same_seq(const kv_cell & other) const { | ||||
|             return seq_id == other.seq_id; | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     struct kv_layer { | ||||
|         // layer index in the model | ||||
|         // note: can be different from the layer index in the KV cache | ||||
| @@ -209,15 +191,13 @@ private: | ||||
|         ggml_tensor * v; | ||||
|     }; | ||||
|  | ||||
|     bool has_shift = false; | ||||
|     bool do_defrag = false; | ||||
|     bool v_trans   = true;  // the value tensor is transposed | ||||
|  | ||||
|     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) | ||||
|     uint32_t size = 0; // total number of cells, shared across all sequences | ||||
|     uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt) | ||||
|  | ||||
|     // computed before each graph build | ||||
|     // TODO: cells should start to maintain this value dynamically based on the edits | ||||
|     uint32_t n = 0; | ||||
|  | ||||
|     const uint32_t n_seq_max = 1; | ||||
| @@ -233,19 +213,29 @@ private: | ||||
|     std::vector<ggml_context_ptr>        ctxs; | ||||
|     std::vector<ggml_backend_buffer_ptr> bufs; | ||||
|  | ||||
|     std::vector<kv_cell>  cells;  // TODO: replace with `struct kv_cells` | ||||
|     llama_kv_cells_unified cells; | ||||
|  | ||||
|     std::vector<kv_layer> layers; | ||||
|  | ||||
|     // model layer id -> KV cache layer id | ||||
|     std::unordered_map<int32_t, int32_t> map_layer_ids; | ||||
|  | ||||
|     // recovery information used to restore the KV cells to their original state in case of a failure | ||||
|     // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation | ||||
|     //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] | ||||
|     struct { | ||||
|         void clear() { | ||||
|             cells.clear(); | ||||
|             states.clear(); | ||||
|         } | ||||
|  | ||||
|         std::unordered_map<uint32_t, kv_cell> cells; | ||||
|         struct state { | ||||
|             uint32_t i; | ||||
|  | ||||
|             llama_kv_cells_unified cells; | ||||
|         }; | ||||
|  | ||||
|         // stack with the partial states before each ubatch | ||||
|         std::vector<state> states; | ||||
|     } recovery; | ||||
|  | ||||
|     // defrag | ||||
| @@ -257,6 +247,7 @@ private: | ||||
|     bool defrag_prepare(int32_t n_max_nodes); | ||||
|  | ||||
|     // find how many cells are currently in use | ||||
|     // TODO: optimize | ||||
|     uint32_t cell_max() const; | ||||
|  | ||||
|     size_t total_size() const; | ||||
| @@ -325,7 +316,7 @@ public: | ||||
|     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; | ||||
|     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; | ||||
|     void seq_keep(llama_seq_id seq_id)                                                          override; | ||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override; | ||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override; | ||||
|     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; | ||||
|  | ||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const override; | ||||
| @@ -431,7 +422,7 @@ public: | ||||
|     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; | ||||
|     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; | ||||
|     void seq_keep(llama_seq_id seq_id)                                                          override; | ||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override; | ||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override; | ||||
|     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; | ||||
|  | ||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const override; | ||||
|   | ||||
							
								
								
									
										273
									
								
								src/llama-kv-cells.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										273
									
								
								src/llama-kv-cells.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,273 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "llama.h" | ||||
| #include "llama-cparams.h" | ||||
|  | ||||
| #include <bitset> | ||||
| #include <cassert> | ||||
| #include <vector> | ||||
|  | ||||
| // meta information about KV cells that can be part of multiple sequences at the same time | ||||
| // TODO: add unit tests | ||||
| class llama_kv_cells_unified { | ||||
| public: | ||||
|     void reset() { | ||||
|         for (uint32_t i = 0; i < pos.size(); ++i) { | ||||
|             pos[i]   = -1; | ||||
|             shift[i] =  0; | ||||
|             seq[i].reset(); | ||||
|         } | ||||
|  | ||||
|         used      = 0; | ||||
|         has_shift = false; | ||||
|     } | ||||
|  | ||||
|     void reset_shift() { | ||||
|         has_shift = false; | ||||
|  | ||||
|         for (uint32_t i = 0; i < shift.size(); ++i) { | ||||
|             shift[i] = 0; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     uint32_t size() const { | ||||
|         return pos.size(); | ||||
|     } | ||||
|  | ||||
|     void resize(uint32_t n) { | ||||
|         pos.resize(n); | ||||
|         shift.resize(n); | ||||
|         seq.resize(n); | ||||
|  | ||||
|         reset(); | ||||
|     } | ||||
|  | ||||
|     bool is_empty(uint32_t i) const { | ||||
|         assert(i < pos.size()); | ||||
|         assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); | ||||
|  | ||||
|         return pos[i] == -1; | ||||
|     } | ||||
|  | ||||
|     uint32_t get_used() const { | ||||
|         return used; | ||||
|     } | ||||
|  | ||||
|     bool get_has_shift() const { | ||||
|         return has_shift; | ||||
|     } | ||||
|  | ||||
|     // move cell isrc to idst (used during defrag) | ||||
|     void mv(uint32_t isrc, uint32_t idst) { | ||||
|         assert(isrc < pos.size()); | ||||
|         assert(idst < pos.size()); | ||||
|  | ||||
|         pos  [idst] = pos  [isrc]; | ||||
|         shift[idst] = shift[isrc]; | ||||
|         seq  [idst] = seq  [isrc]; | ||||
|  | ||||
|         pos  [isrc] = -1; | ||||
|         shift[isrc] =  0; | ||||
|         seq  [isrc].reset(); | ||||
|     } | ||||
|  | ||||
|     // copy the state of cells [i, i + n) (used for save/restore the state of the cells) | ||||
|     llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { | ||||
|         assert(i + n <= pos.size()); | ||||
|  | ||||
|         llama_kv_cells_unified res; | ||||
|  | ||||
|         res.resize(n); | ||||
|  | ||||
|         for (uint32_t j = 0; j < n; ++j) { | ||||
|             res.pos[j] = pos[i + j]; | ||||
|             res.seq[j] = seq[i + j]; | ||||
|  | ||||
|             assert(shift[i + j] == 0); | ||||
|         } | ||||
|  | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) | ||||
|     void set(uint32_t i, const llama_kv_cells_unified & other) { | ||||
|         assert(i + other.pos.size() <= pos.size()); | ||||
|  | ||||
|         for (uint32_t j = 0; j < other.pos.size(); ++j) { | ||||
|             if (pos[i + j] == -1 && other.pos[j] != -1) { | ||||
|                 used++; | ||||
|             } | ||||
|  | ||||
|             if (pos[i + j] != -1 && other.pos[j] == -1) { | ||||
|                 used--; | ||||
|             } | ||||
|  | ||||
|             pos[i + j] = other.pos[j]; | ||||
|             seq[i + j] = other.seq[j]; | ||||
|  | ||||
|             assert(shift[i + j] == 0); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // note: call only if the cell has seq_id | ||||
|     // return true if the cell becomes empty | ||||
|     bool seq_rm(uint32_t i, llama_seq_id seq_id) { | ||||
|         assert(i < pos.size()); | ||||
|         assert(seq[i].test(seq_id)); | ||||
|         assert(pos[i] != -1); | ||||
|         assert(seq_id >= 0); | ||||
|  | ||||
|         seq[i].reset(seq_id); | ||||
|  | ||||
|         if (seq[i].none()) { | ||||
|             pos[i] = -1; | ||||
|  | ||||
|             used--; | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         return false; | ||||
|     } | ||||
|  | ||||
|     // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) | ||||
|     bool seq_keep(uint32_t i, llama_seq_id seq_id) { | ||||
|         assert(i < pos.size()); | ||||
|  | ||||
|         if (seq[i].test(seq_id)) { | ||||
|             seq[i].reset(); | ||||
|             seq[i].set(seq_id); | ||||
|  | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         if (seq[i].any()) { | ||||
|             seq[i].reset(); | ||||
|             pos[i] = -1; | ||||
|  | ||||
|             used--; | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         assert(pos[i] == -1); | ||||
|  | ||||
|         return false; | ||||
|     } | ||||
|  | ||||
|     bool seq_has(uint32_t i, llama_seq_id seq_id) const { | ||||
|         assert(i < pos.size()); | ||||
|         assert(seq_id >= 0); | ||||
|  | ||||
|         return seq[i].test(seq_id); | ||||
|     } | ||||
|  | ||||
|     // note: call only if the cell is not empty and the seq_id is not in the cell | ||||
|     void seq_add(uint32_t i, llama_seq_id seq_id) { | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] != -1); | ||||
|         assert(!seq[i].test(seq_id)); | ||||
|  | ||||
|         seq[i].set(seq_id); | ||||
|     } | ||||
|  | ||||
|     // note: call only if the cell is not empty | ||||
|     llama_pos pos_get(uint32_t i) const { | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] != -1); | ||||
|  | ||||
|         return pos[i]; | ||||
|     } | ||||
|  | ||||
|     // note: call only if the cell is not empty | ||||
|     llama_pos get_shift(uint32_t i) const { | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] != -1); | ||||
|  | ||||
|         return shift[i]; | ||||
|     } | ||||
|  | ||||
|     // check if a cell is not empty and its position is within [p0, p1) | ||||
|     bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { | ||||
|         assert(i < pos.size()); | ||||
|  | ||||
|         return pos[i] >= p0 && pos[i] < p1; | ||||
|     } | ||||
|  | ||||
|     // set the position of an empty cell | ||||
|     // does not modify "has_shift" | ||||
|     // note: call only if the cell is empty | ||||
|     void pos_set(uint32_t i, llama_pos p) { | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] == -1); | ||||
|  | ||||
|         pos[i] = p; | ||||
|         used++; | ||||
|     } | ||||
|  | ||||
|     // pos[i] = pos[i] + d | ||||
|     // sets "has_shift" to true | ||||
|     // note: call only if the cell is not empty | ||||
|     bool pos_add(uint32_t i, llama_pos d) { | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] != -1); | ||||
|  | ||||
|         pos[i]   += d; | ||||
|         shift[i] += d; | ||||
|  | ||||
|         has_shift = true; | ||||
|  | ||||
|         if (pos[i] < 0) { | ||||
|             pos[i] = -1; | ||||
|             seq[i].reset(); | ||||
|  | ||||
|             used--; | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         return false; | ||||
|     } | ||||
|  | ||||
|     // pos[i] = pos[i] / d | ||||
|     // sets "has_shift" to true | ||||
|     // note: call only if the cell is not empty | ||||
|     void pos_div(uint32_t i, int d) { | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] != -1); | ||||
|  | ||||
|         const llama_pos p_old = pos[i]; | ||||
|  | ||||
|         pos[i]   /= d; | ||||
|         shift[i] += p_old - pos[i]; | ||||
|  | ||||
|         has_shift = true; | ||||
|     } | ||||
|  | ||||
| private: | ||||
|     uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id) | ||||
|  | ||||
|     bool has_shift = false; | ||||
|  | ||||
|     std::vector<llama_pos> pos; | ||||
|  | ||||
|     // this array accumulates any applied shifts to the pos array since the last reset_shift() call | ||||
|     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: | ||||
|     // | ||||
|     //   cells.pos_add(x, shift_x); | ||||
|     //   cells.pos_div(y, shift_y); | ||||
|     //   ... | ||||
|     // | ||||
|     //   if (cells.has_shift()) { | ||||
|     //      for (int i = 0; i < n; ++i) { | ||||
|     //          auto shift_i = cells.get_shift(i); | ||||
|     //          ... | ||||
|     //      } | ||||
|     //      cells.reset_shift(); | ||||
|     //   } | ||||
|     // | ||||
|     std::vector<llama_pos> shift; | ||||
|  | ||||
|     std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq; | ||||
| }; | ||||
|  | ||||
| @@ -22,7 +22,7 @@ public: | ||||
|     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0; | ||||
|     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; | ||||
|     virtual void seq_keep(llama_seq_id seq_id) = 0; | ||||
|     virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0; | ||||
|     virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0; | ||||
|     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0; | ||||
|  | ||||
|     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov