Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-28 08:31:25 +00:00
server : support unified context across slots
@@ -2413,7 +2413,7 @@ struct server_context {
         params_dft.devices      = params_base.speculative.devices;
         params_dft.model        = params_base.speculative.model;
-        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+        params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? slots.front().n_ctx : params_base.speculative.n_ctx;
         params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
         params_dft.n_parallel   = 1;
         params_dft.cache_type_k = params_base.speculative.cache_type_k;
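
Note: the hunk above only changes the default for the draft model's context. When no draft context size is given (speculative.n_ctx == 0), the draft context now inherits the actual slot context (slots.front().n_ctx) instead of recomputing n_ctx / n_parallel, which would be wrong once the KV cache is unified. A minimal standalone sketch of that fallback rule, with illustrative names rather than the server's real structs:

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for params_base.speculative; n_ctx == 0 means "not set".
struct spec_params_t {
    int32_t n_ctx = 0;
};

// Fallback rule from the hunk above: an unset draft context inherits the
// slot context, so it stays correct for both split and unified KV caches.
static int32_t draft_n_ctx(const spec_params_t & spec, int32_t slot_n_ctx) {
    return spec.n_ctx == 0 ? slot_n_ctx : spec.n_ctx;
}

int main() {
    spec_params_t spec;
    printf("default:  %d\n", draft_n_ctx(spec, 4096)); // 4096 (inherited)
    spec.n_ctx = 1024;
    printf("explicit: %d\n", draft_n_ctx(spec, 4096)); // 1024 (user override)
    return 0;
}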
@@ -2501,7 +2501,7 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
+        const int32_t n_ctx_slot = params_base.kv_unified ? n_ctx : n_ctx / params_base.n_parallel;
 
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
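
Note: this hunk is the core of the feature. With a split KV cache, each of the n_parallel slots owns an equal static share of n_ctx; with a unified cache (params_base.kv_unified), all slots draw from one shared window, so each slot reports the full n_ctx. A standalone sketch of the arithmetic, using hypothetical values:

#include <cstdint>
#include <cstdio>

// Sketch of the per-slot context rule changed above.
static int32_t slot_n_ctx(int32_t n_ctx, int32_t n_parallel, bool kv_unified) {
    // unified: all slots share the whole window; split: equal static shares
    return kv_unified ? n_ctx : n_ctx / n_parallel;
}

int main() {
    const int32_t n_ctx      = 8192;
    const int32_t n_parallel = 4;

    printf("split:   %d tokens per slot\n", slot_n_ctx(n_ctx, n_parallel, false)); // 2048
    printf("unified: %d tokens per slot\n", slot_n_ctx(n_ctx, n_parallel, true));  // 8192
    return 0;
}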
@@ -2705,6 +2705,36 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    //       - smarter decision which slot to purge
+    //       - move slot to level 2 cache instead of removing?
+    //       - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
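
Note: try_purge_idle_slots() is the safety valve that makes the shared window workable: when decoding runs out of KV-cache space, the cached tokens of an idle slot can be dropped to make room. Below is a standalone mock of just the policy, for experimentation; slot_t is a stand-in for server_slot, and the real code additionally removes the slot's sequence from the llama.cpp memory via llama_memory_seq_rm:

#include <cstdio>
#include <vector>

struct slot_t {
    int  id         = 0;
    bool processing = false; // mid-generation?
    int  n_tokens   = 0;     // cached prompt tokens
};

// returns true if at least one idle, non-empty slot was purged
static bool try_purge_idle_slots(std::vector<slot_t> & slots, bool kv_unified) {
    bool res = false;

    if (!kv_unified) {
        return res; // purging only helps when slots share one KV window
    }

    for (auto & slot : slots) {
        if (slot.processing) {
            continue; // never evict a slot that is mid-generation
        }

        if (slot.n_tokens > 0) {
            printf("purging slot %d with %d tokens\n", slot.id, slot.n_tokens);
            slot.n_tokens = 0;
            res = true;
        }
    }

    return res;
}

int main() {
    std::vector<slot_t> slots = {{0, true, 100}, {1, false, 50}};
    printf("purged: %d\n", try_purge_idle_slots(slots, true)); // purges slot 1 only
    return 0;
}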
@@ -3661,9 +3691,10 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
         float alora_scale = -1.0f;
         size_t alora_disabled_id = 0;
 
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -4144,6 +4175,8 @@ struct server_context {
                 std::string err;
 
                 if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot and continue
+                    //       need to remove the tokens from the current batch too
                     err = "Context size has been exceeded.";
                 }
 
@@ -4159,17 +4192,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting
 
                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
 
                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }
 
                     break;
                 }
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }
 
             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
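
Note: the two hunks above rework the failure path of decoding. The sketch below shows the retry policy in isolation: on a KV-space failure, first try to purge idle slots and retry at the same batch size; only halve the batch when nothing could be purged; at n_batch == 1 give up and report the context-exceeded error (which the server now sends only to slots that are actually processing). The helpers are hypothetical stand-ins for the server's internals:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins: purge frees idle slots (unified cache only);
// decode returns 0 on success and 1 when no KV-cache space was found.
static bool try_purge_idle_slots() { return false; }
static int  decode_batch(int32_t n_batch) { return n_batch > 64 ? 1 : 0; }

static void decode_with_retry(int32_t n_batch) {
    while (n_batch > 0) {
        const int ret = decode_batch(n_batch);
        if (ret == 0) {
            printf("decoded with n_batch = %d\n", n_batch);
            return;
        }

        if (n_batch == 1 && ret == 1) {
            // out of options: "Context size has been exceeded."
            return;
        }

        // prefer reclaiming cache space over shrinking the batch
        if (!try_purge_idle_slots()) {
            n_batch /= 2;
        }
    }
}

int main() {
    decode_with_retry(512); // fails at 512, 256, 128, then succeeds at 64
    return 0;
}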
@@ -4963,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();
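
Note: the final hunk keeps request validation consistent with the new slot sizes: the handler asks an actual slot for its context (ctx_server.slots.front().n_ctx) instead of recomputing n_ctx / n_parallel, which would under-report the usable window under a unified cache and wrongly reject prompts that in fact fit. A sketch of the kind of check this value feeds; illustrative only, not the server code:

#include <cstddef>
#include <cstdio>

static bool prompt_fits(size_t n_prompt_tokens, size_t n_ctx_slot) {
    // leave at least one position for the first generated token
    return n_prompt_tokens + 1 <= n_ctx_slot;
}

int main() {
    const size_t n_ctx_slot = 8192; // e.g. ctx_server.slots.front().n_ctx
    printf("fits: %d\n", prompt_fits(5000, n_ctx_slot)); // 1
    printf("fits: %d\n", prompt_fits(9000, n_ctx_slot)); // 0
    return 0;
}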