Mirror of https://github.com/ggml-org/llama.cpp.git
	server : speed up tests (#15836)
* server : speed up tests
* clean up
* restore timeout_seconds in some places
* flake8
* explicit offline
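In outline, the speed-up works like this: a session-scoped pytest fixture pre-downloads every preset model once, the presets that require a download are marked offline so individual tests never reach for the network, and per-test server starts can then rely on the short default timeout instead of a 15-minute one. A condensed sketch of the pattern, assuming the helpers introduced in the diff below (ServerPreset.load_all(), the offline flag, TIMEOUT_START_SLOW); the second test function name is illustrative:

import pytest
from utils import ServerPreset  # test helper module from tools/server/tests

@pytest.fixture(scope="session", autouse=True)
def do_something():
    # runs once per test session: start/stop each preset so every model
    # file is already in the local cache before any test executes
    ServerPreset.load_all()

def test_uses_cached_model():
    server = ServerPreset.tinyllama2()
    # the model is cached, so the default startup timeout is enough; only
    # real-model tests still pass timeout_seconds=TIMEOUT_START_SLOW
    server.start()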
@@ -53,7 +53,7 @@ import typer
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather


 @contextmanager
@@ -335,7 +335,7 @@ def run(
                     # server.debug = True

                     with scoped_server(server):
-                        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+                        server.start(timeout_seconds=15 * 60)
                         for ignore_chat_grammar in [False]:
                             run(
                                 server,

@@ -5,6 +5,12 @@ from utils import *
 server = ServerPreset.tinyllama2()


+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server

@@ -14,14 +14,11 @@ from utils import *

 server: ServerProcess

-TIMEOUT_SERVER_START = 15*60
-
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2"
-    server.server_port = 8081
     server.n_slots = 1


@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [

@@ -12,7 +12,7 @@ from enum import Enum

 server: ServerProcess

-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60

 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)


@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)


@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)


@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)


@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)


@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)


@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)

     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)

@@ -5,18 +5,31 @@ import requests

 server: ServerProcess

-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-
-response = requests.get(IMG_URL_0)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-
-response = requests.get(IMG_URL_1)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id

 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():

 def test_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():

 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0,                True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1,                True, "(frog)+"),
-        ("Test test\n",     IMG_URL_1,                True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0",              True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1",              True, "(frog)+"),
+        ("Test test\n",     "IMG_URL_1",              True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed",              False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
@@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
@@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,           True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1,           True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0",         True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1",         True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed",            False, None), # non-image data
         ("What is this:\n",             "",                     False, None), # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,           True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1,           True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0",         True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1",         True),
         ("What is this: <__media__>\n", "malformed",            False), # non-image data
         ("What is this:\n",             "base64",               False), # non-image data
     ]
 )
 def test_vision_embeddings(prompt, image_data, success):
     global server
-    server.server_embeddings=True
-    server.n_batch=512
-    server.start() # vision model may take longer to load due to download size
+    server.server_embeddings = True
+    server.n_batch = 512
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },

@@ -26,7 +26,7 @@ from re import RegexFlag
 import wget


-DEFAULT_HTTP_TIMEOUT = 30
+DEFAULT_HTTP_TIMEOUT = 60


 class ServerResponse:
@@ -45,6 +45,7 @@ class ServerProcess:
     model_alias: str = "tinyllama-2"
     temperature: float = 0.8
     seed: int = 42
+    offline: bool = False

     # custom options
     model_alias: str | None = None
@@ -118,6 +119,8 @@ class ServerProcess:
             "--seed",
             self.seed,
         ]
+        if self.offline:
+            server_args.append("--offline")
         if self.model_file:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
@@ -392,6 +395,19 @@ server_instances: Set[ServerProcess] = set()


 class ServerPreset:
+    @staticmethod
+    def load_all() -> None:
+        """ Load all server presets to ensure model files are cached. """
+        servers: List[ServerProcess] = [
+            method()
+            for name, method in ServerPreset.__dict__.items()
+            if callable(method) and name != "load_all"
+        ]
+        for server in servers:
+            server.offline = False
+            server.start()
+            server.stop()
+
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
@@ -408,6 +424,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -422,6 +439,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small_with_fa() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -437,6 +455,7 @@ class ServerPreset:
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
         server.model_alias = "tinyllama-infill"
@@ -451,6 +470,7 @@ class ServerPreset:
     @staticmethod
     def stories15m_moe() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/stories15M_MOE"
         server.model_hf_file = "stories15M_MOE-F16.gguf"
         server.model_alias = "stories15m-moe"
@@ -465,6 +485,7 @@ class ServerPreset:
     @staticmethod
     def jina_reranker_tiny() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
         server.model_alias = "jina-reranker"
@@ -478,6 +499,7 @@ class ServerPreset:
     @staticmethod
     def tinygemma3() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # mmproj is already provided by HF registry API
         server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
         server.model_hf_file = "tinygemma3-Q8_0.gguf"

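A side effect worth noting: the vision-test images are now resolved lazily through get_img_url() instead of being downloaded at module import time, so merely collecting the test file no longer touches the network. Illustrative calls, using the IDs defined in the diff above:

url = get_img_url("IMG_URL_0")              # returns the plain URL, nothing is downloaded
data_uri = get_img_url("IMG_BASE64_URI_0")  # fetches the PNG and wraps it in a data: URI
passthrough = get_img_url("malformed")      # unknown IDs are returned unchanged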
Xuan-Son Nguyen