vulkan: fuse mul_mat_id + mul (#17095)

* vulkan: fuse mul_mat_id + mul

This comes up in qwen3 moe.

* split mul_mat_id fusion tests into a separate class
This commit is contained in:
Jeff Bolz
2025-11-09 02:48:42 -06:00
committed by GitHub
parent 0750a59903
commit 80a6cf6347
3 changed files with 180 additions and 51 deletions

View File

@@ -3557,6 +3557,27 @@ struct test_mul_mat : public test_case {
}
};
// Initialize the tensors used by MUL_MAT_ID tests.
//
// For every non-view I32 tensor in the context (assumed to be an "ids"
// tensor selecting experts), each row is filled with the values
// i % n_mats for i in [0, ne[0]) and then shuffled, so every row references
// a random but evenly-distributed subset of the n_mats expert matrices.
// All other tensors get uniform random data via init_tensor_uniform().
//
// ctx:    context whose tensors are initialized in place
// n_mats: number of expert matrices; ids entries are kept in [0, n_mats)
static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {
    std::random_device rd;
    std::default_random_engine rng(rd());
    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
        if (t->type == GGML_TYPE_I32) {
            if (ggml_is_view_op(t->op)) { continue; }
            // ids: buffer size is loop-invariant, so allocate it once per
            // tensor instead of once per row (was re-allocated every row).
            std::vector<int32_t> data(t->ne[0]);
            for (int64_t r = 0; r < ggml_nrows(t); r++) {
                // int64_t index to match t->ne[0] (avoids a signed/width
                // mismatch in the comparison and in data[i]).
                for (int64_t i = 0; i < t->ne[0]; i++) {
                    data[i] = (int32_t)(i % n_mats);
                }
                std::shuffle(data.begin(), data.end(), rng);
                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
            }
        } else {
            init_tensor_uniform(t);
        }
    }
}
// GGML_OP_MUL_MAT_ID
struct test_mul_mat_id : public test_case {
const ggml_type type_a;
@@ -3567,10 +3588,9 @@ struct test_mul_mat_id : public test_case {
const int64_t m;
const int64_t n;
const int64_t k;
const uint32_t o; // number of outputs
std::string vars() override {
return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o);
return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
}
double max_nmse_err() override {
@@ -3584,9 +3604,69 @@ struct test_mul_mat_id : public test_case {
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 8, int n_used = 2, bool b = false,
int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1)
int64_t m = 32, int64_t n = 32, int64_t k = 32)
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
m(m), n(n), k(k), o(o) {
m(m), n(n), k(k) {
GGML_ASSERT(n_used <= n_mats);
}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
ggml_set_name(as, "as");
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
ggml_set_name(ids, "ids");
if (n_used != n_mats) {
ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
ggml_set_name(ids, "view_of_ids");
}
ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
ggml_set_name(b, "b");
ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
init_mul_mat_id_tensors(ctx, n_mats);
}
};
// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
struct test_mul_mat_id_fusion : public test_case {
const ggml_type type_a;
const ggml_type type_b;
const int n_mats;
const int n_used;
const bool b; // broadcast b matrix
const int64_t m;
const int64_t n;
const int64_t k;
const uint32_t o; // number of outputs
const bool mul;
std::string vars() override {
return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);
}
double max_nmse_err() override {
return 5e-4;
}
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * k * n * n_used;
}
test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 8, int n_used = 2, bool b = false,
int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
m(m), n(n), k(k), o(o), mul(mul) {
GGML_ASSERT(n_used <= n_mats);
}
@@ -3615,35 +3695,25 @@ struct test_mul_mat_id : public test_case {
out = ggml_add(ctx, out, out2);
}
if (mul) {
std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };
ne[0] = 1;
ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data());
out = ggml_mul(ctx, out, m);
}
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
std::random_device rd;
std::default_random_engine rng(rd());
// ids
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<int32_t> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i % n_mats;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
}
} else {
init_tensor_uniform(t);
}
}
init_mul_mat_id_tensors(ctx, n_mats);
}
bool run_whole_graph() override { return o > 1; }
bool run_whole_graph() override { return true; }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return ggml_op_name(GGML_OP_MUL_MAT_ID);
return "MUL_MAT_ID_FUSION";
}
};
@@ -4992,24 +5062,7 @@ struct test_mul_mat_vec_fusion : public test_case {
init_tensor_uniform(t);
}
} else {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// ids
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<int32_t> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i % n_mats;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
}
} else {
init_tensor_uniform(t);
}
}
init_mul_mat_id_tensors(ctx, n_mats);
}
}
@@ -6979,7 +7032,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
@@ -7016,6 +7069,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (int bs : {1, 4, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
// test with mul after (ffn_moe_weighted)
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));
}
}
}
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (int n : {1, 16}) {
@@ -7472,7 +7534,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
}
}
}
@@ -7480,7 +7542,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
}
}
}
@@ -7490,7 +7552,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
}
}
}