Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-18 11:46:58 +00:00)
server : handle context overflow during decode (#17267)
* server : handle context overflow during decode

* server : minor refactor
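In brief, reading the hunks below: prompt_load() now returns whether loading from the prompt cache succeeded and the caller clears the slot on failure; a new clear_slot() helper centralizes removing a slot's sequence from memory and dropping its cached tokens; the global clean_kv_cache / kv_cache_clear() path is removed; try_purge_idle_slots() becomes try_clear_idle_slots(); and when a decode error occurs, each slot in the failed batch gets an error response, is released, and has its context cleared.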
@@ -1686,14 +1686,13 @@ struct server_slot {
         llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0);
     }
 
-    void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
+    bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
-
-            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
-            prompt.tokens.clear();
         }
+
+        return res;
     }
 
     std::vector<common_adapter_lora_info> lora;
@@ -2339,7 +2338,6 @@ struct server_context {
 
     llama_batch batch {};
 
-    bool clean_kv_cache = true;
     bool add_bos_token = true;
 
     int32_t n_ctx; // total context for all clients / slots
@@ -2702,7 +2700,10 @@ struct server_context {
             const int64_t t_start = ggml_time_us();
 
             ret->prompt_save(*prompt_cache);
-            ret->prompt_load(*prompt_cache, task.tokens);
+
+            if (!ret->prompt_load(*prompt_cache, task.tokens)) {
+                clear_slot(*ret);
+            }
 
             prompt_cache->update();
 
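Taken together, the two hunks above change prompt_load() from cleaning up after itself into a function that just reports success, with the caller deciding to clear the slot. A minimal self-contained sketch of that pattern, using toy stand-ins (toy_slot and its members are hypothetical, not the real server_slot / server_prompt_cache):

#include <cassert>
#include <cstdio>
#include <vector>

// toy stand-in for the server slot (illustration only)
struct toy_slot {
    int              id = 0;
    std::vector<int> cached_tokens;

    // reports success instead of cleaning up on failure
    bool prompt_load(const std::vector<int> & tokens) {
        const bool res = !tokens.empty();   // pretend cache lookup
        if (res) {
            cached_tokens = tokens;
        } else {
            std::fprintf(stderr, "failed to load prompt from cache\n");
        }
        return res;
    }

    void clear() {
        // in the real code this also removes the slot's sequence from memory
        cached_tokens.clear();
    }
};

int main() {
    toy_slot slot;

    // caller-driven cleanup, mirroring the pattern in the diff
    if (!slot.prompt_load({})) {
        slot.clear();
    }

    assert(slot.cached_tokens.empty());
    return 0;
}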
@@ -2713,12 +2714,21 @@ struct server_context {
         return ret;
     }
 
-    // return true if at least one slot has been purged
+    void clear_slot(server_slot & slot) const {
+        GGML_ASSERT(!slot.is_processing());
+
+        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+        slot.prompt.tokens.clear();
+    }
+
+    // return true if at least one slot has been cleared
     // TODO: improve logic
-    // - smarter decision which slot to purge (LRU or longest prompt?)
+    // - smarter decision which slot to clear (LRU or longest prompt?)
     // - move slot to level 2 cache instead of removing?
     // - instead of purging, try to store and resume later?
-    bool try_purge_idle_slots() {
+    bool try_clear_idle_slots() {
         bool res = false;
 
         if (!params_base.kv_unified) {
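The new clear_slot() helper gathers the recurring "remove the slot's sequence from memory, then drop its cached tokens" pair behind an assertion that the slot is idle. A standalone sketch of the same idea; toy_memory and toy_slot are illustrative stand-ins, not llama.cpp types:

#include <cassert>
#include <cstdio>
#include <map>
#include <vector>

// toy sequence store standing in for the llama memory (illustration only)
struct toy_memory {
    std::map<int, std::vector<int>> seqs;
    void seq_rm(int seq_id) { seqs.erase(seq_id); }
};

struct toy_slot {
    int              id = 0;
    bool             processing = false;
    std::vector<int> cached_tokens;
};

// single place that owns "forget everything this slot has in the context"
void clear_slot(toy_memory & mem, toy_slot & slot) {
    assert(!slot.processing);   // only idle slots may be cleared

    std::fprintf(stderr, "clearing slot %d with %zu tokens\n", slot.id, slot.cached_tokens.size());

    mem.seq_rm(slot.id);        // drop the slot's sequence from the KV store
    slot.cached_tokens.clear(); // and the token cache that mirrors it
}

int main() {
    toy_memory mem;
    toy_slot   slot;
    slot.id            = 3;
    slot.cached_tokens = {1, 2, 3};
    mem.seqs[slot.id]  = slot.cached_tokens;

    clear_slot(mem, slot);

    assert(mem.seqs.empty() && slot.cached_tokens.empty());
    return 0;
}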
@@ -2733,12 +2743,11 @@ struct server_context {
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
 
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-                slot.prompt.tokens.clear();
+                clear_slot(slot);
 
                 res = true;
 
-                // purge slots one by one
+                // clear slots one by one
                 break;
             }
         }
@@ -2848,14 +2857,6 @@ struct server_context {
         return true;
     }
 
-    void kv_cache_clear() {
-        SRV_DBG("%s", "clearing KV cache\n");
-
-        // clear the entire KV cache
-        llama_memory_clear(llama_get_memory(ctx), true);
-        clean_kv_cache = false;
-    }
-
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;
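Presumably the dedicated clear_slot() path, plus the full-slot clear on decode failure later in the diff, makes the lazy global cleanup redundant, which is why the clean_kv_cache flag and kv_cache_clear() are dropped here and the call site disappears in a later hunk.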
@@ -3443,8 +3444,8 @@ struct server_context {
 
                 // Erase token cache
                 const size_t n_erased = slot->prompt.tokens.size();
-                llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1);
-                slot->prompt.tokens.clear();
+
+                clear_slot(*slot);
 
                 auto res = std::make_unique<server_task_result_slot_erase>();
                 res->id = task.id;
@@ -3477,9 +3478,6 @@ struct server_context {
 
         if (all_idle) {
             SRV_INF("%s", "all slots are idle\n");
-            if (clean_kv_cache) {
-                kv_cache_clear();
-            }
 
             return;
         }
@@ -3873,12 +3871,11 @@ struct server_context {
 
                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                     SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
-                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+
+                    clear_slot(slot);
 
                     // there is no common part left
                     slot.n_prompt_tokens_cache = 0;
-
-                    slot.prompt.tokens.clear();
                 }
 
                 // check if we should process the image
@@ -4108,6 +4105,10 @@ struct server_context {
                 if (slot.is_processing()) {
                     send_error(slot, err);
                     slot.release();
+
+                    // note: it's complicated to keep track of how much of the current batch has been
+                    // processed before the error occurred, so we simply clear the entire context
+                    clear_slot(slot);
                 }
             }
 
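Because it is hard to tell how much of a partially processed batch already landed in the KV cache, the error path now simply clears the whole slot after reporting the error and releasing it. A toy sketch of that flow (handle_decode_error and the types are hypothetical stand-ins, not server API):

#include <cstdio>
#include <string>
#include <vector>

struct toy_slot {
    int              id = 0;
    bool             processing = false;
    std::vector<int> cached_tokens;

    void release() { processing = false; }
    void clear()   { cached_tokens.clear(); } // the real code also removes the sequence from memory
};

void send_error(const toy_slot & slot, const std::string & err) {
    std::fprintf(stderr, "slot %d: %s\n", slot.id, err.c_str());
}

// called for every slot that had tokens in the failed batch
void handle_decode_error(toy_slot & slot, const std::string & err) {
    if (slot.processing) {
        send_error(slot, err);
        slot.release();

        // it is hard to know how much of the batch was already processed,
        // so simply forget the slot's whole context and let the client retry
        slot.clear();
    }
}

int main() {
    toy_slot slot;
    slot.processing    = true;
    slot.cached_tokens = {1, 2, 3};

    handle_decode_error(slot, "decode failed");
    return 0;
}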
@@ -4116,7 +4117,7 @@ struct server_context {
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            if (!try_purge_idle_slots()) {
+            if (!try_clear_idle_slots()) {
                 n_batch /= 2;
             }
 
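The retry policy around decoding is unchanged apart from the rename: when the KV cache has no room, first try to clear an idle slot, and only halve the batch size if nothing could be cleared. A compact standalone sketch under toy assumptions (kv_free and idle_used are made-up bookkeeping, not real counters):

#include <cstdio>

// toy stand-ins: pretend the KV cache has room for 64 more tokens,
// plus whatever an idle slot frees when it is cleared
static int kv_free   = 64;
static int idle_used = 512; // tokens held by one idle slot

static bool try_decode(int n_batch) {
    return n_batch <= kv_free;
}

static bool try_clear_idle_slots() {
    if (idle_used == 0) return false;
    kv_free  += idle_used; // clearing the idle slot frees its cells
    idle_used = 0;
    return true;
}

int main() {
    int n_batch = 512;

    while (n_batch > 0) {
        if (try_decode(n_batch)) {
            std::printf("decode ok with n_batch = %d\n", n_batch);
            break;
        }

        // mirror of the policy in the diff: prefer clearing an idle slot,
        // and only fall back to halving the batch size when nothing was cleared
        if (!try_clear_idle_slots()) {
            n_batch /= 2;
        }
    }
    return 0;
}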