From 77d1b8622ae0eed63337d8cfa19e53321ca9db7a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 10 Oct 2025 10:44:41 +0300
Subject: [PATCH] graph : fix reuse check for recurrent inputs

---
 src/llama-graph.cpp | 11 ++++++++++-
 src/llama-graph.h   |  4 ++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 4d92c20e2d..a3d8444223 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
     return res;
 }
 
@@ -509,6 +512,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     return res;
 }
 
@@ -1893,6 +1899,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
@@ -1961,7 +1970,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 944d129c3e..caba9779b5 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -234,6 +234,10 @@ public:
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // used in view offsets, need to match for valid graph reuse
+    uint32_t head;
+    int32_t  rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {