diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index e949afab21..8641586eeb 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -115,19 +115,8 @@ llama_context::llama_context(
     if (cparams.kv_unified) {
         cparams.n_ctx_seq = cparams.n_ctx;
     } else {
-        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
-    }
-
-    if (cparams.n_ctx_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: capping n_ctx_seq (%u) to n_ctx_train (%u)\n", __func__, cparams.n_ctx_seq, hparams.n_ctx_train);
-
-        cparams.n_ctx_seq = hparams.n_ctx_train;
-    }
-
-    if (cparams.kv_unified) {
-        cparams.n_ctx = cparams.n_ctx_seq;
-    } else {
-        cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
+        cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max;
     }

     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 29ba95ded5..90f49d4aa6 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2497,12 +2497,20 @@ struct server_context {
     void init() {
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

+        const int n_ctx_train = llama_model_n_ctx_train(model);
+
+        int n_ctx_slot = llama_n_ctx_seq(ctx);
+        if (n_ctx_slot > n_ctx_train) {
+            SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train);
+            n_ctx_slot = n_ctx_train;
+        }
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;

             slot.id = i;
             slot.ctx = ctx;
-            slot.n_ctx = llama_n_ctx_seq(ctx);
+            slot.n_ctx = n_ctx_slot;
             slot.mctx = mctx;

             slot.prompt.tokens.has_mtmd = mctx != nullptr;