ggml: add ggml_can_fuse_subgraph (#16662)

* ggml: add ggml_can_fuse_subgraph

* ggml-cuda: use ggml_can_fuse_subgraph for topk-moe

* format

* 1. remove inputs from signature as they are transient nodes
2. add check for views: view_src should be part of the subgraph

* - combine check into one loop
- check all view_src parents
- other minor review comments

* remove redudant if test

* - rename and other minor review comments

* add assert about count < 32
This commit is contained in:
Aman Gupta
2025-10-21 16:43:14 +08:00
committed by GitHub
parent 6ea37f5739
commit 4926419c4d
3 changed files with 113 additions and 19 deletions

View File

@@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
}
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
const int * node_idxs,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs);
// Returns true if the subgraph formed by {node_idxs} can be fused
// checks whethers all nodes which are not part of outputs can be elided
// by checking if their num_uses are confined to the subgraph
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int node_idx,
int count,
const enum ggml_op * ops,
const int * outputs,
int num_outputs) {
GGML_ASSERT(count < 32);
if (node_idx + count > cgraph->n_nodes) {
return false;
}
int idxs[32];
for (int i = 0; i < count; ++i) {
idxs[i] = node_idx + i;
}
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
}
#ifdef __cplusplus
}
#endif
@@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
}
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<int> outputs = {}) {
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);