vulkan: fuse rms_norm + mul + rope (+ view + set_rows) (#16977)

This change combines the rms_norm+mul and rope+view+set_rows fusions so
that the whole sequence can be fused at once. This pattern comes up in
Qwen3, Bailing, and some other models.
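For context, here is a minimal sketch of the unfused op sequence being targeted, written against the public ggml API. The helper name, tensor names, and shapes are illustrative placeholders, not taken from this commit:

#include "ggml.h"

// Hypothetical helper: builds the unfused chain that the new fusion covers.
static ggml_tensor * build_norm_rope_store(ggml_context * ctx,
        ggml_tensor * qkv,      // F32 activations   [n_embd_head, n_head, n_tokens]
        ggml_tensor * norm_w,   // F32 norm weight, broadcast over qkv
        ggml_tensor * pos,      // I32 positions     [n_tokens]
        ggml_tensor * k_cache,  // F16 destination   [n_embd_head*n_head, n_cache_rows]
        ggml_tensor * row_idxs) // I64 row indices   [n_tokens]
{
    ggml_tensor * cur = ggml_rms_norm(ctx, qkv, 1e-6f);              // GGML_OP_RMS_NORM
    cur = ggml_mul(ctx, cur, norm_w);                                // GGML_OP_MUL
    cur = ggml_rope(ctx, cur, pos, qkv->ne[0], GGML_ROPE_TYPE_NEOX); // GGML_OP_ROPE
    // flatten the heads into rows, then scatter them into the cache
    ggml_tensor * view = ggml_view_2d(ctx, cur, cur->ne[0] * cur->ne[1],
                                      cur->ne[2], cur->nb[2], 0);    // GGML_OP_VIEW
    return ggml_set_rows(ctx, k_cache, view, row_idxs);              // GGML_OP_SET_ROWS
}

The new test case below builds essentially this graph, so the backend's fused path can be checked against the unfused CPU reference.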
Author: Jeff Bolz
Date:   2025-11-08 01:52:15 -06:00 (committed via GitHub)
Parent: d6fe40fa00
Commit: b4e335d8dc

12 changed files with 695 additions and 281 deletions


@@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case {
    }
};

// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
struct test_rms_norm_mul_rope : public test_case {
    const std::array<int64_t, 4> ne;
    const float eps;
    const bool multi_add; // test a sequence of adds feeding into rms_norm
    const bool set_rows;
    int mode;

    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "RMS_NORM_MUL_ROPE";
    }

    // evaluate the whole graph so the backend has a chance to apply the fusion
    bool run_whole_graph() override { return true; }

    std::string vars() override {
        return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
    }

    test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
                           bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
        : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
        ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
        ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
        if (multi_add) {
            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
        }
        a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);

        ggml_tensor * pos  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
        ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);

        ggml_tensor * out;
        if (set_rows) {
            // flatten to 2D, then scatter the rows into an F16 destination
            ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
            ggml_tensor * dst  = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
            ggml_set_name(dst, "dst");
            ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
            ggml_set_name(row_idxs, "row_idxs");
            out = ggml_set_rows(ctx, dst, view, row_idxs);
            ggml_set_name(out, "out");
        } else {
            out = rope;
        }
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
                if (ggml_is_view_op(t->op)) {
                    continue;
                }
                // positions (I32) and row indices (I64) need valid, in-range values
                init_set_rows_row_ids(t, ne[2]);
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};

// GGML_OP_ARGMAX
struct test_argmax : public test_case {
    const ggml_type type;
@@ -6751,6 +6824,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        }
    }

    for (auto multi_add : {false, true}) {
        for (auto set_rows : {false, true}) {
            for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
            }
        }
    }

    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));

    for (int64_t d_conv : {3, 4}) {