mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
vulkan: Use larger workgroups for mul_mat_vec when M is small (#15355)
* vulkan: Use larger workgroups for mul_mat_vec when M is small

  Also use subgroup instructions for (part of) the reduction when supported.
  Without this, the more expensive reductions would eat into the benefits of
  the larger workgroups.

* update heuristic for amd/intel

  Co-authored-by: 0cc4m <picard12@live.de>

---------

Co-authored-by: 0cc4m <picard12@live.de>
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_8bit_storage : require
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_basic : require
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : require
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#define EXPERT_COUNT 8
|
||||
@@ -90,7 +94,38 @@ layout (constant_id = 2) const uint NUM_COLS = 1;
|
||||
|
||||
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
|
||||
|
||||
void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
|
||||
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
|
||||
// subgroupAdd is probably faster on devices that support it,
|
||||
// particularly when the workgroup has more than one subgroup
|
||||
#if USE_SUBGROUP_ADD
|
||||
// sum up partial sums within a subgroup
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
temp[j][n] = subgroupAdd(temp[j][n]);
|
||||
}
|
||||
}
|
||||
|
||||
// Go through shared memory to sum partials across subgroups
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
tmpsh[j][n][gl_SubgroupID] = temp[j][n];
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (tid == 0) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
temp[j][n] = FLOAT_TYPE(0);
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
temp[j][n] += tmpsh[j][n][s];
|
||||
}
|
||||
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// sum up partial sums and write back result
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
@@ -115,4 +150,5 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user