	server : (refactor) no more json in server_task input (#10691)
* server : (refactor) no more json in server_task input
* add test for slots endpoint
* add tests for /props and /slots
* remove task inf_type
* fix CI by adding safe_json_to_str
* add "model_path" to /props
* update readme
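The title suggests the incoming JSON is now parsed once at the HTTP boundary and the task queue receives plain C++ data instead of a json object. Below is a minimal sketch of that idea; the names (task_params, task_input, parse_task_input) are hypothetical illustrations, not the actual llama.cpp structs.

```cpp
// Hypothetical sketch of the refactor's shape: parse JSON once at the HTTP
// boundary and hand the worker a typed task input instead of a raw json blob.
// All names below are illustrative, not the real llama.cpp types.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

#include "json.hpp" // nlohmann/json, bundled with the server example

using json = nlohmann::ordered_json;

struct task_params {
    int32_t  n_predict   = -1;
    uint32_t seed        = 0xFFFFFFFF;
    float    temperature = 0.8f;
};

struct task_input {
    std::vector<int32_t> prompt_tokens; // already tokenized, no JSON attached
    task_params          params;
};

// all JSON handling stays at the boundary; workers never see json objects
static task_input parse_task_input(const json & body, std::vector<int32_t> tokens) {
    task_input in;
    in.prompt_tokens      = std::move(tokens);
    in.params.n_predict   = body.value("n_predict",   in.params.n_predict);
    in.params.seed        = body.value("seed",        in.params.seed);
    in.params.temperature = body.value("temperature", in.params.temperature);
    return in;
}

int main() {
    json body = json::parse(R"({"n_predict": 64, "temperature": 0.2})");
    task_input in = parse_task_input(body, {1, 2, 3});
    printf("n_predict=%d seed=%u temp=%.2f tokens=%zu\n",
           in.params.n_predict, in.params.seed, in.params.temperature,
           in.prompt_tokens.size());
    return 0;
}
```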
@@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
    }
  },
  "total_slots": 1,
  "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
  "chat_template": "..."
}
```

- `default_generation_settings` - the default generation settings for the `/completion` endpoint, which have the same fields as the `generation_settings` response object from the `/completion` endpoint.
- `total_slots` - the total number of slots for processing requests (defined by the `--parallel` option)
- `model_path` - the path to the model file (same as the `-m` argument)
- `chat_template` - the model's original Jinja2 prompt template

### POST `/props`: Change server global properties.

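For reference, a minimal sketch of reading these documented fields, including the newly added `model_path`, from a running server. It assumes a llama-server instance listening on localhost:8080 and the single-header httplib.h and json.hpp that ship with the server example; error handling is kept to a bare minimum.

```cpp
// Sketch: fetch /props and print a few of the documented fields.
#include <cstdio>
#include <string>

#include "httplib.h" // cpp-httplib, bundled with the server example
#include "json.hpp"  // nlohmann/json, bundled with the server example

using json = nlohmann::ordered_json;

int main() {
    httplib::Client cli("localhost", 8080);

    auto res = cli.Get("/props");
    if (!res || res->status != 200) {
        fprintf(stderr, "request to /props failed\n");
        return 1;
    }

    json props = json::parse(res->body);

    // "model_path" is the field added by this commit
    printf("model_path : %s\n", props.at("model_path").get<std::string>().c_str());
    printf("total_slots: %d\n", props.at("total_slots").get<int>());
    printf("has chat_template: %s\n", props.contains("chat_template") ? "yes" : "no");
    return 0;
}
```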
										
											
File diff suppressed because it is too large
@@ -22,7 +22,12 @@ def test_server_props():
    server.start()
    res = server.make_request("GET", "/props")
    assert res.status_code == 200
    assert ".gguf" in res.body["model_path"]
    assert res.body["total_slots"] == server.n_slots
    default_val = res.body["default_generation_settings"]
    assert server.n_ctx is not None and server.n_slots is not None
    assert default_val["n_ctx"] == server.n_ctx / server.n_slots
    assert default_val["params"]["seed"] == server.seed


def test_server_models():
@@ -33,6 +38,31 @@ def test_server_models():
    assert len(res.body["data"]) == 1
    assert res.body["data"][0]["id"] == server.model_alias


def test_server_slots():
    global server

    # without slots endpoint enabled, this should return error
    server.server_slots = False
    server.start()
    res = server.make_request("GET", "/slots")
    assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
    assert "error" in res.body
    server.stop()

    # with slots endpoint enabled, this should return slots info
    server.server_slots = True
    server.n_slots = 2
    server.start()
    res = server.make_request("GET", "/slots")
    assert res.status_code == 200
    assert len(res.body) == server.n_slots
    assert server.n_ctx is not None and server.n_slots is not None
    assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
    assert "params" in res.body[0]
    assert res.body[0]["params"]["seed"] == server.seed


def test_load_split_model():
    global server
    server.model_hf_repo = "ggml-org/models"

@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
        ],
    })
    assert res.status_code == 200
    assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
    assert res.body["model"] == (model if model is not None else server.model_alias)
    assert res.body["usage"]["prompt_tokens"] == n_prompt
    assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
        "stream": True,
    })
    content = ""
    last_cmpl_id = None
    for data in res:
        choice = data["choices"][0]
        assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, may be changed in the future
        if last_cmpl_id is None:
            last_cmpl_id = data["id"]
        assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
        if choice["finish_reason"] in ["stop", "length"]:
            assert data["usage"]["prompt_tokens"] == n_prompt
            assert data["usage"]["completion_tokens"] == n_predicted

@@ -64,6 +64,7 @@ class ServerProcess:
    server_embeddings: bool | None = False
    server_reranking: bool | None = False
    server_metrics: bool | None = False
    server_slots: bool | None = False
    draft: int | None = None
    api_key: str | None = None
    response_format: str | None = None
@@ -91,7 +92,6 @@ class ServerProcess:
        else:
            server_path = "../../../build/bin/llama-server"
        server_args = [
            "--slots",  # requires to get slot status via /slots endpoint
            "--host",
            self.server_host,
            "--port",
@@ -129,6 +129,8 @@
            server_args.append("--reranking")
        if self.server_metrics:
            server_args.append("--metrics")
        if self.server_slots:
            server_args.append("--slots")
        if self.model_alias:
            server_args.extend(["--alias", self.model_alias])
        if self.n_ctx:
@@ -181,7 +183,7 @@
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            try:
                response = self.make_request("GET", "/slots", headers={
                response = self.make_request("GET", "/health", headers={
                    "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                })
                if response.status_code == 200:
@@ -224,7 +226,7 @@
        result.headers = dict(response.headers)
        result.status_code = response.status_code
        result.body = response.json() if parse_body else None
        print("Response from server", result.body)
        print("Response from server", json.dumps(result.body, indent=2))
        return result

    def make_stream_request(
@@ -245,7 +247,7 @@
                break
            elif line.startswith('data: '):
                data = json.loads(line[6:])
                print("Partial response from server", data)
                print("Partial response from server", json.dumps(data, indent=2))
                yield data

@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
    } else {
        throw std::runtime_error("\"prompt\" must be a string, a list of tokens, a list of mixed strings & tokens, or a list of prompts");
    }
    if (result.empty()) {
        throw std::runtime_error("\"prompt\" must not be empty");
    }
    return result;
}

@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
    const std::string & chat_template) {
    json llama_params;

    llama_params["__oaicompat"] = true;

    // Apply chat template to the list of messages
    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));

@@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
        {"content", content}
    };
}

static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
    json data = json::array();
    for (const auto & lb : logit_bias) {
        data.push_back(json{
            {"bias", lb.bias},
            {"token", lb.token},
        });
    }
    return data;
}

static std::string safe_json_to_str(json data) {
    return data.dump(-1, ' ', false, json::error_handler_t::replace);
}

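The new safe_json_to_str helper leans on nlohmann::json's replace error handler instead of the default strict one. The commit message only says it fixes CI, so tying this to invalid UTF-8 in responses is an inference from that choice; the small standalone check below (assuming only the bundled json.hpp) shows the difference in behaviour.

```cpp
// Standalone check of the error-handling behaviour safe_json_to_str relies on:
// nlohmann::json::dump() throws on invalid UTF-8 by default, while
// error_handler_t::replace substitutes U+FFFD and never throws.
#include <cstdio>
#include <string>

#include "json.hpp" // nlohmann/json, bundled with the server example

using json = nlohmann::ordered_json;

int main() {
    // "\xC3" on its own is a truncated UTF-8 sequence, e.g. from a partially
    // detokenized multi-byte character in a streamed response
    json data = { {"content", std::string("caf\xC3")} };

    try {
        data.dump(); // default (strict) handler: throws type_error.316
    } catch (const json::type_error & e) {
        printf("default dump threw: %s\n", e.what());
    }

    // the replace handler never throws; invalid bytes become U+FFFD
    std::string safe = data.dump(-1, ' ', false, json::error_handler_t::replace);
    printf("safe dump: %s\n", safe.c_str());
    return 0;
}
```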
Xuan Son Nguyen