import pytest from utils import * import base64 import requests server: ServerProcess def get_img_url(id: str) -> str: IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" if id == "IMG_URL_0": return IMG_URL_0 elif id == "IMG_URL_1": return IMG_URL_1 elif id == "IMG_BASE64_URI_0": response = requests.get(IMG_URL_0) response.raise_for_status() # Raise an exception for bad status codes return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") elif id == "IMG_BASE64_0": response = requests.get(IMG_URL_0) response.raise_for_status() # Raise an exception for bad status codes return base64.b64encode(response.content).decode("utf-8") elif id == "IMG_BASE64_URI_1": response = requests.get(IMG_URL_1) response.raise_for_status() # Raise an exception for bad status codes return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") elif id == "IMG_BASE64_1": response = requests.get(IMG_URL_1) response.raise_for_status() # Raise an exception for bad status codes return base64.b64encode(response.content).decode("utf-8") else: return id JSON_MULTIMODAL_KEY = "multimodal_data" JSON_PROMPT_STRING_KEY = "prompt_string" @pytest.fixture(autouse=True) def create_server(): global server server = ServerPreset.tinygemma3() def test_models_supports_multimodal_capability(): global server server.start() res = server.make_request("GET", "/models", data={}) assert res.status_code == 200 model_info = res.body["models"][0] print(model_info) assert "completion" in model_info["capabilities"] assert "multimodal" in model_info["capabilities"] def test_v1_models_supports_multimodal_capability(): global server server.start() res = server.make_request("GET", "/v1/models", data={}) assert res.status_code == 200 model_info = res.body["models"][0] print(model_info) assert "completion" in model_info["capabilities"] assert "multimodal" in model_info["capabilities"] @pytest.mark.parametrize( "prompt, image_url, success, re_content", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size ("What is this:\n", "IMG_URL_0", True, "(cat)+"), ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), ("What is this:\n", "IMG_URL_1", True, "(frog)+"), ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache ("What is this:\n", "malformed", False, None), ("What is this:\n", "https://google.com/404", False, None), # non-existent image ("What is this:\n", "https://ggml.ai", False, None), # non-image data # TODO @ngxson : test with multiple images, no images and with audio ] ) def test_vision_chat_completion(prompt, image_url, success, re_content): global server server.start() res = server.make_request("POST", "/chat/completions", data={ "temperature": 0.0, "top_k": 1, "messages": [ {"role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "image_url", "image_url": { "url": get_img_url(image_url), }}, ]}, ], }) if success: assert res.status_code == 200 choice = res.body["choices"][0] assert "assistant" == choice["message"]["role"] assert match_regex(re_content, choice["message"]["content"]) else: assert res.status_code != 200 @pytest.mark.parametrize( "prompt, image_data, success, re_content", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"), ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"), ("What is this: <__media__>\n", "malformed", False, None), # non-image data ("What is this:\n", "", False, None), # empty string ] ) def test_vision_completion(prompt, image_data, success, re_content): global server server.start() res = server.make_request("POST", "/completions", data={ "temperature": 0.0, "top_k": 1, "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ], }, }) if success: assert res.status_code == 200 content = res.body["content"] assert match_regex(re_content, content) else: assert res.status_code != 200 @pytest.mark.parametrize( "prompt, image_data, success", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size ("What is this: <__media__>\n", "IMG_BASE64_0", True), ("What is this: <__media__>\n", "IMG_BASE64_1", True), ("What is this: <__media__>\n", "malformed", False), # non-image data ("What is this:\n", "base64", False), # non-image data ] ) def test_vision_embeddings(prompt, image_data, success): global server server.server_embeddings = True server.n_batch = 512 server.start() image_data = get_img_url(image_data) res = server.make_request("POST", "/embeddings", data={ "content": [ { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] }, { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] }, { JSON_PROMPT_STRING_KEY: prompt, }, ], }) if success: assert res.status_code == 200 content = res.body # Ensure embeddings are stable when multimodal. assert content[0]['embedding'] == content[1]['embedding'] # Ensure embeddings without multimodal but same prompt do not match multimodal embeddings. assert content[0]['embedding'] != content[2]['embedding'] else: assert res.status_code != 200