Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	llama : allow exporting a view of the KV cache (#4180)
* Allow exporting a view of the KV cache
* Allow dumping the sequences per cell in common
* Track max contiguous cells value and position as well
* Fix max contiguous empty cells index calculation
* Make dump functions deal with lengths or sequence counts > 10 better
* Fix off by one error in dump_kv_cache_view
* Add doc comments for KV cache view functions
* Eliminate cell sequence struct; use llama_seq_id directly
* Minor cleanups
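Taken together, the diff adds a three-call lifecycle (create a view, refresh it on demand, free it when done) plus two dump helpers in common. A minimal caller might look like the sketch below; it uses only the functions introduced in this commit, but the surrounding setup (a valid ctx, the choice of 16 sequences per cell, and the row widths) is illustrative rather than taken from the diff.

// Minimal sketch of using the KV cache view API added by this commit.
// Assumes `ctx` is an already-initialized llama_context; error handling omitted.
#include "common.h"
#include "llama.h"

static void inspect_kv_cache(const llama_context * ctx) {
    // Track up to 16 sequence ids per cell; additional sequences show up as '+'.
    llama_kv_cache_view view = llama_kv_cache_view_init(ctx, 16);

    // Snapshot the current cache state into the view (reallocates buffers as needed).
    llama_kv_cache_view_update(ctx, &view);

    // Compact dump: one character per cell giving the number of sequences in it.
    dump_kv_cache_view(view, 80);

    // Verbose dump: one character per (cell, sequence slot) pair.
    dump_kv_cache_view_seqs(view, 40);

    // Release the buffers owned by the view.
    llama_kv_cache_view_free(&view);
}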
@@ -12,6 +12,7 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <cinttypes>
@@ -1386,3 +1387,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
 }
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) { seq_count++; }
+        }
+        putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    std::unordered_map<llama_seq_id, size_t> seqs;
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] < 0) { continue; }
+            if (seqs.find(cs_curr[j]) == seqs.end()) {
+                if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+                seqs[cs_curr[j]] = seqs.size();
+            }
+        }
+        if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+    }
+
+    printf("=== Sequence legend: ");
+    for (const auto & it : seqs) {
+        printf("%zu=%d, ", it.second, it.first);
+    }
+    printf("'+'=other sequence ids");
+
+    c_curr = view.cells;
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) {
+                const auto & it = seqs.find(cs_curr[j]);
+                putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+            } else {
+                putchar('.');
+            }
+        }
+        putchar(' ');
+    }
+
+    printf("\n=== Done dumping\n");
+}
@@ -218,3 +218,13 @@ std::string get_sortable_timestamp();
 void dump_non_result_info_yaml(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
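Both dump helpers rely on the same flat layout: cells_sequences holds n_max_seq entries per cell, and llama_kv_cache_view_update pads unused slots with -1. A small sketch of reading one cell directly from the view (the helper name is illustrative, not part of the commit):

// Sketch: count the sequences recorded for a single cell by scanning its slice
// of cells_sequences, mirroring what dump_kv_cache_view does per cell.
static int count_seqs_in_cell(const llama_kv_cache_view & view, int cell) {
    int n = 0;
    for (int j = 0; j < view.n_max_seq; j++) {
        // Unused slots are padded with -1 by llama_kv_cache_view_update.
        if (view.cells_sequences[cell * view.n_max_seq + j] >= 0) {
            n++;
        }
    }
    return n;
}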
@@ -172,6 +172,8 @@ int main(int argc, char ** argv) {
     int32_t n_total_gen    = 0;
     int32_t n_cache_miss   = 0;
 
+    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
+
     const auto t_main_start = ggml_time_us();
 
     LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
@@ -201,6 +203,9 @@ int main(int argc, char ** argv) {
     LOG_TEE("Processing requests ...\n\n");
 
     while (true) {
+        llama_kv_cache_view_update(ctx, &kvc_view);
+        dump_kv_cache_view_seqs(kvc_view, 40);
+
         llama_batch_clear(batch);
 
         // decode any currently ongoing sequences
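The example refreshes and dumps the view at the top of every scheduling-loop iteration; a matching llama_kv_cache_view_free call is not shown in these hunks. A caller wanting a clean shutdown would release the view after the loop, roughly as below (the placement is an assumption, not part of the shown diff):

// Hypothetical teardown, not part of the hunks above: release the view's
// buffers once the request loop has finished.
llama_kv_cache_view_free(&kvc_view);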
							
								
								
									
llama.cpp (+89 lines)
@@ -8805,6 +8805,95 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }
 
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells            = */ 0,
+        /*.n_max_seq          = */ n_max_seq,
+        /*.token_count        = */ 0,
+        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /*.max_contiguous     = */ 0,
+        /*.max_contiguous_idx = */ -1,
+        /*.cells              = */ nullptr,
+        /*.cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     int result = 0;
 
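The contiguous-empty-run tracking in llama_kv_cache_view_update is easy to check in isolation. The toy program below reimplements just that bookkeeping over a plain occupancy array (the array contents are invented for illustration); it mirrors the logic above, including the final check that closes an empty run reaching the end of the cache.

// Standalone illustration of the max-contiguous-empty-cell bookkeeping used in
// llama_kv_cache_view_update. Occupancy data here is made up for the example.
#include <cstdio>
#include <vector>

int main() {
    // true = cell holds at least one sequence, false = empty cell
    const std::vector<bool> occupied = { true, true, false, false, false, true, false, false };

    int curr_contig_idx = -1;  // start of the empty run currently being scanned
    int max_contig      = 0;   // longest empty run seen so far
    int max_contig_idx  = -1;  // where that run starts

    for (int i = 0; i < (int) occupied.size(); i++) {
        if (occupied[i]) {
            // A populated cell ends the current empty run; record it if it is the longest.
            if (curr_contig_idx >= 0 && i - curr_contig_idx > max_contig) {
                max_contig     = i - curr_contig_idx;
                max_contig_idx = curr_contig_idx;
            }
            curr_contig_idx = -1;
        } else if (curr_contig_idx < 0) {
            curr_contig_idx = i;  // a new empty run starts here
        }
    }
    // An empty run that extends to the end of the cache is only closed after the loop.
    if (curr_contig_idx >= 0 && (int) occupied.size() - curr_contig_idx > max_contig) {
        max_contig     = (int) occupied.size() - curr_contig_idx;
        max_contig_idx = curr_contig_idx;
    }

    printf("largest empty run: %d cells starting at %d\n", max_contig, max_contig_idx);
    // Expected output for this data: "largest empty run: 3 cells starting at 2"
    return 0;
}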
							
								
								
									
llama.h (+48 lines)
@@ -361,6 +361,54 @@ extern "C" {
     // KV cache
     //
 
+    // Information associated with an individual cell in the KV cache view.
+    struct llama_kv_cache_view_cell {
+        // The position for this cell. Takes KV cache shifts into account.
+        // May be negative if the cell is not populated.
+        llama_pos pos;
+    };
+
+    // An updateable view of the KV cache.
+    struct llama_kv_cache_view {
+        // Number of KV cache cells. This will be the same as the context size.
+        int32_t n_cells;
+
+        // Maximum number of sequences that can exist in a cell. It's not an error
+        // if there are more sequences in a cell than this value, however they will
+        // not be visible in the view cells_sequences.
+        int32_t n_max_seq;
+
+        // Number of tokens in the cache. For example, if there are two populated
+        // cells, the first with 1 sequence id in it and the second with 2 sequence
+        // ids then you'll have 3 tokens.
+        int32_t token_count;
+
+        // Number of populated cache cells.
+        int32_t used_cells;
+
+        // Maximum contiguous empty slots in the cache.
+        int32_t max_contiguous;
+
+        // Index to the start of the max_contiguous slot range. Can be negative
+        // when cache is full.
+        int32_t max_contiguous_idx;
+
+        // Information for an individual cell.
+        struct llama_kv_cache_view_cell * cells;
+
+        // The sequences for each cell. There will be n_max_seq items per cell.
+        llama_seq_id * cells_sequences;
+    };
+
+    // Create an empty KV cache view.
+    LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
+
+    // Free a KV cache view.
+    LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
+
+    // Update the KV cache view structure with the current state of the KV cache.
+    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
+
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
     LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
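Beyond dumping, the struct fields can be read directly. For instance, max_contiguous gives a quick debug-time answer to whether a given number of cells is available back to back. A minimal sketch, assuming llama.h is included and the view has already been created with llama_kv_cache_view_init; the helper name and the idea of gating on it are illustrative, not something this commit adds:

// Illustrative helper (not part of the commit): refresh the view and report
// whether `n_tokens` empty cells are available contiguously.
static bool kv_has_contiguous_room(const llama_context * ctx,
                                   llama_kv_cache_view * view,
                                   int32_t n_tokens) {
    llama_kv_cache_view_update(ctx, view);
    return view->max_contiguous >= n_tokens;
}

Because the update walks every cell in the cache, this is intended for debugging and instrumentation rather than the hot path.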
Author: Kerfuffle