diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 8737fba124..b84c119579 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2413,7 +2413,7 @@ struct server_context {
 
         params_dft.devices      = params_base.speculative.devices;
         params_dft.model        = params_base.speculative.model;
-        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? slots.front().n_ctx : params_base.speculative.n_ctx;
         params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
         params_dft.n_parallel   = 1;
         params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2501,7 +2501,7 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
+        const int32_t n_ctx_slot = params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
 
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
@@ -2705,6 +2705,36 @@ struct server_context {
 
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
@@ -3661,9 +3691,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
-        float alora_scale = -1.0f;
+        float alora_scale = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -4144,6 +4175,8 @@ struct server_context {
                     std::string err;
 
                     if (n_batch == 1 && ret == 1) {
+                        // TODO: try to terminate only the largest active slot and continue
+                        //       need to remove the tokens from the current batch too
                         err = "Context size has been exceeded.";
                     }
 
@@ -4159,17 +4192,23 @@ struct server_context {
                     // TODO: handle ret == 2 (abort) when we start aborting
 
                     if (!err.empty()) {
-                        SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                        SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                         for (auto & slot : slots) {
-                            send_error(slot, err);
-                            slot.release();
+                            if (slot.is_processing()) {
+                                send_error(slot, err);
+                                slot.release();
+                            }
                         }
+
                         break;
                     }
                 }
 
                 // retry with half the batch size to try to find a free slot in the KV cache
-                n_batch /= 2;
+                if (!try_purge_idle_slots()) {
+                    n_batch /= 2;
+                }
 
                 SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4963,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
            }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
            tasks.reserve(inputs.size());
            for (size_t i = 0; i < inputs.size(); i++) {
                auto n_prompt_tokens = inputs[i].size();
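For illustration, here is a minimal, self-contained sketch of the purge-then-shrink retry strategy that the last two server_context hunks introduce. The `Slot` struct, its `n_cached` field, and the free function below are hypothetical stand-ins for this sketch only: the real code operates on `server_slot` and evicts cache cells via `llama_memory_seq_rm()`.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for the server's per-slot state.
struct Slot {
    int  id         = 0;
    bool processing = false;
    int  n_cached   = 0; // tokens this slot holds in the shared KV cache
};

// Same contract as try_purge_idle_slots() in the patch: returns true if at
// least one idle slot was purged. Purging only helps with a unified KV cache,
// where all slots compete for the same n_ctx cells.
static bool try_purge_idle_slots(std::vector<Slot> & slots, bool kv_unified) {
    bool res = false;

    if (!kv_unified) {
        return res; // split cache: each slot owns a fixed slice, nothing to reclaim
    }

    for (auto & slot : slots) {
        if (slot.processing || slot.n_cached == 0) {
            continue;
        }

        std::printf("purging slot %d with %d tokens\n", slot.id, slot.n_cached);
        slot.n_cached = 0; // the server also calls llama_memory_seq_rm() here

        res = true;
    }

    return res;
}

int main() {
    std::vector<Slot> slots = {
        { 0, true,  512 }, // busy: never purged
        { 1, false, 256 }, // idle with cached tokens: purged
        { 2, false,   0 }, // idle and empty: skipped
    };

    int n_batch = 64;

    // Retry policy from the patch: on a KV-cache allocation failure, first try
    // to reclaim idle slots; only halve the batch size if nothing was freed.
    if (!try_purge_idle_slots(slots, /*kv_unified=*/true)) {
        n_batch /= 2;
    }

    std::printf("n_batch after retry decision: %d\n", n_batch);
    return 0;
}

The early return for `kv_unified == false` mirrors the patch: with a split KV cache each slot owns a fixed `n_ctx / n_parallel` slice, so purging other slots frees no space for the failing one. With a unified cache all slots share the full `n_ctx`, which is also why `init()` now sizes every slot at `n_ctx` (e.g. with `n_ctx = 8192` and `n_parallel = 4`, each slot advertises 8192 tokens instead of 2048).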