server : support unified cache across slots (#16736)
* server : support unified context across slots
* cont : fix speculative decoding initialization
* context : fix n_ctx_per_seq computation
* server : purge slots one by one
* tests : add unified cache server tests
* llama : update per-seq context computation
* test-thread-safety : handle tiny training context of the input model
* server : fix server_tokens clear()
* server : use 4 slots + unified KV by default
* llama : add note about context size queries
* cont : update todos [no ci]
* context : do not cap the size of the context
* tests : adjust parameters to be CI friendlier
* context : add warning
@@ -461,7 +461,10 @@ extern "C" {
     LLAMA_API bool llama_supports_gpu_offload(void);
     LLAMA_API bool llama_supports_rpc (void);
 
+    // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
+    // In some cases the requested values via llama_context_params may differ from the actual values used by the context
     LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
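The note above means callers should read back the effective sizes right after creating a context. A minimal sketch of that pattern, assuming the public API shape in llama.h after this change (the model path, the example parameter values and the `kv_unified` field name in `llama_context_params` are illustrative assumptions, not taken from this diff):

    // sketch: create a context and query the actual context geometry back
    #include "llama.h"
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            return 1;
        }

        llama_backend_init();

        llama_model * model = llama_model_load_from_file(argv[1], llama_model_default_params());
        if (model == nullptr) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx      = 8192; // requested total context (example value)
        cparams.n_seq_max  = 4;    // number of parallel sequences (example value)
        cparams.kv_unified = true; // assumed field name: single KV buffer shared by all sequences

        llama_context * ctx = llama_init_from_model(model, cparams);
        if (ctx == nullptr) {
            llama_model_free(model);
            return 1;
        }

        // the context may have adjusted the requested values - query them back
        printf("n_ctx     = %u\n", llama_n_ctx(ctx));     // total context
        printf("n_ctx_seq = %u\n", llama_n_ctx_seq(ctx)); // context available to one sequence
        printf("n_seq_max = %u\n", llama_n_seq_max(ctx));

        llama_free(ctx);
        llama_model_free(model);
        llama_backend_free();
        return 0;
    }

With a unified KV cache, `llama_n_ctx_seq()` is expected to report the full `n_ctx`; without it, roughly `n_ctx / n_seq_max`, as the context initialization below shows.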
@@ -585,7 +588,7 @@ extern "C" {
     LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
 
     // Manually free a LoRA adapter
-    // Note: loaded adapters will be free when the associated model is deleted
+    // NOTE: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
     // Get the invocation tokens if the current lora is an alora
@@ -112,11 +112,24 @@ llama_context::llama_context(
         }
     }
 
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    if (cparams.kv_unified) {
+        cparams.n_ctx_seq = cparams.n_ctx;
+    } else {
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+
+        if (cparams.n_ctx_seq == 0) {
+            throw std::runtime_error("n_ctx_seq == 0");
+        }
+
+        if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) {
+            cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
+            LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx);
+        }
+    }
 
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n", __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
     LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
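The new initialization replaces the old single division with an explicit unified/split distinction, and rounds `n_ctx` down instead of capping it. A standalone sketch of the same rule, not the llama.cpp implementation itself, with a worked example in the comments:

    #include <cstdint>
    #include <stdexcept>

    // sketch of the per-sequence context rule introduced above
    struct ctx_sizes {
        uint32_t n_ctx;     // total context (possibly rounded down)
        uint32_t n_ctx_seq; // context available to a single sequence
    };

    static ctx_sizes compute_ctx_sizes(uint32_t n_ctx, uint32_t n_seq_max, bool kv_unified) {
        if (kv_unified) {
            // one shared KV buffer: a single sequence may use the whole context
            return { n_ctx, n_ctx };
        }

        const uint32_t n_ctx_seq = n_ctx / n_seq_max;
        if (n_ctx_seq == 0) {
            throw std::runtime_error("n_ctx_seq == 0");
        }

        // example: n_ctx = 4097, n_seq_max = 4 -> n_ctx_seq = 1024 and n_ctx is rounded down to 4096
        return { n_ctx_seq * n_seq_max, n_ctx_seq };
    }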
@@ -125,14 +138,14 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
     }
 
     if (!hparams.vocab_only) {
@@ -453,8 +466,8 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
-uint32_t llama_context::n_ctx_per_seq() const {
-    return cparams.n_ctx / cparams.n_seq_max;
+uint32_t llama_context::n_ctx_seq() const {
+    return cparams.n_ctx_seq;
 }
 
 uint32_t llama_context::n_batch() const {
@@ -2383,6 +2396,10 @@ uint32_t llama_n_ctx(const llama_context * ctx) {
     return ctx->n_ctx();
 }
 
+uint32_t llama_n_ctx_seq(const llama_context * ctx) {
+    return ctx->n_ctx_seq();
+}
+
 uint32_t llama_n_batch(const llama_context * ctx) {
     return ctx->n_batch();
 }
@@ -44,7 +44,7 @@ struct llama_context {
     ggml_backend_sched_t get_sched() const;
 
     uint32_t n_ctx() const;
-    uint32_t n_ctx_per_seq() const;
+    uint32_t n_ctx_seq() const;
     uint32_t n_batch() const;
     uint32_t n_ubatch() const;
     uint32_t n_seq_max() const;
@@ -8,6 +8,7 @@
 
 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
+    uint32_t n_ctx_seq; // context for a single sequence
     uint32_t n_batch;
     uint32_t n_ubatch;
     uint32_t n_seq_max;
@@ -6712,14 +6712,14 @@ float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) co
 }
 
 ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+    const uint32_t n_ctx_seq = cparams.n_ctx_seq;
 
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
     }
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+    if (n_ctx_seq > hparams.n_ctx_orig_yarn) {
         return layers[il].rope_long;
     }
 
@@ -6795,12 +6795,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     /* filter_attn */ std::move(filter_attn),
                     /* filter_recr */ std::move(filter_recr));
         } else {
-            uint32_t n_ctx_per_stream = cparams.n_ctx;
-
-            if (!cparams.kv_unified) {
-                n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-            }
-
             llama_memory_i::layer_reuse_cb reuse = nullptr;
 
             if (arch == LLM_ARCH_GEMMA3N) {
@@ -6824,7 +6818,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.offload_kqv,
                         params.swa_full,
                         cparams.kv_unified,
-                        n_ctx_per_stream,
+                        cparams.n_ctx_seq,
                         cparams.n_seq_max,
                         cparams.n_ubatch,
                         1,
@@ -6840,7 +6834,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         !cparams.flash_attn,
                         cparams.offload_kqv,
                         cparams.kv_unified,
-                        n_ctx_per_stream,
+                        cparams.n_ctx_seq,
                         cparams.n_seq_max,
                         1,
                         hparams.n_swa,
@@ -131,7 +131,14 @@ int main(int argc, char ** argv) {
                 }
 
                 batch = llama_batch_get_one(&token, 1);
-                if (llama_decode(ctx.get(), batch)) {
+
+                int ret = llama_decode(ctx.get(), batch);
+                if (ret == 1 && i > 0) {
+                    LOG_INF("Context full, stopping generation.\n");
+                    break;
+                }
+
+                if (ret != 0) {
                     LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                     failed.store(true);
                     return;
@@ -2407,7 +2407,7 @@ struct server_context {
 
             params_dft.devices = params_base.speculative.devices;
             params_dft.model = params_base.speculative.model;
-            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel = 1;
             params_dft.cache_type_k = params_base.speculative.cache_type_k;
@@ -2495,10 +2495,16 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
-
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
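Each slot now takes its size from `llama_n_ctx_seq()` and is capped to the model's training context; with a unified cache this means every slot advertises the full context. A one-function sketch of that rule (the numbers in the comment are an example, not from the diff):

    #include <algorithm>
    #include <cstdint>

    // sketch of the per-slot context rule used in init() above
    static int32_t slot_n_ctx(int32_t n_ctx_seq, int32_t n_ctx_train) {
        // e.g. unified cache of 32768 tokens, model trained on 8192 -> each slot reports 8192
        return std::min(n_ctx_seq, n_ctx_train);
    }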
@@ -2527,7 +2533,7 @@ struct server_context {
                 }
             }
 
-            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
+            SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -2699,6 +2705,39 @@ struct server_context {
         return ret;
     }
 
+    // return true if at least one slot has been purged
+    // TODO: improve logic
+    // - smarter decision which slot to purge (LRU or longest prompt?)
+    // - move slot to level 2 cache instead of removing?
+    // - instead of purging, try to store and resume later?
+    bool try_purge_idle_slots() {
+        bool res = false;
+
+        if (!params_base.kv_unified) {
+            return res;
+        }
+
+        for (auto & slot : slots) {
+            if (slot.is_processing()) {
+                continue;
+            }
+
+            if (slot.prompt.n_tokens() > 0) {
+                SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
+
+                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+                slot.prompt.tokens.clear();
+
+                res = true;
+
+                // purge slots one by one
+                break;
+            }
+        }
+
+        return res;
+    }
+
     bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
 
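The purge helper only does anything in unified-KV mode and is consumed by the decode-retry path in a later hunk of this change: when a decode fails for lack of KV-cache space, the server first tries to free an idle slot and only halves the batch size if nothing could be purged. A condensed, compilable sketch of that interplay (the function name and callback shape are illustrative, not the server.cpp code):

    #include <cstdint>
    #include <functional>

    // sketch: decide the next batch size after a decode that found no KV-cache space
    static int32_t next_batch_size(int32_t n_batch, const std::function<bool()> & try_purge_idle_slots) {
        if (!try_purge_idle_slots()) {
            // nothing could be freed - retry with a smaller batch, as before this change
            n_batch /= 2;
        }
        return n_batch;
    }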
@@ -3635,9 +3674,10 @@ struct server_context {
         int32_t n_batch = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // next, batch any pending prompts without exceeding n_batch
         float alora_scale = -1.0f;
         size_t alora_disabled_id = 0;
+
+        // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
@@ -3914,8 +3954,11 @@ struct server_context {
 
                     // truncate any tokens that are beyond n_past for this slot
                     const llama_pos p0 = slot.prompt.tokens.pos_next();
+
+                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
+
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                        SLT_WRN(slot, "failed to truncate tokens with position >= %d\n", p0);
+                        SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
@@ -3924,8 +3967,6 @@ struct server_context {
                         slot.prompt.tokens.clear();
                     }
 
-                    SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
-
                     // check if we should process the image
                     if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                         // process the image
@@ -4126,6 +4167,8 @@ struct server_context {
                 std::string err;
 
                 if (n_batch == 1 && ret == 1) {
+                    // TODO: try to terminate only the largest active slot/sequence and continue with the rest
+                    // need to remove the tokens from the current batch too
                     err = "Context size has been exceeded.";
                 }
 
@@ -4141,17 +4184,23 @@ struct server_context {
                 // TODO: handle ret == 2 (abort) when we start aborting
 
                 if (!err.empty()) {
-                    SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+                    SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
+
                     for (auto & slot : slots) {
-                        send_error(slot, err);
-                        slot.release();
+                        if (slot.is_processing()) {
+                            send_error(slot, err);
+                            slot.release();
+                        }
                     }
+
                     break;
                 }
             }
 
             // retry with half the batch size to try to find a free slot in the KV cache
-            n_batch /= 2;
+            if (!try_purge_idle_slots()) {
+                n_batch /= 2;
+            }
 
             SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
@@ -4391,6 +4440,15 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // TODO: should we have a separate n_parallel parameter for the server?
+    // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    if (params.n_parallel == 1 && params.kv_unified == false) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true\n", __func__);
+
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
+
     common_init();
 
     // struct that contains llama context and inference
@@ -4944,7 +5002,7 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
+            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 auto n_prompt_tokens = inputs[i].size();
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 15, False),
+        (64, 3, False),
         (64, 1, True),
     ]
 )
-def test_return_progresssss(n_batch, batch_count, reuse_cache):
+def test_return_progress(n_batch, batch_count, reuse_cache):
     global server
     server.n_batch = n_batch
-    server.n_ctx = 2048
+    server.n_ctx = 256
     server.n_slots = 1
     server.start()
     def make_cmpl_request():
         return server.make_stream_request("POST", "/chat/completions", data={
             "max_tokens": 10,
             "messages": [
-                {"role": "user", "content": "This is a test" * 100},
+                {"role": "user", "content": "This is a test" * 10},
             ],
             "stream": True,
             "return_progress": True,
@@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "n_ctx,n_slots,n_predict_vals,expected_success",
+    [
+        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
+        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
+        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
+        (256, 4, [90, 90, 40, 75], [True, True, True, True]),
+    ],
+)
+def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
+    global server
+    server.n_slots = n_slots
+    server.kv_unified = True
+    server.n_ctx = n_ctx
+    server.start()
+    prompt = "A"
+    tasks = []
+    for n_predict in n_predict_vals:
+        tasks.append((server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": n_predict})))
+    results = parallel_function_calls(tasks)
+    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
+        if expect_ok:
+            assert res.status_code == 200
+            assert "content" in res.body
+            if "timings" in res.body:
+                assert res.body["timings"]["predicted_n"] == n_predict
+        else:
+            assert res.status_code == 500
+            assert "content" not in res.body
+
+
 @pytest.mark.parametrize(
     "prompt,n_predict,response_fields",
     [
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(Dad|excited|park)+", res.body["content"])
+    assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
@@ -78,6 +78,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
    server_metrics: bool | None = False
+    kv_unified: bool | None = False
     server_slots: bool | None = False
     pooling: str | None = None
     draft: int | None = None
@@ -159,6 +160,8 @@ class ServerProcess:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.kv_unified:
+            server_args.append("--kv-unified")
         if self.server_slots:
             server_args.append("--slots")
         else:
@@ -1244,6 +1244,7 @@ public:
     }
 
     void clear() {
+        map_idx_to_media.clear();
         tokens.clear();
     }
 