mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	 3c3635d2f2
			
		
	
	3c3635d2f2
	
	
	
		
			
			* server : speed up tests
			* clean up
			* restore timeout_seconds in some places
			* flake8
			* explicit offline
		
			
				
	
	
		
			161 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			161 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import base64
import functools

import pytest
import requests

from utils import *
| 
 | |
# Module-level handle to the server under test. It is declared here without a
# value and rebound by the autouse `create_server` fixture before every test.
server: ServerProcess
| 
 | |
# Remote test images used by this module.
_IMG_URLS = {
    "IMG_URL_0": "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png",
    "IMG_URL_1": "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png",
}

# Placeholder id -> (key into _IMG_URLS, wrap the payload as a PNG data: URI?)
_IMG_BASE64_SOURCES = {
    "IMG_BASE64_URI_0": ("IMG_URL_0", True),
    "IMG_BASE64_0": ("IMG_URL_0", False),
    "IMG_BASE64_URI_1": ("IMG_URL_1", True),
    "IMG_BASE64_1": ("IMG_URL_1", False),
}

# Seconds to wait on the image host. Without a timeout `requests.get` can block
# forever on a stalled connection and hang the entire test session.
_IMG_DOWNLOAD_TIMEOUT = 60


@functools.lru_cache(maxsize=None)
def _download_img_base64(url: str) -> str:
    """Download *url* and return its body base64-encoded as a str.

    Cached so that parametrized tests reuse a single download per URL instead
    of fetching the same image again for every case. A failed download raises
    and is therefore not cached.
    """
    response = requests.get(url, timeout=_IMG_DOWNLOAD_TIMEOUT)
    response.raise_for_status()  # fail loudly on HTTP error statuses
    return base64.b64encode(response.content).decode("utf-8")


def get_img_url(id: str) -> str:
    """Resolve a test-image placeholder id to the value sent to the server.

    Args:
        id: ``IMG_URL_0`` / ``IMG_URL_1`` return the remote image URL;
            ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1`` download that image and
            return a ``data:image/png;base64,...`` URI; ``IMG_BASE64_0`` /
            ``IMG_BASE64_1`` return the bare base64 payload. Any other string
            is returned unchanged, so tests can pass literal (possibly invalid)
            inputs straight through. (The name shadows the ``id`` builtin but
            is kept for backward compatibility.)

    Returns:
        The URL, data URI, base64 string, or *id* itself.

    Raises:
        requests.RequestException: if downloading a base64 image fails or
            times out.
    """
    if id in _IMG_URLS:
        return _IMG_URLS[id]
    source = _IMG_BASE64_SOURCES.get(id)
    if source is None:
        return id
    url_key, as_data_uri = source
    encoded = _download_img_base64(_IMG_URLS[url_key])
    return ("data:image/png;base64," + encoded) if as_data_uri else encoded
| 
 | |
# Field names of the prompt objects these tests send to /completions and
# /embeddings: the text goes under JSON_PROMPT_STRING_KEY and the list of
# media payloads under JSON_MULTIMODAL_KEY.
JSON_PROMPT_STRING_KEY = "prompt_string"
JSON_MULTIMODAL_KEY = "multimodal_data"
 | |
| 
 | |
@pytest.fixture(autouse=True, scope="function")
def create_server():
    """Rebind the module-level ``server`` to a fresh tinygemma3 preset.

    Runs automatically before each test. The fixture does not start the
    server; each test configures and starts it itself.
    """
    global server
    server = ServerPreset.tinygemma3()
 | |
| 
 | |
def _assert_models_endpoint_reports_multimodal(path: str) -> None:
    """GET *path* and check the first listed model advertises vision support.

    Shared by the ``/models`` and ``/v1/models`` tests. Both expect a body
    shaped like ``{"models": [{"capabilities": [...]}, ...]}``.
    """
    server.start()
    res = server.make_request("GET", path, data={})
    assert res.status_code == 200
    model_info = res.body["models"][0]
    capabilities = model_info["capabilities"]
    # model_info is attached as the assertion message so a failure shows what
    # the server actually reported (replaces the old debug print).
    assert "completion" in capabilities, model_info
    assert "multimodal" in capabilities, model_info


def test_models_supports_multimodal_capability():
    _assert_models_endpoint_reports_multimodal("/models")


def test_v1_models_supports_multimodal_capability():
    _assert_models_endpoint_reports_multimodal("/v1/models")
 | |
| 
 | |
@pytest.mark.parametrize(
    "prompt, image_url, success, re_content",
    [
        # The test model was trained on CIFAR-10 and is very small, so the
        # expected labels need not match the picture (the truck image is
        # expected to come back as "cat").
        ("What is this:\n", "IMG_URL_0",              True, "(cat)+"),
        ("What is this:\n", "IMG_BASE64_URI_0",       True, "(cat)+"),
        ("What is this:\n", "IMG_URL_1",              True, "(frog)+"),
        ("Test test\n",     "IMG_URL_1",              True, "(frog)+"), # test invalidate cache
        ("What is this:\n", "malformed",              False, None),
        ("What is this:\n", "https://google.com/404", False, None), # non-existent image
        ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
        # TODO @ngxson : test with multiple images, no images and with audio
    ]
)
def test_vision_chat_completion(prompt, image_url, success, re_content):
    """Send one text+image chat message and check the reply or the rejection."""
    server.start()
    text_part = {"type": "text", "text": prompt}
    image_part = {"type": "image_url", "image_url": {"url": get_img_url(image_url)}}
    request_body = {
        "temperature": 0.0,
        "top_k": 1,
        "messages": [
            {"role": "user", "content": [text_part, image_part]},
        ],
    }
    res = server.make_request("POST", "/chat/completions", data=request_body)

    if not success:
        # Bad inputs must produce an error response.
        assert res.status_code != 200
        return

    assert res.status_code == 200
    message = res.body["choices"][0]["message"]
    assert message["role"] == "assistant"
    assert match_regex(re_content, message["content"])
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize(
    "prompt, image_data, success, re_content",
    [
        # The test model was trained on CIFAR-10 and is very small, so the
        # expected labels need not match the picture.
        ("What is this: <__media__>\n", "IMG_BASE64_0",         True, "(cat)+"),
        ("What is this: <__media__>\n", "IMG_BASE64_1",         True, "(frog)+"),
        ("What is this: <__media__>\n", "malformed",            False, None), # non-image data
        ("What is this:\n",             "",                     False, None), # empty string
    ]
)
def test_vision_completion(prompt, image_data, success, re_content):
    """POST a prompt object carrying one image to /completions and check the text."""
    server.start()
    multimodal_prompt = {
        JSON_PROMPT_STRING_KEY: prompt,
        JSON_MULTIMODAL_KEY: [get_img_url(image_data)],
    }
    request_body = {
        "temperature": 0.0,
        "top_k": 1,
        "prompt": multimodal_prompt,
    }
    res = server.make_request("POST", "/completions", data=request_body)

    if not success:
        # Bad inputs must produce an error response.
        assert res.status_code != 200
        return

    assert res.status_code == 200
    assert match_regex(re_content, res.body["content"])
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize(
    "prompt, image_data, success",
    [
        # Each case embeds the prompt twice with the image and once without it.
        ("What is this: <__media__>\n", "IMG_BASE64_0",         True),
        ("What is this: <__media__>\n", "IMG_BASE64_1",         True),
        ("What is this: <__media__>\n", "malformed",            False), # non-image data
        ("What is this:\n",             "base64",               False), # non-image data
    ]
)
def test_vision_embeddings(prompt, image_data, success):
    """Embed a prompt with and without an image and compare the resulting vectors."""
    server.server_embeddings = True
    server.n_batch = 512
    server.start()
    encoded_image = get_img_url(image_data)

    def with_image() -> dict:
        # A fresh dict per entry, matching the two separate literals sent before.
        return {JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [encoded_image]}

    request_body = {
        "content": [
            with_image(),
            with_image(),
            {JSON_PROMPT_STRING_KEY: prompt},
        ],
    }
    res = server.make_request("POST", "/embeddings", data=request_body)

    if not success:
        # Bad inputs must produce an error response.
        assert res.status_code != 200
        return

    assert res.status_code == 200
    results = res.body
    # The same prompt + image sent twice must give exactly the same embedding.
    assert results[0]['embedding'] == results[1]['embedding']
    # Dropping the image must change the embedding of the same prompt text.
    assert results[0]['embedding'] != results[2]['embedding']