mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	kv-cells : track min/max used cells and per-sequence positions (#13808)
* kv-cells : track min/max used cells and per-sequence positions ggml-ci * kv-cells : fix pos-modification updates for seq_pos ggml-ci * kv-cells : add comments ggml-ci
This commit is contained in:
		| @@ -286,31 +286,11 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po | ||||
| } | ||||
|  | ||||
| llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { | ||||
|     llama_pos result = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (cells.seq_has(i, seq_id)) { | ||||
|             result = std::min(result, cells.pos_get(i)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if (result == std::numeric_limits<llama_pos>::max()) { | ||||
|         result = -1; | ||||
|     } | ||||
|  | ||||
|     return result; | ||||
|     return cells.seq_pos_min(seq_id); | ||||
| } | ||||
|  | ||||
| llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { | ||||
|     llama_pos result = -1; | ||||
|  | ||||
|     for (uint32_t i = 0; i < cells.size(); ++i) { | ||||
|         if (cells.seq_has(i, seq_id)) { | ||||
|             result = std::max(result, cells.pos_get(i)); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     return result; | ||||
|     return cells.seq_pos_max(seq_id); | ||||
| } | ||||
|  | ||||
| void llama_kv_cache_unified::restore() { | ||||
| @@ -504,7 +484,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { | ||||
|     // a heuristic, to avoid attending the full cache if it is not yet utilized | ||||
|     // after enough generations, the benefit from this heuristic disappears | ||||
|     // if we start defragmenting the cache, the benefit from this will be more important | ||||
|     n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad))); | ||||
|     n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); | ||||
|  | ||||
| #ifdef FIND_SLOT_DEBUG | ||||
|     LLAMA_LOG_WARN("end:   n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); | ||||
| @@ -1018,7 +998,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( | ||||
| bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|     const uint32_t n_layer = layers.size(); | ||||
|  | ||||
|     const uint32_t n_kv   = cell_max(); | ||||
|     const uint32_t n_kv   = cells.used_max_p1(); | ||||
|     const uint32_t n_used = cells.get_used(); | ||||
|  | ||||
|     assert(n_used <= n_kv); | ||||
| @@ -1144,16 +1124,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { | ||||
|     return true; | ||||
| } | ||||
|  | ||||
| uint32_t llama_kv_cache_unified::cell_max() const { | ||||
|     for (uint32_t i = cells.size(); i > 0; --i) { | ||||
|         if (!cells.is_empty(i - 1)) { | ||||
|             return i; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     return 0; | ||||
| } | ||||
|  | ||||
| bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { | ||||
|     assert(p0 >= 0 && p1 >= 0); | ||||
|  | ||||
|   | ||||
| @@ -246,10 +246,6 @@ private: | ||||
|     // return true if cells have been moved | ||||
|     bool defrag_prepare(int32_t n_max_nodes); | ||||
|  | ||||
|     // find how many cells are currently in use | ||||
|     // TODO: optimize | ||||
|     uint32_t cell_max() const; | ||||
|  | ||||
|     size_t total_size() const; | ||||
|  | ||||
|     size_t size_k_bytes() const; | ||||
|   | ||||
| @@ -6,6 +6,7 @@ | ||||
| #include <bitset> | ||||
| #include <cassert> | ||||
| #include <vector> | ||||
| #include <set> | ||||
|  | ||||
| // meta information about KV cells that can be part of multiple sequences at the same time | ||||
| // TODO: add unit tests | ||||
| @@ -18,8 +19,13 @@ public: | ||||
|             seq[i].reset(); | ||||
|         } | ||||
|  | ||||
|         used      = 0; | ||||
|         has_shift = false; | ||||
|  | ||||
|         used.clear(); | ||||
|  | ||||
|         for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { | ||||
|             seq_pos[s].clear(); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     void reset_shift() { | ||||
| @@ -50,7 +56,25 @@ public: | ||||
|     } | ||||
|  | ||||
|     uint32_t get_used() const { | ||||
|         return used; | ||||
|         return used.size(); | ||||
|     } | ||||
|  | ||||
|     // the index of the first cell that is used | ||||
|     // return 0 if no cells are used | ||||
|     uint32_t used_min() const { | ||||
|         return used.empty() ? 0 : *used.begin(); | ||||
|     } | ||||
|  | ||||
|     // the index of the last cell that is used + 1 | ||||
|     // return 0 if no cells are used | ||||
|     uint32_t used_max_p1() const { | ||||
| #if 0 | ||||
|         if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); | ||||
|         if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); | ||||
|         if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); | ||||
| #endif | ||||
|  | ||||
|         return used.empty() ? 0 : *used.rbegin() + 1; | ||||
|     } | ||||
|  | ||||
|     bool get_has_shift() const { | ||||
| @@ -69,6 +93,9 @@ public: | ||||
|         pos  [isrc] = -1; | ||||
|         shift[isrc] =  0; | ||||
|         seq  [isrc].reset(); | ||||
|  | ||||
|         used.erase (isrc); | ||||
|         used.insert(idst); | ||||
|     } | ||||
|  | ||||
|     // copy the state of cells [i, i + n) (used for save/restore the state of the cells) | ||||
| @@ -95,16 +122,24 @@ public: | ||||
|  | ||||
|         for (uint32_t j = 0; j < other.pos.size(); ++j) { | ||||
|             if (pos[i + j] == -1 && other.pos[j] != -1) { | ||||
|                 used++; | ||||
|                 used.insert(i + j); | ||||
|             } | ||||
|  | ||||
|             if (pos[i + j] != -1 && other.pos[j] == -1) { | ||||
|                 used--; | ||||
|                 used.erase(i + j); | ||||
|             } | ||||
|  | ||||
|             if (pos[i + j] != -1) { | ||||
|                 seq_pos_rm(i + j); | ||||
|             } | ||||
|  | ||||
|             pos[i + j] = other.pos[j]; | ||||
|             seq[i + j] = other.seq[j]; | ||||
|  | ||||
|             if (pos[i + j] != -1) { | ||||
|                 seq_pos_add(i + j); | ||||
|             } | ||||
|  | ||||
|             assert(shift[i + j] == 0); | ||||
|         } | ||||
|     } | ||||
| @@ -118,11 +153,12 @@ public: | ||||
|         assert(seq_id >= 0); | ||||
|  | ||||
|         seq[i].reset(seq_id); | ||||
|         seq_pos[seq_id].erase(pos[i]); | ||||
|  | ||||
|         if (seq[i].none()) { | ||||
|             pos[i] = -1; | ||||
|  | ||||
|             used--; | ||||
|             used.erase(i); | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
| @@ -135,17 +171,22 @@ public: | ||||
|         assert(i < pos.size()); | ||||
|  | ||||
|         if (seq[i].test(seq_id)) { | ||||
|             seq_pos_rm(i); | ||||
|             seq[i].reset(); | ||||
|  | ||||
|             seq[i].set(seq_id); | ||||
|             seq_pos[seq_id].insert(pos[i]); | ||||
|  | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         if (seq[i].any()) { | ||||
|             seq_pos_rm(i); | ||||
|             seq[i].reset(); | ||||
|  | ||||
|             pos[i] = -1; | ||||
|  | ||||
|             used--; | ||||
|             used.erase(i); | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
| @@ -169,6 +210,33 @@ public: | ||||
|         assert(!seq[i].test(seq_id)); | ||||
|  | ||||
|         seq[i].set(seq_id); | ||||
|         seq_pos[seq_id].insert(pos[i]); | ||||
|     } | ||||
|  | ||||
|     // the minimum position of sequence seq_id currently present in any of the cells | ||||
|     // return -1 if the sequence is not present | ||||
|     llama_pos seq_pos_min(llama_seq_id seq_id) const { | ||||
|         assert(seq_id >= 0); | ||||
|         assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); | ||||
|  | ||||
|         if (seq_pos[seq_id].empty()) { | ||||
|             return -1; | ||||
|         } | ||||
|  | ||||
|         return *seq_pos[seq_id].begin(); | ||||
|     } | ||||
|  | ||||
|     // the maximum position of sequence seq_id currently present in any of the cells | ||||
|     // return -1 if the sequence is not present | ||||
|     llama_pos seq_pos_max(llama_seq_id seq_id) const { | ||||
|         assert(seq_id >= 0); | ||||
|         assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); | ||||
|  | ||||
|         if (seq_pos[seq_id].empty()) { | ||||
|             return -1; | ||||
|         } | ||||
|  | ||||
|         return *seq_pos[seq_id].rbegin(); | ||||
|     } | ||||
|  | ||||
|     // note: call only if the cell is not empty | ||||
| @@ -202,7 +270,8 @@ public: | ||||
|         assert(pos[i] == -1); | ||||
|  | ||||
|         pos[i] = p; | ||||
|         used++; | ||||
|  | ||||
|         used.insert(i); | ||||
|     } | ||||
|  | ||||
|     // pos[i] = pos[i] + d | ||||
| @@ -212,16 +281,22 @@ public: | ||||
|         assert(i < pos.size()); | ||||
|         assert(pos[i] != -1); | ||||
|  | ||||
|         seq_pos_rm(i); | ||||
|  | ||||
|         pos[i]   += d; | ||||
|         shift[i] += d; | ||||
|  | ||||
|         seq_pos_add(i); | ||||
|  | ||||
|         has_shift = true; | ||||
|  | ||||
|         if (pos[i] < 0) { | ||||
|             pos[i] = -1; | ||||
|             seq[i].reset(); | ||||
|             seq_pos_rm(i); | ||||
|  | ||||
|             used--; | ||||
|             seq[i].reset(); | ||||
|             pos[i] = -1; | ||||
|  | ||||
|             used.erase(i); | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
| @@ -238,17 +313,22 @@ public: | ||||
|  | ||||
|         const llama_pos p_old = pos[i]; | ||||
|  | ||||
|         seq_pos_rm(i); | ||||
|  | ||||
|         pos[i]   /= d; | ||||
|         shift[i] += p_old - pos[i]; | ||||
|  | ||||
|         seq_pos_add(i); | ||||
|  | ||||
|         has_shift = true; | ||||
|     } | ||||
|  | ||||
| private: | ||||
|     uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id) | ||||
|  | ||||
|     bool has_shift = false; | ||||
|  | ||||
|     // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id) | ||||
|     std::set<uint32_t> used; | ||||
|  | ||||
|     std::vector<llama_pos> pos; | ||||
|  | ||||
|     // this array accumulates any applied shifts to the pos array since the last reset_shift() call | ||||
| @@ -268,6 +348,32 @@ private: | ||||
|     // | ||||
|     std::vector<llama_pos> shift; | ||||
|  | ||||
|     std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq; | ||||
| }; | ||||
|     using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>; | ||||
|  | ||||
|     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell | ||||
|     std::vector<bits_t> seq; | ||||
|  | ||||
|     // the set seq_pos[s] tells us which positions are currently present for sequence s | ||||
|     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache | ||||
|     std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES]; | ||||
|  | ||||
|     // helper functions for updating `seq_pos`, once cell at a time: | ||||
|  | ||||
|     // remove cell i | ||||
|     void seq_pos_rm(uint32_t i) { | ||||
|         for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { | ||||
|             if (seq[i].test(s)) { | ||||
|                 seq_pos[s].erase(pos[i]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // add cell i | ||||
|     void seq_pos_add(uint32_t i) { | ||||
|         for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { | ||||
|             if (seq[i].test(s)) { | ||||
|                 seq_pos[s].insert(pos[i]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov