vulkan: fuse rms_norm + mul + rope (+ view + set_rows) (#16977)
This change combines the rms_norm+mul and rope+view+set_rows fusions to allow fusing the whole sequence together. This comes up in Qwen3, Bailing, and some other models.
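For context, the fused sequence corresponds to the ggml graph fragment below. This is a minimal sketch, not code from the commit: the helper name `build_fused_chain` and the shapes are illustrative, and a `ggml_context` set up by the caller is assumed.

    #include "ggml.h"

    // Sketch of the op chain the Vulkan backend can now fuse:
    // RMS_NORM -> MUL -> ROPE (-> VIEW -> SET_ROWS). Shapes are illustrative.
    static ggml_tensor * build_fused_chain(ggml_context * ctx) {
        const int64_t head_dim = 128, n_head = 32, n_tokens = 4;

        ggml_tensor * cur    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
        ggml_tensor * weight = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
        ggml_tensor * pos    = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);

        // rms_norm + mul: RMS-normalize, then scale elementwise by the norm weights
        cur = ggml_mul(ctx, ggml_rms_norm(ctx, cur, 1e-6f), weight);

        // rope applied per head to the normalized activations
        cur = ggml_rope(ctx, cur, pos, head_dim, GGML_ROPE_TYPE_NEOX);

        // view + set_rows: flatten the heads, then scatter each token's row into a cache tensor
        ggml_tensor * view     = ggml_view_2d(ctx, cur, head_dim * n_head, n_tokens, cur->nb[2], 0);
        ggml_tensor * cache    = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, head_dim * n_head, n_tokens);
        ggml_tensor * row_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
        return ggml_set_rows(ctx, cache, view, row_idxs);
    }

This roughly mirrors the attention key path in Qwen3-style models: per-head RMS norm scaled by the norm weights, RoPE, then scattering the rotated rows into the KV cache.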
@@ -2294,6 +2294,79 @@ struct test_rope_set_rows : public test_case {
     }
 };
 
+// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ROPE (+ GGML_OP_VIEW + GGML_OP_SET_ROWS)
+struct test_rms_norm_mul_rope : public test_case {
+    const std::array<int64_t, 4> ne;
+    const float eps;
+    const bool multi_add; // test a sequence of adds feeding into rms_norm
+    const bool set_rows;
+    int mode;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "RMS_NORM_MUL_ROPE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR5(ne, eps, multi_add, set_rows, mode);
+    }
+
+    test_rms_norm_mul_rope(std::array<int64_t, 4> ne, float eps = 1e-6f, bool multi_add = false,
+                           bool set_rows = false, int mode = GGML_ROPE_TYPE_NORMAL)
+        : ne(ne), eps(eps), multi_add(multi_add), set_rows(set_rows), mode(mode) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+        ggml_tensor * c = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
+
+        if (multi_add) {
+            a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+        }
+
+        a = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+
+        ggml_tensor * rope = ggml_rope(ctx, a, pos, ne[0], mode);
+
+        ggml_tensor * out;
+
+        if (set_rows) {
+            ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+
+            ggml_tensor * dst = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+            ggml_set_name(dst, "dst");
+
+            ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, ne[2], 1, 1);
+            ggml_set_name(row_idxs, "row_idxs");
+
+            out = ggml_set_rows(ctx, dst, view, row_idxs);
+            ggml_set_name(out, "out");
+        } else {
+            out = rope;
+        }
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                init_set_rows_row_ids(t, ne[2]);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -6751,6 +6824,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (auto multi_add : {false, true}) {
+        for (auto set_rows : {false, true}) {
+            for (auto rope : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX}) {
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 1, 1, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 1, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({768, 3, 5, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 32, 50, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({128, 4, 50, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+                test_cases.emplace_back(new test_rms_norm_mul_rope({8192, 2, 2, 1}, 1e-6f, multi_add, set_rows, rope));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
 
     for (int64_t d_conv : {3, 4}) {
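With the cases registered, the fused path can be exercised on its own. Assuming the usual test-backend-ops filters (`-o` matches the op description string, `-b` the backend name, here a Vulkan device name), an invocation along the lines of `test-backend-ops test -b Vulkan0 -o RMS_NORM_MUL_ROPE` runs only these graphs, since `op_desc()` above reports the fused name.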