From 64fe17fbb84f493dbc33e4c13042953c4f5bfaeb Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 8 Nov 2025 21:05:19 +0800 Subject: [PATCH] Revert "CUDA: add expert reduce kernel (#16857)" (#17100) --- ggml/src/ggml-cuda/ggml-cuda.cu | 26 ---- ggml/src/ggml-cuda/moe-expert-reduce.cu | 168 ----------------------- ggml/src/ggml-cuda/moe-expert-reduce.cuh | 11 -- tests/test-backend-ops.cpp | 58 -------- 4 files changed, 263 deletions(-) delete mode 100644 ggml/src/ggml-cuda/moe-expert-reduce.cu delete mode 100644 ggml/src/ggml-cuda/moe-expert-reduce.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 2d4314fba4..68dc57843e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -27,7 +27,6 @@ #include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmvf.cuh" #include "ggml-cuda/mmvq.cuh" -#include "ggml-cuda/moe-expert-reduce.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" #include "ggml-cuda/opt-step-sgd.cuh" @@ -3197,31 +3196,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } - if (node->op == GGML_OP_MUL) { - int current_node = i + 1; - int num_views = 0; - int num_adds = 0; - while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_VIEW) { - num_views++; - current_node++; - } - - while (current_node < cgraph->n_nodes && cgraph->nodes[current_node]->op == GGML_OP_ADD && - num_adds < num_views - 1) { - num_adds++; - current_node++; - } - - if (num_adds == num_views - 1 && num_views > 0) { - ggml_tensor * dst_node = cgraph->nodes[current_node - 1]; - if (ggml_cuda_should_use_moe_expert_reduce(cgraph, i, current_node)) { - ggml_cuda_op_moe_expert_reduce(*cuda_ctx, node->src[0], node->src[1], dst_node); - i += num_views + num_adds; - continue; - } - } - } - if (node->op == GGML_OP_ADD) { int n_fuse = 0; ggml_op ops[8]; diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cu b/ggml/src/ggml-cuda/moe-expert-reduce.cu deleted file mode 100644 index a97c5d573b..0000000000 --- a/ggml/src/ggml-cuda/moe-expert-reduce.cu +++ /dev/null @@ -1,168 +0,0 @@ -#include "moe-expert-reduce.cuh" - -// This kernel is a fusion of the expert weight reduce, common in MoE models - -template -__global__ void moe_expert_reduce_cuda(const float * __restrict__ experts, - const float * __restrict__ weights, - float * __restrict__ dst, - const int n_expert_used, - const int n_cols) { - const int row = blockIdx.x; - const int col = blockIdx.y * blockDim.x + threadIdx.x; - if (col >= n_cols) { - return; - } - - experts += row * n_cols * n_expert_used; - weights += row * n_expert_used; - dst += row * n_cols; - - float acc = 0.f; - if constexpr (n_expert_used_template == 0) { - for (int expert = 0; expert < n_expert_used; ++expert) { - ggml_cuda_mad(acc, experts[col], weights[expert]); - experts += n_cols; - } - dst[col] = acc; - } else { -#pragma unroll - for (int i = 0; i < n_expert_used_template; ++i) { - ggml_cuda_mad(acc, experts[col], weights[i]); - experts += n_cols; - } - dst[col] = acc; - } -} - -static void launch_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const float * experts, - const float * weights, - float * dst, - const int n_expert_used, - const int n_cols, - const int n_rows) { - const int block_size = 32; - - const int n_blocks_x = n_rows; - const int n_blocks_y = (n_cols + block_size - 1) / block_size; - - dim3 block_dims(block_size); - dim3 grid_dims(n_blocks_x, n_blocks_y); - - cudaStream_t stream = ctx.stream(); - switch (n_expert_used) { - case 1: - moe_expert_reduce_cuda<1> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 2: - moe_expert_reduce_cuda<2> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 4: - moe_expert_reduce_cuda<4> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 6: - moe_expert_reduce_cuda<6> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 8: - moe_expert_reduce_cuda<8> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 16: - moe_expert_reduce_cuda<16> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 32: - moe_expert_reduce_cuda<32> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 64: - moe_expert_reduce_cuda<64> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - case 128: - moe_expert_reduce_cuda<128> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - default: - moe_expert_reduce_cuda<0> - <<>>(experts, weights, dst, n_expert_used, n_cols); - break; - } -} - -bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index) { - const ggml_tensor * mul = cgraph->nodes[start_index]; - - if (mul->op != GGML_OP_MUL || !ggml_is_contiguous(mul->src[0]) || !ggml_is_contiguous(mul->src[1])) { - return false; - } - - int current_node = start_index + 1; - size_t current_offset = 0; - - std::vector view_nodes; - //check if all are views of the expert in increasing order - while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_VIEW) { - const ggml_tensor * node = cgraph->nodes[current_node]; - if (node->view_src != mul) { - return false; - } - if (node->view_offs < current_offset) { - return false; - } - current_offset = node->view_offs; - current_node++; - view_nodes.push_back(node); - } - - //check if all the adds are in increasing order - const ggml_tensor * prev_add_src = view_nodes.empty() ? nullptr : view_nodes[0]; - int num_adds = 0; - int num_views = view_nodes.size(); - while (current_node < end_index && cgraph->nodes[current_node]->op == GGML_OP_ADD) { - const ggml_tensor * add_node = cgraph->nodes[current_node]; - - bool is_first_op_ok = num_views > num_adds ? add_node->src[0] == prev_add_src : false; - bool is_second_op_ok = num_views > num_adds ? add_node->src[1] == view_nodes[num_adds + 1] : false; - - if (!is_first_op_ok || !is_second_op_ok) { - return false; - } - prev_add_src = add_node; - - num_adds++; - current_node++; - } - - if (num_views != num_adds + 1) { - return false; - } - - return true; -} - -void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const ggml_tensor * experts, - const ggml_tensor * weights, - ggml_tensor * dst) { - const int n_rows = experts->ne[2]; - const int n_expert_used = experts->ne[1]; - const int n_cols = experts->ne[0]; - - GGML_ASSERT(experts->type == GGML_TYPE_F32); - GGML_ASSERT(weights->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(experts)); - GGML_ASSERT(ggml_is_contiguous(weights)); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const float * experts_d = (const float *) experts->data; - const float * weights_d = (const float *) weights->data; - float * dst_d = (float *) dst->data; - - launch_moe_expert_reduce(ctx, experts_d, weights_d, dst_d, n_expert_used, n_cols, n_rows); -} diff --git a/ggml/src/ggml-cuda/moe-expert-reduce.cuh b/ggml/src/ggml-cuda/moe-expert-reduce.cuh deleted file mode 100644 index cafc50e104..0000000000 --- a/ggml/src/ggml-cuda/moe-expert-reduce.cuh +++ /dev/null @@ -1,11 +0,0 @@ -#include "common.cuh" -#include "ggml.h" - -#include - -void ggml_cuda_op_moe_expert_reduce(ggml_backend_cuda_context & ctx, - const ggml_tensor * experts, - const ggml_tensor * weights, - ggml_tensor * dst); - -bool ggml_cuda_should_use_moe_expert_reduce(const ggml_cgraph * cgraph, int start_index, int end_index); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 31625bcc7a..2470c148d6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4882,60 +4882,6 @@ struct test_topk_moe: public test_case { } }; -struct test_moe_expert_reduce : public test_case { - const int64_t n_embd; - const int64_t n_tokens; - const int64_t n_expert_used; - - test_moe_expert_reduce(int64_t n_embd = 64, int64_t n_tokens = 5, int64_t n_expert_used = 4) - : n_embd(n_embd), n_tokens(n_tokens), n_expert_used(n_expert_used) { - GGML_ASSERT(n_expert_used > 1); - } - - std::string vars() override { - return VARS_TO_STR3(n_embd, n_tokens, n_expert_used); - } - - std::string op_desc(ggml_tensor * t) override { - GGML_UNUSED(t); - return "MOE_EXPERT_REDUCE"; - } - - bool run_whole_graph() override { return true; } - - ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * experts = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens); - ggml_set_name(experts, "experts"); - - ggml_tensor * weights = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, n_expert_used, n_tokens); - ggml_set_name(weights, "weights"); - - ggml_tensor * weighted = ggml_mul(ctx, experts, weights); - ggml_set_name(weighted, "weighted_experts"); - - std::vector expert_views(n_expert_used); - for (int64_t i = 0; i < n_expert_used; ++i) { - expert_views[i] = ggml_view_2d(ctx, weighted, n_embd, n_tokens, weighted->nb[2], i * weighted->nb[1]); - - std::string name = "expert_view_" + std::to_string(i); - ggml_set_name(expert_views[i], name.c_str()); - ggml_build_forward_expand(gf, expert_views[i]); - } - - ggml_tensor * moe_out = expert_views[0]; - for (int64_t i = 1; i < n_expert_used; ++i) { - moe_out = ggml_add(ctx, moe_out, expert_views[i]); - - std::string name = "expert_add_" + std::to_string(i - 1); - ggml_set_name(moe_out, name.c_str()); - } - - ggml_set_name(moe_out, "moe_out"); - - return moe_out; - } -}; - struct test_mul_mat_vec_fusion : public test_case { const ggml_type type; const ggml_glu_op glu_op; @@ -7415,10 +7361,6 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true)); test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true)); - test_cases.emplace_back(new test_moe_expert_reduce(1024, 5, 4)); - test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 6)); - test_cases.emplace_back(new test_moe_expert_reduce(80, 3, 7)); - #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true));