mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-10-30 08:42:00 +00:00
	llama : keep track of used KV cells + better KV cache management
Author: Georgi Gerganov

 llama.cpp | 38 ++++++++++++++++++++++++++++++++++----
 llama.h   |  9 ++++++---
 2 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
 
     // computed before each graph build
     uint32_t n = 0;
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
 
     cache.head = 0;
     cache.size = n_ctx;
+    cache.used = 0;
 
     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    cache.used += n_tokens;
+
     return true;
 }
 
@@ -1647,6 +1651,9 @@ static void llama_kv_cache_seq_rm(
                 continue;
             }
             if (cache.cells[i].seq_id.empty()) {
+                // keep count of the number of used cells
+                if (cache.cells[i].pos >= 0) cache.used--;
+
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
             }
@@ -1654,7 +1661,7 @@ static void llama_kv_cache_seq_rm(
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -1680,6 +1687,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
+            if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
@@ -1690,7 +1698,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1711,6 +1719,7 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;
 
             if (cache.cells[i].pos < 0) {
+                if (!cache.cells[i].seq_id.empty()) cache.used--;
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
@@ -5469,6 +5478,12 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (kv_self.head > kv_self.used + 2*n_tokens) {
+        kv_self.head = 0;
+    }
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -5479,7 +5494,7 @@ static int llama_decode_internal(
     //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32));   // TODO: this might be better for CUDA?
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
 
-    //printf("kv_self.n = %d\n", kv_self.n);
+    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_allocr_reset(lctx.alloc);
 
@@ -8790,7 +8805,17 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
 }
 
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.head;
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -8960,10 +8985,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const size_t   kv_buf_size = kv_self.buf.size;
         const uint32_t kv_head     = kv_self.head;
         const uint32_t kv_size     = kv_self.size;
+        const uint32_t kv_used     = kv_self.used;
 
         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
         data_ctx->write(&kv_head,     sizeof(kv_head));
         data_ctx->write(&kv_size,     sizeof(kv_size));
+        data_ctx->write(&kv_used,     sizeof(kv_used));
 
         if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9113,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         size_t   kv_buf_size;
         uint32_t kv_head;
         uint32_t kv_size;
+        uint32_t kv_used;
 
         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
         memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
+        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
 
         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@@ -9124,6 +9153,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+        ctx->kv_self.used = kv_used;
 
         ctx->kv_self.cells.resize(kv_size);
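The invariant the patch maintains is that `used` always equals the number of cells holding at least one sequence id: it is incremented when `llama_kv_cache_find_slot` allocates cells and decremented whenever a cell's last sequence id goes away. The stand-alone sketch below restates that bookkeeping and the new head-reset heuristic outside the diff; it is illustrative only (the names mirror the patch, but the real cells also carry a position delta and back the K/V tensors):

    #include <cassert>
    #include <cstdint>
    #include <set>
    #include <vector>

    // Simplified stand-alone model of the cache bookkeeping in this patch.
    struct toy_kv_cell {
        int32_t pos = -1;               // -1 means the cell is free
        std::set<int32_t> seq_id;       // sequences that reference this cell
    };

    struct toy_kv_cache {
        uint32_t head = 0;              // where the slot search starts
        uint32_t size = 0;
        uint32_t used = 0;              // cells with at least one seq_id
        std::vector<toy_kv_cell> cells;
    };

    // Invariant checker: "used" must equal the number of occupied cells.
    static void check_used(const toy_kv_cache & cache) {
        uint32_t occupied = 0;
        for (const auto & c : cache.cells) {
            if (c.pos >= 0) occupied++;
        }
        assert(cache.used == occupied);
    }

    // The head-reset heuristic from llama_decode_internal: if there is plenty
    // of free space before the current head, restart the slot search at cell 0
    // so freed cells get refilled instead of growing the tail.
    static void maybe_reset_head(toy_kv_cache & cache, uint32_t n_tokens) {
        if (cache.head > cache.used + 2*n_tokens) {
            cache.head = 0;
        }
    }

For example, with head = 100, used = 10 and an 8-token batch, 100 > 10 + 2*8 holds, so the search restarts at cell 0 and the freed cells in front of the old head are refilled instead of pushing the occupied region further out.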
							
								
								
									
diff --git a/llama.h b/llama.h
--- a/llama.h
+++ b/llama.h
@@ -361,9 +361,12 @@ extern "C" {
     // KV cache
     //
 
-    // Returns the number of tokens in the KV cache
-    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
-            "avoid using this, it will be removed in the future, instead - count the tokens in user code");
+    // Returns the number of tokens in the KV cache (slow, use only for debug)
+    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+    LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
+
+    // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
+    LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
 
     // Clear the KV cache
     LLAMA_API void llama_kv_cache_clear(
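A usage sketch for the two queries declared above (the helper name is hypothetical; it assumes a valid context and omits error handling):

    #include <cstdio>

    #include "llama.h"

    // Hypothetical monitoring helper. The used-cells query is O(1) (it just
    // reads the new counter), while the token count now scans every cell,
    // which is why the header marks it as slow / debug-only.
    static void print_kv_usage(const struct llama_context * ctx) {
        const int used   = llama_get_kv_cache_used_cells(ctx);
        const int tokens = llama_get_kv_cache_token_count(ctx);

        // tokens >= used: a cell shared by k sequences contributes k to tokens
        printf("KV cells in use: %d, KV token count: %d\n", used, tokens);
    }

The two values diverge as soon as cells are shared between sequences, e.g. after llama_kv_cache_seq_cp: a cell referenced by k sequences contributes k to the token count but only 1 to the used-cell count.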