Mirror of https://github.com/ggml-org/llama.cpp.git — synced 2025-10-30 08:42:00 +00:00
server: add exceed_context_size_error type (#15780)

* server: add exceed_context_size_error type
* change error code to 400

This commit is contained in:
		| @@ -86,6 +86,7 @@ enum error_type { | ||||
|     ERROR_TYPE_PERMISSION, | ||||
|     ERROR_TYPE_UNAVAILABLE, // custom error | ||||
|     ERROR_TYPE_NOT_SUPPORTED, // custom error | ||||
|     ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error | ||||
| }; | ||||
|  | ||||
| static bool server_task_type_need_embd(server_task_type task_type) { | ||||
| @@ -1224,6 +1225,10 @@ static json format_error_response(const std::string & message, const enum error_ | ||||
|             type_str = "unavailable_error"; | ||||
|             code = 503; | ||||
|             break; | ||||
|         case ERROR_TYPE_EXCEED_CONTEXT_SIZE: | ||||
|             type_str = "exceed_context_size_error"; | ||||
|             code = 400; | ||||
|             break; | ||||
|     } | ||||
|     return json { | ||||
|         {"code", code}, | ||||
| @@ -1237,12 +1242,21 @@ struct server_task_result_error : server_task_result { | ||||
|     error_type err_type = ERROR_TYPE_SERVER; | ||||
|     std::string err_msg; | ||||
|  | ||||
|     // for ERROR_TYPE_EXCEED_CONTEXT_SIZE | ||||
|     int32_t n_prompt_tokens = 0; | ||||
|     int32_t n_ctx           = 0; | ||||
|  | ||||
|     virtual bool is_error() override { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     virtual json to_json() override { | ||||
|         return format_error_response(err_msg, err_type); | ||||
|         json res = format_error_response(err_msg, err_type); | ||||
|         if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { | ||||
|             res["n_prompt_tokens"] = n_prompt_tokens; | ||||
|             res["n_ctx"]           = n_ctx; | ||||
|         } | ||||
|         return res; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| @@ -2605,16 +2619,22 @@ struct server_context { | ||||
|     } | ||||
|  | ||||
|     void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { | ||||
|         send_error(slot.id_task, error, type); | ||||
|         send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx); | ||||
|     } | ||||
|  | ||||
|     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { | ||||
|     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { | ||||
|         SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); | ||||
|  | ||||
|         if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { | ||||
|             GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); | ||||
|         } | ||||
|  | ||||
|         auto res = std::make_unique<server_task_result_error>(); | ||||
|         res->id       = id_task; | ||||
|         res->err_type = type; | ||||
|         res->err_msg  = error; | ||||
|         res->id              = id_task; | ||||
|         res->err_type        = type; | ||||
|         res->err_msg         = error; | ||||
|         res->n_prompt_tokens = n_prompt_tokens; | ||||
|         res->n_ctx           = n_ctx; | ||||
|  | ||||
|         queue_results.send(std::move(res)); | ||||
|     } | ||||
| @@ -3286,7 +3306,7 @@ struct server_context { | ||||
|  | ||||
|                             if (slot.n_prompt_tokens > slot.n_ctx) { | ||||
|                                 slot.release(); | ||||
|                                 send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); | ||||
|                                 send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); | ||||
|                                 continue; | ||||
|                             } | ||||
|                         } else { | ||||
| @@ -3296,7 +3316,7 @@ struct server_context { | ||||
|                                 //       context shift should be applied only during the generation phase | ||||
|                                 if (slot.n_prompt_tokens >= slot.n_ctx) { | ||||
|                                     slot.release(); | ||||
|                                     send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); | ||||
|                                     send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); | ||||
|                                     continue; | ||||
|                                 } | ||||
|                             } | ||||
|   | ||||
| @@ -385,3 +385,20 @@ def test_logit_bias(): | ||||
|     output_text = res.choices[0].message.content | ||||
|     assert output_text | ||||
|     assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude) | ||||
|  | ||||
| def test_context_size_exceeded(): | ||||
|     global server | ||||
|     server.start() | ||||
|     res = server.make_request("POST", "/chat/completions", data={ | ||||
|         "messages": [ | ||||
|             {"role": "system", "content": "Book"}, | ||||
|             {"role": "user", "content": "What is the best book"}, | ||||
|         ] * 100, # make the prompt too long | ||||
|     }) | ||||
|     assert res.status_code == 400 | ||||
|     assert "error" in res.body | ||||
|     assert res.body["error"]["type"] == "exceed_context_size_error" | ||||
|     assert res.body["error"]["n_prompt_tokens"] > 0 | ||||
|     assert server.n_ctx is not None | ||||
|     assert server.n_slots is not None | ||||
|     assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots | ||||
|   | ||||
Reference in New Issue
Block a user
Author: Xuan-Son Nguyen