server : speed up tests (#15836)

* server : speed up tests

* clean up

* restore timeout_seconds in some places

* flake8

* explicit offline
Xuan-Son Nguyen
2025-09-06 19:45:24 +07:00
committed by GitHub
parent 61bdfd5298
commit 3c3635d2f2
6 changed files with 90 additions and 50 deletions


@@ -53,7 +53,7 @@ import typer
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather
 
 @contextmanager
@@ -335,7 +335,7 @@ def run(
     # server.debug = True
     with scoped_server(server):
-        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+        server.start(timeout_seconds=15 * 60)
         for ignore_chat_grammar in [False]:
             run(
                 server,


@@ -5,6 +5,12 @@ from utils import *
 server = ServerPreset.tinyllama2()
 
+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server
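
A note on the fixture added above: a session-scoped autouse fixture runs exactly once, before the first test of the pytest session, so ServerPreset.load_all() pays the model-download cost a single time instead of inside each test's start timeout. A minimal self-contained sketch of the mechanism (the fixture and list names here are illustrative, not from this commit):

import pytest

calls = []

@pytest.fixture(scope="session", autouse=True)
def session_setup():
    # runs once per pytest session, before any test function
    calls.append("setup")
    yield
    # optional teardown would run here, after the last test

def test_first():
    assert calls == ["setup"]

def test_second():
    # still a single setup call, even across multiple tests
    assert calls == ["setup"]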


@@ -14,14 +14,11 @@ from utils import *
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
-
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2"
     server.server_port = 8081
     server.n_slots = 1
@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [


@@ -12,7 +12,7 @@ from enum import Enum
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60  # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60
 
 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
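
The changes above split startup into two paths: tests against the tiny cached models call server.start() with its default timeout, while tests that may download a real model keep the generous TIMEOUT_START_SLOW = 15 * 60. A sketch of how such a start-with-deadline loop can look (the class, default value, and health probe are assumptions for illustration, not the actual utils.py implementation):

import time

DEFAULT_START_TIMEOUT = 60  # assumed default; cached tiny models load quickly

class ServerProcessSketch:
    def _is_healthy(self) -> bool:
        # stand-in for polling the server's /health endpoint
        return True

    def start(self, timeout_seconds: int = DEFAULT_START_TIMEOUT) -> None:
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            if self._is_healthy():
                return
            time.sleep(0.5)
        raise TimeoutError(f"server did not become healthy within {timeout_seconds}s")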


@@ -5,18 +5,31 @@ import requests
 server: ServerProcess
 
-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-
-response = requests.get(IMG_URL_0)
-response.raise_for_status()  # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-
-response = requests.get(IMG_URL_1)
-response.raise_for_status()  # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status()  # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status()  # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status()  # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status()  # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id
 
 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():
 
 def test_models_supports_multimodal_capability():
     global server
-    server.start()  # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
 
 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start()  # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0,          True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1,          True, "(frog)+"),
-        ("Test test\n",     IMG_URL_1,          True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0",        True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1",        True, "(frog)+"),
+        ("Test test\n",     "IMG_URL_1",        True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed",        False, None),
         ("What is this:\n", "https://google.com/404", False, None),  # non-existent image
         ("What is this:\n", "https://ggml.ai",  False, None),  # non-image data
@@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60)  # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
@@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,   True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1,   True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed",    False, None),  # non-image data
         ("What is this:\n",             "",             False, None),  # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start()  # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0,   True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1,   True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
         ("What is this: <__media__>\n", "malformed",    False),  # non-image data
         ("What is this:\n",             "base64",       False),  # non-image data
     ]
 )
 def test_vision_embeddings(prompt, image_data, success):
     global server
-    server.server_embeddings=True
-    server.n_batch=512
-    server.start()  # vision model may take longer to load due to download size
+    server.server_embeddings = True
+    server.n_batch = 512
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
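
With get_img_url, the two test images are downloaded from Hugging Face only when a test case actually needs base64 data, instead of unconditionally at import time; cases like "malformed" or the 404 URL never trigger a fetch. A further refinement, not part of this commit, would be to memoize the downloads so repeated parametrized cases reuse the same bytes (fetch_base64 is a hypothetical helper):

import base64
import functools

import requests

@functools.lru_cache(maxsize=None)
def fetch_base64(url: str) -> str:
    # download each image once and reuse it across parametrized test cases
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return base64.b64encode(response.content).decode("utf-8")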


@@ -26,7 +26,7 @@ from re import RegexFlag
 import wget
 
-DEFAULT_HTTP_TIMEOUT = 30
+DEFAULT_HTTP_TIMEOUT = 60
 
 class ServerResponse:
@@ -45,6 +45,7 @@ class ServerProcess:
     model_alias: str = "tinyllama-2"
     temperature: float = 0.8
     seed: int = 42
+    offline: bool = False
 
     # custom options
     model_alias: str | None = None
@@ -118,6 +119,8 @@ class ServerProcess:
             "--seed",
             self.seed,
         ]
+        if self.offline:
+            server_args.append("--offline")
         if self.model_file:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
@@ -392,6 +395,19 @@ server_instances: Set[ServerProcess] = set()
 
 class ServerPreset:
+    @staticmethod
+    def load_all() -> None:
+        """ Load all server presets to ensure model files are cached. """
+        servers: List[ServerProcess] = [
+            method()
+            for name, method in ServerPreset.__dict__.items()
+            if callable(method) and name != "load_all"
+        ]
+        for server in servers:
+            server.offline = False
+            server.start()
+            server.stop()
+
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
@@ -408,6 +424,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -422,6 +439,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small_with_fa() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -437,6 +455,7 @@ class ServerPreset:
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
         server.model_alias = "tinyllama-infill"
@@ -451,6 +470,7 @@ class ServerPreset:
     @staticmethod
     def stories15m_moe() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/stories15M_MOE"
         server.model_hf_file = "stories15M_MOE-F16.gguf"
         server.model_alias = "stories15m-moe"
@@ -465,6 +485,7 @@ class ServerPreset:
     @staticmethod
     def jina_reranker_tiny() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
         server.model_alias = "jina-reranker"
@@ -478,6 +499,7 @@ class ServerPreset:
     @staticmethod
    def tinygemma3() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True  # will be downloaded by load_all()
         # mmproj is already provided by HF registry API
         server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
         server.model_hf_file = "tinygemma3-Q8_0.gguf"