server : handle failures to restore host cache (#17078)

* server : handle failures to restore host cache
* server : add tests for the prompt cache
@@ -1690,6 +1690,9 @@ struct server_slot {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
+
+            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+            prompt.tokens.clear();
         }
     }
 
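Reading the added lines: if prompt_cache.load() fails, the slot no longer trusts whatever state the restore may have written, falls back to an empty prompt, and reprocesses from scratch. A minimal annotated restatement of the failure path (the comments are mine; the assumption that a failed restore can leave partial state behind is not stated in the commit):

    bool res = prompt_cache.load(prompt, tokens, ctx, id);
    if (!res) {
        SLT_WRN(*this, "%s", "failed to load prompt from cache\n");

        // drop the slot's entire sequence (all positions) from the llama memory / KV cache,
        // in case the failed restore left it partially populated
        llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);

        // forget the tokens tracked for this prompt so the next request
        // reprocesses the full prompt instead of trusting stale cache state
        prompt.tokens.clear();
    }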
@@ -1,6 +1,8 @@
 import pytest
 import requests
 import time
+import random
+
 from openai import OpenAI
 from utils import *
 
@@ -564,3 +566,43 @@ def test_cancel_request():
     time.sleep(1) # wait for HTTP_POLLING_SECONDS
     res = server.make_request("GET", "/slots")
     assert res.body[0]["is_processing"] == False
+
+
+# this test exercises the host-memory prompt cache
+# ref: https://github.com/ggml-org/llama.cpp/pull/16391
+# ref: https://github.com/ggml-org/llama.cpp/pull/17078
+def test_completion_prompt_cache():
+    global server
+    server.n_slots = 2
+    server.kv_unified = True
+    server.start()
+
+    for _ in range(16):
+        # generate alternating random prompts with variable lengths in order to get them in and out of the cache
+        r = random.randint(0, 4)
+        prompt = (" Hello " + str(r)) * (40 + r)
+        n_prompt = (40 + r)*5 + 2
+        n_predict = random.randint(1, 8)
+
+        res = server.make_request(
+            "POST",
+            "/completion",
+            data={
+                "prompt": prompt,
+                "n_predict": n_predict,
+            },
+        )
+
+        assert res.status_code == 200
+        assert "content" in res.body
+        content = res.body["content"]
+        assert isinstance(content, str)
+        assert len(content) > 0
+
+        assert type(res.body["has_new_line"]) == bool
+        assert "timings" in res.body
+        timings = res.body["timings"]
+
+        assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
+        assert "predicted_n" in timings and timings["predicted_n"] == n_predict
+        assert "tokens" in res.body and isinstance(res.body["tokens"], list)
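What the test checks: each prompt repeats the chunk " Hello <r>" exactly 40 + r times, and the timings assertion requires that every prompt token is accounted for either as evaluated (prompt_n) or as reused from cache (cache_n). A small sketch of that bookkeeping, assuming the 5-tokens-per-chunk count and the +2 offset (e.g. BOS/prefix tokens) implied by the test's own formula for the test model's tokenizer:

    import random

    # one iteration of the test's bookkeeping (illustrative only)
    r = random.randint(0, 4)
    prompt = (" Hello " + str(r)) * (40 + r)   # e.g. r = 3 -> 43 copies of " Hello 3"
    n_prompt = (40 + r) * 5 + 2                # assumed: 5 tokens per copy + 2 extra tokens

    # expected accounting in the /completion response:
    #   timings["prompt_n"] - prompt tokens evaluated for this request
    #   timings["cache_n"]  - prompt tokens reused from the cache
    #   timings["prompt_n"] + timings["cache_n"] == n_prompt

Because the server runs with n_slots = 2 and a unified KV cache while the prompts alternate in content and length, successive requests push slot prompts in and out of the host cache, which is what exercises the restore path changed by this commit.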