mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-17 11:37:10 +00:00
server : handle context overflow during decode (#17267)
* server : handle context overflow during decode * server : minor refactor
This commit is contained in:
@@ -1686,14 +1686,13 @@ struct server_slot {
|
||||
llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
|
||||
}
|
||||
|
||||
void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
|
||||
bool res = prompt_cache.load(prompt, tokens, ctx, id);
|
||||
if (!res) {
|
||||
SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<common_adapter_lora_info> lora;
|
||||
@@ -2339,7 +2338,6 @@ struct server_context {
|
||||
|
||||
llama_batch batch {};
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
@@ -2702,7 +2700,10 @@ struct server_context {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
ret->prompt_save(*prompt_cache);
|
||||
ret->prompt_load(*prompt_cache, task.tokens);
|
||||
|
||||
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
|
||||
clear_slot(*ret);
|
||||
}
|
||||
|
||||
prompt_cache->update();
|
||||
|
||||
@@ -2713,12 +2714,21 @@ struct server_context {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// return true if at least one slot has been purged
|
||||
void clear_slot(server_slot & slot) const {
|
||||
GGML_ASSERT(!slot.is_processing());
|
||||
|
||||
SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// return true if at least one slot has been cleared
|
||||
// TODO: improve logic
|
||||
// - smarter decision which slot to purge (LRU or longest prompt?)
|
||||
// - smarter decision which slot to clear (LRU or longest prompt?)
|
||||
// - move slot to level 2 cache instead of removing?
|
||||
// - instead of purging, try to store and resume later?
|
||||
bool try_purge_idle_slots() {
|
||||
bool try_clear_idle_slots() {
|
||||
bool res = false;
|
||||
|
||||
if (!params_base.kv_unified) {
|
||||
@@ -2733,12 +2743,11 @@ struct server_context {
|
||||
if (slot.prompt.n_tokens() > 0) {
|
||||
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
clear_slot(slot);
|
||||
|
||||
res = true;
|
||||
|
||||
// purge slots one by one
|
||||
// clear slots one by one
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -2848,14 +2857,6 @@ struct server_context {
|
||||
return true;
|
||||
}
|
||||
|
||||
void kv_cache_clear() {
|
||||
SRV_DBG("%s", "clearing KV cache\n");
|
||||
|
||||
// clear the entire KV cache
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
clean_kv_cache = false;
|
||||
}
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = result.text_to_send;
|
||||
@@ -3443,8 +3444,8 @@ struct server_context {
|
||||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->prompt.tokens.size();
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
|
||||
slot->prompt.tokens.clear();
|
||||
|
||||
clear_slot(*slot);
|
||||
|
||||
auto res = std::make_unique<server_task_result_slot_erase>();
|
||||
res->id = task.id;
|
||||
@@ -3477,9 +3478,6 @@ struct server_context {
|
||||
|
||||
if (all_idle) {
|
||||
SRV_INF("%s", "all slots are idle\n");
|
||||
if (clean_kv_cache) {
|
||||
kv_cache_clear();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
@@ -3873,12 +3871,11 @@ struct server_context {
|
||||
|
||||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
|
||||
clear_slot(slot);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
|
||||
slot.prompt.tokens.clear();
|
||||
}
|
||||
|
||||
// check if we should process the image
|
||||
@@ -4108,6 +4105,10 @@ struct server_context {
|
||||
if (slot.is_processing()) {
|
||||
send_error(slot, err);
|
||||
slot.release();
|
||||
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
clear_slot(slot);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4116,7 +4117,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
if (!try_purge_idle_slots()) {
|
||||
if (!try_clear_idle_slots()) {
|
||||
n_batch /= 2;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user