Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-27 08:21:30 +00:00)
* server / ranking : add sorting and management of top_n
* Keep it backward compatible: if no top_n is provided, all results are returned (a sketch of the intended behavior follows).
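For context, here is a minimal sketch of the intended server-side behavior, using hypothetical names (rerank_result, sort_by_score, resolve_top_n are illustrative, not the actual llama.cpp code): results are sorted by relevance score, best first, and when the request carries no top_n the limit falls back to the number of documents, so older clients still get every result.

```cpp
// Hypothetical sketch of the top_n handling -- illustrative names only,
// not the real llama.cpp server implementation.
#include <algorithm>
#include <cstddef>
#include <vector>

struct rerank_result {
    size_t index; // position of the document in the request
    float  score; // relevance score produced by the reranker
};

// Sort results best-first so the first top_n entries are the ones to keep.
void sort_by_score(std::vector<rerank_result> & results) {
    std::sort(results.begin(), results.end(),
              [](const rerank_result & a, const rerank_result & b) {
                  return a.score > b.score;
              });
}

// Backward-compatible default: when the request has no "top_n" field,
// use the document count, i.e. return all results exactly as before.
size_t resolve_top_n(bool has_top_n, size_t requested, size_t n_docs) {
    return has_top_n ? std::min(requested, n_docs) : n_docs;
}
```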
Here is a script for a quick manual test: it sends 10 documents with top_n set to 6, so only the 6 highest-scoring documents should come back.
```sh
URL=${1:-http://127.0.0.1:8181}
curl "$URL/v1/rerank" -H "Content-Type: application/json" \
-d '{ "model": "M", "query": "What is the recipe to make bread ?",
"return_text" : true,
"texts" : true,
"top_n": 6,
"documents": [
"voici la recette pour faire du pain, il faut de la farine de l eau et du levain et du sel",
"it is a bear",
"bread recipe : floor, water, yest, salt",
"The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.",
"here is the ingedients to bake bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt",
"recipe to make cookies : floor, eggs, water, chocolat",
"here is the recipe to make bread : 500g floor, 350g water, 120g fresh refresh yest, 15g salt",
"il fait tres beau aujourd hui",
"je n ai pas faim, je ne veux pas manger",
"je suis a paris"
] }' | jq
```
* use resize() instead of a for(...) loop
* simplify the top_n initialization, since there is no need to return an error (see the sketch below)
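As a sketch of what these two bullets amount to (again with a hypothetical name, not the actual diff): the already-sorted result vector is shrunk in place with resize() instead of being rebuilt entry by entry in a for(...) loop, and top_n simply falls back to the full result count instead of going through an error path.

```cpp
// Hypothetical illustration of the resize() simplification, not the real diff.
#include <algorithm>
#include <cstddef>
#include <vector>

// Keep only the first top_n entries of an already-sorted result vector.
// resize() shrinks the vector in place, replacing a manual for(...) loop
// that copied the kept elements one by one into a temporary.
template <typename T>
void trim_to_top_n(std::vector<T> & results, size_t top_n) {
    results.resize(std::min(top_n, results.size()));
}
```

Because top_n is clamped to results.size(), resize() only ever shrinks the vector here and never value-initializes new elements.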
Test result:
```
./tests.sh unit/test_rerank.py -v -x
==================================================== test session starts =====================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3
cachedir: .pytest_cache
rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests
configfile: pytest.ini
plugins: anyio-4.11.0
collected 8 items
unit/test_rerank.py::test_rerank PASSED [ 12%]
unit/test_rerank.py::test_rerank_tei_format PASSED [ 25%]
unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED [ 37%]
unit/test_rerank.py::test_invalid_rerank_req[None] PASSED [ 50%]
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED [ 62%]
unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED [ 75%]
unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED [ 87%]
unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED [100%]
===================================================== 8 passed in 4.31s ======================================================
```
* add rerank top_n unit test
Here is the result:
```
./tests.sh unit/test_rerank.py -v -x
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.6.0 -- /home/yann/dev/yann/llama.cpp/tools/server/tests/test/bin/python3
cachedir: .pytest_cache
rootdir: /home/yann/dev/yann/llama.cpp/tools/server/tests
configfile: pytest.ini
plugins: anyio-4.11.0
collected 16 items
unit/test_rerank.py::test_rerank PASSED [ 6%]
unit/test_rerank.py::test_rerank_tei_format PASSED [ 12%]
unit/test_rerank.py::test_invalid_rerank_req[documents0] PASSED [ 18%]
unit/test_rerank.py::test_invalid_rerank_req[None] PASSED [ 25%]
unit/test_rerank.py::test_invalid_rerank_req[123] PASSED [ 31%]
unit/test_rerank.py::test_invalid_rerank_req[documents3] PASSED [ 37%]
unit/test_rerank.py::test_rerank_usage[Machine learning is-A machine-Learning is-19] PASSED [ 43%]
unit/test_rerank.py::test_rerank_usage[Which city?-Machine learning is -Paris, capitale de la-26] PASSED [ 50%]
unit/test_rerank.py::test_rerank_top_n[None-4] PASSED [ 56%]
unit/test_rerank.py::test_rerank_top_n[2-2] PASSED [ 62%]
unit/test_rerank.py::test_rerank_top_n[4-4] PASSED [ 68%]
unit/test_rerank.py::test_rerank_top_n[99-4] PASSED [ 75%]
unit/test_rerank.py::test_rerank_tei_top_n[None-4] PASSED [ 81%]
unit/test_rerank.py::test_rerank_tei_top_n[2-2] PASSED [ 87%]
unit/test_rerank.py::test_rerank_tei_top_n[4-4] PASSED [ 93%]
unit/test_rerank.py::test_rerank_tei_top_n[99-4] PASSED [100%]
=================================================================== 16 passed in 8.84s ===================================================================
```
* editor config check fix
tools/server/tests/unit/test_rerank.py (147 lines, 4.7 KiB, Python):
```python
import pytest

from utils import *

server = ServerPreset.jina_reranker_tiny()


@pytest.fixture(autouse=True)
def create_server():
    global server
    server = ServerPreset.jina_reranker_tiny()


TEST_DOCUMENTS = [
    "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
    "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
    "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
    "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]


def test_rerank():
    global server
    server.start()
    res = server.make_request("POST", "/rerank", data={
        "query": "Machine learning is",
        "documents": TEST_DOCUMENTS,
    })
    assert res.status_code == 200
    assert len(res.body["results"]) == 4

    most_relevant = res.body["results"][0]
    least_relevant = res.body["results"][0]
    for doc in res.body["results"]:
        if doc["relevance_score"] > most_relevant["relevance_score"]:
            most_relevant = doc
        if doc["relevance_score"] < least_relevant["relevance_score"]:
            least_relevant = doc

    assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
    assert most_relevant["index"] == 2
    assert least_relevant["index"] == 3


def test_rerank_tei_format():
    global server
    server.start()
    res = server.make_request("POST", "/rerank", data={
        "query": "Machine learning is",
        "texts": TEST_DOCUMENTS,
    })
    assert res.status_code == 200
    assert len(res.body) == 4

    most_relevant = res.body[0]
    least_relevant = res.body[0]
    for doc in res.body:
        if doc["score"] > most_relevant["score"]:
            most_relevant = doc
        if doc["score"] < least_relevant["score"]:
            least_relevant = doc

    assert most_relevant["score"] > least_relevant["score"]
    assert most_relevant["index"] == 2
    assert least_relevant["index"] == 3


@pytest.mark.parametrize("documents", [
    [],
    None,
    123,
    [1, 2, 3],
])
def test_invalid_rerank_req(documents):
    global server
    server.start()
    res = server.make_request("POST", "/rerank", data={
        "query": "Machine learning is",
        "documents": documents,
    })
    assert res.status_code == 400
    assert "error" in res.body


@pytest.mark.parametrize(
    "query,doc1,doc2,n_tokens",
    [
        ("Machine learning is", "A machine", "Learning is", 19),
        ("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
    ]
)
def test_rerank_usage(query, doc1, doc2, n_tokens):
    global server
    server.start()

    res = server.make_request("POST", "/rerank", data={
        "query": query,
        "documents": [
            doc1,
            doc2,
        ]
    })
    assert res.status_code == 200
    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
    assert res.body['usage']['prompt_tokens'] == n_tokens


@pytest.mark.parametrize("top_n,expected_len", [
    (None, len(TEST_DOCUMENTS)),  # no top_n parameter
    (2, 2),
    (4, 4),
    (99, len(TEST_DOCUMENTS)),  # higher than available docs
])
def test_rerank_top_n(top_n, expected_len):
    global server
    server.start()
    data = {
        "query": "Machine learning is",
        "documents": TEST_DOCUMENTS,
    }
    if top_n is not None:
        data["top_n"] = top_n

    res = server.make_request("POST", "/rerank", data=data)
    assert res.status_code == 200
    assert len(res.body["results"]) == expected_len


@pytest.mark.parametrize("top_n,expected_len", [
    (None, len(TEST_DOCUMENTS)),  # no top_n parameter
    (2, 2),
    (4, 4),
    (99, len(TEST_DOCUMENTS)),  # higher than available docs
])
def test_rerank_tei_top_n(top_n, expected_len):
    global server
    server.start()
    data = {
        "query": "Machine learning is",
        "texts": TEST_DOCUMENTS,
    }
    if top_n is not None:
        data["top_n"] = top_n

    res = server.make_request("POST", "/rerank", data=data)
    assert res.status_code == 200
    assert len(res.body) == expected_len
```