ggml webgpu: add support for soft_max, optimize rms_norm (#16357)

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* Update tests/test-backend-ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
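
The "split row approach" named in the commit message is a standard GPU reduction layout: instead of a single invocation computing a whole row's RMS, a full workgroup is assigned to each row, each thread accumulates a partial sum of squares over a strided slice of that row, and the partials are combined before the row is scaled. The commit's actual implementation is a WGSL shader in the WebGPU backend; the following is only a CPU-side C++ sketch of the pattern, with a hypothetical workgroup size and function name:

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch of a split-row rms_norm (not the backend's shader).
// WG_SIZE plays the role of the WebGPU workgroup size.
constexpr std::size_t WG_SIZE = 256;

static void rms_norm_row_split(float * row, std::size_t n, float eps) {
    // Phase 1: each "thread" reduces a strided slice of the row.
    std::vector<float> partial(WG_SIZE, 0.0f);
    for (std::size_t t = 0; t < WG_SIZE; ++t) {
        for (std::size_t i = t; i < n; i += WG_SIZE) {
            partial[t] += row[i] * row[i];
        }
    }
    // Phase 2: tree-reduce the partial sums (shared memory on the GPU).
    for (std::size_t stride = WG_SIZE / 2; stride > 0; stride /= 2) {
        for (std::size_t t = 0; t < stride; ++t) {
            partial[t] += partial[t + stride];
        }
    }
    // Phase 3: scale the row by 1/rms.
    const float scale = 1.0f / std::sqrt(partial[0] / (float) n + eps);
    for (std::size_t i = 0; i < n; ++i) {
        row[i] *= scale;
    }
}

Splitting the row this way spreads the O(n) sum of squares across the workgroup instead of serializing it in one invocation, which is the optimization the title refers to.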
Author:       Reese Levine
Date:         2025-10-02 11:00:31 -07:00
Committed by: GitHub
Parent:       34fcc5a4ac
Commit:       ef07a40906

6 changed files with 566 additions and 48 deletions
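
For context on the test diff below: ggml builds graph nodes out-of-place by default, and the new inplace flag switches the soft_max test to the variant that writes its result into a view of the input tensor rather than a freshly allocated one. A minimal sketch of the two call paths (the tensor shapes and helper name are illustrative, not from the commit):

#include "ggml.h"

// Sketch: build a soft_max graph node, optionally in place.
static ggml_tensor * make_soft_max_node(ggml_context * ctx, bool inplace) {
    ggml_tensor * a    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 32);
    ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 32);
    const float scale    = 0.1f;
    const float max_bias = 0.0f;
    // The _inplace variant returns a view of `a`, so the backend must
    // handle src and dst sharing a buffer.
    return inplace ? ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias)
                   : ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
}

This is the distinction that the added inplace member threads through build_graph in the diff.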

tests/test-backend-ops.cpp

@@ -3752,9 +3752,10 @@ struct test_soft_max : public test_case {
     const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
+    const bool inplace;
 
     std::string vars() override {
-        return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
+        return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -3770,8 +3771,9 @@ struct test_soft_max : public test_case {
                   ggml_type m_prec = GGML_TYPE_F32,
                   std::array<int64_t, 2> nr23 = {1, 1},
                   float scale = 1.0f,
-                  float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+                  float max_bias = 0.0f,
+                  bool inplace = false)
+        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -3790,7 +3792,12 @@ struct test_soft_max : public test_case {
             ggml_set_name(sinks, "sinks");
         }
 
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias);
+        } else {
+            out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        }
 
         ggml_soft_max_add_sinks(out, sinks);
         ggml_set_name(out, "out");
@@ -6562,6 +6569,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     }
                 }
             }
+            // inplace tests
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true));
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true));
         }
     }
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
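
Assuming test-backend-ops' usual filters, the new cases can be run in isolation with the op filter, e.g. test-backend-ops test -o SOFT_MAX, which exercises both the in-place and out-of-place paths against every available backend, WebGPU included when it is built in.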