#version 450

// Tiled matrix-matrix multiplication compute shader.
//
// Each workgroup computes a BM x BN tile of the output D. buf_a holds a BM x BK
// slice of A and buf_b a BN x BK slice of B (filled by load_a_to_shmem /
// load_b_to_shmem from mul_mm_funcs.comp); the inner loop forms dot products over
// K between a row of buf_a and a row of buf_b. Output element (m, n) is written to
// data_d at n * stride_d + m (plus batch / split offsets), see the store phase.
//
// Variants selected by preprocessor defines:
//   COOPMAT                  - inner product via KHR cooperative matrices.
//   MUL_MAT_ID               - gl_GlobalInvocationID.z is an id; only positions of
//                              data_ids equal to it are processed (collected into
//                              row_ids, which also addresses the output).
//   MUL_MAT_ID_USE_SUBGROUPS - build row_ids with subgroup ballots instead of a
//                              serial scan.
//   ALIGNED                  - disables the 2-wide load batching (LOAD_VEC_BATCH_*).
//   ACC_TYPE_MAX             - clamp accumulators to [-ACC_TYPE_MAX, ACC_TYPE_MAX]
//                              before storing.

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#if defined(DATA_A_IQ1_M)
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#if defined(DATA_A_BF16) && defined(COOPMAT)
#extension GL_EXT_bfloat16 : enable
#endif

#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#endif

#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#include "types.comp"

#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
#endif
#ifndef LOAD_VEC_B
#define LOAD_VEC_B 1
#endif

// Load 2 values at once without affecting index calculations through LOAD_VEC
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif

#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif

layout (push_constant) uniform parameter
{
    uint M;             // extent of the output along the BM-tiled dimension
    uint N;             // extent of the output along the BN-tiled dimension
    uint K;             // reduction length
    uint stride_a;      // element stride between consecutive BM-rows of A
    uint stride_b;      // element stride between consecutive BN-rows of B
    uint stride_d;      // element stride between consecutive N-columns of D

    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0;          // inner extent of the ids matrix
    uint nei1;          // outer extent of the ids matrix
    uint nbi1;          // element stride between ids rows
    uint ne11;
#else
    uint k_split;       // K range handled by one split (gl_WorkGroupID.x / blocks_m)
    uint ne02;
    uint ne12;
    uint broadcast2;    // batch broadcast ratios used to map a B/D batch onto A
    uint broadcast3;
#endif
} p;

layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;

// Row pitch of the shared tiles, in FLOAT_TYPE_VEC2 elements (BK scalars = BK/2 vec2
// plus padding). NOTE(review): padding presumably reduces shared-memory bank
// conflicts / suits coopMatLoad — confirm before changing.
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 2 + 4)
#else
#define SHMEM_STRIDE (BK / 2 + 1)
#endif

shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

#ifdef MUL_MAT_ID
// (ii0, ii1) positions in data_ids that belong to this workgroup's BN-wide slice.
shared u16vec2 row_ids[BN];
// Per-invocation (not shared) count of matching ids; every invocation derives the
// same value, so it can be used in workgroup-uniform control flow.
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
// One ballot per subgroup. Assumes gl_NumSubgroups <= NUM_WARPS, i.e. the
// subgroup size equals WARP.
shared uvec4 ballots_sh[NUM_WARPS];

// Scans data_ids (nei1 x nei0, row stride nbi1) for entries equal to expert_idx.
// Matches are ranked by (chunk of BLOCK_SIZE elements, gl_SubgroupID, lane); those
// whose rank is in [ic*BN, (ic+1)*BN) are written to row_ids[rank - ic*BN].
// On return _ne1 holds the number of matches counted; the scan stops once
// (ic+1)*BN matches are seen, so _ne1 is not necessarily the global total.
// nei0_is_pow2 selects a shift instead of a division when splitting the flat index.
// Contains barrier(): must be called in workgroup-uniform control flow.
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        // Exclusive prefix sum of match counts over subgroups, plus the chunk total.
        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        // ballots_sh is rewritten in the next iteration; wait until all reads are done.
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN) {
            break;
        }
    }
    // Make row_ids visible to the whole workgroup.
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID

#ifdef COOPMAT
// Per-subgroup TM x TN staging area used when a cooperative-matrix result cannot be
// stored directly to data_d.
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif

#include "mul_mm_funcs.comp"

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
#else
    const uint batch_idx = gl_GlobalInvocationID.z;

    // Map the B/D batch index onto the (possibly broadcast) A batch index.
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    // gl_WorkGroupID.x packs (K split ik, M tile ir); gl_WorkGroupID.y is the N tile.
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;
    const uint ic = gl_WorkGroupID.y;

    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;

#ifdef COOPMAT
    const uint warp_i = gl_SubgroupID;

    const uint tiw = gl_SubgroupInvocationID;

    const uint cms_per_row = WM / TM;
    const uint cms_per_col = WN / TN;

    // Lane -> (row, column) mapping used when copying a staged TM x TN result out.
    const uint storestride = WARP / TM;
    const uint store_r = tiw % TM;
    const uint store_c = tiw / TM;
#else
    const uint warp_i = gl_LocalInvocationID.x / WARP;

    const uint tiw = gl_LocalInvocationID.x % WARP;

    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);
#endif

    // Position of this warp's WM x WN sub-tile inside the BM x BN tile.
    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);

    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;

#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }
#else
    // Serial scan: every invocation walks data_ids and records the matches whose
    // rank falls in this workgroup's [ic*BN, (ic+1)*BN) slice.
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }
    }

    barrier();
#endif

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    uint pos_a = (
#ifdef MUL_MAT_ID
        expert_idx * p.batch_stride_a +
#else
        batch_idx_a * p.batch_stride_a +
#endif
        ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
#ifdef MUL_MAT_ID
    // B rows are gathered through row_ids inside load_b_to_shmem.
    uint pos_b = 0;
#else
    uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif

#ifdef COOPMAT
    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
    }
#else
    ACC_TYPE sums[WMITER * TM * WNITER * TN];
    FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
    FLOAT_TYPE_VEC2 cache_b[TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }
#endif

    for (uint block = start_k; block < end_k; block += BK) {
        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
        }
        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if !defined(MUL_MAT_ID)
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
#else
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
#endif
        }

        // The whole BK slice must be in shared memory before any warp reads it.
        barrier();

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

#ifdef COOPMAT
        [[unroll]] for (uint i = 0; i < BK; i += TK) {
            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
                // Load from shared into cache (offsets are in vec2 units, hence i / 2)
                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);

                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);

                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
                }
            }
        }
#else
        [[unroll]] for (uint i = 0; i < BK / 2; i++) {
            // Load from shared into cache
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                }
            }
            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint j = 0; j < TN; j++) {
                    cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
                }

                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
                            // Each vec2 carries two consecutive K values.
                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
                        }
                    }
                }
            }
        }
#endif

        // Keep buf_a/buf_b intact until every warp has finished reading this slice.
        barrier();
    }

#if defined(ACC_TYPE_MAX)
#ifdef COOPMAT
    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
        }
    }
#else
    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
    }
#endif
#endif

    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    // Each K split writes its partial result into its own slice of D.
    // NOTE(review): combining the slices presumably happens in a separate pass.
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

#ifdef COOPMAT
#ifdef MUL_MAT_ID
    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
            // Lanes below read elements stored by other lanes of the subgroup.
            subgroupBarrier();

            [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                const uint row_i = dc + cm_col * TN + col + store_c;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];

                if (dr + cm_row * TM + store_r < p.M) {
                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            }
            // The next coopMatStore reuses this subgroup's staging slot.
            subgroupBarrier();
        }
    }
#else
    const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float

    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;

            if (is_aligned && is_in_bounds) {
                // Full coopMat is within bounds and stride_d is aligned with 16B
                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
            } else if (is_in_bounds) {
                // Full coopMat is within bounds, but stride_d is not aligned
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
                // Lanes below read elements stored by other lanes of the subgroup.
                subgroupBarrier();

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
                // The next coopMatStore reuses this subgroup's staging slot.
                subgroupBarrier();
            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
                // Partial coopMat is within bounds
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
                // Lanes below read elements stored by other lanes of the subgroup.
                subgroupBarrier();

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                    }
                }
                // The next coopMatStore reuses this subgroup's staging slot.
                subgroupBarrier();
            }
        }
    }
#endif // MUL_MAT_ID
#else
    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                const uint row_i = dc_warp + cc;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
                    if (dr_warp + cr < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
#else
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
#endif // COOPMAT
}