From 80a6cf63473b95742444a1b27d45164591282a7d Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 9 Nov 2025 02:48:42 -0600 Subject: [PATCH] vulkan: fuse mul_mat_id + mul (#17095) * vulkan: fuse mul_mat_id + mul This comes up in qwen3 moe. * split mul_mat_id fusion tests into a separate class --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 58 ++++++- .../vulkan-shaders/mul_mat_vec_base.glsl | 19 +++ tests/test-backend-ops.cpp | 154 ++++++++++++------ 3 files changed, 180 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6da7bbd2f6..054e8cbdb8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -830,6 +830,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t enable_bias; + uint32_t enable_scale; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -852,6 +853,7 @@ struct vk_mat_vec_id_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t enable_bias; + uint32_t enable_scale; uint32_t nei0; uint32_t ne11; }; @@ -6863,7 +6865,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& // compute const vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, + stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, @@ -7684,13 +7686,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte groups_x = CEIL_DIV(groups_x, groups_z); } - uint32_t enable_bias = ctx->num_additional_fused_ops > 0; + uint32_t enable_bias = 0; + uint32_t enable_scale = 0; + if (ctx->num_additional_fused_ops > 0) { + if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) { + enable_scale = 1; + } else { + GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID); + enable_bias = 1; + } + } vk_buffer d_B = d_D; size_t b_buf_offset = 0; uint64_t b_sz = 0; - if (enable_bias) { + if (enable_bias || enable_scale) { const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1]; bool b_uma = false; @@ -7712,7 +7723,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), - enable_bias, + enable_bias, enable_scale, (uint32_t)nei0, (uint32_t)ne11, }; @@ -12490,6 +12501,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g } } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *mmid = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + const ggml_tensor *scale = mul->src[1]; + + if (mmid != mul->src[0]) { + return false; + } + // mat-vec only + if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) { + return false; + } + // shaders assume the types match + if (mmid->type != scale->type) { + return false; + } + // shaders assume the bias is contiguous + if (!ggml_is_contiguous(scale)) { + return false; + } + // unaligned bias isn't handled + if (get_misalign_bytes(ctx, scale) != 0) { + return false; + } + // shader only indexes by expert index + if (scale->ne[0] != 1 || + scale->ne[1] != mul->ne[1] || + scale->ne[2] != 1 || + scale->ne[3] != 1) { + return false; + } + } + return true; } @@ -12798,6 +12843,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = 1; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; + } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && @@ -13033,7 +13080,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * is_src_of(graph->nodes[j], graph->nodes[c]) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) { ok = false; break; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index bbb4d1206b..eb8fa6dc09 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -49,6 +49,7 @@ layout (push_constant) uniform parameter uint batch_stride_d; uint enable_bias; + uint enable_scale; #ifdef MUL_MAT_ID uint nei0; @@ -129,6 +130,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); #endif } +#ifdef MUL_MAT_ID + if (p.enable_scale != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); + } +#endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); } } @@ -171,6 +178,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); #endif } +#ifdef MUL_MAT_ID + if (p.enable_scale != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); + } +#endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); } } @@ -203,6 +216,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); #endif } +#ifdef MUL_MAT_ID + if (p.enable_scale != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]); + } +#endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2470c148d6..21c7e3a8cf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3557,6 +3557,27 @@ struct test_mul_mat : public test_case { } }; +static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i % n_mats; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } +} + // GGML_OP_MUL_MAT_ID struct test_mul_mat_id : public test_case { const ggml_type type_a; @@ -3567,10 +3588,9 @@ struct test_mul_mat_id : public test_case { const int64_t m; const int64_t n; const int64_t k; - const uint32_t o; // number of outputs std::string vars() override { - return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o); + return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k); } double max_nmse_err() override { @@ -3584,9 +3604,69 @@ struct test_mul_mat_id : public test_case { test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int n_mats = 8, int n_used = 2, bool b = false, - int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1) + int64_t m = 32, int64_t n = 32, int64_t k = 32) : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b), - m(m), n(n), k(k), o(o) { + m(m), n(n), k(k) { + GGML_ASSERT(n_used <= n_mats); + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + // C^T = A * B^T: (k, m) * (k, n) => (m, n) + ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats); + ggml_set_name(as, "as"); + + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n); + ggml_set_name(ids, "ids"); + if (n_used != n_mats) { + ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0); + ggml_set_name(ids, "view_of_ids"); + } + + ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n); + ggml_set_name(b, "b"); + + ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + init_mul_mat_id_tensors(ctx, n_mats); + } +}; + +// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL +struct test_mul_mat_id_fusion : public test_case { + const ggml_type type_a; + const ggml_type type_b; + const int n_mats; + const int n_used; + const bool b; // broadcast b matrix + const int64_t m; + const int64_t n; + const int64_t k; + const uint32_t o; // number of outputs + const bool mul; + + std::string vars() override { + return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul); + } + + double max_nmse_err() override { + return 5e-4; + } + + uint64_t op_flops(ggml_tensor * t) override { + GGML_UNUSED(t); + return 2 * m * k * n * n_used; + } + + test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, + int n_mats = 8, int n_used = 2, bool b = false, + int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false) + : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b), + m(m), n(n), k(k), o(o), mul(mul) { GGML_ASSERT(n_used <= n_mats); } @@ -3615,35 +3695,25 @@ struct test_mul_mat_id : public test_case { out = ggml_add(ctx, out, out2); } + if (mul) { + std::array ne { 1, out->ne[1], out->ne[2], out->ne[3] }; + ne[0] = 1; + ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data()); + out = ggml_mul(ctx, out, m); + } + return out; } void initialize_tensors(ggml_context * ctx) override { - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (t->type == GGML_TYPE_I32) { - if (ggml_is_view_op(t->op)) { continue; } - std::random_device rd; - std::default_random_engine rng(rd()); - // ids - for (int64_t r = 0; r < ggml_nrows(t); r++) { - std::vector data(t->ne[0]); - for (int i = 0; i < t->ne[0]; i++) { - data[i] = i % n_mats; - } - std::shuffle(data.begin(), data.end(), rng); - ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); - } - } else { - init_tensor_uniform(t); - } - } + init_mul_mat_id_tensors(ctx, n_mats); } - bool run_whole_graph() override { return o > 1; } + bool run_whole_graph() override { return true; } std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); - return ggml_op_name(GGML_OP_MUL_MAT_ID); + return "MUL_MAT_ID_FUSION"; } }; @@ -4992,24 +5062,7 @@ struct test_mul_mat_vec_fusion : public test_case { init_tensor_uniform(t); } } else { - std::random_device rd; - std::default_random_engine rng(rd()); - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (t->type == GGML_TYPE_I32) { - if (ggml_is_view_op(t->op)) { continue; } - // ids - for (int64_t r = 0; r < ggml_nrows(t); r++) { - std::vector data(t->ne[0]); - for (int i = 0; i < t->ne[0]; i++) { - data[i] = i % n_mats; - } - std::shuffle(data.begin(), data.end(), rng); - ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); - } - } else { - init_tensor_uniform(t); - } - } + init_mul_mat_id_tensors(ctx, n_mats); } } @@ -6979,7 +7032,7 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1)); - test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); + test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3)); // gpt-oss issue with Vulkan mmq_id test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); @@ -7016,6 +7069,15 @@ static std::vector> make_test_cases_eval() { } } + for (int bs : {1, 4, 512}) { + for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + // test with mul after (ffn_moe_weighted) + test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true)); + } + } + } + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (int n : {1, 16}) { @@ -7472,7 +7534,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); + test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); } } } @@ -7480,7 +7542,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); + test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1)); } } } @@ -7490,7 +7552,7 @@ static std::vector> make_test_cases_perf() { for (int bs : {1, 4, 8, 512}) { for (ggml_type type_a : {GGML_TYPE_MXFP4}) { for (ggml_type type_b : {GGML_TYPE_F32}) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); + test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1)); } } }