graph : refactor context to not pass gf explicitly (#14629)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-07-18 08:29:28 +03:00
committed by GitHub
parent 09651d09ff
commit 8f974bc1e9
5 changed files with 295 additions and 341 deletions

View File

@@ -371,31 +371,11 @@ public:
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
// TODO: this interface seems redundant - remove it
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() const = 0;
virtual ggml_tensor * get_logits() const = 0;
virtual ggml_tensor * get_embd() const = 0;
virtual ggml_tensor * get_embd_pooled() const = 0;
virtual ggml_cgraph * get_gf() = 0;
virtual ggml_context * get_ctx() = 0;
virtual void reset() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
virtual bool can_reuse(const llm_graph_params & params) = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
class llm_graph_result;
struct llm_graph_params {
llm_arch arch = LLM_ARCH_UNKNOWN;
@@ -418,8 +398,7 @@ struct llm_graph_params {
llm_graph_cb cb;
// TODO: temporary
llm_graph_result_i * res;
llm_graph_result * res;
// return true if the "other" params would result in a graph with the same topology as with the current params
// having the same topology allows us to reuse the graph in some cases
@@ -464,35 +443,37 @@ struct llm_graph_params {
}
};
class llm_graph_result : public llm_graph_result_i {
class llm_graph_result {
public:
llm_graph_result(int64_t max_nodes);
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() const override { return t_tokens; }
ggml_tensor * get_logits() const override { return t_logits; }
ggml_tensor * get_embd() const override { return t_embd; }
ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
ggml_tensor * get_tokens() const { return t_tokens; }
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_cgraph * get_gf() override { return gf; }
ggml_context * get_ctx() override { return ctx_compute.get(); }
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
int64_t get_max_nodes() const;
void reset() override;
void reset();
void set_inputs(const llama_ubatch * ubatch) override;
void set_inputs(const llama_ubatch * ubatch);
// try to update the existing graph result using the new graph parameters in order to reuse it
// this can only be done if we determine that the resulting graph using the new graph parameters
// would be identical to the existing graph. in that case, we simply have to update the memory
// contexts of the input tensors of the graph and we can reuse it for another computation
// return true if the graph was updated and can be reused
bool can_reuse(const llm_graph_params & params) override;
bool can_reuse(const llm_graph_params & params);
llm_graph_input_i * add_input(llm_graph_input_ptr input);
void set_params(const llm_graph_params & params);
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
@@ -510,6 +491,7 @@ public:
int64_t max_nodes;
private:
// keep a copy of the previous graph parameters
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
// note: these are updated after constructing the new graph
@@ -519,6 +501,8 @@ public:
int debug = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
//
// llm_graph_context
//
@@ -576,6 +560,7 @@ struct llm_graph_context {
llm_graph_result * res;
ggml_context * ctx0 = nullptr;
ggml_cgraph * gf = nullptr;
llm_graph_context(const llm_graph_params & params);
virtual ~llm_graph_context() = default;
@@ -661,7 +646,6 @@ struct llm_graph_context {
//
ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -674,7 +658,6 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -689,7 +672,6 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -705,7 +687,6 @@ struct llm_graph_context {
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -720,7 +701,6 @@ struct llm_graph_context {
ggml_tensor * build_attn(
llm_graph_input_attn_cross * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -742,7 +722,6 @@ struct llm_graph_context {
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
ggml_tensor * build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
@@ -757,7 +736,6 @@ struct llm_graph_context {
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
@@ -765,9 +743,8 @@ struct llm_graph_context {
ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
ggml_tensor * token_shift,
@@ -784,7 +761,6 @@ struct llm_graph_context {
//
void build_pooling(
ggml_cgraph * gf,
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,