mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
server: add exceed_context_size_error type (#15780)
* server: add exceed_context_size_error type
* change error code to 400
This commit is contained in:
@@ -86,6 +86,7 @@ enum error_type {
|
|||||||
ERROR_TYPE_PERMISSION,
|
ERROR_TYPE_PERMISSION,
|
||||||
ERROR_TYPE_UNAVAILABLE, // custom error
|
ERROR_TYPE_UNAVAILABLE, // custom error
|
||||||
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
||||||
|
ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
|
||||||
};
|
};
|
||||||
|
|
||||||
static bool server_task_type_need_embd(server_task_type task_type) {
|
static bool server_task_type_need_embd(server_task_type task_type) {
|
||||||
@@ -1224,6 +1225,10 @@ static json format_error_response(const std::string & message, const enum error_
|
|||||||
type_str = "unavailable_error";
|
type_str = "unavailable_error";
|
||||||
code = 503;
|
code = 503;
|
||||||
break;
|
break;
|
||||||
|
case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
|
||||||
|
type_str = "exceed_context_size_error";
|
||||||
|
code = 400;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
return json {
|
return json {
|
||||||
{"code", code},
|
{"code", code},
|
||||||
@@ -1237,12 +1242,21 @@ struct server_task_result_error : server_task_result {
|
|||||||
error_type err_type = ERROR_TYPE_SERVER;
|
error_type err_type = ERROR_TYPE_SERVER;
|
||||||
std::string err_msg;
|
std::string err_msg;
|
||||||
|
|
||||||
|
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
|
||||||
|
int32_t n_prompt_tokens = 0;
|
||||||
|
int32_t n_ctx = 0;
|
||||||
|
|
||||||
virtual bool is_error() override {
|
virtual bool is_error() override {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return format_error_response(err_msg, err_type);
|
json res = format_error_response(err_msg, err_type);
|
||||||
|
if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
|
||||||
|
res["n_prompt_tokens"] = n_prompt_tokens;
|
||||||
|
res["n_ctx"] = n_ctx;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -2605,16 +2619,22 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
||||||
send_error(slot.id_task, error, type);
|
send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
|
||||||
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
|
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
|
||||||
|
|
||||||
|
if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
|
||||||
|
GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
|
||||||
|
}
|
||||||
|
|
||||||
auto res = std::make_unique<server_task_result_error>();
|
auto res = std::make_unique<server_task_result_error>();
|
||||||
res->id = id_task;
|
res->id = id_task;
|
||||||
res->err_type = type;
|
res->err_type = type;
|
||||||
res->err_msg = error;
|
res->err_msg = error;
|
||||||
|
res->n_prompt_tokens = n_prompt_tokens;
|
||||||
|
res->n_ctx = n_ctx;
|
||||||
|
|
||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
}
|
}
|
||||||
@@ -3286,7 +3306,7 @@ struct server_context {
|
|||||||
|
|
||||||
if (slot.n_prompt_tokens > slot.n_ctx) {
|
if (slot.n_prompt_tokens > slot.n_ctx) {
|
||||||
slot.release();
|
slot.release();
|
||||||
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
|
send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -3296,7 +3316,7 @@ struct server_context {
|
|||||||
// context shift should be applied only during the generation phase
|
// context shift should be applied only during the generation phase
|
||||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||||
slot.release();
|
slot.release();
|
||||||
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
|
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -385,3 +385,20 @@ def test_logit_bias():
|
|||||||
output_text = res.choices[0].message.content
|
output_text = res.choices[0].message.content
|
||||||
assert output_text
|
assert output_text
|
||||||
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
|
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
|
||||||
|
|
||||||
|
def test_context_size_exceeded():
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "Book"},
|
||||||
|
{"role": "user", "content": "What is the best book"},
|
||||||
|
] * 100, # make the prompt too long
|
||||||
|
})
|
||||||
|
assert res.status_code == 400
|
||||||
|
assert "error" in res.body
|
||||||
|
assert res.body["error"]["type"] == "exceed_context_size_error"
|
||||||
|
assert res.body["error"]["n_prompt_tokens"] > 0
|
||||||
|
assert server.n_ctx is not None
|
||||||
|
assert server.n_slots is not None
|
||||||
|
assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
|
||||||
|
|||||||
Reference in New Issue
Block a user