From 20f8e43e63033a1bf5ba936b468e70aec36f6e53 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Thu, 3 Jul 2025 17:07:46 -0400
Subject: [PATCH] graph : add back hybrid memory graph input

But this time it contains the sub-cache graph inputs.
This *should* make it easier to handle updating the inputs
when caching the graph (eventually).
---
 src/llama-graph.cpp | 59 ++++++++++++++++++++++++++++++++++++---------
 src/llama-graph.h   | 33 ++++++++++++++++++++++---
 src/llama-model.cpp | 10 +++-----
 3 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 750175856c..7c2e880066 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
@@ -1147,10 +1152,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
@@ -1158,6 +1165,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe split the inner implementation into a separate function,
+//       like with the non-sliding-window equivalent,
+//       once sliding-window hybrid caches are a thing
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
diff --git a/src/llama-graph.h b/src/llama-graph.h
index b3542e337d..d8dc1e3307 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -319,6 +319,28 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_rs> inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 // TODO: remove this when ggml_scale_add is implemented
 class llm_graph_input_one : public llm_graph_input_i {
 public:
@@ -575,7 +597,7 @@ struct llm_graph_context {
             float     kq_scale,
             int       il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
@@ -590,7 +612,7 @@ struct llm_graph_context {
             float     kq_scale,
             int       il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
@@ -643,7 +665,7 @@ struct llm_graph_context {
             int32_t   rs_zero,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
-    llm_graph_input_rs * build_rs_inp(const llama_memory_recurrent_context * mctx_cur = nullptr) const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
@@ -663,6 +685,11 @@ struct llm_graph_context {
             ggml_tensor * token_shift,
             const llama_ubatch & ubatch,
             int   il) const;
 
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // pooling
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index e965715f2d..3121d9c4b7 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -10220,11 +10220,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
-
-        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
+        auto * inp_hybrid = build_inp_mem_hybrid();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10235,7 +10231,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (n_head_kv == 0) {
-                cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
             } else {
                 // Attention
@@ -10256,7 +10252,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
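
---

Note (editor's sketch, not part of the patch): the new llm_graph_input_mem_hybrid is a
straightforward composite input. It owns the attention and recurrent sub-cache inputs and
forwards set_input() to both, so a hybrid model such as Jamba updates its memory inputs
through a single handle. The minimal, self-contained C++ sketch below illustrates that
pattern; the type names (ubatch, graph_input, attn_input, rs_input, hybrid_input) are
hypothetical stand-ins for llama_ubatch, llm_graph_input_i, llm_graph_input_attn_kv_unified,
llm_graph_input_rs and llm_graph_input_mem_hybrid, not the actual llama.cpp API.

    #include <memory>
    #include <utility>

    struct ubatch { int n_tokens; };    // stand-in for llama_ubatch

    // Stand-in for llm_graph_input_i: a graph input populates its input
    // tensors from the current ubatch just before graph evaluation.
    struct graph_input {
        virtual ~graph_input() = default;
        virtual void set_input(const ubatch * ub) = 0;
    };

    struct attn_input : graph_input {   // ~ llm_graph_input_attn_kv_unified
        void set_input(const ubatch *) override { /* fill KQ mask, K/V idxs */ }
    };

    struct rs_input : graph_input {     // ~ llm_graph_input_rs
        void set_input(const ubatch *) override { /* fill s_copy */ }
    };

    // ~ llm_graph_input_mem_hybrid: owns both sub-cache inputs and forwards
    // set_input, so callers update the hybrid input as one unit.
    struct hybrid_input : graph_input {
        hybrid_input(std::unique_ptr<attn_input> a, std::unique_ptr<rs_input> r)
            : attn(std::move(a)), rs(std::move(r)) {}

        void set_input(const ubatch * ub) override {
            attn->set_input(ub);  // delegate to the attention sub-cache input
            rs->set_input(ub);    // delegate to the recurrent sub-cache input
        }

        std::unique_ptr<attn_input> attn;
        std::unique_ptr<rs_input>   rs;
    };

Keeping the sub-inputs behind one composite is what the commit message is after: when graph
caching lands, only the outer input has to be revisited to refresh both sub-caches.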