vulkan: Use larger workgroups for mul_mat_vec when M is small (#15355)

* vulkan: Use larger workgroups for mul_mat_vec when M is small

Also use subgroup instructions for (part of) the reduction when supported.
Without this, the more expensive reductions would eat into the benefits of
the larger workgroups.

* update heuristic for amd/intel

Co-authored-by: 0cc4m <picard12@live.de>

---------

Co-authored-by: 0cc4m <picard12@live.de>
This commit is contained in:
Jeff Bolz
2025-08-17 11:08:57 -05:00
committed by GitHub
parent 19f4decae0
commit 21c17b5bef
3 changed files with 134 additions and 53 deletions

View File

@@ -1,6 +1,10 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#endif
#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
@@ -90,7 +94,38 @@ layout (constant_id = 2) const uint NUM_COLS = 1;
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
// subgroupAdd is probably faster on devices that support it,
// particularly when the workgroup has more than one subgroup
#if USE_SUBGROUP_ADD
// sum up partial sums within a subgroup
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}
// Go through shared memory to sum partials across subgroups
if (gl_SubgroupInvocationID == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][gl_SubgroupID] = temp[j][n];
}
}
}
barrier();
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
temp[j][n] += tmpsh[j][n][s];
}
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
#else
// sum up partial sums and write back result
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -115,4 +150,5 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32
}
}
}
#endif
}