mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
CUDA: topk-moe: add optional parameter for gpt-oss (#16649)
This commit is contained in:
@@ -4669,14 +4669,21 @@ struct test_topk_moe: public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int n_expert_used;
|
||||
const bool with_norm;
|
||||
test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
|
||||
: ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
|
||||
const bool delayed_softmax;
|
||||
|
||||
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
|
||||
int n_expert_used = 1,
|
||||
bool with_norm = false,
|
||||
bool delayed_softmax = false) :
|
||||
ne(ne),
|
||||
n_expert_used(n_expert_used),
|
||||
with_norm(with_norm),
|
||||
delayed_softmax(delayed_softmax) {
|
||||
GGML_ASSERT(n_expert_used <= ne[0]);
|
||||
GGML_ASSERT(!(with_norm && delayed_softmax));
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(ne, n_expert_used, with_norm);
|
||||
}
|
||||
std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
@@ -4690,11 +4697,17 @@ struct test_topk_moe: public test_case {
|
||||
const int n_tokens = ne[1];
|
||||
|
||||
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * probs = ggml_soft_max(ctx, logits);
|
||||
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
|
||||
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
|
||||
if (delayed_softmax) {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
|
||||
if (with_norm) {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
|
||||
@@ -6975,6 +6988,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
|
||||
Reference in New Issue
Block a user