graph : add back hybrid memory graph input

But this time it contains the sub-cache graph inputs.
This *should* make it easier to handle updating the inputs
when caching the graph (eventually).
Author: Francis Couture-Harpin
Date:   2025-07-03 17:07:46 -04:00
parent 4682e21c46
commit 20f8e43e63
3 changed files with 80 additions and 22 deletions
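In short: instead of passing an optional memory-context pointer to each per-cache input builder, a hybrid model now builds one wrapper input that owns both sub-cache graph inputs and forwards set_input() to them. A condensed sketch of that structure, pulled from the hunks below (surrounding llama.cpp types are omitted, so this is illustrative only):

    // condensed from the graph-input hunks in this commit
    class llm_graph_input_mem_hybrid : public llm_graph_input_i {
    public:
        // owns the graph inputs of both sub-caches of the hybrid memory
        std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn; // unified KV cache input
        std::unique_ptr<llm_graph_input_rs>              inp_rs;   // recurrent-state input

        // one registered input updates both sub-cache inputs at once
        void set_input(const llama_ubatch * ubatch) override {
            inp_attn->set_input(ubatch);
            inp_rs->set_input(ubatch);
        }

        llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
        llm_graph_input_rs *              get_recr() const { return inp_rs.get(); }
    };

Model code obtains it once via build_inp_mem_hybrid() and reaches the per-cache inputs through get_attn() / get_recr(), as the Jamba hunks at the end show.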

@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
@@ -1147,10 +1152,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_unified_context * mctx_cur) {
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
@@ -1158,6 +1165,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
     const auto n_kv     = mctx_cur->get_n_kv();
     const auto n_tokens = ubatch.n_tokens;
 
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
 }
 
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+}
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
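As an aside, the split that this TODO anticipates would presumably mirror build_attn_inp_kv_unified_impl() above; a purely hypothetical sketch of its shape (the _iswa_impl name, parameter list, and elided body are assumptions, not part of this commit):

    // hypothetical future refactor hinted at by the TODO; nothing below exists yet
    static std::unique_ptr<llm_graph_input_attn_kv_unified_iswa> build_attn_inp_kv_unified_iswa_impl(
            ggml_context * ctx0,
            const llama_ubatch & ubatch,
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_kv_cache_unified_iswa_context * mctx_cur) {
        auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
        // ... build the SWA and non-SWA masks/index tensors here, as the current
        //     member function body already does ...
        return inp;
    }

Such a helper would let a future sliding-window hybrid cache construct this input without registering it directly, exactly as build_inp_mem_hybrid() does with the non-SWA helpers below.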
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
 }
 
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
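Note that only the wrapping hybrid input goes through res->add_input(); the two sub-inputs are produced by the new _impl helpers and are never registered on their own. That is what the commit message is getting at: when the inputs of an (eventually cached) graph need to be refreshed for a new ubatch, the refresh stays a flat loop over the registered inputs, roughly like the sketch below (assuming the existing set_inputs()-style update loop, which is not part of this diff):

    // assumed shape of the input-refresh step; the hybrid input fans the call
    // out to inp_attn and inp_rs via its set_input() override above
    for (auto & input : inputs) { // inputs collected through res->add_input()
        input->set_input(&ubatch);
    }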

@@ -319,6 +319,28 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_rs> inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 // TODO: remove this when ggml_scale_add is implemented
 class llm_graph_input_one : public llm_graph_input_i {
 public:
@@ -575,7 +597,7 @@ struct llm_graph_context {
             float     kq_scale,
             int       il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
@@ -590,7 +612,7 @@ struct llm_graph_context {
             float     kq_scale,
             int       il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
@@ -643,7 +665,7 @@ struct llm_graph_context {
             int32_t   rs_zero,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
-    llm_graph_input_rs * build_rs_inp(const llama_memory_recurrent_context * mctx_cur = nullptr) const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
@@ -663,6 +685,11 @@ struct llm_graph_context {
             ggml_tensor * token_shift,
             const llama_ubatch & ubatch,
             int       il) const;
 
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // pooling

@@ -10220,11 +10220,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
-
-        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
+        auto * inp_hybrid = build_inp_mem_hybrid();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -10235,7 +10231,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (n_head_kv == 0) {
-                cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
             } else {
                 // Attention
@@ -10256,7 +10252,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {