	server : avoid context swaps by shifting the KV cache
examples/server/server.cpp

@@ -381,6 +381,10 @@ struct llama_server_context
 
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
+
+        // since #3228 we now have to manually manage the KV cache
+        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)
         {
@@ -411,19 +415,27 @@ struct llama_server_context
 
         if (embd.size() >= (size_t)params.n_ctx)
         {
-            // Reset context
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            // Shift context
+
+            const int n_left    = n_past - params.n_keep - 1;
+            const int n_discard = n_left/2;
+
+            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+            for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
+            {
+                embd[i - n_discard] = embd[i];
+            }
+            embd.resize(embd.size() - n_discard);
+
+            n_past -= n_discard;
 
-            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
-            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
-            embd = new_tokens;
-            n_past = params.n_keep;
             truncated = true;
             LOG_VERBOSE("input truncated", {
                                                {"n_ctx", params.n_ctx},
                                                {"n_keep", params.n_keep},
                                                {"n_left", n_left},
-                                               {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                                            });
         }
 
@@ -435,9 +447,6 @@ struct llama_server_context
                 n_eval = params.n_batch;
             }
 
-            // since #3228 we now have to manually manage the KV cache
-            llama_kv_cache_tokens_rm(ctx, n_past, -1);
-
             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
             {
                 LOG_ERROR("failed to eval", {
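In this scheme, the first params.n_keep + 1 tokens are preserved (the +1 presumably keeping the very first token in place), half of the remaining n_left = n_past - params.n_keep - 1 tokens are removed from sequence 0 of the KV cache, and the positions of everything after the removed range are shifted down by n_discard, so generation can continue without re-evaluating the whole prompt. As a rough, standalone sketch of that arithmetic (not part of the commit), the C++ snippet below replays only the token-buffer side of the shift on a plain std::vector<int>; the n_ctx and n_keep values are made up for illustration, and the two llama_kv_cache_* calls from the patch appear only as comments.

// Standalone sketch of the context-shift arithmetic (illustrative values only).
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 16;   // hypothetical context size
    const int n_keep = 4;    // prompt tokens to always keep (the patch preserves n_keep + 1 positions)

    // Stand-in for the server's token buffer, filled with dummy token ids.
    std::vector<int> embd;
    for (int i = 0; i < n_ctx; i++) {
        embd.push_back(100 + i);
    }
    int n_past = (int) embd.size();   // assume everything is already in the KV cache

    if (embd.size() >= (size_t) n_ctx) {
        // Same arithmetic as the patch: keep the first n_keep + 1 tokens,
        // discard half of the rest, and slide the remainder down.
        const int n_left    = n_past - n_keep - 1;
        const int n_discard = n_left / 2;

        // In the server these positions are removed/shifted in the KV cache:
        //   llama_kv_cache_seq_rm   (ctx, 0, n_keep + 1,             n_keep + n_discard + 1);
        //   llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);

        // Mirror the same shift in the token buffer.
        for (size_t i = n_keep + 1 + n_discard; i < embd.size(); i++) {
            embd[i - n_discard] = embd[i];
        }
        embd.resize(embd.size() - n_discard);

        n_past -= n_discard;
    }

    // With n_ctx = 16 and n_keep = 4: n_left = 11, n_discard = 5, n_past becomes 11.
    printf("n_past after shift: %d, tokens kept: %zu\n", n_past, embd.size());
    return 0;
}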