Mirror of https://github.com/ggml-org/llama.cpp.git
bench : cache the llama_context state at computed depth (#16944)
* bench : cache llama_context state at depth
* cont : handle failures to restore the old state
* cont : print information when the state is being reused
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
     }
 };
 
+struct ctx_state {
+    int depth = 0; // in tokens
+
+    std::vector<uint8_t> buf; // the llama_context state buffer
+};
+
 static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
     llama_model * lmodel = nullptr;
     const cmd_params_instance * prev_inst = nullptr;
 
+    // store the llama_context state at the previous depth that we performed a test
+    // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
+    ctx_state cstate;
+
     int params_idx = 0;
     auto params_count = params_instances.size();
     for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
             llama_memory_clear(llama_get_memory(ctx), false);
 
             if (t.n_depth > 0) {
-                if (params.progress) {
-                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
-                            i + 1, params.reps);
-                }
-                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
-                if (!res) {
-                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
-                    exit(1);
+                bool is_cached = t.n_depth == cstate.depth;
+
+                if (is_cached) {
+                    // if previously we have computed at this depth, just restore the state
+                    const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                    if (ret == 0) {
+                        // if the old state is incompatible with the current context - reprocess from scratch
+                        is_cached = false;
+                    }
+                }
+
+                if (!is_cached) {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
+                    bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                    if (!res) {
+                        fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                        exit(1);
+                    }
+
+                    // store the context state for reuse in later runs
+                    cstate.depth = t.n_depth;
+                    cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
+                    llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                } else {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
                 }
             }
 
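The caching hinges on the llama_state_seq_* API for serializing and restoring the state of a single sequence, the same calls used in the diff above. As a minimal, self-contained sketch of that save/restore cycle (not part of the commit; the helper names are illustrative, and `ctx` is assumed to be an initialized llama_context whose sequence 0 has already been evaluated to the desired depth):

#include "llama.h"

#include <cstdint>
#include <vector>

// serialize the state of sequence 0 into a byte buffer
std::vector<uint8_t> save_seq0_state(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, 0));
    llama_state_seq_get_data(ctx, buf.data(), buf.size(), 0);
    return buf;
}

// try to restore a previously saved buffer into sequence 0;
// llama_state_seq_set_data returns 0 when the buffer is incompatible with the
// current context, in which case the caller must reprocess the prompt from scratch
bool restore_seq0_state(llama_context * ctx, const std::vector<uint8_t> & buf) {
    return llama_state_seq_set_data(ctx, buf.data(), buf.size(), 0) != 0;
}

In the patch, this pair of operations is what fills and reuses cstate.buf, with cstate.depth acting as the cache key: a depth run is skipped only when the requested n_depth matches the depth at which the buffer was captured and the restore succeeds.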