mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	llama : fix handling of "future" tokens when loading sessions
This commit is contained in:
		| @@ -543,6 +543,9 @@ int main(int argc, char ** argv) { | ||||
|                 if (i > 0) { | ||||
|                     embd.erase(embd.begin(), embd.begin() + i); | ||||
|                 } | ||||
|  | ||||
|                 // remove any "future" tokens that we might have inherited from the session from the KV cache | ||||
|                 llama_kv_cache_tokens_rm(ctx, n_past, -1); | ||||
|             } | ||||
|  | ||||
|             // evaluate tokens in batches | ||||
|   | ||||
| @@ -332,7 +332,7 @@ int main(int argc, char ** argv) { | ||||
|                     } | ||||
|  | ||||
|                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache | ||||
|                     llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx); | ||||
|                     llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1); | ||||
|  | ||||
|                     const auto t_main_end = ggml_time_us(); | ||||
|  | ||||
|   | ||||
| @@ -448,7 +448,7 @@ struct llama_server_context | ||||
|         n_past = common_part(embd, prompt_tokens); | ||||
|  | ||||
|         // since #3228 we now have to manually manage the KV cache | ||||
|         llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx); | ||||
|         llama_kv_cache_seq_rm(ctx, 0, n_past, -1); | ||||
|  | ||||
|         embd = prompt_tokens; | ||||
|         if (n_past == num_prompt_tokens) | ||||
|   | ||||
| @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { | ||||
|                 LOG("out of drafted tokens\n"); | ||||
|             } | ||||
|  | ||||
|             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx); | ||||
|             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); | ||||
|             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); | ||||
|             ++n_past_dft; | ||||
|  | ||||
| @@ -257,7 +257,7 @@ int main(int argc, char ** argv) { | ||||
|             } | ||||
|  | ||||
|             // evaluate the drafted token on the draft model | ||||
|             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx); | ||||
|             llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1); | ||||
|             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); | ||||
|             ++n_past_cur; | ||||
|  | ||||
| @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { | ||||
|         } | ||||
|  | ||||
|         // evaluate the target model on the drafted tokens | ||||
|         llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx); | ||||
|         llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1); | ||||
|         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); | ||||
|         ++n_past_tgt; | ||||
|  | ||||
|   | ||||
							
								
								
									
										60
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -1281,8 +1281,8 @@ static bool llama_kv_cache_init( | ||||
| // find an empty slot of size "n_tokens" in the cache | ||||
| // updates the cache head | ||||
| static bool llama_kv_cache_find_slot( | ||||
|              struct llama_kv_cache & cache, | ||||
|           const struct llama_batch & batch) { | ||||
|            struct llama_kv_cache & cache, | ||||
|         const struct llama_batch & batch) { | ||||
|     const uint32_t n_ctx    = cache.size; | ||||
|     const uint32_t n_tokens = batch.n_tokens; | ||||
|  | ||||
| @@ -1350,10 +1350,13 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, | ||||
| } | ||||
|  | ||||
| static void llama_kv_cache_seq_rm( | ||||
|              struct llama_kv_cache & cache, | ||||
|                       llama_seq_id   seq_id, | ||||
|                          llama_pos   p0, | ||||
|                          llama_pos   p1) { | ||||
|         struct llama_kv_cache & cache, | ||||
|                  llama_seq_id   seq_id, | ||||
|                     llama_pos   p0, | ||||
|                     llama_pos   p1) { | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cache.size; ++i) { | ||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||
|             cache.cells[i].seq_id.erase(seq_id); | ||||
| @@ -1365,11 +1368,14 @@ static void llama_kv_cache_seq_rm( | ||||
| } | ||||
|  | ||||
| static void llama_kv_cache_seq_cp( | ||||
|              struct llama_kv_cache & cache, | ||||
|                       llama_seq_id   seq_id_src, | ||||
|                       llama_seq_id   seq_id_dst, | ||||
|                          llama_pos   p0, | ||||
|                          llama_pos   p1) { | ||||
|         struct llama_kv_cache & cache, | ||||
|                  llama_seq_id   seq_id_src, | ||||
|                  llama_seq_id   seq_id_dst, | ||||
|                     llama_pos   p0, | ||||
|                     llama_pos   p1) { | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cache.size; ++i) { | ||||
|         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||
|             cache.cells[i].seq_id.insert(seq_id_dst); | ||||
| @@ -1387,11 +1393,14 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id | ||||
| } | ||||
|  | ||||
| static void llama_kv_cache_seq_shift( | ||||
|              struct llama_kv_cache & cache, | ||||
|                       llama_seq_id   seq_id, | ||||
|                          llama_pos   p0, | ||||
|                          llama_pos   p1, | ||||
|                          llama_pos   delta) { | ||||
|         struct llama_kv_cache & cache, | ||||
|                  llama_seq_id   seq_id, | ||||
|                     llama_pos   p0, | ||||
|                     llama_pos   p1, | ||||
|                     llama_pos   delta) { | ||||
|     if (p0 < 0) p0 = 0; | ||||
|     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); | ||||
|  | ||||
|     for (uint32_t i = 0; i < cache.size; ++i) { | ||||
|         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { | ||||
|             cache.cells[i].pos += delta; | ||||
| @@ -7478,25 +7487,6 @@ void llama_batch_free(struct llama_batch batch) { | ||||
| int llama_decode( | ||||
|         struct llama_context * ctx, | ||||
|           struct llama_batch   batch) { | ||||
|     // TODO: temporary solution to auto clear "future" tokens from the cache | ||||
|     //       ref: https://github.com/ggerganov/llama.cpp/pull/3400 | ||||
|     if (batch.pos) { | ||||
|         std::map<llama_seq_id, llama_pos> seq_min_pos; | ||||
|         for (int i = 0; i < batch.n_tokens; i++) { | ||||
|             if (seq_min_pos.count(batch.seq_id[i]) == 0) { | ||||
|                 seq_min_pos[batch.seq_id[i]] = batch.pos[i]; | ||||
|             } else { | ||||
|                 seq_min_pos[batch.seq_id[i]] = std::min(seq_min_pos[batch.seq_id[i]], batch.pos[i]); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         for (auto & kv : seq_min_pos) { | ||||
|             llama_kv_cache_seq_rm(ctx->kv_self, kv.first, kv.second, ctx->cparams.n_ctx); | ||||
|         } | ||||
|     } else { | ||||
|         llama_kv_cache_seq_rm(ctx->kv_self, batch.all_seq_id, batch.all_pos_0, ctx->cparams.n_ctx); | ||||
|     } | ||||
|  | ||||
|     const int ret = llama_decode_internal(*ctx, batch); | ||||
|     if (ret < 0) { | ||||
|         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); | ||||
|   | ||||
							
								
								
									
										8
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								llama.h
									
									
									
									
									
								
							| @@ -330,12 +330,16 @@ extern "C" { | ||||
|             "avoid using this, it will be removed in the future, instead - count the tokens in user code"); | ||||
|  | ||||
|     // Remove all tokens data of cells in [c0, c1) | ||||
|     // c0 < -1 : [0,  c1] | ||||
|     // c1 < -1 : [c0, inf) | ||||
|     LLAMA_API void llama_kv_cache_tokens_rm( | ||||
|             struct llama_context * ctx, | ||||
|                          int32_t   c0, | ||||
|                          int32_t   c1); | ||||
|  | ||||
|     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) | ||||
|     // p0 < -1 : [0,  p1] | ||||
|     // p1 < -1 : [p0, inf) | ||||
|     LLAMA_API void llama_kv_cache_seq_rm( | ||||
|             struct llama_context * ctx, | ||||
|                     llama_seq_id   seq_id, | ||||
| @@ -344,6 +348,8 @@ extern "C" { | ||||
|  | ||||
|     // Copy all tokens that belong to the specified sequence to another sequence | ||||
|     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence | ||||
|     // p0 < -1 : [0,  p1] | ||||
|     // p1 < -1 : [p0, inf) | ||||
|     LLAMA_API void llama_kv_cache_seq_cp( | ||||
|             struct llama_context * ctx, | ||||
|                     llama_seq_id   seq_id_src, | ||||
| @@ -358,6 +364,8 @@ extern "C" { | ||||
|  | ||||
|     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) | ||||
|     // If the KV cache is RoPEd, the KV data is updated accordingly | ||||
|     // p0 < -1 : [0,  p1] | ||||
|     // p1 < -1 : [p0, inf) | ||||
|     LLAMA_API void llama_kv_cache_seq_shift( | ||||
|             struct llama_context * ctx, | ||||
|                     llama_seq_id   seq_id, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov