mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-11 10:36:54 +00:00
Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (#10597)
* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader * Improve performance with better q4_k and q5_k dequant and store unrolling * Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection * Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device * Vulkan: Implement accumulator switch for specific mul mat mat shaders * Vulkan: Unroll more loops for more mul mat mat performance * Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic * Disable coopmat support on AMD proprietary driver * Remove redundant checks * Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support * Fix rebase typo * Fix coopmat2 MUL_MAT_ID pipeline selection
This commit is contained in:
@@ -7,6 +7,12 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||
#endif
|
||||
@@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
|
||||
#endif
|
||||
} p;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
|
||||
layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
|
||||
@@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
|
||||
layout (constant_id = 6) const uint WMITER = 2;
|
||||
layout (constant_id = 7) const uint TM = 4;
|
||||
layout (constant_id = 8) const uint TN = 2;
|
||||
layout (constant_id = 9) const uint WARP = 32;
|
||||
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
|
||||
layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
shared FLOAT_TYPE buf_a[BM * (BK+1)];
|
||||
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK + 8)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK + 1)
|
||||
#endif
|
||||
|
||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
@@ -98,17 +118,32 @@ void main() {
|
||||
const uint ik = gl_WorkGroupID.x / blocks_m;
|
||||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const uint WSUBM = WM / WMITER;
|
||||
const uint WSUBN = WN / WNITER;
|
||||
|
||||
#ifdef COOPMAT
|
||||
const uint warp_i = gl_SubgroupID;
|
||||
|
||||
const uint tiw = gl_SubgroupInvocationID;
|
||||
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
|
||||
const uint storestride = WARP / TM;
|
||||
const uint store_r = tiw % TM;
|
||||
const uint store_c = tiw / TM;
|
||||
#else
|
||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||
|
||||
const uint tiw = gl_LocalInvocationID.x % WARP;
|
||||
|
||||
const uint tiwr = tiw % (WSUBM / TM);
|
||||
const uint tiwc = tiw / (WSUBM / TM);
|
||||
#endif
|
||||
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
||||
@@ -156,21 +191,31 @@ void main() {
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
||||
#endif
|
||||
|
||||
float sums[WMITER * TM * WNITER * TN];
|
||||
#ifdef COOPMAT
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
FLOAT_TYPE cache_a[WMITER * TM];
|
||||
FLOAT_TYPE cache_b[WNITER * TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
[[unroll]] for (uint block = start_k; block < end_k; block += BK) {
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
|
||||
@@ -181,21 +226,21 @@ void main() {
|
||||
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
|
||||
#else
|
||||
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
|
||||
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
|
||||
} else {
|
||||
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
|
||||
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
||||
|
||||
const uint ib = idx / 16;
|
||||
const uint iqs = idx & 0xF;
|
||||
@@ -208,7 +253,7 @@ void main() {
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
||||
|
||||
const uint ib = idx / 16;
|
||||
const uint iqs = idx & 0xF;
|
||||
@@ -222,7 +267,7 @@ void main() {
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
||||
|
||||
const uint ib = idx / 16;
|
||||
const uint iqs = idx & 0xF;
|
||||
@@ -237,7 +282,7 @@ void main() {
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
||||
|
||||
const uint ib = idx / 16;
|
||||
const uint iqs = idx & 0xF;
|
||||
@@ -253,7 +298,7 @@ void main() {
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 16;
|
||||
const uint iqs = (idx & 0xF) * 2;
|
||||
@@ -265,7 +310,7 @@ void main() {
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -284,7 +329,7 @@ void main() {
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -308,7 +353,7 @@ void main() {
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -320,15 +365,20 @@ void main() {
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].d);
|
||||
|
||||
uint8_t sc;
|
||||
uint8_t mbyte;
|
||||
if (is < 4) {
|
||||
sc = uint8_t(data_a[ib].scales[is ] & 63);
|
||||
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
|
||||
} else {
|
||||
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
|
||||
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
|
||||
}
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
||||
const uint mbidx0 = is + 4;
|
||||
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
||||
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
|
||||
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
@@ -336,7 +386,7 @@ void main() {
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -351,15 +401,20 @@ void main() {
|
||||
|
||||
const vec2 loadd = vec2(data_a[ib].d);
|
||||
|
||||
uint8_t sc;
|
||||
uint8_t mbyte;
|
||||
if (is < 4) {
|
||||
sc = uint8_t(data_a[ib].scales[is ] & 63);
|
||||
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
|
||||
} else {
|
||||
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
|
||||
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
|
||||
}
|
||||
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
||||
const uint mbidx0 = is + 4;
|
||||
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
||||
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
|
||||
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
@@ -367,7 +422,7 @@ void main() {
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -386,7 +441,7 @@ void main() {
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
||||
|
||||
const uint ib = idx / 16;
|
||||
const uint iqs = idx & 0xF;
|
||||
@@ -407,7 +462,7 @@ void main() {
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#endif
|
||||
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
|
||||
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
|
||||
@@ -423,24 +478,24 @@ void main() {
|
||||
#else
|
||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
#endif
|
||||
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
|
||||
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
|
||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
|
||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
|
||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
|
||||
#elif !MUL_MAT_ID
|
||||
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
||||
} else {
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
#else
|
||||
const uint row_i = ic * BN + loadc_b + l;
|
||||
if (row_i < _ne1) {
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||
} else {
|
||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -450,16 +505,30 @@ void main() {
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (uint i = 0; i < BK; i++) {
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
// Load from shared into cache
|
||||
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < BK; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint j = 0; j < TM; j++) {
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
|
||||
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint j = 0; j < TN; j++) {
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,12 +537,13 @@ void main() {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
@@ -485,6 +555,54 @@ void main() {
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#ifdef MUL_MAT_ID
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
||||
|
||||
if (is_aligned && is_in_bounds) {
|
||||
// Full coopMat is within bounds and stride_d is aligned with 16B
|
||||
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
||||
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
} else if (is_in_bounds) {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
#else
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
@@ -496,7 +614,7 @@ void main() {
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
#ifdef MUL_MAT_ID
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
@@ -504,9 +622,10 @@ void main() {
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // COOPMAT
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user