Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)

	server : speed up tests (#15836)
Author: Xuan-Son Nguyen

* server : speed up tests
* clean up
* restore timeout_seconds in some places
* flake8
* explicit offline
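In short: the model downloads that previously happened inside individual tests (guarded by a 15-minute start timeout) now happen once per pytest session, and most presets are afterwards started with --offline so no network access is needed. A minimal sketch of that session-level warm-up, using the ServerPreset helper that appears in the diff below (the fixture name here is illustrative, not the one used in the change):

import pytest

from utils import ServerPreset  # test helper module used by the server tests


@pytest.fixture(scope="session", autouse=True)
def warm_model_cache():
    # Runs once per test session, before any test: download and cache every
    # preset model so later server starts are fast and can run offline.
    ServerPreset.load_all()
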
@@ -53,7 +53,7 @@ import typer
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather
 
 
 @contextmanager
@@ -335,7 +335,7 @@ def run(
                     # server.debug = True
 
                     with scoped_server(server):
-                        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+                        server.start(timeout_seconds=15 * 60)
                         for ignore_chat_grammar in [False]:
                             run(
                                 server,

@@ -5,6 +5,12 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
 
+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server

@@ -14,14 +14,11 @@ from utils import *
 
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
-
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2"
-    server.server_port = 8081
     server.n_slots = 1
 
 
@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [

@@ -12,7 +12,7 @@ from enum import Enum
 
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60
 
 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
 
 
@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 

@@ -5,18 +5,31 @@ import requests
 
 server: ServerProcess
 
-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-
-response = requests.get(IMG_URL_0)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-
-response = requests.get(IMG_URL_1)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id
 
 
 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():
 
 def test_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
 
 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0,                True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1,                True, "(frog)+"),
-        ("Test test\n",     IMG_URL_1,                True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0",              True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1",              True, "(frog)+"),
+        ("Test test\n",     "IMG_URL_1",              True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed",              False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
@@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
@@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,           True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1,           True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0",         True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1",         True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed",            False, None), # non-image data
         ("What is this:\n",             "",                     False, None), # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,           True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1,           True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0",         True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1",         True),
        ("What is this: <__media__>\n", "malformed",            False), # non-image data
         ("What is this:\n",             "base64",               False), # non-image data
     ]
 )
 def test_vision_embeddings(prompt, image_data, success):
     global server
-    server.server_embeddings=True
-    server.n_batch=512
-    server.start() # vision model may take longer to load due to download size
+    server.server_embeddings = True
+    server.n_batch = 512
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },

@@ -26,7 +26,7 @@ from re import RegexFlag
 import wget
 
 
-DEFAULT_HTTP_TIMEOUT = 30
+DEFAULT_HTTP_TIMEOUT = 60
 
 
 class ServerResponse:
@@ -45,6 +45,7 @@ class ServerProcess:
     model_alias: str = "tinyllama-2"
     temperature: float = 0.8
     seed: int = 42
+    offline: bool = False
 
     # custom options
     model_alias: str | None = None
@@ -118,6 +119,8 @@ class ServerProcess:
             "--seed",
             self.seed,
         ]
+        if self.offline:
+            server_args.append("--offline")
         if self.model_file:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
@@ -392,6 +395,19 @@ server_instances: Set[ServerProcess] = set()
 
 
 class ServerPreset:
+    @staticmethod
+    def load_all() -> None:
+        """ Load all server presets to ensure model files are cached. """
+        servers: List[ServerProcess] = [
+            method()
+            for name, method in ServerPreset.__dict__.items()
+            if callable(method) and name != "load_all"
+        ]
+        for server in servers:
+            server.offline = False
+            server.start()
+            server.stop()
+
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
@@ -408,6 +424,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -422,6 +439,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small_with_fa() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -437,6 +455,7 @@ class ServerPreset:
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
         server.model_alias = "tinyllama-infill"
@@ -451,6 +470,7 @@ class ServerPreset:
     @staticmethod
     def stories15m_moe() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/stories15M_MOE"
         server.model_hf_file = "stories15M_MOE-F16.gguf"
         server.model_alias = "stories15m-moe"
@@ -465,6 +485,7 @@ class ServerPreset:
     @staticmethod
     def jina_reranker_tiny() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
         server.model_alias = "jina-reranker"
@@ -478,6 +499,7 @@ class ServerPreset:
     @staticmethod
     def tinygemma3() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # mmproj is already provided by HF registry API
         server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
         server.model_hf_file = "tinygemma3-Q8_0.gguf"

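How the new offline flag and load_all() fit together, sketched with names from the diff above (an illustration of the sequencing, not a verbatim excerpt): presets that require a download are constructed with offline = True; load_all() clears the flag once, starts and stops each preset to populate the local model cache, and every later start of such a preset appends --offline so the cached model is reused without touching the network.

# Illustrative sequencing of the warm-up performed by ServerPreset.load_all():
preset = ServerPreset.bert_bge_small()  # preset is constructed with offline = True
preset.offline = False                  # load_all() clears the flag once...
preset.start()                          # ...so the model file is downloaded and cached
preset.stop()

# A test later builds the same preset again; offline stays True there, so
# start() appends --offline to the server arguments and the cached model is used.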