Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-11-04 09:32:00 +00:00

	graph : add back hybrid memory graph input
But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually).
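The payoff the last sentence hints at: with the sub-cache inputs stored inside the hybrid input, a future graph cache could reuse a built graph for a new ubatch by re-running set_input() on each registered input, and the hybrid input forwards that call to both of its sub-inputs. A minimal, self-contained sketch of that flow (hypothetical cached_graph/input_i stand-ins; no such cache exists in this commit):

#include <memory>
#include <vector>

struct llama_ubatch;  // opaque here; only passed by pointer

// stand-in for llm_graph_input_i: one hook to (re)bind input data
struct input_i {
    virtual ~input_i() = default;
    virtual void set_input(const llama_ubatch * ubatch) = 0;
};

// hypothetical cache entry: graph topology plus its registered inputs
struct cached_graph {
    std::vector<std::unique_ptr<input_i>> inputs;

    // reuse the graph for a new ubatch: no rebuild, just re-bind the inputs;
    // a hybrid input forwards this to its attention and recurrent sub-inputs
    void rebind(const llama_ubatch * ubatch) {
        for (auto & inp : inputs) {
            inp->set_input(ubatch);
        }
    }
};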
src/llama-graph.cpp

@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
@@ -1147,10 +1152,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
@@ -1158,6 +1165,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+           ggml_context * ctx0,
+    const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
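The refactor pattern in llama-graph.cpp above: the public builders no longer take an optional memory-context override. Each gets a static _impl helper that receives its dependencies explicitly, a thin public wrapper casts mctx itself, and build_inp_mem_hybrid() reuses the same helpers with the hybrid context's sub-contexts. A compilable sketch of that split, with stand-in types (hypothetical names, not the real signatures):

#include <memory>
#include <utility>
#include <vector>

struct mem_ctx { };                          // stand-in for a memory-context type
struct rs_input { const mem_ctx * mctx; };   // stand-in for llm_graph_input_rs

// inner builder: every dependency is a parameter, so it works equally well
// with a standalone context or with a sub-context of a hybrid memory
static std::unique_ptr<rs_input> build_rs_inp_impl(const mem_ctx * mctx_cur) {
    return std::make_unique<rs_input>(rs_input{mctx_cur});
}

struct graph_ctx {
    const mem_ctx * mctx = nullptr;                 // the graph's own memory context
    std::vector<std::unique_ptr<rs_input>> inputs;  // stand-in for res->add_input()

    // thin public wrapper: derive the context, build, register, return non-owning
    rs_input * build_rs_inp() {
        auto inp = build_rs_inp_impl(mctx);
        inputs.push_back(std::move(inp));
        return inputs.back().get();
    }
};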
src/llama-graph.h

@@ -319,6 +319,28 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_rs>              inp_rs,
+            const llama_memory_hybrid_context *              mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>              inp_rs;
+
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 // TODO: remove this when ggml_scale_add is implemented
 class llm_graph_input_one : public llm_graph_input_i {
 public:
@@ -575,7 +597,7 @@ struct llm_graph_context {
                   float   kq_scale,
                    int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
@@ -590,7 +612,7 @@ struct llm_graph_context {
                   float   kq_scale,
                    int   il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
@@ -643,7 +665,7 @@ struct llm_graph_context {
                 int32_t   rs_zero,
            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
-    llm_graph_input_rs * build_rs_inp(const llama_memory_recurrent_context * mctx_cur = nullptr) const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
@@ -663,6 +685,11 @@ struct llm_graph_context {
             ggml_tensor * token_shift,
      const llama_ubatch & ubatch,
                     int   il) const;
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
 
     //
     // pooling
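Note the ownership split in llm_graph_input_mem_hybrid above: the wrapper owns both sub-inputs through std::unique_ptr, while get_attn()/get_recr() hand out raw, non-owning pointers in exactly the form the existing build_attn(...) and build_rs(...) overloads already take; only the outer wrapper is registered with res->add_input().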
src/llama-model.cpp

@@ -10220,11 +10220,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
-
-        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
+        auto * inp_hybrid = build_inp_mem_hybrid();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10235,7 +10231,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (n_head_kv == 0) {
-                cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
             } else {
                 // Attention
 
@@ -10256,7 +10252,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
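Net effect for Jamba: the builder now fetches one composed input via build_inp_mem_hybrid() instead of casting mctx and assembling the two inputs itself, and each layer picks the sub-input it needs: get_recr() for the Mamba layers, get_attn() for the attention layers. Both sub-inputs travel through a single registered input object, which is what sets up the eventual input-update path the commit message mentions.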