mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	kv-cache : rework kv_cell (#13706)
* kv-cache : rework kv_cell ggml-ci * kv-cells : use "shift" instead of "delta" consistently ggml-ci * llama : add llama_max_parallel_sequences() ggml-ci * kv-cells : update comments [no ci] * context : fail upon construction if sequences exceed max value ggml-ci * kv-cells : get_pos() -> pos_get() + comments ggml-ci * kv-cells : fix tracking of "used" cells ggml-ci
This commit is contained in:
		| @@ -471,6 +471,7 @@ extern "C" { | |||||||
|     LLAMA_API int64_t llama_time_us(void); |     LLAMA_API int64_t llama_time_us(void); | ||||||
|  |  | ||||||
|     LLAMA_API size_t llama_max_devices(void); |     LLAMA_API size_t llama_max_devices(void); | ||||||
|  |     LLAMA_API size_t llama_max_parallel_sequences(void); | ||||||
|  |  | ||||||
|     LLAMA_API bool llama_supports_mmap       (void); |     LLAMA_API bool llama_supports_mmap       (void); | ||||||
|     LLAMA_API bool llama_supports_mlock      (void); |     LLAMA_API bool llama_supports_mlock      (void); | ||||||
|   | |||||||
| @@ -25,7 +25,11 @@ llama_context::llama_context( | |||||||
|  |  | ||||||
|     const auto & hparams = model.hparams; |     const auto & hparams = model.hparams; | ||||||
|  |  | ||||||
|     cparams.n_seq_max        = std::max(1u, params.n_seq_max); |     cparams.n_seq_max = std::max(1u, params.n_seq_max); | ||||||
|  |     if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) { | ||||||
|  |         throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     cparams.n_threads        = params.n_threads; |     cparams.n_threads        = params.n_threads; | ||||||
|     cparams.n_threads_batch  = params.n_threads_batch; |     cparams.n_threads_batch  = params.n_threads_batch; | ||||||
|     cparams.yarn_ext_factor  = params.yarn_ext_factor; |     cparams.yarn_ext_factor  = params.yarn_ext_factor; | ||||||
|   | |||||||
| @@ -1 +1,5 @@ | |||||||
| #include "llama-cparams.h" | #include "llama-cparams.h" | ||||||
|  |  | ||||||
|  | size_t llama_max_parallel_sequences(void) { | ||||||
|  |     return LLAMA_MAX_PARALLEL_SEQUENCES; | ||||||
|  | } | ||||||
|   | |||||||
| @@ -4,6 +4,8 @@ | |||||||
|  |  | ||||||
| #include <cstdint> | #include <cstdint> | ||||||
|  |  | ||||||
|  | #define LLAMA_MAX_PARALLEL_SEQUENCES 64 | ||||||
|  |  | ||||||
| struct llama_cparams { | struct llama_cparams { | ||||||
|     uint32_t n_ctx;           // context size used during inference |     uint32_t n_ctx;           // context size used during inference | ||||||
|     uint32_t n_batch; |     uint32_t n_batch; | ||||||
|   | |||||||
| @@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified( | |||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     head = 0; |     head = 0; | ||||||
|     size = kv_size; |  | ||||||
|     used = 0; |  | ||||||
|  |  | ||||||
|     cells.resize(kv_size); |     cells.resize(kv_size); | ||||||
|  |  | ||||||
| @@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( | |||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::clear() { | void llama_kv_cache_unified::clear() { | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     cells.reset(); | ||||||
|         cells[i].pos = -1; |  | ||||||
|         cells[i].seq_id.clear(); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     head = 0; |     head = 0; | ||||||
|     used = 0; |  | ||||||
|  |  | ||||||
|     for (auto & buf : bufs) { |     for (auto & buf : bufs) { | ||||||
|         ggml_backend_buffer_clear(buf.get(), 0); |         ggml_backend_buffer_clear(buf.get(), 0); | ||||||
| @@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() { | |||||||
| } | } | ||||||
|  |  | ||||||
| bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { | bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { | ||||||
|     uint32_t new_head = size; |     uint32_t new_head = cells.size(); | ||||||
|  |  | ||||||
|     if (p0 < 0) { |     if (p0 < 0) { | ||||||
|         p0 = 0; |         p0 = 0; | ||||||
| @@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos | |||||||
|         p1 = std::numeric_limits<llama_pos>::max(); |         p1 = std::numeric_limits<llama_pos>::max(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if (cells[i].pos >= p0 && cells[i].pos < p1) { |         if (!cells.pos_in(i, p0, p1)) { | ||||||
|             if (seq_id < 0) { |             continue; | ||||||
|                 cells[i].seq_id.clear(); |         } | ||||||
|             } else if (cells[i].has_seq_id(seq_id)) { |  | ||||||
|                 cells[i].seq_id.erase(seq_id); |  | ||||||
|             } else { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if (cells[i].is_empty()) { |         if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { | ||||||
|                 // keep count of the number of used cells |             if (new_head == cells.size()) { | ||||||
|                 if (cells[i].pos >= 0) { |                 new_head = i; | ||||||
|                     used--; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 cells[i].pos = -1; |  | ||||||
|  |  | ||||||
|                 if (new_head == size) { |  | ||||||
|                     new_head = i; |  | ||||||
|                 } |  | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // If we freed up a slot, set head to it so searching can start there. |     // If we freed up a slot, set head to it so searching can start there. | ||||||
|     if (new_head != size && new_head < head) { |     if (new_head != cells.size() && new_head < head) { | ||||||
|         head = new_head; |         head = new_head; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id | |||||||
|         p1 = std::numeric_limits<llama_pos>::max(); |         p1 = std::numeric_limits<llama_pos>::max(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // otherwise, this is the KV of a Transformer-like model |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|     head = 0; |         if (!cells.pos_in(i, p0, p1)) { | ||||||
|  |             continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |         if (cells.seq_has(i, seq_id_src)) { | ||||||
|         if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { |             cells.seq_add(i, seq_id_dst); | ||||||
|             cells[i].seq_id.insert(seq_id_dst); |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { | void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { | ||||||
|     uint32_t new_head = size; |     uint32_t new_head = cells.size(); | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if (!cells[i].has_seq_id(seq_id)) { |         if (cells.seq_keep(i, seq_id)) { | ||||||
|             if (cells[i].pos >= 0) { |             if (new_head == cells.size()) { | ||||||
|                 used--; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             cells[i].pos = -1; |  | ||||||
|             cells[i].seq_id.clear(); |  | ||||||
|  |  | ||||||
|             if (new_head == size){ |  | ||||||
|                 new_head = i; |                 new_head = i; | ||||||
|             } |             } | ||||||
|         } else { |  | ||||||
|             cells[i].seq_id.clear(); |  | ||||||
|             cells[i].seq_id.insert(seq_id); |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // If we freed up a slot, set head to it so searching can start there. |     // If we freed up a slot, set head to it so searching can start there. | ||||||
|     if (new_head != size && new_head < head) { |     if (new_head != cells.size() && new_head < head) { | ||||||
|         head = new_head; |         head = new_head; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { | void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { | ||||||
|     if (delta == 0) { |     if (shift == 0) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     uint32_t new_head = size; |     uint32_t new_head = cells.size(); | ||||||
|  |  | ||||||
|     if (p0 < 0) { |     if (p0 < 0) { | ||||||
|         p0 = 0; |         p0 = 0; | ||||||
| @@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po | |||||||
|         p1 = std::numeric_limits<llama_pos>::max(); |         p1 = std::numeric_limits<llama_pos>::max(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // If there is no range then return early to avoid looping over the |     // If there is no range then return early to avoid looping over all cells. | ||||||
|     if (p0 == p1) { |     if (p0 == p1) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { |         if (!cells.pos_in(i, p0, p1)) { | ||||||
|             has_shift = true; |             continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|             cells[i].pos   += delta; |         if (cells.seq_has(i, seq_id)) { | ||||||
|             cells[i].delta += delta; |             if (cells.pos_add(i, shift)) { | ||||||
|  |                 if (new_head == cells.size()) { | ||||||
|             if (cells[i].pos < 0) { |  | ||||||
|                 if (!cells[i].is_empty()) { |  | ||||||
|                     used--; |  | ||||||
|                 } |  | ||||||
|                 cells[i].pos = -1; |  | ||||||
|                 cells[i].seq_id.clear(); |  | ||||||
|                 if (new_head == size) { |  | ||||||
|                     new_head = i; |                     new_head = i; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po | |||||||
|  |  | ||||||
|     // If we freed up a slot, set head to it so searching can start there. |     // If we freed up a slot, set head to it so searching can start there. | ||||||
|     // Otherwise we just start the next search from the beginning. |     // Otherwise we just start the next search from the beginning. | ||||||
|     head = new_head != size ? new_head : 0; |     head = new_head != cells.size() ? new_head : 0; | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { | void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { | ||||||
| @@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { |         if (!cells.pos_in(i, p0, p1)) { | ||||||
|             has_shift = true; |             continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|             { |         if (cells.seq_has(i, seq_id)) { | ||||||
|                 llama_pos p_old = cells[i].pos; |             cells.pos_div(i, d); | ||||||
|                 cells[i].pos   /= d; |  | ||||||
|                 cells[i].delta += cells[i].pos - p_old; |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po | |||||||
| llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { | llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { | ||||||
|     llama_pos result = std::numeric_limits<llama_pos>::max(); |     llama_pos result = std::numeric_limits<llama_pos>::max(); | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if (cells[i].has_seq_id(seq_id)) { |         if (cells.seq_has(i, seq_id)) { | ||||||
|             result = std::min(result, cells[i].pos); |             result = std::min(result, cells.pos_get(i)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { | |||||||
| llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { | llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { | ||||||
|     llama_pos result = -1; |     llama_pos result = -1; | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if (cells[i].has_seq_id(seq_id)) { |         if (cells.seq_has(i, seq_id)) { | ||||||
|             result = std::max(result, cells[i].pos); |             result = std::max(result, cells.pos_get(i)); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { | |||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::restore() { | void llama_kv_cache_unified::restore() { | ||||||
|     for (const auto & [id, cell] : recovery.cells) { |     for (auto & state : recovery.states) { | ||||||
|         // TODO: move to new `struct kv_cells` |         cells.set(state.i, state.cells); | ||||||
|         const bool is_empty0 = cells[id].is_empty(); |  | ||||||
|         const bool is_empty1 = cell.is_empty(); |  | ||||||
|  |  | ||||||
|         if (!is_empty0 && is_empty1) { |  | ||||||
|             used--; |  | ||||||
|         } else if (is_empty0 && !is_empty1) { |  | ||||||
|             used++; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         cells[id] = cell; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     recovery.clear(); |     recovery.clear(); | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::commit() { | void llama_kv_cache_unified::commit() { | ||||||
|     if (recovery.cells.empty()) { |     if (recovery.states.empty()) { | ||||||
|         LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", |         LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", | ||||||
|                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); |                 __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); | ||||||
|         return; |         return; | ||||||
| @@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { | |||||||
|  |  | ||||||
|     auto * sched = lctx.get_sched(); |     auto * sched = lctx.get_sched(); | ||||||
|  |  | ||||||
|     if (has_shift) { |     if (cells.get_has_shift()) { | ||||||
|         if (!get_can_shift()) { |         if (!get_can_shift()) { | ||||||
|             GGML_ABORT("The current KV cache / model configuration does not support K-shift"); |             GGML_ABORT("The current KV cache / model configuration does not support K-shift"); | ||||||
|         } |         } | ||||||
| @@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { | |||||||
|             need_reserve = true; |             need_reserve = true; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         { |         cells.reset_shift(); | ||||||
|             has_shift = false; |  | ||||||
|  |  | ||||||
|             for (uint32_t i = 0; i < size; ++i) { |  | ||||||
|                 cells[i].delta = 0; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (do_defrag) { |     if (do_defrag) { | ||||||
| @@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { | |||||||
| void llama_kv_cache_unified::defrag_sched(float thold) { | void llama_kv_cache_unified::defrag_sched(float thold) { | ||||||
|     // - do not defrag small contexts (i.e. < 2048 tokens) |     // - do not defrag small contexts (i.e. < 2048 tokens) | ||||||
|     // - count the padding towards the number of used tokens |     // - count the padding towards the number of used tokens | ||||||
|     const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f; |     const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; | ||||||
|  |  | ||||||
|     // queue defragmentation for next llama_kv_cache_update |     // queue defragmentation for next llama_kv_cache_update | ||||||
|     if (fragmentation > thold) { |     if (fragmentation > thold) { | ||||||
| @@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) { | |||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified::set_full() { | void llama_kv_cache_unified::set_full() { | ||||||
|     n = size; |     n = cells.size(); | ||||||
|  |  | ||||||
|     // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not |     // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not | ||||||
|     //   affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. |     //   affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. | ||||||
| @@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | |||||||
|  |  | ||||||
|     // if we have enough unused cells before the current head -> |     // if we have enough unused cells before the current head -> | ||||||
|     //   better to start searching from the beginning of the cache, hoping to fill it |     //   better to start searching from the beginning of the cache, hoping to fill it | ||||||
|     if (head > used + 2*ubatch.n_tokens) { |     if (head > cells.get_used() + 2*ubatch.n_tokens) { | ||||||
|         head = 0; |         head = 0; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // otherwise, one cell per token. |     // otherwise, one cell per token. | ||||||
|  |  | ||||||
|     if (n_tokens > size) { |     if (n_tokens > cells.size()) { | ||||||
|         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); |         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | |||||||
|         std::string ss; |         std::string ss; | ||||||
|         if (n_swa > 0) { |         if (n_swa > 0) { | ||||||
|             for (uint32_t i = 0; i < size; ++i) { |             for (uint32_t i = 0; i < size; ++i) { | ||||||
|                 if (cells[i].pos == -1) { |                 if (cells.is_empty(i)) { | ||||||
|                     ss += '.'; |                     ss += '.'; | ||||||
|                 } else { |                 } else { | ||||||
|                     ss += std::to_string(*cells[i].seq_id.begin()); |                     ss += 'x'; | ||||||
|                 } |                 } | ||||||
|                 if (i%256 == 255) { |                 if (i%256 == 255) { | ||||||
|                     ss += '\n'; |                     ss += '\n'; | ||||||
| @@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | |||||||
|     uint32_t n_tested = 0; |     uint32_t n_tested = 0; | ||||||
|  |  | ||||||
|     while (true) { |     while (true) { | ||||||
|         if (head + n_tokens > size) { |         if (head + n_tokens > cells.size()) { | ||||||
|             n_tested += size - head; |             n_tested += cells.size() - head; | ||||||
|             head = 0; |             head = 0; | ||||||
|             continue; |             continue; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         bool found = true; |         bool found = true; | ||||||
|         for (uint32_t i = 0; i < n_tokens; i++) { |         for (uint32_t i = 0; i < n_tokens; i++) { | ||||||
|             if (cells[head + i].pos >= 0) { |             // TODO: improve to accept cells that are masked by the SWA | ||||||
|  |             if (!cells.is_empty(head + i)) { | ||||||
|                 found = false; |                 found = false; | ||||||
|                 head     += i + 1; |                 head     += i + 1; | ||||||
|                 n_tested += i + 1; |                 n_tested += i + 1; | ||||||
| @@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | |||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (n_tested >= size) { |         if (n_tested >= cells.size()) { | ||||||
|             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); |             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < n_tokens; ++i) { |     // store the old state of the cells in the recovery stack | ||||||
|         // remember the original state |     recovery.states.push_back({head, cells.cp(head, n_tokens)}); | ||||||
|         if (recovery.cells.find(head + i) == recovery.cells.end()) { |  | ||||||
|             recovery.cells[head + i] = cells[head + i]; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         cells[head + i].pos = ubatch.pos[i]; |     for (uint32_t i = 0; i < n_tokens; ++i) { | ||||||
|  |         cells.pos_set(head + i, ubatch.pos[i]); | ||||||
|  |  | ||||||
|         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { |         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { | ||||||
|             cells[head + i].seq_id.insert(ubatch.seq_id[i][j]); |             cells.seq_add(head + i, ubatch.seq_id[i][j]); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     used += n_tokens; |  | ||||||
|  |  | ||||||
|     // a heuristic, to avoid attending the full cache if it is not yet utilized |     // a heuristic, to avoid attending the full cache if it is not yet utilized | ||||||
|     // after enough generations, the benefit from this heuristic disappears |     // after enough generations, the benefit from this heuristic disappears | ||||||
|     // if we start defragmenting the cache, the benefit from this will be more important |     // if we start defragmenting the cache, the benefit from this will be more important | ||||||
|     n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); |     n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad))); | ||||||
|  |  | ||||||
| #ifdef FIND_SLOT_DEBUG | #ifdef FIND_SLOT_DEBUG | ||||||
|     LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); |     LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); | ||||||
| @@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const { | |||||||
| } | } | ||||||
|  |  | ||||||
| uint32_t llama_kv_cache_unified::get_size() const { | uint32_t llama_kv_cache_unified::get_size() const { | ||||||
|     return size; |     return cells.size(); | ||||||
| } | } | ||||||
|  |  | ||||||
| ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { | ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { | ||||||
| @@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam | |||||||
|  |  | ||||||
|     int n_attended = 0; |     int n_attended = 0; | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         const llama_pos p0 = cells[i].pos; |         if (!cells.seq_has(i, seq_id)) { | ||||||
|  |             continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const llama_pos p0 = cells.pos_get(i); | ||||||
|  |  | ||||||
|         if (p0 <= pmin && !is_masked_swa(p0, pmin)) { |         if (p0 <= pmin && !is_masked_swa(p0, pmin)) { | ||||||
|             n_attended++; |             n_attended++; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (is_masked_swa(p0, pmax)) { |         if (is_masked_swa(p0, pmax)) { | ||||||
|             if (seq_id < 0) { |             cells.seq_rm(i, seq_id); | ||||||
|                 cells[i].seq_id.clear(); |  | ||||||
|             } else if (cells[i].has_seq_id(seq_id)) { |  | ||||||
|                 cells[i].seq_id.erase(seq_id); |  | ||||||
|             } else { |  | ||||||
|                 continue; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if (cells[i].is_empty()) { |  | ||||||
|                 // keep count of the number of used cells |  | ||||||
|                 if (cells[i].pos >= 0) { |  | ||||||
|                     used--; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 cells[i].pos = -1; |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -723,25 +657,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub | |||||||
|                 const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; |                 const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; | ||||||
|  |  | ||||||
|                 for (int i = 0; i < n_kv; ++i) { |                 for (int i = 0; i < n_kv; ++i) { | ||||||
|                     const llama_pos p0 = cells[i].pos; |                     float f = 0.0f; | ||||||
|  |  | ||||||
|                     bool masked = false; |                     bool masked = false; | ||||||
|  |  | ||||||
|                     // mask the token if not the same sequence |                     if (cells.is_empty(i)) { | ||||||
|                     masked = masked || (!cells[i].has_seq_id(seq_id)); |                         masked = true; | ||||||
|  |                     } else { | ||||||
|  |                         const llama_pos p0 = cells.pos_get(i); | ||||||
|  |  | ||||||
|                     // mask future tokens |                         // mask the token if not the same sequence | ||||||
|                     masked = masked || (causal_attn && p0 > p1); |                         masked = masked || (!cells.seq_has(i, seq_id)); | ||||||
|  |  | ||||||
|                     // apply SWA if any |                         // mask future tokens | ||||||
|                     masked = masked || (is_masked_swa(p0, p1)); |                         masked = masked || (causal_attn && p0 > p1); | ||||||
|  |  | ||||||
|                     float f = 0.0f; |                         // apply SWA if any | ||||||
|  |                         masked = masked || (is_masked_swa(p0, p1)); | ||||||
|  |  | ||||||
|  |                         if (!masked && hparams.use_alibi) { | ||||||
|  |                             f = -std::abs(p0 - p1); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |  | ||||||
|                     if (masked) { |                     if (masked) { | ||||||
|                         f = -INFINITY; |                         f = -INFINITY; | ||||||
|                     } else if (hparams.use_alibi) { |  | ||||||
|                         f = -std::abs(p0 - p1); |  | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; |                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; | ||||||
| @@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { | |||||||
|  |  | ||||||
|     int32_t * data = (int32_t *) dst->data; |     int32_t * data = (int32_t *) dst->data; | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         data[i] = cells[i].delta; |         data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama | |||||||
|     for (int h = 0; h < 1; ++h) { |     for (int h = 0; h < 1; ++h) { | ||||||
|         for (int j = 0; j < n_tokens; ++j) { |         for (int j = 0; j < n_tokens; ++j) { | ||||||
|             for (int i = 0; i < n_kv; ++i) { |             for (int i = 0; i < n_kv; ++i) { | ||||||
|                 data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); |                 // the position when the cells is empty is irrelevant - it will be masked out later in the attention | ||||||
|  |                 const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); | ||||||
|  |  | ||||||
|  |                 data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( | |||||||
|  |  | ||||||
|         ggml_tensor * k = |         ggml_tensor * k = | ||||||
|             ggml_view_3d(ctx, layer.k, |             ggml_view_3d(ctx, layer.k, | ||||||
|                 n_embd_head_k, n_head_kv, size, |                 n_embd_head_k, n_head_kv, cells.size(), | ||||||
|                 ggml_row_size(layer.k->type, n_embd_head_k), |                 ggml_row_size(layer.k->type, n_embd_head_k), | ||||||
|                 ggml_row_size(layer.k->type, n_embd_k_gqa), |                 ggml_row_size(layer.k->type, n_embd_k_gqa), | ||||||
|                 0); |                 0); | ||||||
| @@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( | |||||||
|             } else { |             } else { | ||||||
|                 view_v_src = ggml_view_2d(ctx, layer.v, |                 view_v_src = ggml_view_2d(ctx, layer.v, | ||||||
|                         nm, n_embd_v_gqa, |                         nm, n_embd_v_gqa, | ||||||
|                         ggml_row_size(layer.v->type, size), |                         ggml_row_size(layer.v->type, cells.size()), | ||||||
|                         ggml_row_size(layer.v->type, i)); |                         ggml_row_size(layer.v->type, i)); | ||||||
|  |  | ||||||
|                 view_v_dst = ggml_view_2d(ctx, layer.v, |                 view_v_dst = ggml_view_2d(ctx, layer.v, | ||||||
|                         nm, n_embd_v_gqa, |                         nm, n_embd_v_gqa, | ||||||
|                         ggml_row_size(layer.v->type, size), |                         ggml_row_size(layer.v->type, cells.size()), | ||||||
|                         ggml_row_size(layer.v->type, id)); |                         ggml_row_size(layer.v->type, id)); | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
|     const uint32_t n_layer = layers.size(); |     const uint32_t n_layer = layers.size(); | ||||||
|  |  | ||||||
|     const uint32_t n_kv   = cell_max(); |     const uint32_t n_kv   = cell_max(); | ||||||
|     const uint32_t n_used = used; |     const uint32_t n_used = cells.get_used(); | ||||||
|  |  | ||||||
|     assert(n_used <= n_kv); |     assert(n_used <= n_kv); | ||||||
|  |  | ||||||
| @@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
|     ids.resize(n_kv, n_kv); |     ids.resize(n_kv, n_kv); | ||||||
|  |  | ||||||
|     for (uint32_t i0 = 0; i0 < n_used; ++i0) { |     for (uint32_t i0 = 0; i0 < n_used; ++i0) { | ||||||
|         const auto & cell0 = cells[i0]; |         if (!cells.is_empty(i0)) { | ||||||
|  |  | ||||||
|         if (!cell0.is_empty()) { |  | ||||||
|             ids[i0] = i0; |             ids[i0] = i0; | ||||||
|  |  | ||||||
|             continue; |             continue; | ||||||
| @@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
|         uint32_t nh = 1; |         uint32_t nh = 1; | ||||||
|  |  | ||||||
|         // determine the size of the hole |         // determine the size of the hole | ||||||
|         while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { |         while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { | ||||||
|             nh++; |             nh++; | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
|  |  | ||||||
|         // starting from the end, find nh non-empty cells |         // starting from the end, find nh non-empty cells | ||||||
|         for (; is > i0; --is) { |         for (; is > i0; --is) { | ||||||
|             const auto & cell1 = cells[is]; |             if (cells.is_empty(is) || ids[is] != n_kv) { | ||||||
|  |  | ||||||
|             if (cell1.is_empty() || ids[is] != n_kv) { |  | ||||||
|                 continue; |                 continue; | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
|  |  | ||||||
|         // go back and move the nf cells to the hole |         // go back and move the nf cells to the hole | ||||||
|         for (; i1 < n_kv; ++i1) { |         for (; i1 < n_kv; ++i1) { | ||||||
|             auto & cell1 = cells[i1]; |             if (cells.is_empty(i1) || ids[i1] != n_kv) { | ||||||
|  |  | ||||||
|             if (cell1.is_empty() || ids[i1] != n_kv) { |  | ||||||
|                 if (n_moves == max_moves) { |                 if (n_moves == max_moves) { | ||||||
|                     stop = true; |                     stop = true; | ||||||
|                     break; |                     break; | ||||||
| @@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
|             ids[i1] = i0 + nf; |             ids[i1] = i0 + nf; | ||||||
|  |  | ||||||
|             // move the cell meta data |             // move the cell meta data | ||||||
|             cells[i0 + nf] = cell1; |             cells.mv(i1, i0 + nf); | ||||||
|  |  | ||||||
|             // clear the old cell and move the head there |  | ||||||
|             cell1 = kv_cell(); |  | ||||||
|             head = n_used; |             head = n_used; | ||||||
|  |  | ||||||
|             if (!cont) { |             if (!cont) { | ||||||
| @@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | |||||||
| } | } | ||||||
|  |  | ||||||
| uint32_t llama_kv_cache_unified::cell_max() const { | uint32_t llama_kv_cache_unified::cell_max() const { | ||||||
|     for (uint32_t i = size; i > 0; --i) { |     for (uint32_t i = cells.size(); i > 0; --i) { | ||||||
|         const kv_cell & cell = cells[i - 1]; |         if (!cells.is_empty(i - 1)) { | ||||||
|  |  | ||||||
|         if (cell.pos >= 0 && !cell.is_empty()) { |  | ||||||
|             return i; |             return i; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const { | |||||||
| } | } | ||||||
|  |  | ||||||
| bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { | bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { | ||||||
|     if (p0 < 0) { |     assert(p0 >= 0 && p1 >= 0); | ||||||
|         return true; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     switch (swa_type) { |     switch (swa_type) { | ||||||
|         case LLAMA_SWA_TYPE_NONE: |         case LLAMA_SWA_TYPE_NONE: | ||||||
| @@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq | |||||||
|  |  | ||||||
|     // Count the number of cells with the specified seq_id |     // Count the number of cells with the specified seq_id | ||||||
|     // Find all the ranges of cells with this seq id (or all, when -1) |     // Find all the ranges of cells with this seq id (or all, when -1) | ||||||
|     uint32_t cell_range_begin = size; |     uint32_t cell_range_begin = cells.size(); | ||||||
|     for (uint32_t i = 0; i < size; ++i) { |  | ||||||
|         const auto & cell = cells[i]; |     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||||
|         if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { |         if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { | ||||||
|             ++cell_count; |             ++cell_count; | ||||||
|             if (cell_range_begin == size) { |             if (cell_range_begin == cells.size()) { | ||||||
|                 cell_range_begin = i; |                 cell_range_begin = i; | ||||||
|             } |             } | ||||||
|         } else { |         } else { | ||||||
|             if (cell_range_begin != size) { |             if (cell_range_begin != cells.size()) { | ||||||
|                 cell_ranges.emplace_back(cell_range_begin, i); |                 cell_ranges.emplace_back(cell_range_begin, i); | ||||||
|                 cell_range_begin = size; |                 cell_range_begin = cells.size(); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     if (cell_range_begin != size) { |  | ||||||
|         cell_ranges.emplace_back(cell_range_begin, size); |     if (cell_range_begin != cells.size()) { | ||||||
|  |         cell_ranges.emplace_back(cell_range_begin, cells.size()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count |     // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count | ||||||
| @@ -1308,17 +1240,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i | |||||||
| void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const { | void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const { | ||||||
|     for (const auto & range : cell_ranges) { |     for (const auto & range : cell_ranges) { | ||||||
|         for (uint32_t i = range.first; i < range.second; ++i) { |         for (uint32_t i = range.first; i < range.second; ++i) { | ||||||
|             const auto & cell = cells[i]; |             std::vector<llama_seq_id> seq_ids; | ||||||
|             const llama_pos pos      = cell.pos; |  | ||||||
|             const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; |             for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { | ||||||
|  |                 if (cur == seq_id || seq_id == -1) { | ||||||
|  |                     if (cells.seq_has(i, cur)) { | ||||||
|  |                         seq_ids.push_back(cur); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             const llama_pos pos     = cells.pos_get(i); | ||||||
|  |             const uint32_t n_seq_id = seq_ids.size(); | ||||||
|  |  | ||||||
|             io.write(&pos,      sizeof(pos)); |             io.write(&pos,      sizeof(pos)); | ||||||
|             io.write(&n_seq_id, sizeof(n_seq_id)); |             io.write(&n_seq_id, sizeof(n_seq_id)); | ||||||
|  |  | ||||||
|             if (n_seq_id) { |             for (const auto & seq_id : seq_ids) { | ||||||
|                 for (auto seq_id : cell.seq_id) { |                 io.write(&seq_id, sizeof(seq_id)); | ||||||
|                     io.write(&seq_id, sizeof(seq_id)); |  | ||||||
|                 } |  | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: | |||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
|         // When v is transposed, we also need the element size and get the element ranges from each row |         // When v is transposed, we also need the element size and get the element ranges from each row | ||||||
|         const uint32_t kv_size = size; |         const uint32_t kv_size = cells.size(); | ||||||
|  |  | ||||||
|         for (const auto & layer : layers) { |         for (const auto & layer : layers) { | ||||||
|             const uint32_t il = layer.il; |             const uint32_t il = layer.il; | ||||||
| @@ -1429,14 +1368,20 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | |||||||
|             io.read_to(&pos,      sizeof(pos)); |             io.read_to(&pos,      sizeof(pos)); | ||||||
|             io.read_to(&n_seq_id, sizeof(n_seq_id)); |             io.read_to(&n_seq_id, sizeof(n_seq_id)); | ||||||
|  |  | ||||||
|             if (n_seq_id != 0) { |             if (n_seq_id != 1) { | ||||||
|                 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); |                 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             batch.pos[i] = pos; |             // read the sequence id, but directly discard it - we will use dest_seq_id instead | ||||||
|             batch.n_seq_id[i] = 1; |             { | ||||||
|             batch.seq_id[i] = &dest_seq_id; |                 llama_seq_id seq_id; | ||||||
|  |                 io.read_to(&seq_id, sizeof(seq_id)); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             batch.pos[i]      = pos; | ||||||
|  |             batch.n_seq_id[i] = n_seq_id; | ||||||
|  |             batch.seq_id[i]   = &dest_seq_id; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (!find_slot(batch)) { |         if (!find_slot(batch)) { | ||||||
| @@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | |||||||
|  |  | ||||||
|         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) |         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) | ||||||
|         // Assume that this is one contiguous block of cells |         // Assume that this is one contiguous block of cells | ||||||
|         GGML_ASSERT(head + cell_count <= size); |         GGML_ASSERT(head + cell_count <= cells.size()); | ||||||
|         GGML_ASSERT(cells[head].pos == batch.pos[0]); |         GGML_ASSERT(cells.pos_get(head)                  == batch.pos[0]); | ||||||
|         GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); |         GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); | ||||||
|         GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); |         GGML_ASSERT(cells.seq_has(head,                  dest_seq_id)); | ||||||
|         GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); |         GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); | ||||||
|     } else { |     } else { | ||||||
|         // whole KV cache restore |         // whole KV cache restore | ||||||
|  |  | ||||||
|         if (cell_count > size) { |         if (cell_count > cells.size()) { | ||||||
|             LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); |             LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
| @@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | |||||||
|         clear(); |         clear(); | ||||||
|  |  | ||||||
|         for (uint32_t i = 0; i < cell_count; ++i) { |         for (uint32_t i = 0; i < cell_count; ++i) { | ||||||
|             kv_cell & cell = cells[i]; |  | ||||||
|  |  | ||||||
|             llama_pos pos; |             llama_pos pos; | ||||||
|             uint32_t  n_seq_id; |             uint32_t  n_seq_id; | ||||||
|  |  | ||||||
|             io.read_to(&pos,      sizeof(pos)); |             io.read_to(&pos,      sizeof(pos)); | ||||||
|             io.read_to(&n_seq_id, sizeof(n_seq_id)); |             io.read_to(&n_seq_id, sizeof(n_seq_id)); | ||||||
|  |  | ||||||
|             cell.pos = pos; |             cells.pos_set(i, pos); | ||||||
|  |  | ||||||
|             for (uint32_t j = 0; j < n_seq_id; ++j) { |             for (uint32_t j = 0; j < n_seq_id; ++j) { | ||||||
|                 llama_seq_id seq_id; |                 llama_seq_id seq_id; | ||||||
| @@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell | |||||||
|                     return false; |                     return false; | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 cell.seq_id.insert(seq_id); |                 cells.seq_add(i, seq_id); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         head = 0; |         head = 0; | ||||||
|         used = cell_count; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     return true; |     return true; | ||||||
| @@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell | |||||||
|         LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); |         LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|     if (cell_count > size) { |     if (cell_count > cells.size()) { | ||||||
|         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); |         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|     if (this->v_trans != (bool) v_trans) { |     if (this->v_trans != (bool) v_trans) { | ||||||
| @@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell | |||||||
|             if (cell_count) { |             if (cell_count) { | ||||||
|                 // For each row in the transposed matrix, read the values for the whole cell range |                 // For each row in the transposed matrix, read the values for the whole cell range | ||||||
|                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { |                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { | ||||||
|                     const size_t dst_offset = (head + j * size) * v_size_el; |                     const size_t dst_offset = (head + j * cells.size()) * v_size_el; | ||||||
|                     ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); |                     ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { | |||||||
|     kv_swa ->seq_keep(seq_id); |     kv_swa ->seq_keep(seq_id); | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { | void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { | ||||||
|     kv_base->seq_add(seq_id, p0, p1, delta); |     kv_base->seq_add(seq_id, p0, p1, shift); | ||||||
|     kv_swa ->seq_add(seq_id, p0, p1, delta); |     kv_swa ->seq_add(seq_id, p0, p1, shift); | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { | void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { | ||||||
| @@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { | void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { | ||||||
|     if (delta == 0) { |     if (shift == 0) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_ | |||||||
|         if (tail_id >= 0) { |         if (tail_id >= 0) { | ||||||
|             kv_cell & cell = cells[tail_id]; |             kv_cell & cell = cells[tail_id]; | ||||||
|             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { |             if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { | ||||||
|                 cell.pos += delta; |                 cell.pos += shift; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ | |||||||
| #include "llama-io.h" | #include "llama-io.h" | ||||||
| #include "llama-graph.h" | #include "llama-graph.h" | ||||||
| #include "llama-memory.h" | #include "llama-memory.h" | ||||||
|  | #include "llama-kv-cells.h" | ||||||
|  |  | ||||||
| #include "ggml-cpp.h" | #include "ggml-cpp.h" | ||||||
|  |  | ||||||
| @@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i { | |||||||
|     virtual void defrag_sched(float thold) = 0; |     virtual void defrag_sched(float thold) = 0; | ||||||
|  |  | ||||||
|     // simulate full cache, used for allocating worst-case compute buffers |     // simulate full cache, used for allocating worst-case compute buffers | ||||||
|  |     // TODO: remove | ||||||
|     virtual void set_full() = 0; |     virtual void set_full() = 0; | ||||||
|  |  | ||||||
|     // |     // | ||||||
| @@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i { | |||||||
|     // |     // | ||||||
|  |  | ||||||
|     // ============================================================================================================= |     // ============================================================================================================= | ||||||
|     // TODO: refactor  and simplify this |     // TODO: refactor and simplify this [TAG: KV_API] | ||||||
|  |  | ||||||
|     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; |     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; | ||||||
|  |  | ||||||
| @@ -121,7 +123,7 @@ public: | |||||||
|     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; |     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; | ||||||
|     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; |     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; | ||||||
|     void seq_keep(llama_seq_id seq_id)                                                          override; |     void seq_keep(llama_seq_id seq_id)                                                          override; | ||||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override; |     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override; | ||||||
|     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; |     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; | ||||||
|  |  | ||||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const override; |     llama_pos seq_pos_min(llama_seq_id seq_id) const override; | ||||||
| @@ -159,7 +161,7 @@ public: | |||||||
|     // llama_kv_cache_unified specific API |     // llama_kv_cache_unified specific API | ||||||
|     // |     // | ||||||
|  |  | ||||||
|     uint32_t get_n() const; |     uint32_t get_n()    const; | ||||||
|     uint32_t get_size() const; |     uint32_t get_size() const; | ||||||
|  |  | ||||||
|     // get views of the current state of the cache |     // get views of the current state of the cache | ||||||
| @@ -180,26 +182,6 @@ private: | |||||||
|     const llama_model & model; |     const llama_model & model; | ||||||
|     const llama_hparams & hparams; |     const llama_hparams & hparams; | ||||||
|  |  | ||||||
|     struct kv_cell { |  | ||||||
|         llama_pos pos   = -1; |  | ||||||
|         llama_pos delta =  0; |  | ||||||
|  |  | ||||||
|         // TODO: replace with bitset uint64_t |  | ||||||
|         std::set<llama_seq_id> seq_id; |  | ||||||
|  |  | ||||||
|         bool has_seq_id(const llama_seq_id & id) const { |  | ||||||
|             return seq_id.find(id) != seq_id.end(); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         bool is_empty() const { |  | ||||||
|             return seq_id.empty(); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         bool is_same_seq(const kv_cell & other) const { |  | ||||||
|             return seq_id == other.seq_id; |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     struct kv_layer { |     struct kv_layer { | ||||||
|         // layer index in the model |         // layer index in the model | ||||||
|         // note: can be different from the layer index in the KV cache |         // note: can be different from the layer index in the KV cache | ||||||
| @@ -209,15 +191,13 @@ private: | |||||||
|         ggml_tensor * v; |         ggml_tensor * v; | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     bool has_shift = false; |  | ||||||
|     bool do_defrag = false; |     bool do_defrag = false; | ||||||
|     bool v_trans   = true;  // the value tensor is transposed |     bool v_trans   = true;  // the value tensor is transposed | ||||||
|  |  | ||||||
|     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) |     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) | ||||||
|     uint32_t size = 0; // total number of cells, shared across all sequences |  | ||||||
|     uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt) |  | ||||||
|  |  | ||||||
|     // computed before each graph build |     // computed before each graph build | ||||||
|  |     // TODO: cells should start to maintain this value dynamically based on the edits | ||||||
|     uint32_t n = 0; |     uint32_t n = 0; | ||||||
|  |  | ||||||
|     const uint32_t n_seq_max = 1; |     const uint32_t n_seq_max = 1; | ||||||
| @@ -233,19 +213,29 @@ private: | |||||||
|     std::vector<ggml_context_ptr>        ctxs; |     std::vector<ggml_context_ptr>        ctxs; | ||||||
|     std::vector<ggml_backend_buffer_ptr> bufs; |     std::vector<ggml_backend_buffer_ptr> bufs; | ||||||
|  |  | ||||||
|     std::vector<kv_cell>  cells;  // TODO: replace with `struct kv_cells` |     llama_kv_cells_unified cells; | ||||||
|  |  | ||||||
|     std::vector<kv_layer> layers; |     std::vector<kv_layer> layers; | ||||||
|  |  | ||||||
|     // model layer id -> KV cache layer id |     // model layer id -> KV cache layer id | ||||||
|     std::unordered_map<int32_t, int32_t> map_layer_ids; |     std::unordered_map<int32_t, int32_t> map_layer_ids; | ||||||
|  |  | ||||||
|     // recovery information used to restore the KV cells to their original state in case of a failure |     // recovery information used to restore the KV cells to their original state in case of a failure | ||||||
|  |     // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation | ||||||
|  |     //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] | ||||||
|     struct { |     struct { | ||||||
|         void clear() { |         void clear() { | ||||||
|             cells.clear(); |             states.clear(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         std::unordered_map<uint32_t, kv_cell> cells; |         struct state { | ||||||
|  |             uint32_t i; | ||||||
|  |  | ||||||
|  |             llama_kv_cells_unified cells; | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         // stack with the partial states before each ubatch | ||||||
|  |         std::vector<state> states; | ||||||
|     } recovery; |     } recovery; | ||||||
|  |  | ||||||
|     // defrag |     // defrag | ||||||
| @@ -257,6 +247,7 @@ private: | |||||||
|     bool defrag_prepare(int32_t n_max_nodes); |     bool defrag_prepare(int32_t n_max_nodes); | ||||||
|  |  | ||||||
|     // find how many cells are currently in use |     // find how many cells are currently in use | ||||||
|  |     // TODO: optimize | ||||||
|     uint32_t cell_max() const; |     uint32_t cell_max() const; | ||||||
|  |  | ||||||
|     size_t total_size() const; |     size_t total_size() const; | ||||||
| @@ -325,7 +316,7 @@ public: | |||||||
|     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; |     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; | ||||||
|     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; |     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; | ||||||
|     void seq_keep(llama_seq_id seq_id)                                                          override; |     void seq_keep(llama_seq_id seq_id)                                                          override; | ||||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override; |     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override; | ||||||
|     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; |     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; | ||||||
|  |  | ||||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const override; |     llama_pos seq_pos_min(llama_seq_id seq_id) const override; | ||||||
| @@ -431,7 +422,7 @@ public: | |||||||
|     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; |     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override; | ||||||
|     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; |     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; | ||||||
|     void seq_keep(llama_seq_id seq_id)                                                          override; |     void seq_keep(llama_seq_id seq_id)                                                          override; | ||||||
|     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override; |     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override; | ||||||
|     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; |     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override; | ||||||
|  |  | ||||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const override; |     llama_pos seq_pos_min(llama_seq_id seq_id) const override; | ||||||
|   | |||||||
							
								
								
									
										273
									
								
								src/llama-kv-cells.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										273
									
								
								src/llama-kv-cells.h
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,273 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "llama.h" | ||||||
|  | #include "llama-cparams.h" | ||||||
|  |  | ||||||
|  | #include <bitset> | ||||||
|  | #include <cassert> | ||||||
|  | #include <vector> | ||||||
|  |  | ||||||
|  | // meta information about KV cells that can be part of multiple sequences at the same time | ||||||
|  | // TODO: add unit tests | ||||||
|  | class llama_kv_cells_unified { | ||||||
|  | public: | ||||||
|  |     void reset() { | ||||||
|  |         for (uint32_t i = 0; i < pos.size(); ++i) { | ||||||
|  |             pos[i]   = -1; | ||||||
|  |             shift[i] =  0; | ||||||
|  |             seq[i].reset(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         used      = 0; | ||||||
|  |         has_shift = false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void reset_shift() { | ||||||
|  |         has_shift = false; | ||||||
|  |  | ||||||
|  |         for (uint32_t i = 0; i < shift.size(); ++i) { | ||||||
|  |             shift[i] = 0; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     uint32_t size() const { | ||||||
|  |         return pos.size(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void resize(uint32_t n) { | ||||||
|  |         pos.resize(n); | ||||||
|  |         shift.resize(n); | ||||||
|  |         seq.resize(n); | ||||||
|  |  | ||||||
|  |         reset(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     bool is_empty(uint32_t i) const { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); | ||||||
|  |  | ||||||
|  |         return pos[i] == -1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     uint32_t get_used() const { | ||||||
|  |         return used; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     bool get_has_shift() const { | ||||||
|  |         return has_shift; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // move cell isrc to idst (used during defrag) | ||||||
|  |     void mv(uint32_t isrc, uint32_t idst) { | ||||||
|  |         assert(isrc < pos.size()); | ||||||
|  |         assert(idst < pos.size()); | ||||||
|  |  | ||||||
|  |         pos  [idst] = pos  [isrc]; | ||||||
|  |         shift[idst] = shift[isrc]; | ||||||
|  |         seq  [idst] = seq  [isrc]; | ||||||
|  |  | ||||||
|  |         pos  [isrc] = -1; | ||||||
|  |         shift[isrc] =  0; | ||||||
|  |         seq  [isrc].reset(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // copy the state of cells [i, i + n) (used for save/restore the state of the cells) | ||||||
|  |     llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { | ||||||
|  |         assert(i + n <= pos.size()); | ||||||
|  |  | ||||||
|  |         llama_kv_cells_unified res; | ||||||
|  |  | ||||||
|  |         res.resize(n); | ||||||
|  |  | ||||||
|  |         for (uint32_t j = 0; j < n; ++j) { | ||||||
|  |             res.pos[j] = pos[i + j]; | ||||||
|  |             res.seq[j] = seq[i + j]; | ||||||
|  |  | ||||||
|  |             assert(shift[i + j] == 0); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         return res; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) | ||||||
|  |     void set(uint32_t i, const llama_kv_cells_unified & other) { | ||||||
|  |         assert(i + other.pos.size() <= pos.size()); | ||||||
|  |  | ||||||
|  |         for (uint32_t j = 0; j < other.pos.size(); ++j) { | ||||||
|  |             if (pos[i + j] == -1 && other.pos[j] != -1) { | ||||||
|  |                 used++; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             if (pos[i + j] != -1 && other.pos[j] == -1) { | ||||||
|  |                 used--; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             pos[i + j] = other.pos[j]; | ||||||
|  |             seq[i + j] = other.seq[j]; | ||||||
|  |  | ||||||
|  |             assert(shift[i + j] == 0); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // note: call only if the cell has seq_id | ||||||
|  |     // return true if the cell becomes empty | ||||||
|  |     bool seq_rm(uint32_t i, llama_seq_id seq_id) { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(seq[i].test(seq_id)); | ||||||
|  |         assert(pos[i] != -1); | ||||||
|  |         assert(seq_id >= 0); | ||||||
|  |  | ||||||
|  |         seq[i].reset(seq_id); | ||||||
|  |  | ||||||
|  |         if (seq[i].none()) { | ||||||
|  |             pos[i] = -1; | ||||||
|  |  | ||||||
|  |             used--; | ||||||
|  |  | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) | ||||||
|  |     bool seq_keep(uint32_t i, llama_seq_id seq_id) { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |  | ||||||
|  |         if (seq[i].test(seq_id)) { | ||||||
|  |             seq[i].reset(); | ||||||
|  |             seq[i].set(seq_id); | ||||||
|  |  | ||||||
|  |             return false; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if (seq[i].any()) { | ||||||
|  |             seq[i].reset(); | ||||||
|  |             pos[i] = -1; | ||||||
|  |  | ||||||
|  |             used--; | ||||||
|  |  | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         assert(pos[i] == -1); | ||||||
|  |  | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     bool seq_has(uint32_t i, llama_seq_id seq_id) const { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(seq_id >= 0); | ||||||
|  |  | ||||||
|  |         return seq[i].test(seq_id); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // note: call only if the cell is not empty and the seq_id is not in the cell | ||||||
|  |     void seq_add(uint32_t i, llama_seq_id seq_id) { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(pos[i] != -1); | ||||||
|  |         assert(!seq[i].test(seq_id)); | ||||||
|  |  | ||||||
|  |         seq[i].set(seq_id); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // note: call only if the cell is not empty | ||||||
|  |     llama_pos pos_get(uint32_t i) const { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(pos[i] != -1); | ||||||
|  |  | ||||||
|  |         return pos[i]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // note: call only if the cell is not empty | ||||||
|  |     llama_pos get_shift(uint32_t i) const { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(pos[i] != -1); | ||||||
|  |  | ||||||
|  |         return shift[i]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // check if a cell is not empty and its position is within [p0, p1) | ||||||
|  |     bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |  | ||||||
|  |         return pos[i] >= p0 && pos[i] < p1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // set the position of an empty cell | ||||||
|  |     // does not modify "has_shift" | ||||||
|  |     // note: call only if the cell is empty | ||||||
|  |     void pos_set(uint32_t i, llama_pos p) { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(pos[i] == -1); | ||||||
|  |  | ||||||
|  |         pos[i] = p; | ||||||
|  |         used++; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // pos[i] = pos[i] + d | ||||||
|  |     // sets "has_shift" to true | ||||||
|  |     // note: call only if the cell is not empty | ||||||
|  |     bool pos_add(uint32_t i, llama_pos d) { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(pos[i] != -1); | ||||||
|  |  | ||||||
|  |         pos[i]   += d; | ||||||
|  |         shift[i] += d; | ||||||
|  |  | ||||||
|  |         has_shift = true; | ||||||
|  |  | ||||||
|  |         if (pos[i] < 0) { | ||||||
|  |             pos[i] = -1; | ||||||
|  |             seq[i].reset(); | ||||||
|  |  | ||||||
|  |             used--; | ||||||
|  |  | ||||||
|  |             return true; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // pos[i] = pos[i] / d | ||||||
|  |     // sets "has_shift" to true | ||||||
|  |     // note: call only if the cell is not empty | ||||||
|  |     void pos_div(uint32_t i, int d) { | ||||||
|  |         assert(i < pos.size()); | ||||||
|  |         assert(pos[i] != -1); | ||||||
|  |  | ||||||
|  |         const llama_pos p_old = pos[i]; | ||||||
|  |  | ||||||
|  |         pos[i]   /= d; | ||||||
|  |         shift[i] += p_old - pos[i]; | ||||||
|  |  | ||||||
|  |         has_shift = true; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |     uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id) | ||||||
|  |  | ||||||
|  |     bool has_shift = false; | ||||||
|  |  | ||||||
|  |     std::vector<llama_pos> pos; | ||||||
|  |  | ||||||
|  |     // this array accumulates any applied shifts to the pos array since the last reset_shift() call | ||||||
|  |     // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: | ||||||
|  |     // | ||||||
|  |     //   cells.pos_add(x, shift_x); | ||||||
|  |     //   cells.pos_div(y, shift_y); | ||||||
|  |     //   ... | ||||||
|  |     // | ||||||
|  |     //   if (cells.has_shift()) { | ||||||
|  |     //      for (int i = 0; i < n; ++i) { | ||||||
|  |     //          auto shift_i = cells.get_shift(i); | ||||||
|  |     //          ... | ||||||
|  |     //      } | ||||||
|  |     //      cells.reset_shift(); | ||||||
|  |     //   } | ||||||
|  |     // | ||||||
|  |     std::vector<llama_pos> shift; | ||||||
|  |  | ||||||
|  |     std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq; | ||||||
|  | }; | ||||||
|  |  | ||||||
| @@ -22,7 +22,7 @@ public: | |||||||
|     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0; |     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0; | ||||||
|     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; |     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; | ||||||
|     virtual void seq_keep(llama_seq_id seq_id) = 0; |     virtual void seq_keep(llama_seq_id seq_id) = 0; | ||||||
|     virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0; |     virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0; | ||||||
|     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0; |     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0; | ||||||
|  |  | ||||||
|     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; |     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov