mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
ggml: add ggml_can_fuse_subgraph (#16662)
* ggml: add ggml_can_fuse_subgraph * ggml-cuda: use ggml_can_fuse_subgraph for topk-moe * format * 1. remove inputs from signature as they are transient nodes 2. add check for views: view_src should be part of the subgraph * - combine check into one loop - check all view_src parents - other minor review comments * remove redudant if test * - rename and other minor review comments * add assert about count < 32
This commit is contained in:
@@ -2821,15 +2821,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||||||
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
|
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
|
||||||
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
|
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
|
||||||
|
|
||||||
if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
|
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||||
|
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
|
||||||
if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
|
|
||||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
|
|
||||||
}
|
|
||||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||||
ggml_tensor * weights = cgraph->nodes[node_idx+8];
|
ggml_tensor * weights = cgraph->nodes[node_idx+8];
|
||||||
|
|
||||||
@@ -2838,16 +2831,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
|
if (ops.size() == topk_moe_ops.size() &&
|
||||||
|
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
|
||||||
if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < topk_moe_ops.size(); i++) {
|
|
||||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||||
ggml_tensor * weights = cgraph->nodes[node_idx+4];
|
ggml_tensor * weights = cgraph->nodes[node_idx+4];
|
||||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||||
|
|||||||
@@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
|||||||
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
|
||||||
|
const int * node_idxs,
|
||||||
|
int count,
|
||||||
|
const enum ggml_op * ops,
|
||||||
|
const int * outputs,
|
||||||
|
int num_outputs);
|
||||||
|
|
||||||
|
// Returns true if the subgraph formed by {node_idxs} can be fused
|
||||||
|
// checks whethers all nodes which are not part of outputs can be elided
|
||||||
|
// by checking if their num_uses are confined to the subgraph
|
||||||
|
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||||
|
int node_idx,
|
||||||
|
int count,
|
||||||
|
const enum ggml_op * ops,
|
||||||
|
const int * outputs,
|
||||||
|
int num_outputs) {
|
||||||
|
GGML_ASSERT(count < 32);
|
||||||
|
if (node_idx + count > cgraph->n_nodes) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int idxs[32];
|
||||||
|
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
idxs[i] = node_idx + i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
|
|||||||
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
|
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||||
|
int start_idx,
|
||||||
|
std::initializer_list<enum ggml_op> ops,
|
||||||
|
std::initializer_list<int> outputs = {}) {
|
||||||
|
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
||||||
|
}
|
||||||
|
|
||||||
// expose GGUF internals for test code
|
// expose GGUF internals for test code
|
||||||
GGML_API size_t gguf_type_size(enum gguf_type type);
|
GGML_API size_t gguf_type_size(enum gguf_type type);
|
||||||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
||||||
|
|||||||
@@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|||||||
GGML_LOG_INFO("========================================\n");
|
GGML_LOG_INFO("========================================\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
|
||||||
|
const int * idxs,
|
||||||
|
int count,
|
||||||
|
const struct ggml_tensor * tensor) {
|
||||||
|
GGML_ASSERT(cgraph && idxs);
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
const int node_idx = idxs[i];
|
||||||
|
|
||||||
|
if (node_idx >= cgraph->n_nodes) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (cgraph->nodes[node_idx] == tensor) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
|
||||||
|
const int * node_idxs,
|
||||||
|
int count,
|
||||||
|
const enum ggml_op * ops,
|
||||||
|
const int * outputs,
|
||||||
|
int num_outputs) {
|
||||||
|
GGML_ASSERT(outputs && num_outputs > 0);
|
||||||
|
|
||||||
|
for (int i = 0; i < count; ++i) {
|
||||||
|
if (node_idxs[i] >= cgraph->n_nodes) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
|
||||||
|
|
||||||
|
if (node->op != ops[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int subgraph_uses = 0;
|
||||||
|
for (int j = i + 1; j < count; ++j) {
|
||||||
|
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
|
||||||
|
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
|
||||||
|
if (other_node->src[src_idx] == node) {
|
||||||
|
subgraph_uses++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
|
||||||
|
struct ggml_tensor * view_src = node->view_src;
|
||||||
|
while (view_src) {
|
||||||
|
if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
view_src = view_src->view_src;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// check if node is part of the graph
|
// check if node is part of the graph
|
||||||
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
|
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
|
||||||
if (cgraph == NULL) {
|
if (cgraph == NULL) {
|
||||||
|
|||||||
Reference in New Issue
Block a user