mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
server : Support multimodal completion and embeddings prompts in JSON format (#15108)
- Use server_tokens in more places in server and util.cpp - Convert most functions that used llama_tokens to server_tokens - Modify input tokenizer to handle JSON objects as subprompts - Break out MTMD prompt parsing into utility function - Support JSON objects with multimodal_data arrays for MTMD prompts along with other existing types - Add capability to model endpoint to indicate if client can send multimodal data - Add tests.
This commit is contained in:
@@ -6,6 +6,8 @@ from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
@@ -231,6 +233,28 @@ def test_nocache_long_input_prompt():
|
||||
})
|
||||
assert res.status_code == 400
|
||||
|
||||
def test_json_prompt_no_mtmd():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
|
||||
def test_json_prompt_mtm_error_when_not_supported():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: "I believe the meaning of life is <__media__>", JSON_MULTIMODAL_KEY: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" },
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"cache_prompt": False,
|
||||
})
|
||||
# MTMD is disabled on this model, so this should fail.
|
||||
assert res.status_code != 200
|
||||
|
||||
def test_completion_with_tokens_input():
|
||||
global server
|
||||
@@ -269,6 +293,20 @@ def test_completion_with_tokens_input():
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed JSON and tokens
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [
|
||||
tokens,
|
||||
{
|
||||
JSON_PROMPT_STRING_KEY: "I believe the meaning of life is",
|
||||
},
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert type(res.body) == list
|
||||
assert len(res.body) == 2
|
||||
assert res.body[0]["content"] == res.body[1]["content"]
|
||||
|
||||
# mixed string and tokens in one sequence
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
|
||||
|
||||
@@ -10,21 +10,48 @@ IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/9
|
||||
|
||||
response = requests.get(IMG_URL_0)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
response = requests.get(IMG_URL_1)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
JSON_MULTIMODAL_KEY = "multimodal_data"
|
||||
JSON_PROMPT_STRING_KEY = "prompt_string"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinygemma3()
|
||||
|
||||
def test_models_supports_multimodal_capability():
|
||||
global server
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("GET", "/models", data={})
|
||||
assert res.status_code == 200
|
||||
model_info = res.body["models"][0]
|
||||
print(model_info)
|
||||
assert "completion" in model_info["capabilities"]
|
||||
assert "multimodal" in model_info["capabilities"]
|
||||
|
||||
def test_v1_models_supports_multimodal_capability():
|
||||
global server
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("GET", "/v1/models", data={})
|
||||
assert res.status_code == 200
|
||||
model_info = res.body["models"][0]
|
||||
print(model_info)
|
||||
assert "completion" in model_info["capabilities"]
|
||||
assert "multimodal" in model_info["capabilities"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_url, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this:\n", IMG_URL_0, True, "(cat)+"),
|
||||
("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
||||
("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
||||
("What is this:\n", IMG_URL_1, True, "(frog)+"),
|
||||
("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
|
||||
("What is this:\n", "malformed", False, None),
|
||||
@@ -36,8 +63,8 @@ def create_server():
|
||||
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
global server
|
||||
server.start(timeout_seconds=60) # vision model may take longer to load due to download size
|
||||
if image_url == "IMG_BASE64_0":
|
||||
image_url = IMG_BASE64_0
|
||||
if image_url == "IMG_BASE64_URI_0":
|
||||
image_url = IMG_BASE64_URI_0
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
@@ -58,3 +85,61 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_data, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"),
|
||||
("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
|
||||
("What is this: <__media__>\n", "malformed", False, None), # non-image data
|
||||
("What is this:\n", "", False, None), # empty string
|
||||
]
|
||||
)
|
||||
def test_vision_completion(prompt, image_data, success, re_content):
|
||||
global server
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("POST", "/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
content = res.body["content"]
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_data, success",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log
|
||||
("What is this: <__media__>\n", IMG_BASE64_1, True),
|
||||
("What is this: <__media__>\n", "malformed", False), # non-image data
|
||||
("What is this:\n", "base64", False), # non-image data
|
||||
]
|
||||
)
|
||||
def test_vision_embeddings(prompt, image_data, success):
|
||||
global server
|
||||
server.server_embeddings=True
|
||||
server.n_batch=512
|
||||
server.start() # vision model may take longer to load due to download size
|
||||
res = server.make_request("POST", "/embeddings", data={
|
||||
"content": [
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
|
||||
{ JSON_PROMPT_STRING_KEY: prompt, },
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
content = res.body
|
||||
# Ensure embeddings are stable when multimodal.
|
||||
assert content[0]['embedding'] == content[1]['embedding']
|
||||
# Ensure embeddings without multimodal but same prompt do not match multimodal embeddings.
|
||||
assert content[0]['embedding'] != content[2]['embedding']
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
Reference in New Issue
Block a user