server : speed up tests (#15836)
* server : speed up tests
* clean up
* restore timeout_seconds in some places
* flake8
* explicit offline
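The speed-up comes from two changes that are easiest to read together in the diff below: every ServerPreset now marks itself as offline, and a session-scoped pytest fixture downloads all preset models once before the first test, so individual tests start their servers from the local cache with short timeouts. A minimal, self-contained sketch of that fixture pattern (the stub names below are illustrative; the real code calls ServerPreset.load_all() from the test utils shown further down):

import pytest

_MODEL_CACHE: set[str] = set()


def download_all_models() -> None:
    # stand-in for ServerPreset.load_all(): fetch every preset's model file once
    _MODEL_CACHE.update({"tinyllama-2", "tinygemma3", "bert-bge-small"})


@pytest.fixture(scope="session", autouse=True)
def warm_model_cache():
    # autouse + scope="session": pytest runs this exactly once, before the first
    # test, so later tests can start llama-server with --offline and never wait
    # on a download
    download_all_models()


def test_model_already_cached():
    assert "tinyllama-2" in _MODEL_CACHE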
@@ -53,7 +53,7 @@ import typer
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather


 @contextmanager
@@ -335,7 +335,7 @@ def run(
         # server.debug = True

         with scoped_server(server):
-            server.start(timeout_seconds=TIMEOUT_SERVER_START)
+            server.start(timeout_seconds=15 * 60)
             for ignore_chat_grammar in [False]:
                 run(
                     server,
@@ -5,6 +5,12 @@ from utils import *
 server = ServerPreset.tinyllama2()


+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server
@@ -14,14 +14,11 @@ from utils import *

 server: ServerProcess

-TIMEOUT_SERVER_START = 15*60
-
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
     server.model_alias = "tinyllama-2"
-    server.server_port = 8081
     server.n_slots = 1


@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()

     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -12,7 +12,7 @@ from enum import Enum

 server: ServerProcess

-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60  # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60

 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)


@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)


@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)


@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)


@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)


@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)


@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)

     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)

@@ -5,18 +5,31 @@ import requests

 server: ServerProcess

-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-response = requests.get(IMG_URL_0)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-response = requests.get(IMG_URL_1)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id

 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():

 def test_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():

 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0, True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1, True, "(frog)+"),
-        ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
+        ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed", False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai", False, None), # non-image data
@@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
@@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed", False, None), # non-image data
         ("What is this:\n", "", False, None), # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,8 +130,8 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1, True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
         ("What is this: <__media__>\n", "malformed", False), # non-image data
         ("What is this:\n", "base64", False), # non-image data
     ]
@@ -126,7 +140,8 @@ def test_vision_embeddings(prompt, image_data, success):
     global server
     server.server_embeddings = True
     server.n_batch = 512
-    server.start() # vision model may take longer to load due to download size
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
@@ -26,7 +26,7 @@ from re import RegexFlag
 import wget


-DEFAULT_HTTP_TIMEOUT = 30
+DEFAULT_HTTP_TIMEOUT = 60


 class ServerResponse:
@@ -45,6 +45,7 @@ class ServerProcess:
     model_alias: str = "tinyllama-2"
     temperature: float = 0.8
     seed: int = 42
+    offline: bool = False

     # custom options
     model_alias: str | None = None
@@ -118,6 +119,8 @@ class ServerProcess:
             "--seed",
             self.seed,
         ]
+        if self.offline:
+            server_args.append("--offline")
         if self.model_file:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
@@ -392,6 +395,19 @@ server_instances: Set[ServerProcess] = set()


 class ServerPreset:
+    @staticmethod
+    def load_all() -> None:
+        """ Load all server presets to ensure model files are cached. """
+        servers: List[ServerProcess] = [
+            method()
+            for name, method in ServerPreset.__dict__.items()
+            if callable(method) and name != "load_all"
+        ]
+        for server in servers:
+            server.offline = False
+            server.start()
+            server.stop()
+
     @staticmethod
     def tinyllama2() -> ServerProcess:
         server = ServerProcess()
@@ -408,6 +424,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -422,6 +439,7 @@ class ServerPreset:
     @staticmethod
     def bert_bge_small_with_fa() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
         server.model_alias = "bert-bge-small"
@@ -437,6 +455,7 @@ class ServerPreset:
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
         server.model_alias = "tinyllama-infill"
@@ -451,6 +470,7 @@ class ServerPreset:
     @staticmethod
     def stories15m_moe() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/stories15M_MOE"
         server.model_hf_file = "stories15M_MOE-F16.gguf"
         server.model_alias = "stories15m-moe"
@@ -465,6 +485,7 @@ class ServerPreset:
     @staticmethod
     def jina_reranker_tiny() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
         server.model_alias = "jina-reranker"
@@ -478,6 +499,7 @@ class ServerPreset:
     @staticmethod
     def tinygemma3() -> ServerProcess:
         server = ServerProcess()
+        server.offline = True # will be downloaded by load_all()
         # mmproj is already provided by HF registry API
         server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
         server.model_hf_file = "tinygemma3-Q8_0.gguf"
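For reference, the same pieces can also be exercised outside pytest. A hedged sketch, assuming the repository root is on sys.path (as the sys.path.insert in the first hunk above arranges):

from tools.server.tests.utils import ServerPreset

ServerPreset.load_all()             # one network pass: download, start and stop each preset

server = ServerPreset.tinygemma3()  # the preset already sets server.offline = True
server.start()                      # no download: the model comes from the local cache
server.stop()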