mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	 05c3a444b8
			
		
	
	05c3a444b8
	
	
	
		
			
			* server : fill usage info in embeddings response
			* server : fill usage info in reranking response
		
			
				
	
	
		
			130 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			130 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pytest
 | |
| from openai import OpenAI
 | |
| from utils import *
 | |
| 
 | |
# Module-level default server preset; the module-scoped autouse fixture
# `create_server` rebinds this global before the tests run.
server = ServerPreset.bert_bge_small()

# Absolute tolerance used for floating-point comparisons of embedding values.
EPSILON = 0.001
 | |
| 
 | |
@pytest.fixture(scope="module", autouse=True)
def create_server():
    """Rebind the module-global ``server`` to a fresh ``bert_bge_small`` preset.

    Module-scoped and autouse, so it runs once for this module without tests
    requesting it. It only builds the preset; each test calls
    ``server.start()`` itself.
    """
    global server
    fresh_preset = ServerPreset.bert_bge_small()
    server = fresh_preset
 | |
| 
 | |
| 
 | |
def test_embedding_single():
    """A single string input yields exactly one unit-length embedding vector."""
    global server
    server.start()
    payload = {"input": "I believe the meaning of life is"}
    res = server.make_request("POST", "/embeddings", data=payload)
    assert res.status_code == 200

    items = res.body['data']
    assert len(items) == 1
    first = items[0]
    assert 'embedding' in first
    vector = first['embedding']
    assert len(vector) > 1

    # The returned vector is expected to be L2-normalized: its squared norm
    # must equal 1 within EPSILON.
    squared_norm = sum(component ** 2 for component in vector)
    assert abs(squared_norm - 1) < EPSILON
 | |
| 
 | |
| 
 | |
def test_embedding_multiple():
    """A list of four prompts yields four non-trivial embedding vectors."""
    global server
    server.start()
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    res = server.make_request("POST", "/embeddings", data={"input": prompts})
    assert res.status_code == 200

    entries = res.body['data']
    assert len(entries) == 4
    for entry in entries:
        assert 'embedding' in entry
        assert len(entry['embedding']) > 1
 | |
| 
 | |
| 
 | |
def test_embedding_openai_library_single():
    """The OpenAI Python client can fetch a single embedding from the server."""
    global server
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key="dummy", base_url=base_url)
    result = client.embeddings.create(
        model="text-embedding-3-small",
        input="I believe the meaning of life is",
    )
    returned = result.data
    assert len(returned) == 1
    assert len(returned[0].embedding) > 1
 | |
| 
 | |
| 
 | |
def test_embedding_openai_library_multiple():
    """The OpenAI Python client receives one embedding per prompt in a batch."""
    global server
    server.start()
    base_url = f"http://{server.server_host}:{server.server_port}"
    client = OpenAI(api_key="dummy", base_url=base_url)
    prompts = [
        "I believe the meaning of life is",
        "Write a joke about AI from a very long prompt which will not be truncated",
        "This is a test",
        "This is another test",
    ]
    result = client.embeddings.create(model="text-embedding-3-small", input=prompts)
    assert len(result.data) == 4
    for item in result.data:
        assert len(item.embedding) > 1
 | |
| 
 | |
| 
 | |
def test_embedding_error_prompt_too_long():
    """An oversized input (512 repetitions of a short phrase) is rejected.

    The request must fail (non-200), and the error message must mention "too large".
    """
    global server
    server.start()
    oversized = "This is a test " * 512
    res = server.make_request("POST", "/embeddings", data={"input": oversized})
    assert res.status_code != 200
    error_message = res.body["error"]["message"]
    assert "too large" in error_message
 | |
| 
 | |
| 
 | |
def test_same_prompt_give_same_result():
    """Identical prompts in one batch must produce numerically identical embeddings.

    Every returned vector is compared element-wise against the first one,
    with absolute tolerance EPSILON.
    """
    global server
    server.start()
    prompt = "I believe the meaning of life is"
    n_copies = 5
    res = server.make_request("POST", "/embeddings", data={
        "input": [prompt] * n_copies,
    })
    assert res.status_code == 200
    data = res.body['data']
    assert len(data) == n_copies

    # Hoisted: the reference vector does not change across iterations.
    reference = data[0]['embedding']
    for item in data[1:]:
        other = item['embedding']
        # zip() stops at the shorter sequence, so a length mismatch would
        # otherwise pass silently; check it explicitly.
        assert len(other) == len(reference)
        for x, y in zip(reference, other):
            assert abs(x - y) < EPSILON
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("content,n_tokens", [
    ("I believe the meaning of life is", 7),
    ("This is a test", 4),
])
def test_embedding_usage_single(content, n_tokens):
    """The ``usage`` block reports the expected prompt token count for one input.

    The test also requires ``total_tokens`` to equal ``prompt_tokens``.
    """
    global server
    server.start()
    res = server.make_request("POST", "/embeddings", data={"input": content})
    assert res.status_code == 200
    usage = res.body['usage']
    prompt_tokens = usage['prompt_tokens']
    assert prompt_tokens == usage['total_tokens']
    assert prompt_tokens == n_tokens
 | |
| 
 | |
| 
 | |
def test_embedding_usage_multiple():
    """Batch usage counts the prompt tokens of every input (2 prompts x 7 tokens)."""
    global server
    server.start()
    prompt = "I believe the meaning of life is"
    tokens_per_prompt = 7  # same count asserted for this prompt in test_embedding_usage_single
    res = server.make_request("POST", "/embeddings", data={"input": [prompt, prompt]})
    assert res.status_code == 200
    usage = res.body['usage']
    assert usage['prompt_tokens'] == usage['total_tokens']
    assert usage['prompt_tokens'] == 2 * tokens_per_prompt
 |