mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (#10597)
* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader * Improve performance with better q4_k and q5_k dequant and store unrolling * Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection * Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device * Vulkan: Implement accumulator switch for specific mul mat mat shaders * Vulkan: Unroll more loops for more mul mat mat performance * Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic * Disable coopmat support on AMD proprietary driver * Remove redundant checks * Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support * Fix rebase typo * Fix coopmat2 MUL_MAT_ID pipeline selection
This commit is contained in:
		
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -7,6 +7,12 @@ | |||||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | #ifdef COOPMAT | ||||||
|  | #extension GL_KHR_cooperative_matrix : enable | ||||||
|  | #extension GL_KHR_memory_scope_semantics : enable | ||||||
|  | #extension GL_KHR_shader_subgroup_basic : enable | ||||||
|  | #endif | ||||||
|  |  | ||||||
| #ifdef MUL_MAT_ID | #ifdef MUL_MAT_ID | ||||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | ||||||
| #endif | #endif | ||||||
| @@ -57,6 +63,7 @@ layout (push_constant) uniform parameter | |||||||
| #endif | #endif | ||||||
| } p; | } p; | ||||||
|  |  | ||||||
|  | layout (constant_id = 0) const uint BLOCK_SIZE = 64; | ||||||
| layout (constant_id = 1) const uint BM = 64; | layout (constant_id = 1) const uint BM = 64; | ||||||
| layout (constant_id = 2) const uint BN = 64; | layout (constant_id = 2) const uint BN = 64; | ||||||
| layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant | layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant | ||||||
| @@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32; | |||||||
| layout (constant_id = 6) const uint WMITER = 2; | layout (constant_id = 6) const uint WMITER = 2; | ||||||
| layout (constant_id = 7) const uint TM = 4; | layout (constant_id = 7) const uint TM = 4; | ||||||
| layout (constant_id = 8) const uint TN = 2; | layout (constant_id = 8) const uint TN = 2; | ||||||
| layout (constant_id = 9) const uint WARP = 32; | layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat | ||||||
|  | layout (constant_id = 10) const uint WARP = 32; | ||||||
|  |  | ||||||
| shared FLOAT_TYPE buf_a[BM * (BK+1)]; | #ifdef COOPMAT | ||||||
| shared FLOAT_TYPE buf_b[BN * (BK+1)]; | #define SHMEM_STRIDE (BK + 8) | ||||||
|  | #else | ||||||
|  | #define SHMEM_STRIDE (BK + 1) | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; | ||||||
|  | shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; | ||||||
|  |  | ||||||
| #ifdef MUL_MAT_ID | #ifdef MUL_MAT_ID | ||||||
| shared u16vec2 row_ids[3072]; | shared u16vec2 row_ids[3072]; | ||||||
|  | #endif // MUL_MAT_ID | ||||||
|  |  | ||||||
|  | #define NUM_WARPS (BLOCK_SIZE / WARP) | ||||||
|  |  | ||||||
|  | #ifdef COOPMAT | ||||||
|  | shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| void main() { | void main() { | ||||||
| @@ -98,17 +118,32 @@ void main() { | |||||||
|     const uint ik = gl_WorkGroupID.x / blocks_m; |     const uint ik = gl_WorkGroupID.x / blocks_m; | ||||||
|     const uint ic = gl_WorkGroupID.y; |     const uint ic = gl_WorkGroupID.y; | ||||||
|  |  | ||||||
|     const uint warp_i = gl_LocalInvocationID.x / WARP; |  | ||||||
|     const uint warp_r = warp_i % (BM / WM); |  | ||||||
|     const uint warp_c = warp_i / (BM / WM); |  | ||||||
|  |  | ||||||
|     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); |     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); | ||||||
|     const uint WSUBM = WM / WMITER; |     const uint WSUBM = WM / WMITER; | ||||||
|     const uint WSUBN = WN / WNITER; |     const uint WSUBN = WN / WNITER; | ||||||
|  |  | ||||||
|  | #ifdef COOPMAT | ||||||
|  |     const uint warp_i = gl_SubgroupID; | ||||||
|  |  | ||||||
|  |     const uint tiw = gl_SubgroupInvocationID; | ||||||
|  |  | ||||||
|  |     const uint cms_per_row = WM / TM; | ||||||
|  |     const uint cms_per_col = WN / TN; | ||||||
|  |  | ||||||
|  |     const uint storestride = WARP / TM; | ||||||
|  |     const uint store_r = tiw % TM; | ||||||
|  |     const uint store_c = tiw / TM; | ||||||
|  | #else | ||||||
|  |     const uint warp_i = gl_LocalInvocationID.x / WARP; | ||||||
|  |  | ||||||
|     const uint tiw = gl_LocalInvocationID.x % WARP; |     const uint tiw = gl_LocalInvocationID.x % WARP; | ||||||
|  |  | ||||||
|     const uint tiwr = tiw % (WSUBM / TM); |     const uint tiwr = tiw % (WSUBM / TM); | ||||||
|     const uint tiwc = tiw / (WSUBM / TM); |     const uint tiwc = tiw / (WSUBM / TM); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |     const uint warp_r = warp_i % (BM / WM); | ||||||
|  |     const uint warp_c = warp_i / (BM / WM); | ||||||
|  |  | ||||||
|     const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); |     const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); | ||||||
|     const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); |     const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); | ||||||
| @@ -156,21 +191,31 @@ void main() { | |||||||
|     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; |     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|     float sums[WMITER * TM * WNITER * TN]; | #ifdef COOPMAT | ||||||
|  |     coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a; | ||||||
|  |     coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b; | ||||||
|  |     coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col]; | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { | ||||||
|  |         sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f); | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     ACC_TYPE sums[WMITER * TM * WNITER * TN]; | ||||||
|     FLOAT_TYPE cache_a[WMITER * TM]; |     FLOAT_TYPE cache_a[WMITER * TM]; | ||||||
|     FLOAT_TYPE cache_b[WNITER * TN]; |     FLOAT_TYPE cache_b[WNITER * TN]; | ||||||
|  |  | ||||||
|     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { |     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { | ||||||
|         sums[i] = 0.0f; |         sums[i] = ACC_TYPE(0.0f); | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     [[unroll]] for (uint block = start_k; block < end_k; block += BK) { |     for (uint block = start_k; block < end_k; block += BK) { | ||||||
|         [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { |         [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { | ||||||
|  |  | ||||||
| #if defined(DATA_A_F32) || defined(DATA_A_F16) | #if defined(DATA_A_F32) || defined(DATA_A_F16) | ||||||
| #if LOAD_VEC_A == 8 | #if LOAD_VEC_A == 8 | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|             buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x); |             buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x); | ||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); |             buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); | ||||||
|             buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); |             buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); | ||||||
| @@ -181,21 +226,21 @@ void main() { | |||||||
|             buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); |             buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); | ||||||
| #elif LOAD_VEC_A == 4 | #elif LOAD_VEC_A == 4 | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|             buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x); |             buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x); | ||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); |             buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); | ||||||
|             buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); |             buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); | ||||||
|             buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); |             buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); | ||||||
| #else | #else | ||||||
|             if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { |             if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { | ||||||
|                 buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); |                 buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); | ||||||
|             } else { |             } else { | ||||||
|                 buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); |                 buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); | ||||||
|             } |             } | ||||||
| #endif | #endif | ||||||
| #elif defined(DATA_A_Q4_0) | #elif defined(DATA_A_Q4_0) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; | ||||||
|  |  | ||||||
|             const uint ib = idx / 16; |             const uint ib = idx / 16; | ||||||
|             const uint iqs = idx & 0xF; |             const uint iqs = idx & 0xF; | ||||||
| @@ -208,7 +253,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); | ||||||
| #elif defined(DATA_A_Q4_1) | #elif defined(DATA_A_Q4_1) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; | ||||||
|  |  | ||||||
|             const uint ib = idx / 16; |             const uint ib = idx / 16; | ||||||
|             const uint iqs = idx & 0xF; |             const uint iqs = idx & 0xF; | ||||||
| @@ -222,7 +267,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); | ||||||
| #elif defined(DATA_A_Q5_0) | #elif defined(DATA_A_Q5_0) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; | ||||||
|  |  | ||||||
|             const uint ib = idx / 16; |             const uint ib = idx / 16; | ||||||
|             const uint iqs = idx & 0xF; |             const uint iqs = idx & 0xF; | ||||||
| @@ -237,7 +282,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); | ||||||
| #elif defined(DATA_A_Q5_1) | #elif defined(DATA_A_Q5_1) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; | ||||||
|  |  | ||||||
|             const uint ib = idx / 16; |             const uint ib = idx / 16; | ||||||
|             const uint iqs = idx & 0xF; |             const uint iqs = idx & 0xF; | ||||||
| @@ -253,7 +298,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); |             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); | ||||||
| #elif defined(DATA_A_Q8_0) | #elif defined(DATA_A_Q8_0) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|  |  | ||||||
|             const uint ib = idx / 16; |             const uint ib = idx / 16; | ||||||
|             const uint iqs = (idx & 0xF) * 2; |             const uint iqs = (idx & 0xF) * 2; | ||||||
| @@ -265,7 +310,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); |             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); | ||||||
| #elif defined(DATA_A_Q2_K) | #elif defined(DATA_A_Q2_K) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|  |  | ||||||
|             const uint ib = idx / 128;                         // 2 values per idx |             const uint ib = idx / 128;                         // 2 values per idx | ||||||
|             const uint iqs = idx % 128;                        // 0..127 |             const uint iqs = idx % 128;                        // 0..127 | ||||||
| @@ -284,7 +329,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); |             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); | ||||||
| #elif defined(DATA_A_Q3_K) | #elif defined(DATA_A_Q3_K) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|  |  | ||||||
|             const uint ib = idx / 128;                   // 2 values per idx |             const uint ib = idx / 128;                   // 2 values per idx | ||||||
|             const uint iqs = idx % 128;                  // 0..127 |             const uint iqs = idx % 128;                  // 0..127 | ||||||
| @@ -308,7 +353,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); |             buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); | ||||||
| #elif defined(DATA_A_Q4_K) | #elif defined(DATA_A_Q4_K) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|  |  | ||||||
|             const uint ib = idx / 128;                 // 2 values per idx |             const uint ib = idx / 128;                 // 2 values per idx | ||||||
|             const uint iqs = idx % 128;                // 0..127 |             const uint iqs = idx % 128;                // 0..127 | ||||||
| @@ -320,15 +365,20 @@ void main() { | |||||||
|  |  | ||||||
|             const vec2 loadd = vec2(data_a[ib].d); |             const vec2 loadd = vec2(data_a[ib].d); | ||||||
|  |  | ||||||
|             uint8_t sc; |             const uint scidx0 = (is < 4) ? is : (is + 4); | ||||||
|             uint8_t mbyte; |             const uint scidx1 = (is < 4) ? is : (is - 4); | ||||||
|             if (is < 4) { |             const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||||
|                 sc    = uint8_t(data_a[ib].scales[is    ] & 63); |             const uint scidxshift1 = (is < 4) ? 0 : 2; | ||||||
|                 mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); |             const uint mbidx0 = is + 4; | ||||||
|             } else { |             const uint mbidx1 = (is < 4) ? is + 4 : is; | ||||||
|                 sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); |             const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; | ||||||
|                 mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4)); |             const uint mbidxshift0 = (is < 4) ? 0 : 4; | ||||||
|             } |             const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||||
|  |             const uint mbidxshift1 = (is < 4) ? 0 : 2; | ||||||
|  |  | ||||||
|  |             const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); | ||||||
|  |             const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); | ||||||
|  |  | ||||||
|             const float d = loadd.x * sc; |             const float d = loadd.x * sc; | ||||||
|             const float m = -loadd.y * mbyte; |             const float m = -loadd.y * mbyte; | ||||||
|  |  | ||||||
| @@ -336,7 +386,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); |             buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); | ||||||
| #elif defined(DATA_A_Q5_K) | #elif defined(DATA_A_Q5_K) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|  |  | ||||||
|             const uint ib = idx / 128;                 // 2 values per idx |             const uint ib = idx / 128;                 // 2 values per idx | ||||||
|             const uint iqs = idx % 128;                // 0..127 |             const uint iqs = idx % 128;                // 0..127 | ||||||
| @@ -351,15 +401,20 @@ void main() { | |||||||
|  |  | ||||||
|             const vec2 loadd = vec2(data_a[ib].d); |             const vec2 loadd = vec2(data_a[ib].d); | ||||||
|  |  | ||||||
|             uint8_t sc; |             const uint scidx0 = (is < 4) ? is : (is + 4); | ||||||
|             uint8_t mbyte; |             const uint scidx1 = (is < 4) ? is : (is - 4); | ||||||
|             if (is < 4) { |             const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||||
|                 sc    = uint8_t(data_a[ib].scales[is    ] & 63); |             const uint scidxshift1 = (is < 4) ? 0 : 2; | ||||||
|                 mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); |             const uint mbidx0 = is + 4; | ||||||
|             } else { |             const uint mbidx1 = (is < 4) ? is + 4 : is; | ||||||
|                 sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); |             const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; | ||||||
|                 mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4)); |             const uint mbidxshift0 = (is < 4) ? 0 : 4; | ||||||
|             } |             const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; | ||||||
|  |             const uint mbidxshift1 = (is < 4) ? 0 : 2; | ||||||
|  |  | ||||||
|  |             const uint8_t sc    = uint8_t((data_a[ib].scales[scidx0] & 0xF)                         | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); | ||||||
|  |             const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); | ||||||
|  |  | ||||||
|             const float d = loadd.x * sc; |             const float d = loadd.x * sc; | ||||||
|             const float m = -loadd.y * mbyte; |             const float m = -loadd.y * mbyte; | ||||||
|  |  | ||||||
| @@ -367,7 +422,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); |             buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); | ||||||
| #elif defined(DATA_A_Q6_K) | #elif defined(DATA_A_Q6_K) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; | ||||||
|  |  | ||||||
|             const uint ib = idx / 128;                  // 2 values per idx |             const uint ib = idx / 128;                  // 2 values per idx | ||||||
|             const uint iqs = idx % 128;                 // 0..127 |             const uint iqs = idx % 128;                 // 0..127 | ||||||
| @@ -386,7 +441,7 @@ void main() { | |||||||
|             buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); |             buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); | ||||||
| #elif defined(DATA_A_IQ4_NL) | #elif defined(DATA_A_IQ4_NL) | ||||||
|             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; |             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; | ||||||
|             const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; |             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; | ||||||
|  |  | ||||||
|             const uint ib = idx / 16; |             const uint ib = idx / 16; | ||||||
|             const uint iqs = idx & 0xF; |             const uint iqs = idx & 0xF; | ||||||
| @@ -407,7 +462,7 @@ void main() { | |||||||
| #else | #else | ||||||
|             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; |             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; | ||||||
| #endif | #endif | ||||||
|             const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; |             const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; | ||||||
|             buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); |             buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); | ||||||
|             buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); |             buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); | ||||||
|             buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); |             buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); | ||||||
| @@ -423,24 +478,24 @@ void main() { | |||||||
| #else | #else | ||||||
|             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; |             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; | ||||||
| #endif | #endif | ||||||
|             const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; |             const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; | ||||||
|             buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); |             buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); | ||||||
|             buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); |             buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); | ||||||
|             buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); |             buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); | ||||||
|             buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); |             buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); | ||||||
| #elif !MUL_MAT_ID | #elif !MUL_MAT_ID | ||||||
|             if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { |             if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { | ||||||
|                 buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); |                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); | ||||||
|             } else { |             } else { | ||||||
|                 buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); |                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); | ||||||
|             } |             } | ||||||
| #else | #else | ||||||
|             const uint row_i = ic * BN + loadc_b + l; |             const uint row_i = ic * BN + loadc_b + l; | ||||||
|             if (row_i < _ne1) { |             if (row_i < _ne1) { | ||||||
|                 const u16vec2 row_idx = row_ids[row_i]; |                 const u16vec2 row_idx = row_ids[row_i]; | ||||||
|                 buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); |                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); | ||||||
|             } else { |             } else { | ||||||
|                 buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); |                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); | ||||||
|             } |             } | ||||||
| #endif | #endif | ||||||
|         } |         } | ||||||
| @@ -450,16 +505,30 @@ void main() { | |||||||
|         pos_a += BK / LOAD_VEC_A; |         pos_a += BK / LOAD_VEC_A; | ||||||
|         pos_b += BK / LOAD_VEC_B; |         pos_b += BK / LOAD_VEC_B; | ||||||
|  |  | ||||||
|         for (uint i = 0; i < BK; i++) { | #ifdef COOPMAT | ||||||
|  |         [[unroll]] for (uint i = 0; i < BK; i += TK) { | ||||||
|  |             [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { | ||||||
|  |                 // Load from shared into cache | ||||||
|  |                 coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); | ||||||
|  |  | ||||||
|  |                 [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { | ||||||
|  |                     coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); | ||||||
|  |  | ||||||
|  |                     sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | #else | ||||||
|  |         [[unroll]] for (uint i = 0; i < BK; i++) { | ||||||
|             // Load from shared into cache |             // Load from shared into cache | ||||||
|             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { |             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { | ||||||
|                 [[unroll]] for (uint j = 0; j < TM; j++) { |                 [[unroll]] for (uint j = 0; j < TM; j++) { | ||||||
|                     cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; |                     cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { |             [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { | ||||||
|                 [[unroll]] for (uint j = 0; j < TN; j++) { |                 [[unroll]] for (uint j = 0; j < TN; j++) { | ||||||
|                     cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; |                     cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -468,12 +537,13 @@ void main() { | |||||||
|                     [[unroll]] for (uint cc = 0; cc < TN; cc++) { |                     [[unroll]] for (uint cc = 0; cc < TN; cc++) { | ||||||
|                         [[unroll]] for (uint cr = 0; cr < TM; cr++) { |                         [[unroll]] for (uint cr = 0; cr < TM; cr++) { | ||||||
|                             const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; |                             const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; | ||||||
|                             sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]); |                             sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|         barrier(); |         barrier(); | ||||||
|     } |     } | ||||||
| @@ -485,6 +555,54 @@ void main() { | |||||||
|     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; |     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | #ifdef COOPMAT | ||||||
|  | #ifdef MUL_MAT_ID | ||||||
|  |     [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { | ||||||
|  |         [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { | ||||||
|  |             coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||||
|  |  | ||||||
|  |             [[unroll]] for (uint col = 0; col < BN; col += storestride) { | ||||||
|  |                 const uint row_i = dc + cm_col * TN + col + store_c; | ||||||
|  |                 if (row_i >= _ne1) break; | ||||||
|  |  | ||||||
|  |                 const u16vec2 row_idx = row_ids[row_i]; | ||||||
|  |  | ||||||
|  |                 data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { | ||||||
|  |         [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { | ||||||
|  |             const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; | ||||||
|  |  | ||||||
|  |             if (is_aligned && is_in_bounds) { | ||||||
|  |                 // Full coopMat is within bounds and stride_d is aligned with 16B | ||||||
|  |                 coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]); | ||||||
|  |                 coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); | ||||||
|  |             } else if (is_in_bounds) { | ||||||
|  |                 // Full coopMat is within bounds, but stride_d is not aligned | ||||||
|  |                 coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||||
|  |  | ||||||
|  |                 [[unroll]] for (uint col = 0; col < TN; col += storestride) { | ||||||
|  |                     data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); | ||||||
|  |                 } | ||||||
|  |             } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { | ||||||
|  |                 // Partial coopMat is within bounds | ||||||
|  |                 coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||||
|  |  | ||||||
|  |                 [[unroll]] for (uint col = 0; col < TN; col += storestride) { | ||||||
|  |                     if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { | ||||||
|  |                         data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | #endif // MUL_MAT_ID | ||||||
|  | #else | ||||||
|     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { |     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { | ||||||
|         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { |         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { | ||||||
|  |  | ||||||
| @@ -496,7 +614,7 @@ void main() { | |||||||
|                 if (row_i >= _ne1) break; |                 if (row_i >= _ne1) break; | ||||||
|  |  | ||||||
|                 const u16vec2 row_idx = row_ids[row_i]; |                 const u16vec2 row_idx = row_ids[row_i]; | ||||||
| #endif | #endif // MUL_MAT_ID | ||||||
|                 [[unroll]] for (uint cr = 0; cr < TM; cr++) { |                 [[unroll]] for (uint cr = 0; cr < TM; cr++) { | ||||||
| #ifdef MUL_MAT_ID | #ifdef MUL_MAT_ID | ||||||
|                     data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); |                     data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); | ||||||
| @@ -504,9 +622,10 @@ void main() { | |||||||
|                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) { |                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) { | ||||||
|                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); |                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); | ||||||
|                     } |                     } | ||||||
| #endif | #endif // MUL_MAT_ID | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | #endif // COOPMAT | ||||||
| } | } | ||||||
|   | |||||||
| @@ -60,6 +60,7 @@ const std::vector<std::string> type_names = { | |||||||
|     "iq4_nl" |     "iq4_nl" | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | namespace { | ||||||
| void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { | void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { | ||||||
| #ifdef _WIN32 | #ifdef _WIN32 | ||||||
|     HANDLE stdout_read, stdout_write; |     HANDLE stdout_read, stdout_write; | ||||||
| @@ -198,8 +199,8 @@ static uint32_t compile_count = 0; | |||||||
| static std::mutex compile_count_mutex; | static std::mutex compile_count_mutex; | ||||||
| static std::condition_variable compile_count_cond; | static std::condition_variable compile_count_cond; | ||||||
|  |  | ||||||
| void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { | void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { | ||||||
|     std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); |     std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); | ||||||
|     std::string out_fname = join_paths(output_dir, name + ".spv"); |     std::string out_fname = join_paths(output_dir, name + ".spv"); | ||||||
|     std::string in_path = join_paths(input_dir, in_fname); |     std::string in_path = join_paths(input_dir, in_fname); | ||||||
|  |  | ||||||
| @@ -258,7 +259,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s | |||||||
| } | } | ||||||
|  |  | ||||||
| static std::vector<std::future<void>> compiles; | static std::vector<std::future<void>> compiles; | ||||||
| void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) { | void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { | ||||||
|     { |     { | ||||||
|         // wait until fewer than N compiles are in progress. |         // wait until fewer than N compiles are in progress. | ||||||
|         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. |         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. | ||||||
| @@ -269,10 +270,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const | |||||||
|         } |         } | ||||||
|         compile_count++; |         compile_count++; | ||||||
|     } |     } | ||||||
|     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc)); |     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); | ||||||
| } | } | ||||||
|  |  | ||||||
| void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { | void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { | ||||||
|     std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; |     std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; | ||||||
|     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; |     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; | ||||||
|     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; |     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; | ||||||
| @@ -291,14 +292,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { | |||||||
|  |  | ||||||
|     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; |     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; | ||||||
|  |  | ||||||
|  |     if (coopmat) { | ||||||
|  |         base_dict["COOPMAT"] = "1"; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; | ||||||
|  |  | ||||||
|     std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; |     std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; | ||||||
|  |  | ||||||
|     // Shaders with f16 B_TYPE |     // Shaders with f16 B_TYPE | ||||||
|     string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc); |     string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); | ||||||
|     string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); |     string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|  |  | ||||||
|     string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); |     string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|     string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc); |     string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|  |  | ||||||
|     for (const auto& tname : type_names) { |     for (const auto& tname : type_names) { | ||||||
|         std::string data_a_key = "DATA_A_" + to_uppercase(tname); |         std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||||
| @@ -307,12 +314,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) { | |||||||
|         // For aligned matmul loads |         // For aligned matmul loads | ||||||
|         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; |         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; | ||||||
|  |  | ||||||
|         string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); |         string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|         string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); |         string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|  |  | ||||||
|         if (tname != "f16" && tname != "f32") { |         if (tname != "f16" && tname != "f32") { | ||||||
|             string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc); |             string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|             string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc); |             string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -322,25 +329,24 @@ void process_shaders() { | |||||||
|     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; |     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; | ||||||
|  |  | ||||||
|     // matmul |     // matmul | ||||||
|     for (const auto& fp16 : {false, true}) { |     for (const auto& matmul_id : {false, true}) { | ||||||
|         for (const auto& matmul_id : {false, true}) { |         // No coopmats | ||||||
|             for (const auto& coopmat2 : {false, true}) { |         // fp32 | ||||||
|                 for (const auto& f16acc : {false, true}) { |         matmul_shaders(false, matmul_id, false, false, false); | ||||||
| #if !defined(VK_NV_cooperative_matrix2) |  | ||||||
|                     if (coopmat2) { |         // fp16, fp32acc and fp16acc | ||||||
|                         continue; |         matmul_shaders(true, matmul_id, false, false, false); | ||||||
|                     } |         matmul_shaders(true, matmul_id, false, false, true); | ||||||
|  |  | ||||||
|  |         // Coopmat, fp32acc and fp16acc | ||||||
|  |         matmul_shaders(true, matmul_id, true, false, false); | ||||||
|  |         matmul_shaders(true, matmul_id, true, false, true); | ||||||
|  |  | ||||||
|  | #if defined(VK_NV_cooperative_matrix2) | ||||||
|  |         // Coopmat2, fp32acc and fp16acc | ||||||
|  |         matmul_shaders(true, matmul_id, false, true, false); | ||||||
|  |         matmul_shaders(true, matmul_id, false, true, true); | ||||||
| #endif | #endif | ||||||
|                     if (coopmat2 && !fp16) { |  | ||||||
|                         continue; |  | ||||||
|                     } |  | ||||||
|                     if (!coopmat2 && f16acc) { |  | ||||||
|                         continue; |  | ||||||
|                     } |  | ||||||
|                     matmul_shaders(fp16, matmul_id, coopmat2, f16acc); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
| #if defined(VK_NV_cooperative_matrix2) | #if defined(VK_NV_cooperative_matrix2) | ||||||
| @@ -355,11 +361,11 @@ void process_shaders() { | |||||||
|  |  | ||||||
|             if (tname == "f16") { |             if (tname == "f16") { | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||||
|                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc); |                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); | ||||||
|             } else { |             } else { | ||||||
|                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); |                 std::string data_a_key = "DATA_A_" + to_uppercase(tname); | ||||||
|                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", |                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", | ||||||
|                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc); |                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -524,6 +530,7 @@ void write_output_files() { | |||||||
|     fclose(hdr); |     fclose(hdr); | ||||||
|     fclose(src); |     fclose(src); | ||||||
| } | } | ||||||
|  | } | ||||||
|  |  | ||||||
| int main(int argc, char** argv) { | int main(int argc, char** argv) { | ||||||
|     std::map<std::string, std::string> args; |     std::map<std::string, std::string> args; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 0cc4m
					0cc4m