Files
llama.cpp/tools/server/tests/unit/test_ctx_shift.py
Georgi Gerganov d00cbea63c server : host-memory prompt caching (#16391)
* minor : code style

* server : fix prompt similarity calculation

* server : initial host-memory prompt caching

* cont

* server : refactor

* cont

* cont : make the server task of the slot const

* cont : minor [no ci]

* server : cache prompts and checkpoints only for completion tasks

* server : improve prompt caching logic

* cont : fix check for number of cached prompts [no ci]

* server : improve caching logic, add -cram CLI arg

* server : print prompt mismatch info

* cont : better naming [no ci]

* server : improve prompt cache loading logic

* server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : add option to disable prompt cache

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
2025-10-09 18:54:51 +03:00

90 lines
3.0 KiB
Python

import pytest
from utils import *
# Shared handle to the server under test. The autouse `create_server` fixture
# in this module rebinds it to a freshly configured preset for every test.
server = ServerPreset.tinyllama2()
# Filler sentences used as prompts. SHORT_TEXT is the first three lines and
# LONG_TEXT is all four, each joined with a single newline and no surrounding
# whitespace.
_LOREM_LINES = (
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
    "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.",
    "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
)
SHORT_TEXT = "\n".join(_LOREM_LINES[:3])
LONG_TEXT = "\n".join(_LOREM_LINES)
@pytest.fixture(autouse=True)
def create_server():
    """Rebind the module-level ``server`` to a fresh tinyllama2 preset.

    The preset is configured with ``n_ctx = 512``, ``n_slots = 2`` and
    ``n_predict = 128``. As an autouse fixture it runs for every test in
    this module, so individual tests only override what they need.
    """
    global server
    fresh = ServerPreset.tinyllama2()
    fresh.n_ctx = 512
    fresh.n_slots = 2
    fresh.n_predict = 128
    server = fresh
def test_ctx_shift_enabled():
    """Generation runs to the requested length when context shift is on.

    The prompt is 226 tokens and the slot context is 512/2 = 256 tokens, so
    all 96 requested tokens can only be produced by shifting the context
    once it gets full. The response is expected to be flagged as truncated.
    """
    server.enable_ctx_shift = True
    server.start()

    payload = {
        "n_predict": 96,
        "prompt": SHORT_TEXT,
    }
    res = server.make_request("POST", "/completion", data=payload)

    # Check the status first so a failed request is reported as such rather
    # than as a missing key in the body.
    assert res.status_code == 200
    timings = res.body["timings"]
    assert timings["prompt_n"] == 226
    assert timings["predicted_n"] == 96
    assert res.body["truncated"] is True
@pytest.mark.parametrize(
    "n_predict,n_token_output,truncated",
    [
        (64, 64, False),
        (-1, 120, True),
    ],
)
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
    """Short prompt, context shift not enabled by the test.

    The server-side ``n_predict`` is set to -1 so only the per-request value
    applies. A request for 64 tokens yields exactly 64 and is not truncated;
    a request with ``n_predict = -1`` yields 120 tokens and is reported as
    truncated.
    """
    server.n_predict = -1
    server.start()

    payload = {
        "n_predict": n_predict,
        "prompt": "Hi how are you",
    }
    res = server.make_request("POST", "/completion", data=payload)

    assert res.status_code == 200
    body = res.body
    assert body["timings"]["predicted_n"] == n_token_output
    assert body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
    """A prompt that does not fit the context is rejected with an error.

    The test does not enable context shift; sending LONG_TEXT must produce a
    non-200 response whose error message mentions that the request exceeds
    the available context size.
    """
    server.start()

    payload = {
        "n_predict": 64,
        "prompt": LONG_TEXT,
    }
    res = server.make_request("POST", "/completion", data=payload)

    assert res.status_code != 200
    body = res.body
    assert "error" in body
    assert "exceeds the available context size" in body["error"]["message"]
def test_ctx_shift_disabled_stream():
    """Streamed completion without context shift must end with "length".

    Every chunk is inspected: intermediate chunks must carry a
    ``finish_reason`` of ``None`` and contribute text, and the chunk that
    reports ``finish_reason == "length"`` must arrive only after some text
    has been produced. The test also requires that such a terminal chunk is
    actually observed.
    """
    global server
    server.start()
    res = server.make_stream_request("POST", "/v1/completions", data={
        "n_predict": 256,
        "prompt": "Once",
        "stream": True,
    })
    content = ""
    saw_length_finish = False
    for data in res:
        choice = data["choices"][0]
        if choice["finish_reason"] == "length":
            assert len(content) > 0
            saw_length_finish = True
        else:
            assert choice["finish_reason"] is None
            content += choice["text"]
    # All checks above live inside the loop, so an empty stream or one that
    # never reports a finish reason would otherwise pass without verifying
    # anything.
    assert saw_length_finish, "stream ended without finish_reason == 'length'"