mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
server : support unified cache across slots (#16736)
* server : support unified context across slots * cont : fix speculative decoding initialization * context : fix n_ctx_per_seq computation * server : purge slots one by one * tests : add unified cache server tests * llama : update per-seq context computation * test-thread-safety : handle tiny training context of the input model * server : fix server_tokens clear() * server : use 4 slots + unified KV by default * llama : add note about context size queries * cont : update todos [no ci] * context : do not cap the size of the context * tests : adjust parameters to be CI friendlier * context : add warning
This commit is contained in:
@@ -433,21 +433,21 @@ def test_context_size_exceeded_stream():
|
||||
@pytest.mark.parametrize(
|
||||
"n_batch,batch_count,reuse_cache",
|
||||
[
|
||||
(64, 15, False),
|
||||
(64, 3, False),
|
||||
(64, 1, True),
|
||||
]
|
||||
)
|
||||
def test_return_progresssss(n_batch, batch_count, reuse_cache):
|
||||
def test_return_progress(n_batch, batch_count, reuse_cache):
|
||||
global server
|
||||
server.n_batch = n_batch
|
||||
server.n_ctx = 2048
|
||||
server.n_ctx = 256
|
||||
server.n_slots = 1
|
||||
server.start()
|
||||
def make_cmpl_request():
|
||||
return server.make_stream_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 10,
|
||||
"messages": [
|
||||
{"role": "user", "content": "This is a test" * 100},
|
||||
{"role": "user", "content": "This is a test" * 10},
|
||||
],
|
||||
"stream": True,
|
||||
"return_progress": True,
|
||||
|
||||
@@ -368,6 +368,37 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "n_ctx,n_slots,n_predict_vals,expected_success",
    [
        (256, 4, [80, 40, 80, 80], [True, True, True, True]),
        (256, 4, [70, 70, 70, 70], [False, False, False, False]),
        (256, 4, [90, 90, 40, 90], [False, False, True, False]),
        (256, 4, [90, 90, 40, 75], [True, True, True, True]),
    ],
)
def test_completion_unified(n_ctx, n_slots, n_predict_vals, expected_success):
    """Exercise parallel completions against a server running with a unified
    KV cache shared across slots.

    Each parametrized case fires one /completion request per entry of
    ``n_predict_vals`` concurrently and checks, per request, whether it
    succeeds (HTTP 200 with content) or is rejected (HTTP 500, no content)
    according to the matching flag in ``expected_success``.
    """
    global server
    # Configure a unified-KV server before it is launched.
    server.n_slots = n_slots
    server.kv_unified = True
    server.n_ctx = n_ctx
    server.start()

    prompt = "A"
    # One (callable, args) pair per requested prediction length; all are
    # dispatched concurrently so the slots compete for the shared cache.
    tasks = [
        (server.make_request, ("POST", "/completion", {"prompt": prompt, "n_predict": count}))
        for count in n_predict_vals
    ]

    results = parallel_function_calls(tasks)

    for res, n_predict, expect_ok in zip(results, n_predict_vals, expected_success):
        if not expect_ok:
            # Request expected to exceed the shared context budget.
            assert res.status_code == 500
            assert "content" not in res.body
            continue
        assert res.status_code == 200
        assert "content" in res.body
        # Timings are only present on some builds; when reported they must
        # reflect the exact number of predicted tokens requested.
        if "timings" in res.body:
            assert res.body["timings"]["predicted_n"] == n_predict
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,n_predict,response_fields",
|
||||
[
|
||||
|
||||
@@ -18,7 +18,7 @@ def test_infill_without_input_extra():
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
|
||||
assert match_regex("(Ann|small|shiny|Daddy|Jimmy)+", res.body["content"])
|
||||
|
||||
|
||||
def test_infill_with_input_extra():
|
||||
@@ -34,7 +34,7 @@ def test_infill_with_input_extra():
|
||||
"input_suffix": "}\n",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert match_regex("(Dad|excited|park)+", res.body["content"])
|
||||
assert match_regex("(Dad|excited|park|Jimmy)+", res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_extra", [
|
||||
|
||||
@@ -78,6 +78,7 @@ class ServerProcess:
|
||||
server_embeddings: bool | None = False
|
||||
server_reranking: bool | None = False
|
||||
server_metrics: bool | None = False
|
||||
kv_unified: bool | None = False
|
||||
server_slots: bool | None = False
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
@@ -159,6 +160,8 @@ class ServerProcess:
|
||||
server_args.append("--reranking")
|
||||
if self.server_metrics:
|
||||
server_args.append("--metrics")
|
||||
if self.kv_unified:
|
||||
server_args.append("--kv-unified")
|
||||
if self.server_slots:
|
||||
server_args.append("--slots")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user