metal : fuse NORM + MUL + ADD, support non-multiples of 4 (#16220)

* metal : fuse NORM + MUL + ADD

* metal : support norms of non-multiple of 4

* cont : fix comment [no ci]
This commit is contained in:
Georgi Gerganov
2025-09-25 11:30:16 +03:00
committed by GitHub
parent 4ea00794b8
commit dfcd53f7ec
9 changed files with 206 additions and 232 deletions

View File

@@ -6117,7 +6117,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));