mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : fix handling of "future" tokens when loading sessions
This commit is contained in:
		| @@ -543,6 +543,9 @@ int main(int argc, char ** argv) { | |||||||
|                 if (i > 0) { |                 if (i > 0) { | ||||||
|                     embd.erase(embd.begin(), embd.begin() + i); |                     embd.erase(embd.begin(), embd.begin() + i); | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|  |                 // remove any "future" tokens that we might have inherited from the session from the KV cache | ||||||
|  |                 llama_kv_cache_tokens_rm(ctx, n_past, -1); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // evaluate tokens in batches |             // evaluate tokens in batches | ||||||
|   | |||||||
| @@ -332,7 +332,7 @@ int main(int argc, char ** argv) { | |||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache |                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache | ||||||
|                     llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx); |                     llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1); | ||||||
|  |  | ||||||
|                     const auto t_main_end = ggml_time_us(); |                     const auto t_main_end = ggml_time_us(); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -448,7 +448,7 @@ struct llama_server_context | |||||||
|         n_past = common_part(embd, prompt_tokens); |         n_past = common_part(embd, prompt_tokens); | ||||||
|  |  | ||||||
|         // since #3228 we now have to manually manage the KV cache |         // since #3228 we now have to manually manage the KV cache | ||||||
|         llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx); |         llama_kv_cache_seq_rm(ctx, 0, n_past, -1); | ||||||
|  |  | ||||||
|         embd = prompt_tokens; |         embd = prompt_tokens; | ||||||
|         if (n_past == num_prompt_tokens) |         if (n_past == num_prompt_tokens) | ||||||
|   | |||||||
| @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { | |||||||
|                 LOG("out of drafted tokens\n"); |                 LOG("out of drafted tokens\n"); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx); |             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); | ||||||
|             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); |             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); | ||||||
|             ++n_past_dft; |             ++n_past_dft; | ||||||
|  |  | ||||||
| @@ -257,7 +257,7 @@ int main(int argc, char ** argv) { | |||||||
|             } |             } | ||||||
|  |  | ||||||
|             // evaluate the drafted token on the draft model |             // evaluate the drafted token on the draft model | ||||||
|             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx); |             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1); | ||||||
|             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); |             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); | ||||||
|             ++n_past_cur; |             ++n_past_cur; | ||||||
|  |  | ||||||
| @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // evaluate the target model on the drafted tokens |         // evaluate the target model on the drafted tokens | ||||||
|         llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx); |         llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1); | ||||||
|         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); |         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); | ||||||
|         ++n_past_tgt; |         ++n_past_tgt; | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										60
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -1281,8 +1281,8 @@ static bool llama_kv_cache_init( | |||||||
| // find an empty slot of size "n_tokens" in the cache | // find an empty slot of size "n_tokens" in the cache | ||||||
| // updates the cache head | // updates the cache head | ||||||
| static bool llama_kv_cache_find_slot( | static bool llama_kv_cache_find_slot( | ||||||
|              struct llama_kv_cache & cache, |            struct llama_kv_cache & cache, | ||||||
|           const struct llama_batch & batch) { |         const struct llama_batch & batch) { | ||||||
|     const uint32_t n_ctx    = cache.size; |     const uint32_t n_ctx    = cache.size; | ||||||
|     const uint32_t n_tokens = batch.n_tokens; |     const uint32_t n_tokens = batch.n_tokens; | ||||||
|  |  | ||||||
| @@ -1350,10 +1350,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, | |||||||
| } | } | ||||||
|  |  | ||||||
| static void llama_kv_cache_seq_rm( | static void llama_kv_cache_seq_rm( | ||||||
|              struct llama_kv_cache & cache, |         struct llama_kv_cache & cache, | ||||||
|                       llama_seq_id   seq_id, |                  llama_seq_id   seq_id, | ||||||
|                          llama_pos   p0, |                     llama_pos   p0, | ||||||
|                          llama_pos   p1) { |                     llama_pos   p1) { | ||||||
|  |     if (p0 < 0) p0 = 0; | ||||||
|  |     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < cache.size; ++i) { |     for (uint32_t i = 0; i < cache.size; ++i) { | ||||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { |         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||||
|             cache.cells[i].seq_id.erase(seq_id); |             cache.cells[i].seq_id.erase(seq_id); | ||||||
| @@ -1365,11 +1368,14 @@ static void llama_kv_cache_seq_rm( | |||||||
| } | } | ||||||
|  |  | ||||||
| static void llama_kv_cache_seq_cp( | static void llama_kv_cache_seq_cp( | ||||||
|              struct llama_kv_cache & cache, |         struct llama_kv_cache & cache, | ||||||
|                       llama_seq_id   seq_id_src, |                  llama_seq_id   seq_id_src, | ||||||
|                       llama_seq_id   seq_id_dst, |                  llama_seq_id   seq_id_dst, | ||||||
|                          llama_pos   p0, |                     llama_pos   p0, | ||||||
|                          llama_pos   p1) { |                     llama_pos   p1) { | ||||||
|  |     if (p0 < 0) p0 = 0; | ||||||
|  |     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < cache.size; ++i) { |     for (uint32_t i = 0; i < cache.size; ++i) { | ||||||
|         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { |         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||||
|             cache.cells[i].seq_id.insert(seq_id_dst); |             cache.cells[i].seq_id.insert(seq_id_dst); | ||||||
| @@ -1387,11 +1393,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id | |||||||
| } | } | ||||||
|  |  | ||||||
| static void llama_kv_cache_seq_shift( | static void llama_kv_cache_seq_shift( | ||||||
|              struct llama_kv_cache & cache, |         struct llama_kv_cache & cache, | ||||||
|                       llama_seq_id   seq_id, |                  llama_seq_id   seq_id, | ||||||
|                          llama_pos   p0, |                     llama_pos   p0, | ||||||
|                          llama_pos   p1, |                     llama_pos   p1, | ||||||
|                          llama_pos   delta) { |                     llama_pos   delta) { | ||||||
|  |     if (p0 < 0) p0 = 0; | ||||||
|  |     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||||
|  |  | ||||||
|     for (uint32_t i = 0; i < cache.size; ++i) { |     for (uint32_t i = 0; i < cache.size; ++i) { | ||||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { |         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||||
|             cache.cells[i].pos += delta; |             cache.cells[i].pos += delta; | ||||||
| @@ -7478,25 +7487,6 @@ void llama_batch_free(struct llama_batch batch) { | |||||||
| int llama_decode( | int llama_decode( | ||||||
|         struct llama_context * ctx, |         struct llama_context * ctx, | ||||||
|           struct llama_batch   batch) { |           struct llama_batch   batch) { | ||||||
|     // TODO: temporary solution to auto clear "future" tokens from the cache |  | ||||||
|     //       ref: https://github.com/ggerganov/llama.cpp/pull/3400 |  | ||||||
|     if (batch.pos) { |  | ||||||
|         std::map<llama_seq_id, llama_pos> seq_min_pos; |  | ||||||
|         for (int i = 0; i < batch.n_tokens; i++) { |  | ||||||
|             if (seq_min_pos.count(batch.seq_id[i]) == 0) { |  | ||||||
|                 seq_min_pos[batch.seq_id[i]] = batch.pos[i]; |  | ||||||
|             } else { |  | ||||||
|                 seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         for (auto & kv : seq_min_pos) { |  | ||||||
|             llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx); |  | ||||||
|         } |  | ||||||
|     } else { |  | ||||||
|         llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     const int ret = llama_decode_internal(*ctx, batch); |     const int ret = llama_decode_internal(*ctx, batch); | ||||||
|     if (ret < 0) { |     if (ret < 0) { | ||||||
|         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								llama.h
									
									
									
									
									
								
							| @@ -330,12 +330,16 @@ extern "C" { | |||||||
|             "avoid using this, it will be removed in the future, instead - count the tokens in user code"); |             "avoid using this, it will be removed in the future, instead - count the tokens in user code"); | ||||||
|  |  | ||||||
|     // Remove all tokens data of cells in [c0, c1) |     // Remove all tokens data of cells in [c0, c1) | ||||||
|  |     // c0 < -1 : [0,  c1] | ||||||
|  |     // c1 < -1 : [c0, inf) | ||||||
|     LLAMA_API void llama_kv_cache_tokens_rm( |     LLAMA_API void llama_kv_cache_tokens_rm( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                          int32_t   c0, |                          int32_t   c0, | ||||||
|                          int32_t   c1); |                          int32_t   c1); | ||||||
|  |  | ||||||
|     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) |     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) | ||||||
|  |     // p0 < -1 : [0,  p1] | ||||||
|  |     // p1 < -1 : [p0, inf) | ||||||
|     LLAMA_API void llama_kv_cache_seq_rm( |     LLAMA_API void llama_kv_cache_seq_rm( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                     llama_seq_id   seq_id, |                     llama_seq_id   seq_id, | ||||||
| @@ -344,6 +348,8 @@ extern "C" { | |||||||
|  |  | ||||||
|     // Copy all tokens that belong to the specified sequence to another sequence |     // Copy all tokens that belong to the specified sequence to another sequence | ||||||
|     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence |     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence | ||||||
|  |     // p0 < -1 : [0,  p1] | ||||||
|  |     // p1 < -1 : [p0, inf) | ||||||
|     LLAMA_API void llama_kv_cache_seq_cp( |     LLAMA_API void llama_kv_cache_seq_cp( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                     llama_seq_id   seq_id_src, |                     llama_seq_id   seq_id_src, | ||||||
| @@ -358,6 +364,8 @@ extern "C" { | |||||||
|  |  | ||||||
|     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) |     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) | ||||||
|     // If the KV cache is RoPEd, the KV data is updated accordingly |     // If the KV cache is RoPEd, the KV data is updated accordingly | ||||||
|  |     // p0 < -1 : [0,  p1] | ||||||
|  |     // p1 < -1 : [p0, inf) | ||||||
|     LLAMA_API void llama_kv_cache_seq_shift( |     LLAMA_API void llama_kv_cache_seq_shift( | ||||||
|             struct llama_context * ctx, |             struct llama_context * ctx, | ||||||
|                     llama_seq_id   seq_id, |                     llama_seq_id   seq_id, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov