bench : cache the llama_context state at computed depth (#16944)

* bench : cache llama_context state at depth

* cont : handle failures to restore the old state

* cont : print information when the state is being reused
This commit is contained in:
Georgi Gerganov
2025-11-07 21:23:11 +02:00
committed by GitHub
parent 9008027aa3
commit 7956bb4d7f

View File

@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
} }
}; };
struct ctx_state {
int depth = 0; // in tokens
std::vector<uint8_t> buf; // the llama_context state buffer
};
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
llama_model * lmodel = nullptr; llama_model * lmodel = nullptr;
const cmd_params_instance * prev_inst = nullptr; const cmd_params_instance * prev_inst = nullptr;
// store the llama_context state at the previous depth that we performed a test
// ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
ctx_state cstate;
int params_idx = 0; int params_idx = 0;
auto params_count = params_instances.size(); auto params_count = params_instances.size();
for (const auto & inst : params_instances) { for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
llama_memory_clear(llama_get_memory(ctx), false); llama_memory_clear(llama_get_memory(ctx), false);
if (t.n_depth > 0) { if (t.n_depth > 0) {
if (params.progress) { bool is_cached = t.n_depth == cstate.depth;
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
i + 1, params.reps); if (is_cached) {
// if previously we have computed at this depth, just restore the state
const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
if (ret == 0) {
// if the old state is incompatible with the current context - reprocess from scratch
is_cached = false;
}
} }
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
if (!res) { if (!is_cached) {
fprintf(stderr, "%s: error: failed to run depth\n", __func__); if (params.progress) {
exit(1); fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
if (!res) {
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
exit(1);
}
// store the context state for reuse in later runs
cstate.depth = t.n_depth;
cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
} else {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
i + 1, params.reps);
}
} }
} }