From 717f67cf589b5d99923c8e12c638557cf9eff886 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 9 Oct 2025 19:33:30 +0300
Subject: [PATCH] graph : reuse hybrid graphs

---
 src/llama-graph.cpp         | 41 ++++++++++++++++++++++++++++++++++++++++---
 src/llama-graph.h           | 10 ++++++++--
 src/llama-memory-hybrid.cpp |  2 +-
 3 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 41fa689437..e4f17d46a3 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -458,8 +458,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    inp_attn->set_input(ubatch);
-    inp_rs->set_input(ubatch);
+    mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+    res &= inp_attn->self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    return res;
 }
 
 //
@@ -1909,7 +1944,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d0c3934f67..25e50238f5 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -364,22 +364,28 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
+            const llama_cparams & cparams,
             std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
-            std::unique_ptr<llm_graph_input_rs> inp_rs,
-            const llama_memory_hybrid_context * mctx) :
+            std::unique_ptr<llm_graph_input_rs>      inp_rs,
+            const llama_memory_hybrid_context *      mctx) :
         inp_attn(std::move(inp_attn)),
         inp_rs(std::move(inp_rs)),
+        cparams(cparams),
         mctx(mctx) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    bool can_reuse(const llm_graph_params & params) override;
+
     std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
     std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
     llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
     llm_graph_input_rs      * get_recr() const { return inp_rs.get();   }
 
+    const llama_cparams cparams;
+
     const llama_memory_hybrid_context * mctx;
 };
 
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index dfb8439e01..a1b45e4a3c 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
     ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
-    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                 this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
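
Note on how this fits together (not part of the patch): the new `llm_graph_input_mem_hybrid::can_reuse()` override feeds the graph-reuse path, where a previously built compute graph is kept for the next ubatch only if every registered graph input reports that its tensor shapes still match the incoming parameters; on the reuse path only `set_input()` is called to refresh the data. Below is a minimal self-contained sketch of that gating pattern; the type names (graph_params, graph_input, graph_result) are illustrative stand-ins, not the actual llama.cpp types.

#include <memory>
#include <vector>

// Hypothetical stand-in for llm_graph_params: whatever shape-affecting
// state the next ubatch brings.
struct graph_params {
    int n_tokens;
};

struct graph_input {
    virtual ~graph_input() = default;

    // Refresh the input tensor data for the new ubatch.
    virtual void set_input(const graph_params & params) = 0;

    // Return true only if the tensors allocated for this input still have
    // compatible shapes, so the existing graph can be reused as-is.
    virtual bool can_reuse(const graph_params & params) { return false; }
};

struct graph_result {
    std::vector<std::unique_ptr<graph_input>> inputs;

    // The graph is reusable only if every input agrees (mirrors how the
    // hybrid input AND-s the checks of its attention and recurrent parts).
    bool can_reuse(const graph_params & params) const {
        for (const auto & inp : inputs) {
            if (!inp->can_reuse(params)) {
                return false;
            }
        }
        return true;
    }

    // On the reuse path, skip rebuilding the graph and just set inputs again.
    void update(const graph_params & params) {
        for (auto & inp : inputs) {
            inp->set_input(params);
        }
    }
};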