kv-cache : pad the cache size to 256 for performance (#17046)
* kv-cache : pad the size of the small SWA cache for performance
* context : pad the total context to 256
* cont : future-proof the swa pad
* server : adjust test params to new logic
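For reference, GGML_PAD(x, n) rounds x up to the next multiple of n. A minimal standalone sketch of the rounding this commit relies on (the pad_to helper below is an illustrative stand-in, not the actual ggml macro):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for GGML_PAD(x, n): round x up to the next multiple of n.
// (Assumes n > 0; the real macro lives in ggml.h.)
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    // With this commit, the total context is padded to a multiple of 256.
    printf("%u -> %u\n", 4096u, pad_to(4096, 256)); // 4096 (already aligned)
    printf("%u -> %u\n", 4100u, pad_to(4100, 256)); // 4352
    printf("%u -> %u\n",   64u, pad_to(64, 256));   // 256 (why the server tests changed)
    return 0;
}
```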
@@ -463,6 +463,7 @@ extern "C" {
 
     // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
     // In some cases the requested values via llama_context_params may differ from the actual values used by the context
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
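The note added above recommends reading the actual context sizes back from the created context, since the requested values may now be padded. A hedged sketch of what that looks like with the public API (model path and error handling are placeholders):

```cpp
// Sketch: after this change, the context size actually used may be larger than the
// requested value, so query it back instead of assuming the request was honored.
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 4100; // request a size that is not a multiple of 256

    llama_context * ctx = llama_init_from_model(model, cparams);

    // expected to report the padded value (e.g. 4352), not the requested 4100
    printf("n_ctx     = %u\n", llama_n_ctx(ctx));
    printf("n_ctx_seq = %u\n", llama_n_ctx_seq(ctx));

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```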
@@ -114,10 +114,14 @@ llama_context::llama_context(
         }
     }
 
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
     if (cparams.kv_unified) {
         cparams.n_ctx_seq = cparams.n_ctx;
     } else {
         cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
 
         if (cparams.n_ctx_seq == 0) {
             throw std::runtime_error("n_ctx_seq == 0");
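As a worked example of the constructor logic above, assuming a hypothetical request of n_ctx = 10000 with 3 sequences and a non-unified KV cache (pad256 is an illustrative stand-in for GGML_PAD):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative re-implementation of the padding logic above; values are hypothetical.
static uint32_t pad256(uint32_t x) { return ((x + 255) / 256) * 256; }

int main() {
    uint32_t n_ctx      = 10000; // requested total context
    uint32_t n_seq_max  = 3;     // number of parallel sequences
    bool     kv_unified = false;

    n_ctx = pad256(n_ctx);       // 10240

    uint32_t n_ctx_seq;
    if (kv_unified) {
        n_ctx_seq = n_ctx;
    } else {
        n_ctx_seq = n_ctx / n_seq_max; // 3413
        n_ctx_seq = pad256(n_ctx_seq); // 3584 (future-proofs the SWA pad below)
    }

    printf("n_ctx = %u, n_ctx_seq = %u\n", n_ctx, n_ctx_seq);
    return 0;
}
```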
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    // https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
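A worked example of the new SWA sizing, with hypothetical hparams: the pad is now applied to the result of the min(), so the final SWA cache size is always a multiple of 256 (pad256 again stands in for GGML_PAD):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative arithmetic for the new SWA cache sizing; values are hypothetical.
static uint32_t pad256(uint32_t x) { return ((x + 255) / 256) * 256; }

int main() {
    uint32_t size_base = 8192; // full cache size (kv_size)
    uint32_t n_swa     = 1000; // sliding-window size from the model hparams
    uint32_t n_seq_max = 2;
    uint32_t n_ubatch  = 512;
    bool     unified   = true;

    // min(8192, 1000*2 + 512) = 2512 -> padded to 2560
    uint32_t size_swa = pad256(std::min(size_base, n_swa*(unified ? n_seq_max : 1) + n_ubatch));
    printf("size_swa = %u\n", size_swa);
    return 0;
}
```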
@@ -77,10 +77,10 @@ def test_different_draft_min_draft_max():
 
 def test_slot_ctx_not_exceeded():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
         "speculative.p_min": 0.0,
@@ -91,19 +91,19 @@ def test_slot_ctx_not_exceeded():
 
 def test_with_ctx_shift():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
-        "n_predict": 64,
+        "n_predict": 256,
         "speculative.p_min": 0.0,
     })
     assert res.status_code == 200
     assert len(res.body["content"]) > 0
-    assert res.body["tokens_predicted"] == 64
+    assert res.body["tokens_predicted"] == 256
     assert res.body["truncated"] == True
 
 
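The test changes follow from the context padding: a requested n_ctx of 64 would now silently become 256, so the tests request 256 directly and scale the prompt and prediction lengths to match. A small sketch of the arithmetic (the per-repetition token count for "Hello " is an approximate assumption):

```cpp
#include <cstdint>
#include <cstdio>

// Why the test parameters changed: the requested slot context is padded to a multiple
// of 256, so the old request of 64 and the new request of 256 end up the same size.
static uint32_t pad256(uint32_t x) { return ((x + 255) / 256) * 256; }

int main() {
    printf("old requested n_ctx  64 -> actual %u\n", pad256(64));  // 256
    printf("new requested n_ctx 256 -> actual %u\n", pad256(256)); // 256 (already aligned)
    // old prompt: "Hello " * 56 sized against a 64-token slot;
    // new prompt: "Hello " * 248 sized against a 256-token slot.
    return 0;
}
```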