server : host-memory prompt caching (#16391)

* minor : code style

* server : fix prompt similarity calculation

* server : initial host-memory prompt caching

* cont

* server : refactor

* cont

* cont : make the server task of the slot const

* cont : minor [no ci]

* server : cache prompts and checkpoints only for completion tasks

* server : improve prompt caching logic

* cont : fix check for number of cached prompts [no ci]

* server : improve caching logic, add -cram CLI arg

* server : print prompt mismatch info

* cont : better naming [no ci]

* server : improve prompt cache loading logic

* server : add option to debug the slot contents (#16482)

* server : add option to debug the slot contents

* Update tools/server/server.cpp

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>

* server : add option to disable prompt cache

---------

Co-authored-by: Xuan-Son Nguyen <son@huggingface.co>
This commit is contained in:
Georgi Gerganov
2025-10-09 18:54:51 +03:00
committed by GitHub
parent 8328fd4bae
commit d00cbea63c
10 changed files with 813 additions and 471 deletions

View File

@@ -66,8 +66,7 @@ def test_server_slots():
assert len(res.body) == server.n_slots
assert server.n_ctx is not None and server.n_slots is not None
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
assert "params" in res.body[0]
assert res.body[0]["params"]["seed"] == server.seed
assert "params" not in res.body[0]
def test_load_split_model():

View File

@@ -19,8 +19,8 @@ def create_server():
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
]
@@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[
("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):

View File

@@ -16,7 +16,7 @@ def create_server():
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
global server
@@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server

View File

@@ -4,6 +4,12 @@ from utils import *
server = ServerPreset.tinyllama2()
SHORT_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
""".strip()
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
@@ -21,19 +27,18 @@ def create_server():
def test_ctx_shift_enabled():
# the prompt is 301 tokens
# the prompt is 226 tokens
# the slot context is 512/2 = 256 tokens
# the prompt is truncated to keep the last (301 - 256/2) = 173 tokens
# 96 tokens are generated thanks to shifting the context when it gets full
global server
server.enable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 96,
"prompt": LONG_TEXT,
"prompt": SHORT_TEXT,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == 173
assert res.body["timings"]["prompt_n"] == 226
assert res.body["timings"]["predicted_n"] == 96
assert res.body["truncated"] is True