server : return HTTP 400 if prompt exceeds context length (#16486)

In streaming mode, when the prompt exceeds the context length, the server
returns an HTTP 200 status code with a JSON error in the body. This is
confusing and inconsistent with all other inference engines, which return
an HTTP 4xx error in this case.

This patch fixes the problem by making the server return HTTP 400 in
such cases.
This commit is contained in:
Radoslav Gerganov
2025-10-10 17:11:07 +03:00
committed by GitHub
parent cdb6da468c
commit 68ee98ae18
3 changed files with 40 additions and 2 deletions

View File

@@ -3727,7 +3727,7 @@ struct server_context {
}
} else {
if (slot.n_prompt_tokens() >= slot.n_ctx) {
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
slot.release();
continue;
}
@@ -4955,9 +4955,17 @@ int main(int argc, char ** argv) {
// Everything else, including multimodal completions.
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto n_prompt_tokens = inputs[i].size();
if (n_prompt_tokens >= n_ctx_slot) {
json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
error_data["n_prompt_tokens"] = n_prompt_tokens;
error_data["n_ctx"] = n_ctx_slot;
res_error(res, error_data);
return;
}
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();

View File

@@ -408,6 +408,28 @@ def test_context_size_exceeded():
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
def test_context_size_exceeded_stream():
    """A streaming request whose prompt exceeds the slot context size must fail with HTTP 400."""
    global server
    server.start()
    try:
        # Consume the stream: the server rejects the oversized prompt before
        # emitting any chunks, so make_stream_request raises ServerError on
        # the non-200 status and the loop body never runs.
        for _ in server.make_stream_request("POST", "/chat/completions", data={
            "messages": [
                {"role": "system", "content": "Book"},
                {"role": "user", "content": "What is the best book"},
            ] * 100,  # make the prompt too long
            "stream": True}):
            pass
        assert False, "Should have failed"
    except ServerError as e:
        assert e.code == 400
        assert "error" in e.body
        assert e.body["error"]["type"] == "exceed_context_size_error"
        assert e.body["error"]["n_prompt_tokens"] > 0
        assert server.n_ctx is not None
        assert server.n_slots is not None
        # The reported n_ctx is per slot: total context divided among slots.
        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
@pytest.mark.parametrize(
"n_batch,batch_count,reuse_cache",
[

View File

@@ -35,6 +35,12 @@ class ServerResponse:
body: dict | Any
class ServerError(Exception):
    """Raised when the test server answers with a non-200 HTTP status.

    Carries the HTTP status code and the decoded JSON error payload so
    tests can assert on both.
    """

    def __init__(self, code, body):
        self.body = body
        self.code = code
class ServerProcess:
# default options
debug: bool = False
@@ -297,6 +303,8 @@ class ServerProcess:
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
if response.status_code != 200:
raise ServerError(response.status_code, response.json())
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line: