mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	kv-cells : fix tracking of seq_pos (#14339)
* kv-cells : fix tracking of seq_pos during cache reuse ggml-ci * cont : improve error message ggml-ci * cont : add more comments
This commit is contained in:
		| @@ -944,12 +944,14 @@ extern "C" { | |||||||
|     // Requires the context to have a memory. |     // Requires the context to have a memory. | ||||||
|     // For encoder-decoder contexts, processes the batch using the decoder. |     // For encoder-decoder contexts, processes the batch using the decoder. | ||||||
|     // Positive return values do not mean a fatal error, but rather a warning. |     // Positive return values do not mean a fatal error, but rather a warning. | ||||||
|     // Upon non-zero return values, the memory state is restored to the state before this call |     // Upon fatal-error or abort, the ubatches that managed to be processed will remain in the memory state of the context | ||||||
|  |     //   To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max() | ||||||
|  |     // Upon other return values, the memory state is restored to the state before this call | ||||||
|     //    0 - success |     //    0 - success | ||||||
|     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) |     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) | ||||||
|     //    2 - aborted |     //    2 - aborted     (processed ubatches will remain in the context's memory) | ||||||
|     //   -1 - invalid input batch |     //   -1 - invalid input batch | ||||||
|     // < -1 - error |     // < -1 - fatal error (processed ubatches will remain in the context's memory) | ||||||
|     LLAMA_API int32_t llama_decode( |     LLAMA_API int32_t llama_decode( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|               struct llama_batch   batch); |               struct llama_batch   batch); | ||||||
|   | |||||||
| @@ -245,10 +245,11 @@ bool llama_batch_allocr::init( | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (memory) { |         if (memory) { | ||||||
|  |             bool ok = true; | ||||||
|  |  | ||||||
|             if (batch.token) { |             if (batch.token) { | ||||||
|                 if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) { |                 if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) { | ||||||
|                     LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); |                     ok = false; | ||||||
|                     return false; |  | ||||||
|                 } |                 } | ||||||
|             } else { |             } else { | ||||||
|                 assert(batch.embd); |                 assert(batch.embd); | ||||||
| @@ -256,10 +257,20 @@ bool llama_batch_allocr::init( | |||||||
|                 // for embeddings (typically used as vision input), we allow them to have repeating positions |                 // for embeddings (typically used as vision input), we allow them to have repeating positions | ||||||
|                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762 |                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762 | ||||||
|                 if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) { |                 if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) { | ||||||
|                     LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); |                     ok = false; | ||||||
|                     return false; |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  |             if (!ok) { | ||||||
|  |                 LLAMA_LOG_ERROR( | ||||||
|  |                         "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" | ||||||
|  |                         " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" | ||||||
|  |                         " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" | ||||||
|  |                         " it is required that the sequence positions remain consecutive: Y = X + 1\n", | ||||||
|  |                         __func__, s, s, memory->seq_pos_max(s), s, seq_pos_min(s)); | ||||||
|  |  | ||||||
|  |                 return false; | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { |         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { | ||||||
|   | |||||||
| @@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) { | |||||||
|                 pos_min[s] = std::numeric_limits<llama_pos>::max(); |                 pos_min[s] = std::numeric_limits<llama_pos>::max(); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // TODO: fix sequence indexing |  | ||||||
|             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { | ||||||
|                 const auto & seq_id = ubatch.seq_id[i][0]; |                 const auto & seq_id = ubatch.seq_id[i][0]; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ | |||||||
| #include <cassert> | #include <cassert> | ||||||
| #include <vector> | #include <vector> | ||||||
| #include <set> | #include <set> | ||||||
|  | #include <map> | ||||||
|  |  | ||||||
| // meta information about KV cells that can be part of multiple sequences at the same time | // meta information about KV cells that can be part of multiple sequences at the same time | ||||||
| // TODO: add unit tests | // TODO: add unit tests | ||||||
| @@ -164,7 +165,7 @@ public: | |||||||
|         assert(seq_id >= 0); |         assert(seq_id >= 0); | ||||||
|  |  | ||||||
|         seq[i].reset(seq_id); |         seq[i].reset(seq_id); | ||||||
|         seq_pos[seq_id].erase(pos[i]); |         seq_pos_dec(seq_id, pos[i]); | ||||||
|  |  | ||||||
|         if (seq[i].none()) { |         if (seq[i].none()) { | ||||||
|             pos[i] = -1; |             pos[i] = -1; | ||||||
| @@ -187,7 +188,7 @@ public: | |||||||
|             seq[i].reset(); |             seq[i].reset(); | ||||||
|  |  | ||||||
|             seq[i].set(seq_id); |             seq[i].set(seq_id); | ||||||
|             seq_pos[seq_id].insert(pos[i]); |             seq_pos_inc(seq_id, pos[i]); | ||||||
|  |  | ||||||
|             return false; |             return false; | ||||||
|         } |         } | ||||||
| @@ -232,7 +233,7 @@ public: | |||||||
|         assert(!seq[i].test(seq_id)); |         assert(!seq[i].test(seq_id)); | ||||||
|  |  | ||||||
|         seq[i].set(seq_id); |         seq[i].set(seq_id); | ||||||
|         seq_pos[seq_id].insert(pos[i]); |         seq_pos_inc(seq_id, pos[i]); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // return the sequence id of this cell |     // return the sequence id of this cell | ||||||
| @@ -259,7 +260,9 @@ public: | |||||||
|             return -1; |             return -1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return *seq_pos[seq_id].begin(); |         assert(seq_pos[seq_id].begin()->second > 0); | ||||||
|  |  | ||||||
|  |         return seq_pos[seq_id].begin()->first; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // the maximum position of sequence seq_id currently present in any of the cells |     // the maximum position of sequence seq_id currently present in any of the cells | ||||||
| @@ -272,7 +275,9 @@ public: | |||||||
|             return -1; |             return -1; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return *seq_pos[seq_id].rbegin(); |         assert(seq_pos[seq_id].rbegin()->second > 0); | ||||||
|  |  | ||||||
|  |         return seq_pos[seq_id].rbegin()->first; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // note: call only if the cell is not empty |     // note: call only if the cell is not empty | ||||||
| @@ -389,17 +394,36 @@ private: | |||||||
|     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell |     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell | ||||||
|     std::vector<seq_set_t> seq; |     std::vector<seq_set_t> seq; | ||||||
|  |  | ||||||
|     // the set seq_pos[s] tells us which positions are currently present for sequence s |     // the map seq_pos[s][p] tells us how many times the position p is currently present for sequence s | ||||||
|  |     // if the position p is not present, seq_pos[s][p] is not set | ||||||
|     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache |     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache | ||||||
|     std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ]; |     // | ||||||
|  |     // note that we cannot use an std::set because in some cases a position can occur more than once for the same seq: | ||||||
|  |     //  - while performing a cache reuse via (rm + add) | ||||||
|  |     //  - some vision models have input embeddings with repeating positions | ||||||
|  |     // | ||||||
|  |     std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ]; | ||||||
|  |  | ||||||
|     // helper functions for updating `seq_pos`, one cell at a time: |     // helper functions for updating `seq_pos`, one cell at a time: | ||||||
|  |  | ||||||
|  |     void seq_pos_dec(llama_seq_id s, llama_pos p) { | ||||||
|  |         auto it = seq_pos[s].find(p); | ||||||
|  |         assert(it != seq_pos[s].end()); | ||||||
|  |  | ||||||
|  |         if (--it->second == 0) { | ||||||
|  |             seq_pos[s].erase(it); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void seq_pos_inc(llama_seq_id s, llama_pos p) { | ||||||
|  |         seq_pos[s][p]++; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // remove cell i |     // remove cell i | ||||||
|     void seq_pos_rm(uint32_t i) { |     void seq_pos_rm(uint32_t i) { | ||||||
|         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { | ||||||
|             if (seq[i].test(s)) { |             if (seq[i].test(s)) { | ||||||
|                 seq_pos[s].erase(pos[i]); |                 seq_pos_dec(s, pos[i]); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -408,7 +432,7 @@ private: | |||||||
|     void seq_pos_add(uint32_t i) { |     void seq_pos_add(uint32_t i) { | ||||||
|         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { | ||||||
|             if (seq[i].test(s)) { |             if (seq[i].test(s)) { | ||||||
|                 seq_pos[s].insert(pos[i]); |                 seq_pos_inc(s, pos[i]); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -3418,9 +3418,12 @@ struct server_context { | |||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     if (ret < -1) { |                     if (ret < -1) { | ||||||
|  |                         // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() | ||||||
|                         err = "Compute error."; |                         err = "Compute error."; | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|  |                     // TODO: handle ret == 2 (abort) when we start aborting | ||||||
|  |  | ||||||
|                     if (!err.empty()) { |                     if (!err.empty()) { | ||||||
|                         SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); |                         SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); | ||||||
|                         for (auto & slot : slots) { |                         for (auto & slot : slots) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov