server : support unified cache across slots (#16736)

* server : support unified context across slots

* cont : fix speculative decoding initialization

* context : fix n_ctx_per_seq computation

* server : purge slots one by one

* tests : add unified cache server tests

* llama : update per-seq context computation

* test-thread-safety : handle tiny training context of the input model

* server : fix server_tokens clear()

* server : use 4 slots + unified KV by default

* llama : add note about context size queries

* cont : update todos [no ci]

* context : do not cap the size of the context

* tests : adjust parameters to be CI friendlier

* context : add warning
Author: Georgi Gerganov (committed by GitHub)
Date:   2025-11-02 18:14:04 +02:00
Parent: 87c9efc3b2
Commit: cd5e3b5754

12 changed files with 163 additions and 48 deletions
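Background for the per-sequence context changes in this diff: with a unified KV cache (kv_unified = true) all slots share one buffer, so a single sequence may grow up to the full context window, while a split cache gives every slot an equal share of n_ctx. A minimal C++ sketch of that idea follows; the ctx_params struct and n_ctx_per_seq helper are illustrative stand-ins, not the actual llama_n_ctx_seq() implementation.

#include <cstdint>

// illustrative stand-in for the context parameters involved
struct ctx_params {
    uint32_t n_ctx;      // total context size of the llama_context
    uint32_t n_seq_max;  // number of parallel sequences (server slots)
    bool     kv_unified; // true -> one KV buffer shared by all sequences
};

// per-sequence context: the whole window when unified, an even split otherwise
static uint32_t n_ctx_per_seq(const ctx_params & p) {
    return p.kv_unified ? p.n_ctx : p.n_ctx / p.n_seq_max;
}

For example, with n_ctx = 8192 and 4 slots, a split cache gives each slot 2048 tokens, whereas the unified cache lets one busy slot use most of the 8192 tokens as long as the other slots stay small - which is why the server can now default to 4 slots without shrinking the usable context.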


@@ -2407,7 +2407,7 @@ struct server_context {
 params_dft.devices      = params_base.speculative.devices;
 params_dft.model        = params_base.speculative.model;
-params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
 params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
 params_dft.n_parallel   = 1;
 params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2495,10 +2495,16 @@ struct server_context {
 }

 void init() {
-    const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;

     SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

+    const int n_ctx_train = llama_model_n_ctx_train(model);
+
+    int n_ctx_slot = llama_n_ctx_seq(ctx);
+    if (n_ctx_slot > n_ctx_train) {
+        SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+        n_ctx_slot = n_ctx_train;
+    }
+
     for (int i = 0; i < params_base.n_parallel; i++) {
         server_slot slot;
@@ -2527,7 +2533,7 @@ struct server_context {
     }
 }

-SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

 slot.callback_on_release = [this](int) {
     queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
     return ret;
 }

+// return true if at least one slot has been purged
+// TODO: improve logic
+// - smarter decision which slot to purge (LRU or longest prompt?)
+// - move slot to level 2 cache instead of removing?
+// - instead of purging, try to store and resume later?
+bool try_purge_idle_slots() {
+    bool res = false;
+
+    if (!params_base.kv_unified) {
+        return res;
+    }
+
+    for (auto & slot : slots) {
+        if (slot.is_processing()) {
+            continue;
+        }
+
+        if (slot.prompt.n_tokens() > 0) {
+            SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+            llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+            slot.prompt.tokens.clear();
+
+            res = true;
+
+            // purge slots one by one
+            break;
+        }
+    }
+
+    return res;
+}
+
 bool launch_slot_with_task(server_slot & slot, server_task && task) {
     slot.reset();
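The TODO list in try_purge_idle_slots() deliberately keeps the eviction policy simple: the first idle slot with cached tokens is purged, one slot per retry. A possible refinement along the "longest prompt" line mentioned in the TODO is sketched below; this fragment is illustrative only, is not part of this commit, and reuses the same slot fields and llama_memory_seq_rm() call as the function above.

// illustrative sketch: evict the idle slot holding the largest cached prompt
server_slot * victim = nullptr;

for (auto & slot : slots) {
    if (slot.is_processing() || slot.prompt.n_tokens() == 0) {
        continue; // skip busy slots and slots with nothing to free
    }
    if (!victim || slot.prompt.n_tokens() > victim->prompt.n_tokens()) {
        victim = &slot;
    }
}

if (victim != nullptr) {
    // same cleanup as try_purge_idle_slots(), but targeted at one victim
    llama_memory_seq_rm(llama_get_memory(ctx), victim->id, -1, -1);
    victim->prompt.tokens.clear();
}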
@@ -3635,9 +3674,10 @@ struct server_context {
 int32_t n_batch  = llama_n_batch(ctx);
 int32_t n_ubatch = llama_n_ubatch(ctx);

-// next, batch any pending prompts without exceeding n_batch
-float alora_scale = -1.0f;
+float  alora_scale       = -1.0f;
 size_t alora_disabled_id = 0;

+// next, batch any pending prompts without exceeding n_batch
 if (params_base.cont_batching || batch.n_tokens == 0) {
     for (auto & slot : slots) {
         // check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {
 // truncate any tokens that are beyond n_past for this slot
 const llama_pos p0 = slot.prompt.tokens.pos_next();

+SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-    SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+    SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
+    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

     // there is no common part left
@@ -3924,8 +3967,6 @@ struct server_context {
     slot.prompt.tokens.clear();
 }

-SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
 // check if we should process the image
 if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
     // process the image
@@ -4126,6 +4167,8 @@ struct server_context {
 std::string err;

 if (n_batch == 1 && ret == 1) {
+    // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+    // need to remove the tokens from the current batch too
     err = "Context size has been exceeded.";
 }
@@ -4141,17 +4184,23 @@ struct server_context {
     // TODO: handle ret == 2 (abort) when we start aborting

     if (!err.empty()) {
-        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
         for (auto & slot : slots) {
-            send_error(slot, err);
-            slot.release();
+            if (slot.is_processing()) {
+                send_error(slot, err);
+                slot.release();
+            }
         }

         break;
     }
 }

 // retry with half the batch size to try to find a free slot in the KV cache
-n_batch /= 2;
+if (!try_purge_idle_slots()) {
+    n_batch /= 2;
+}

 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
     return 1;
 }

+// TODO: should we have a separate n_parallel parameter for the server?
+// https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+if (params.n_parallel == 1 && params.kv_unified == false) {
+    LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+    params.n_parallel = 4;
+    params.kv_unified = true;
+}
+
 common_init();

 // struct that contains llama context and inference
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
     // Everything else, including multimodal completions.
     inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
 }

-const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;

 tasks.reserve(inputs.size());
 for (size_t i = 0; i < inputs.size(); i++) {
     auto n_prompt_tokens = inputs[i].size();