CUDA: add a fused top-K MoE kernel (#16130)

* CUDA: add a fused top-K MoE kernel

This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_expert_used) logits
3. write the weights + ids to global memory

It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models; a sketch of the idea follows.
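
A minimal sketch of the fused pipeline (one warp per token, assuming n_experts <= WARP_SIZE so each lane keeps one logit in a register; names and launch shape are illustrative, this is not the kernel added by the commit):

#include <cuda_runtime.h>

#include <cfloat>
#include <cstdint>

#define WARP_SIZE 32
#define FULL_MASK 0xFFFFFFFFu

__device__ static float warp_reduce_max(float v) {
    for (int off = WARP_SIZE/2; off > 0; off >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, off));
    }
    return v;
}

__device__ static float warp_reduce_sum(float v) {
    for (int off = WARP_SIZE/2; off > 0; off >>= 1) {
        v += __shfl_xor_sync(FULL_MASK, v, off);
    }
    return v;
}

// illustrative sketch of the fused pipeline, not the commit's kernel
__global__ void topk_moe_sketch(const float * logits, // [n_experts, n_tokens]
                                float * weights,      // [n_expert_used, n_tokens]
                                int   * ids,          // [n_expert_used, n_tokens]
                                const int n_experts, const int n_expert_used) {
    const int token = blockIdx.x;
    const int lane  = threadIdx.x;

    // 1. softmax over this token's logits, held in a register per lane
    const float x     = lane < n_experts ? logits[(int64_t) token*n_experts + lane] : -FLT_MAX;
    const float x_max = warp_reduce_max(x);

    float p = lane < n_experts ? expf(x - x_max) : 0.0f;
    p /= warp_reduce_sum(p);

    // 2. argmax reduce, repeated n_expert_used times: peel off the current max
    for (int k = 0; k < n_expert_used; ++k) {
        const float    p_max = warp_reduce_max(p);
        const unsigned eq    = __ballot_sync(FULL_MASK, p == p_max);
        const int      i_max = __ffs(eq) - 1; // lowest expert id wins ties

        // 3. write the weight + id to global memory
        if (lane == 0) {
            weights[(int64_t) token*n_expert_used + k] = p_max;
            ids    [(int64_t) token*n_expert_used + k] = i_max;
        }
        if (lane == i_max) {
            p = -FLT_MAX; // exclude the winner from later rounds
        }
    }
}

A launch like topk_moe_sketch<<<n_tokens, WARP_SIZE>>>(logits, weights, ids, n_experts, n_expert_used) processes one token per warp; the review items below tighten the memory behaviour well beyond this sketch.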

* Refactor into ggml_cuda_should_use_topk_moe
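
Centralizing the pattern check in a helper lets the op implementation and the graph-fusion pass agree on when the fused path applies. A plausible sketch of such a gate, assuming it verifies that the softmax feeding the top-k is a plain, unmasked, unscaled f32 softmax over contiguous tensors; the commit's actual checks may differ:

#include <cstring>

#include "ggml.h"

// Sketch only: the real ggml_cuda_should_use_topk_moe may check more or
// different properties of the graph.
static bool should_use_topk_moe_sketch(const ggml_tensor * softmax, const ggml_tensor * weights) {
    float scale    = 1.0f;
    float max_bias = 0.0f;
    std::memcpy(&scale,    (const float *) softmax->op_params + 0, sizeof(float));
    std::memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    // the fused kernel computes a plain probability distribution over experts
    if (softmax->src[1] != nullptr || scale != 1.0f || max_bias != 0.0f) {
        return false;
    }
    if (softmax->type != GGML_TYPE_F32 || !ggml_is_contiguous(softmax->src[0]) ||
        !ggml_is_contiguous(weights)) {
        return false;
    }
    return true;
}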

* Review: use a better coalescing pattern, use WARP_SIZE, store logits in registers before reducing

* Review: format + micro-optimizations

* Fix bug: correct tie-breaking in the top-k selection
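
Ties matter for correctness testing: when two experts land on exactly the same probability, the fused kernel must select the same expert as the unfused ggml_top_k path, or whole-graph comparisons diverge. The ballot-based selection in the sketch above gets this for free; a pairwise (value, index) shuffle reduction needs the tie-break spelled out, as in this hypothetical helper (reusing WARP_SIZE from the sketch above):

// Butterfly argmax across a warp; on equal values the lower expert id wins,
// so all lanes converge on the same (value, index) pair. Hypothetical helper,
// not the commit's code.
__device__ static void warp_argmax(float & v, int & i) {
    for (int off = WARP_SIZE/2; off > 0; off >>= 1) {
        const float ov = __shfl_xor_sync(0xFFFFFFFFu, v, off);
        const int   oi = __shfl_xor_sync(0xFFFFFFFFu, i, off);
        if (ov > v || (ov == v && oi < i)) {
            v = ov;
            i = oi;
        }
    }
}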

* Add optional norm + clean-up code
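
With the norm enabled, the unfused graph divides the selected weights by their per-token sum (the ggml_sum_rows + ggml_div pair in the test below); the fused kernel can fold the same rescaling in before writeback. A hypothetical helper:

// Renormalize one token's k selected weights so they sum to 1.
// Hypothetical helper, not the commit's code.
__device__ static void normalize_topk(float * wt, const int n_expert_used) {
    float sum = 0.0f;
    for (int i = 0; i < n_expert_used; ++i) {
        sum += wt[i];
    }
    for (int i = 0; i < n_expert_used; ++i) {
        wt[i] /= sum;
    }
}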

* Use smem for final write

* Add bounds check

* Use a better memory pattern for writeback (see the sketch below)
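
These last three items work together: each block stages its (weight, id) results in shared memory, a bounds check guards the tail block where fewer tokens remain, and the final copy is arranged so consecutive threads store to consecutive addresses and the writes coalesce. A hypothetical sketch, with TOKENS_PER_BLOCK and the staging buffers as illustrative names:

// Coalesced writeback of staged results; min() against tokens_left is the
// bounds check for the tail block. Hypothetical helper, not the commit's code.
template <int TOKENS_PER_BLOCK>
__device__ static void writeback_topk(const float * s_weights, const int * s_ids,
                                      float * weights, int * ids,
                                      const int64_t base, const int tokens_left,
                                      const int n_expert_used) {
    const int n = min(TOKENS_PER_BLOCK, tokens_left) * n_expert_used;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        weights[base + i] = s_weights[i];
        ids    [base + i] = s_ids[i];
    }
}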
Author:       Aman Gupta
Date:         2025-09-25 22:35:05 +08:00
Committed by: GitHub
Parent:       aa3ee0eb0b
Commit:       077c94d0ca

5 changed files with 381 additions and 0 deletions


@@ -4418,6 +4418,49 @@ struct test_argsort : public test_case {
    }
};

// builds the same pattern the fused kernel replaces: softmax -> top-k -> get_rows
struct test_topk_moe : public test_case {
    const std::array<int64_t, 4> ne;
    const int  n_expert_used;
    const bool with_norm;

    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
        : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
        GGML_ASSERT(n_expert_used <= ne[0]);
    }

    std::string vars() override {
        return VARS_TO_STR3(ne, n_expert_used, with_norm);
    }

    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "TOPK_MOE";
    }

    bool run_whole_graph() override { return true; }

    ggml_tensor * build_graph(ggml_context * ctx) override {
        const int n_expert = ne[0];
        const int n_tokens = ne[1];

        ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
        ggml_tensor * probs  = ggml_soft_max(ctx, logits);

        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]

        if (with_norm) {
            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
            ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]

            out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
        }

        ggml_set_name(out, "out");
        return out;
    }
};

// GGML_OP_SUM
struct test_sum : public test_case {
    const ggml_type type;
@@ -6588,6 +6631,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
    test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));

    for (bool with_norm : {false, true}) {
        test_cases.emplace_back(new test_topk_moe({  8, 22, 1, 1},   4, with_norm));
        test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1},   8, with_norm));
        test_cases.emplace_back(new test_topk_moe({128,  1, 1, 1}, 128, with_norm));
    }
#if 0
    // these tests are disabled to save execution time, but they can be handy for debugging
    test_cases.emplace_back(new test_llama(2, true));