llama : reuse compute graphs (#14482)

* llama : reuse compute graphs

ggml-ci

* llama-bench : add graph reuse parameter

ggml-ci

* cont : remove the parameter and the sched resets

ggml-ci

* graph : rename update() to can_reuse()

ggml-ci

* params : remove is_same()

ggml-ci

* graph : set res->params in llm_graph_context constructor

ggml-ci

* graph : avoid set_max_nodes in llm_graph_result

ggml-ci

* kv-cache : reuse llama_context's graph result instance

ggml-ci

* context : reset the previous graph result upon memory updates

ggml-ci

* batch : llama_ubatch now carries its data instead of pointing to balloc

ggml-ci

* merge : fix build

ggml-ci

* graph : fix can_reuse() checks when flash-attention is disabled

* graph : move llm_graph_result impl in source file + debug env

ggml-ci
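Taken together, these steps let llama_context keep the llm_graph_result from the previous decode and, when the new graph parameters describe the same topology, skip rebuilding the graph and only refresh its input tensors. A minimal C++ sketch of that flow (the members and checks here are simplified stand-ins, not the actual llama.cpp definitions; only the names llm_graph_params, llm_graph_result, can_reuse() and reset() come from this commit):

    #include <cstdio>

    // stand-in for the parameters that determine a graph's topology
    struct llm_graph_params {
        int  n_tokens   = 0;
        bool flash_attn = false;

        bool same_topology(const llm_graph_params & other) const {
            return n_tokens == other.n_tokens && flash_attn == other.flash_attn;
        }
    };

    // stand-in for the result of building a compute graph
    struct llm_graph_result {
        llm_graph_params params;

        // the graph built for `params` can be re-evaluated with `p` if the
        // topology is unchanged -- only the input tensors need new data
        bool can_reuse(const llm_graph_params & p) {
            const bool ok = params.same_topology(p);
            if (ok) {
                params = p;
            }
            return ok;
        }

        void reset() { /* drop the previous graph and rebuild from scratch */ }
    };

    int main() {
        llm_graph_result res;
        res.params = { 32, true };

        // same shapes -> skip the rebuild, just refresh the inputs
        std::printf("reuse: %d\n", res.can_reuse({ 32, true })); // 1

        // different token count -> the graph must be rebuilt
        std::printf("reuse: %d\n", res.can_reuse({ 64, true })); // 0
        return 0;
    }

When can_reuse() succeeds, the previously built graph is evaluated as-is with new input data, avoiding the per-decode rebuild cost.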
Author: Georgi Gerganov
Date: 2025-07-17 19:08:33 +03:00
Committed by: GitHub
Parent: 086cf81e88
Commit: 01612b7409

12 changed files with 548 additions and 289 deletions


@@ -193,7 +193,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 
     const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0;
 
     if (!supports_set_rows) {
         // ref: https://github.com/ggml-org/llama.cpp/pull/14363
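The one-line change above makes the env flag a strict boolean: previously LLAMA_SET_ROWS=2 would store the raw value 2 in supports_set_rows, while now any non-zero value normalizes to true. The same pattern as a self-contained helper (env_flag is a hypothetical name, not a llama.cpp function):

    #include <cstdlib>

    // read an integer-valued environment variable as an on/off flag:
    // unset -> false, "0" -> false, any other integer -> true
    static bool env_flag(const char * name) {
        const char * val = std::getenv(name);
        return val ? std::atoi(val) != 0 : false;
    }

    // usage sketch: supports_set_rows = env_flag("LLAMA_SET_ROWS");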
@@ -656,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx->graph_init();
+            auto * res = lctx->get_gf_res_reserve();
 
-            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
-                return updated;
-            }
+            res->reset();
 
+            auto * gf = build_graph_shift(res, lctx);
             if (!ggml_backend_sched_alloc_graph(sched, gf)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
                 return updated;
@@ -713,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 
         ggml_backend_sched_reset(sched);
 
-        auto * gf = lctx->graph_init();
+        auto * res = lctx->get_gf_res_reserve();
 
-        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
-        if (!res) {
-            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-            return updated;
-        }
+        res->reset();
 
+        auto * gf = build_graph_defrag(res, lctx, dinfo);
         if (!ggml_backend_sched_alloc_graph(sched, gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
             return updated;
@@ -1035,6 +1029,10 @@ uint32_t llama_kv_cache_unified::get_n_kv() const {
     return result;
 }
 
+bool llama_kv_cache_unified::get_supports_set_rows() const {
+    return supports_set_rows;
+}
+
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
@@ -1297,6 +1295,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     //   xxxxx-----
     //   To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    //   TODO: optimize this section
     for (uint32_t h = 0; h < 1; ++h) {
         for (uint32_t s = 0; s < n_stream; ++s) {
             for (uint32_t ii = 0; ii < n_tps; ++ii) {
@@ -1346,7 +1345,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     const auto & cells = v_cells[0];
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
 
     int32_t * data = (int32_t *) dst->data;
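The extra parentheses are because equal_seqs is now a method: per the commit message, llama_ubatch carries its own data instead of pointing into the batch allocator (balloc). A hedged sketch of what that ownership change looks like (the member names are illustrative, not the real llama_ubatch layout):

    #include <cstdint>
    #include <vector>

    struct ubatch_sketch {
        // owned copies of the batch data; the ubatch remains valid after the
        // allocator that produced it is gone, which is what makes it safe to
        // keep alongside a cached graph for later reuse checks
        std::vector<int32_t> token;
        std::vector<int32_t> pos;

        bool b_equal_seqs = false; // stored once at construction

        // exposed as a method so callers do not depend on how it is stored
        bool equal_seqs() const { return b_equal_seqs; }
    };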
@@ -1464,11 +1463,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
+ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
 
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;
@@ -1478,6 +1475,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
     ggml_set_input(inp->k_shift);
 
+    const auto & cparams = lctx->get_cparams();
+
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
@@ -1503,15 +1502,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     res->add_input(std::move(inp));
 
-    return res;
+    return gf;
 }
 
-llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf,
-          const defrag_info & dinfo) const {
-    auto res = std::make_unique<llm_graph_result>();
+ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
+         llm_graph_result * res,
+            llama_context * lctx,
+        const defrag_info & dinfo) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
 
     GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
@@ -1519,6 +1518,8 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     const auto & ids = dinfo.ids;
 
+    const auto & cparams = lctx->get_cparams();
+
 #if 0
     // CPU defrag
     //
@@ -1655,7 +1656,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
 #endif
 
-    return res;
+    return gf;
 }
 
 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
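Both builders now share one convention: the caller passes in the context-owned llm_graph_result, the builder records its inputs there, and only the raw ggml_cgraph pointer comes back for scheduling. A simplified sketch of that convention (stand-in types with the same names, not the real definitions):

    struct ggml_cgraph { /* opaque here */ };

    struct llm_graph_result {
        ggml_cgraph gf;

        void          reset()  { /* drop the inputs and the previous graph */ }
        ggml_cgraph * get_gf() { return &gf; }
    };

    // fills `res` and returns the graph for the backend scheduler;
    // ownership of `res` never transfers -- llama_context keeps reusing it
    static ggml_cgraph * build_graph_sketch(llm_graph_result * res) {
        // ... append ops to res->get_gf() and register inputs on res ...
        return res->get_gf();
    }

    int main() {
        llm_graph_result res; // in llama.cpp this lives on llama_context

        res.reset();
        ggml_cgraph * gf = build_graph_sketch(&res);
        (void) gf; // would be handed to ggml_backend_sched_alloc_graph()
        return 0;
    }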
@@ -2331,6 +2332,10 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
     return n_kv;
 }
 
+bool llama_kv_cache_unified_context::get_supports_set_rows() const {
+    return kv->get_supports_set_rows();
+}
+
 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }