mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
ggml webgpu: add support for soft_max, optimize rms_norm (#16357)
* Add inplace softmax * Move rms_norm to split row approach * Update debug for supports_op * clean up debug statements * Update tests/test-backend-ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@@ -3752,9 +3752,10 @@ struct test_soft_max : public test_case {
|
||||
const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
|
||||
const float scale;
|
||||
const float max_bias;
|
||||
const bool inplace;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
|
||||
return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace);
|
||||
}
|
||||
|
||||
// the 1024 test with bias occasionally fails:
|
||||
@@ -3770,8 +3771,9 @@ struct test_soft_max : public test_case {
|
||||
ggml_type m_prec = GGML_TYPE_F32,
|
||||
std::array<int64_t, 2> nr23 = {1, 1},
|
||||
float scale = 1.0f,
|
||||
float max_bias = 0.0f)
|
||||
: type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
|
||||
float max_bias = 0.0f,
|
||||
bool inplace = false)
|
||||
: type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
|
||||
@@ -3790,7 +3792,12 @@ struct test_soft_max : public test_case {
|
||||
ggml_set_name(sinks, "sinks");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
|
||||
ggml_tensor * out;
|
||||
if (inplace) {
|
||||
out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias);
|
||||
} else {
|
||||
out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
|
||||
}
|
||||
ggml_soft_max_add_sinks(out, sinks);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@@ -6562,6 +6569,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
}
|
||||
// inplace tests
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true));
|
||||
}
|
||||
}
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
|
||||
|
||||
Reference in New Issue
Block a user