Mirror of https://github.com/ggml-org/llama.cpp.git
bench : cache the llama_context state at computed depth (#16944)
* bench : cache llama_context state at depth
* cont : handle failures to restore the old state
* cont : print information when the state is being reused
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
     }
 };
 
+struct ctx_state {
+    int depth = 0; // in tokens
+
+    std::vector<uint8_t> buf; // the llama_context state buffer
+};
+
 static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
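The new ctx_state struct pairs a token depth with a serialized copy of the context, produced and consumed through llama.cpp's per-sequence state API (llama_state_seq_get_size / llama_state_seq_get_data / llama_state_seq_set_data). A minimal sketch of that round trip; the helper names save_state and load_state are illustrative, not part of the patch:

    #include <cstdint>
    #include <vector>

    #include "llama.h"

    // as added by this commit
    struct ctx_state {
        int depth = 0; // in tokens

        std::vector<uint8_t> buf; // the llama_context state buffer
    };

    // snapshot sequence 0 of ctx into cstate (illustrative helper, not in the patch)
    static void save_state(llama_context * ctx, ctx_state & cstate, int depth) {
        cstate.depth = depth;
        cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
        llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
    }

    // try to restore sequence 0 from cstate; llama_state_seq_set_data returns 0
    // when the buffer is incompatible with the current context
    static bool load_state(llama_context * ctx, const ctx_state & cstate) {
        return llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0) != 0;
    }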
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
     llama_model * lmodel = nullptr;
     const cmd_params_instance * prev_inst = nullptr;
 
+    // store the llama_context state at the previous depth that we performed a test
+    // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
+    ctx_state cstate;
+
     int params_idx = 0;
     auto params_count = params_instances.size();
     for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
             llama_memory_clear(llama_get_memory(ctx), false);
 
             if (t.n_depth > 0) {
-                if (params.progress) {
-                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
-                            i + 1, params.reps);
-                }
-                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
-                if (!res) {
-                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
-                    exit(1);
+                bool is_cached = t.n_depth == cstate.depth;
+
+                if (is_cached) {
+                    // if previously we have computed at this depth, just restore the state
+                    const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                    if (ret == 0) {
+                        // if the old state is incompatible with the current context - reprocess from scratch
+                        is_cached = false;
+                    }
+                }
+
+                if (!is_cached) {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
+                    bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                    if (!res) {
+                        fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                        exit(1);
+                    }
+
+                    // store the context state for reuse in later runs
+                    cstate.depth = t.n_depth;
+                    cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
+                    llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                } else {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
                 }
             }
 