server : use slot context size instead of training context size

Author: Georgi Gerganov
Date: 2025-10-28 11:39:07 +02:00
parent e776168267
commit 7ebe7f77a1
2 changed files with 4 additions and 6 deletions


@@ -2946,17 +2946,15 @@ struct server_context {
             SLT_DBG(slot, "%s", "stopped by EOS\n");
         }
 
-        const auto n_ctx_train = llama_model_n_ctx_train(model);
-
-        if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
+        if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= slot.n_ctx) {
             slot.truncated      = true;
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false; // stop prediction
 
             SLT_WRN(slot,
                     "n_predict (%d) is set for infinite generation. "
-                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
-                    slot.task->params.n_predict, n_ctx_train);
+                    "Limiting generated tokens to slot.n_ctx (%d) to avoid EOS-less generation infinite loop\n",
+                    slot.task->params.n_predict, slot.n_ctx);
         }
 
         SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
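
Taken on its own, the change caps an "infinite" request (n_predict < 1) at the context actually available to the slot rather than at the model's training context. Below is a minimal standalone sketch of that stop condition, using hypothetical simplified names (slot_state, should_stop_infinite_generation) that stand in for the slot fields used above; it is an illustration, not the server's actual code.

// Minimal sketch (hypothetical names), not the server's actual code.
#include <cstdio>

struct slot_state {
    int n_ctx;           // context window available to this slot (not the model's training context)
    int n_prompt_tokens; // tokens consumed by the prompt
    int n_decoded;       // tokens generated so far
    int n_predict;       // requested limit; < 1 means "infinite" generation
};

// True when an unbounded request has filled the slot's context and must stop.
static bool should_stop_infinite_generation(const slot_state & s) {
    return s.n_predict < 1 && s.n_prompt_tokens + s.n_decoded >= s.n_ctx;
}

int main() {
    // Numbers mirror the updated test below: 256-token slot context, 8-token prompt.
    slot_state s = { /*n_ctx=*/256, /*n_prompt_tokens=*/8, /*n_decoded=*/0, /*n_predict=*/-1 };

    while (!should_stop_infinite_generation(s)) {
        s.n_decoded++; // stand-in for decoding one more token
    }

    printf("stopped after %d generated tokens\n", s.n_decoded); // prints 248
    return 0;
}

Presumably the motivation is that slot.n_ctx reflects the KV cache the slot actually has, while the training context can be much larger (or smaller) than the configured context, making it a poor proxy for when an EOS-less generation must be cut off.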


@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():
 
 @pytest.mark.parametrize("n_predict,n_token_output,truncated", [
     (64, 64, False),
-    (-1, 120, True),
+    (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
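
The new expected output length follows directly from the updated stop condition: assuming the test server gives this slot a 256-token context (as the inline comment's 256-token total implies) and context shifting is disabled, an unbounded request with an 8-token prompt stops once prompt plus generated tokens reach the slot context, i.e. after 256 - 8 = 248 generated tokens. The previous expectation of 120 presumably corresponded to the old bound on the model's training context.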