vulkan: apply MUL_MAT_ID subgroup optimization to non-coopmat devices (#15524)

* vulkan: use subgroup function for mul_mat_id shader even without coopmat

* vulkan: fix compile warnings

* vulkan: properly check for subgroup size control and require full subgroups for subgroup mul_mat_id

* vulkan: disable subgroup mul_mat_id on devices with subgroups < 16
This commit is contained in:
Ruben Ortlam
2025-08-24 19:36:36 +02:00
committed by GitHub
parent b730706a49
commit 043fb27d38
3 changed files with 282 additions and 195 deletions

View File

@@ -17,6 +17,9 @@
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#endif
#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
@@ -108,8 +111,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
@@ -168,7 +173,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
}
barrier();
}
#endif
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
@@ -235,7 +240,7 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
#ifdef COOPMAT
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true);
} else {