diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75fd6db14c..015b37be07 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2821,15 +2821,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops = ggml_cuda_topk_moe_ops(false); std::initializer_list topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true); - if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) { - - if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) { - return false; - } - - for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false; - } + if (ops.size() == topk_moe_ops_with_norm.size() && + ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx+8]; @@ -2838,16 +2831,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) { - - if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) { - return false; - } - - for (size_t i = 0; i < topk_moe_ops.size(); i++) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false; - } - + if (ops.size() == topk_moe_ops.size() && + ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx+4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 18f095b896..e9201cdc68 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); } +GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, + const int * node_idxs, + int count, + const enum ggml_op * ops, + const int * outputs, + int num_outputs); + +// Returns true if the subgraph formed by {node_idxs} can be fused +// checks whethers all nodes which are not part of outputs can be elided +// by checking if their num_uses are confined to the subgraph +static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, + int node_idx, + int count, + const enum ggml_op * ops, + const int * outputs, + int num_outputs) { + GGML_ASSERT(count < 32); + if (node_idx + count > cgraph->n_nodes) { + return false; + } + + int idxs[32]; + + for (int i = 0; i < count; ++i) { + idxs[i] = node_idx + i; + } + + return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs); +} + #ifdef __cplusplus } #endif @@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std:: return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size()); } +inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, + int start_idx, + std::initializer_list ops, + std::initializer_list outputs = {}) { + return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size()); +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 86f1c31afd..9be35c1be8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_LOG_INFO("========================================\n"); } +static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph, + const int * idxs, + int count, + const struct ggml_tensor * tensor) { + GGML_ASSERT(cgraph && idxs); + for (int i = 0; i < count; ++i) { + const int node_idx = idxs[i]; + + if (node_idx >= cgraph->n_nodes) { + return -1; + } + if (cgraph->nodes[node_idx] == tensor) { + return i; + } + } + return -1; +} + +bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, + const int * node_idxs, + int count, + const enum ggml_op * ops, + const int * outputs, + int num_outputs) { + GGML_ASSERT(outputs && num_outputs > 0); + + for (int i = 0; i < count; ++i) { + if (node_idxs[i] >= cgraph->n_nodes) { + return false; + } + + const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]]; + + if (node->op != ops[i]) { + return false; + } + + if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) { + continue; + } + + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + return false; + } + + int subgraph_uses = 0; + for (int j = i + 1; j < count; ++j) { + const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]]; + for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) { + if (other_node->src[src_idx] == node) { + subgraph_uses++; + } + } + } + + if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) { + return false; + } + + // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph + struct ggml_tensor * view_src = node->view_src; + while (view_src) { + if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) { + return false; + } + view_src = view_src->view_src; + } + } + + return true; +} + // check if node is part of the graph static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { if (cgraph == NULL) {