CUDA: topk-moe: add optional parameter for gpt-oss (#16649)

Author: Aman Gupta
Date: 2025-10-21 22:40:38 +08:00
Committed by: GitHub
Parent: 51d1a8c997
Commit: 03792ad936
4 changed files with 154 additions and 63 deletions

@@ -4669,14 +4669,21 @@ struct test_topk_moe: public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
     const bool with_norm;
-    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
-        : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
+    const bool delayed_softmax;
+
+    test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
+                  int n_expert_used = 1,
+                  bool with_norm = false,
+                  bool delayed_softmax = false) :
+        ne(ne),
+        n_expert_used(n_expert_used),
+        with_norm(with_norm),
+        delayed_softmax(delayed_softmax) {
         GGML_ASSERT(n_expert_used <= ne[0]);
+        GGML_ASSERT(!(with_norm && delayed_softmax));
     }
-    std::string vars() override {
-        return VARS_TO_STR3(ne, n_expert_used, with_norm);
-    }
+    std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }

     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
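
Note on the new assert: the delayed softmax is taken over only the selected experts' logits, so the resulting weights already sum to 1 per token; renormalizing them again via with_norm (sum_rows followed by div) would be a no-op, which is presumably why the two options are made mutually exclusive here.
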
@@ -4690,11 +4697,17 @@ struct test_topk_moe: public test_case {
         const int n_tokens = ne[1];
         ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
-        ggml_tensor * probs = ggml_soft_max(ctx, logits);
+        ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
         ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
         ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+        if (delayed_softmax) {
+            out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+            out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
+            out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+        }
+
         if (with_norm) {
             out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
             ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
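
For context, the delayed-softmax order mirrors gpt-oss-style routing: the top-k experts are selected on the raw router logits and the softmax is then taken over only the selected logits, instead of softmaxing over all experts before selection. A minimal standalone sketch of the two orders (plain C++; the values and sizes are illustrative only, not taken from the test):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

// Softmax over a small vector of logits.
static std::vector<float> softmax(const std::vector<float> & x) {
    const float m = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(x[i] - m);
        sum += y[i];
    }
    for (float & v : y) {
        v /= sum;
    }
    return y;
}

// Indices of the k largest entries (descending), analogous to ggml_top_k.
static std::vector<int> top_k(const std::vector<float> & x, int k) {
    std::vector<int> idx(x.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return x[a] > x[b]; });
    idx.resize(k);
    return idx;
}

int main() {
    const std::vector<float> logits = { 1.0f, 3.0f, 0.5f, 2.0f };
    const int n_expert_used = 2;

    // Standard order: softmax over all experts, then gather the top-k probabilities.
    const std::vector<float> probs = softmax(logits);
    for (int i : top_k(probs, n_expert_used)) {
        std::printf("standard: expert %d weight %.4f\n", i, probs[i]);
    }

    // Delayed order: top-k on the raw logits, softmax over only the selected
    // logits, so the weights of the chosen experts sum to 1 per token.
    const std::vector<int> sel = top_k(logits, n_expert_used);
    std::vector<float> sel_logits;
    for (int i : sel) {
        sel_logits.push_back(logits[i]);
    }
    const std::vector<float> weights = softmax(sel_logits);
    for (int j = 0; j < n_expert_used; ++j) {
        std::printf("delayed:  expert %d weight %.4f\n", sel[j], weights[j]);
    }
    return 0;
}

With these numbers both orders pick the same two experts, but they assign different weights: the standard path keeps the full-softmax probabilities of the selected experts (which sum to less than 1), while the delayed path normalizes over just the selected pair. This is why the test graph above applies ggml_soft_max after ggml_get_rows when delayed_softmax is set.
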
@@ -6975,6 +6988,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
     }
+    test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
+    test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
+
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));