From 7956bb4d7f430f23f3c4726f7e6404b89f6e20a4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 7 Nov 2025 21:23:11 +0200
Subject: [PATCH] bench : cache the llama_context state at computed depth
 (#16944)

* bench : cache llama_context state at depth

* cont : handle failures to restore the old state

* cont : print information when the state is being reused
---
 tools/llama-bench/llama-bench.cpp | 47 ++++++++++++++++++++++++++-----
 1 file changed, 40 insertions(+), 7 deletions(-)

diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 0de07b9811..852a512451 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
     }
 };
 
+struct ctx_state {
+    int depth = 0; // in tokens
+
+    std::vector<uint8_t> buf; // the llama_context state buffer
+};
+
 static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
     llama_model * lmodel = nullptr;
     const cmd_params_instance * prev_inst = nullptr;
 
+    // store the llama_context state at the previous depth that we performed a test
+    // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
+    ctx_state cstate;
+
     int params_idx = 0;
     auto params_count = params_instances.size();
     for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@
             llama_memory_clear(llama_get_memory(ctx), false);
 
             if (t.n_depth > 0) {
-                if (params.progress) {
-                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
-                            i + 1, params.reps);
+                bool is_cached = t.n_depth == cstate.depth;
+
+                if (is_cached) {
+                    // if previously we have computed at this depth, just restore the state
+                    const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                    if (ret == 0) {
+                        // if the old state is incompatible with the current context - reprocess from scratch
+                        is_cached = false;
+                    }
                 }
-                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
-                if (!res) {
-                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
-                    exit(1);
+
+                if (!is_cached) {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
+                    bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                    if (!res) {
+                        fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                        exit(1);
+                    }
+
+                    // store the context state for reuse in later runs
+                    cstate.depth = t.n_depth;
+                    cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
+                    llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                } else {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
                 }
             }
 
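
For reference, the caching above relies on the llama_state_seq_* API from
llama.h: llama_state_seq_get_size() reports how large the serialized state of
a sequence is, llama_state_seq_get_data() copies it into a caller-provided
buffer, and llama_state_seq_set_data() loads it back, returning 0 when the
saved state cannot be applied to the current context. A minimal standalone
sketch of the same save/restore pattern follows; the two helper functions are
illustrative only and are not part of the patch:

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // save the serialized state of sequence 0 into a buffer
    static std::vector<uint8_t> save_seq_state(llama_context * ctx) {
        std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, 0));
        llama_state_seq_get_data(ctx, buf.data(), buf.size(), 0);
        return buf;
    }

    // try to restore sequence 0 from a previously saved buffer; returns false
    // if the saved state is incompatible with the current context, in which
    // case the caller should reprocess the prompt from scratch
    static bool restore_seq_state(llama_context * ctx, const std::vector<uint8_t> & buf) {
        return llama_state_seq_set_data(ctx, buf.data(), buf.size(), 0) != 0;
    }

This is why the patch keys the cache on the depth alone: restoring is only
attempted when t.n_depth matches the cached depth, and a failed restore (ret
== 0) simply falls back to the uncached path that recomputes and re-saves the
state.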