From ac7a7aa2c62bdcf9a387c9f230b9947439eb4521 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 19 Oct 2025 09:00:17 +0000 Subject: [PATCH] Reduce mmq register use --- .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 169 +++--------------- .../vulkan-shaders/mul_mmq_funcs.glsl | 57 +++--- 2 files changed, 61 insertions(+), 165 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index ad16a75787..975db09c5b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -10,12 +10,6 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif -#ifdef COOPMAT -#extension GL_KHR_cooperative_matrix : enable -#extension GL_KHR_memory_scope_semantics : enable -#extension GL_KHR_shader_subgroup_basic : enable -#endif - #ifdef MUL_MAT_ID #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif @@ -79,10 +73,6 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#ifdef COOPMAT -#define SHMEM_STRIDE (BK / 4 + 4) -#endif - #define MMQ_SHMEM #include "mul_mmq_shmem_types.glsl" @@ -92,7 +82,7 @@ shared block_a_cache buf_a[BM]; shared block_b_cache buf_b[BN]; // Register cache block_a_cache cache_a[WMITER * TM]; -block_b_cache cache_b[TN]; +block_b_cache cache_b; #define LOAD_VEC_A (4 * QUANT_R_MMQ) #define LOAD_VEC_B 16 @@ -104,10 +94,6 @@ shared u16vec2 row_ids[4096]; #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef COOPMAT -shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; -#endif - #include "mul_mmq_funcs.glsl" void main() { @@ -137,26 +123,12 @@ void main() { const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WSUBM = WM / WMITER; const uint WSUBN = WN / WNITER; - -#ifdef COOPMAT - const uint warp_i = gl_SubgroupID; - - const uint tiw = gl_SubgroupInvocationID; - - const uint cms_per_row = WM / TM; - const uint cms_per_col = WN / TN; - - const uint storestride = WARP / TM; - const uint store_r = tiw % TM; - const uint store_c = tiw / TM; -#else const uint warp_i = gl_LocalInvocationID.x / WARP; const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiwr = tiw % (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM); -#endif const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); @@ -207,26 +179,11 @@ void main() { uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; #endif -#ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; - coopmat cm_result; + ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN / 2]; - coopmat factors[cms_per_row * cms_per_col]; - - coopmat sums[cms_per_row * cms_per_col]; - - [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { - sums[i] = coopmat(0.0f); + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { + sums[i] = ACC_TYPE_VEC2(0.0f); } -#else - - ACC_TYPE sums[WMITER * TM * WNITER * TN]; - - [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = ACC_TYPE(0.0f); - } -#endif for (uint block = start_k; block < end_k; block += BK) { [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { @@ -267,38 +224,6 @@ void main() { pos_a_ib += 1; pos_b_ib += 1; -#ifdef COOPMAT - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - const uint ib_a = warp_r * WM + cm_row * TM; - // Load from shared into cache - coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for 
(uint t_idx = 0; t_idx < TM; t_idx++) { - cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; - } - - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const uint ib_b = warp_c * WN + cm_col * TN; - coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { - cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; - } - - cm_result = coopmat(0); - cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); - } - - coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); - } - } -#else // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { @@ -312,24 +237,22 @@ void main() { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; - cache_b[cc].ds = buf_b[ib].ds; + cache_b.ds = buf_b[ib].ds; [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { - cache_b[cc].qs[iqs] = buf_b[ib].qs[iqs]; + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; } - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint cache_a_idx = wsir * TM + cr; - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + const uint cache_a_idx = wsir * TM + cr * 2; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr; - sums[sums_idx] += mmq_dot_product(cache_a_idx, cc); + sums[sums_idx].x += mmq_dot_product(cache_a_idx); + sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1); } } } } -#endif barrier(); } @@ -341,54 +264,6 @@ void main() { const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif -#ifdef COOPMAT -#ifdef MUL_MAT_ID - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < BN; col += storestride) { - const uint row_i = dc + cm_col * TN + col + store_c; - if (row_i >= _ne1) break; - - const u16vec2 row_idx = row_ids[row_i]; - - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } -#else - const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float - - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; - - if (is_aligned && is_in_bounds) { - // Full coopMat is within bounds and stride_d is aligned with 16B - coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); - coopMatStore(cm_dtype, data_d, offsets + (dc + 
cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); - } else if (is_in_bounds) { - // Full coopMat is within bounds, but stride_d is not aligned - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { - // Partial coopMat is within bounds - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } - } - } -#endif // MUL_MAT_ID -#else [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -399,19 +274,27 @@ void main() { const uint row_i = dc_warp + cc; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID - [[unroll]] for (uint cr = 0; cr < TM; cr++) { + [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; #ifdef MUL_MAT_ID - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); + } #else - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); + } + if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); } #endif // MUL_MAT_ID } } } } -#endif // COOPMAT } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index b7e1ff81ab..cb539fa52e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -62,21 +62,21 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) { } -ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { +ACC_TYPE mmq_dot_product(const uint ib_a) { int32_t q_sum = 0; [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { const uint32_t vui = cache_a[ib_a].qs[iqs]; const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); - const int32_t qs_b0 = cache_b[ib_b].qs[iqs]; - const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4]; + const int32_t qs_b0 = 
cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; q_sum += dotPacked4x8EXT(qs_a.x, qs_b0); q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); } #endif // MMQ_SHMEM @@ -140,7 +140,7 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { } } -ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { +ACC_TYPE mmq_dot_product(const uint ib_a) { int32_t q_sum = 0; [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { const uint32_t vui = cache_a[ib_a].qs[iqs]; @@ -150,14 +150,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - const int32_t qs_b0 = cache_b[ib_b].qs[iqs]; - const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4]; + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; q_sum += dotPacked4x8EXT(qs_a0, qs_b0); q_sum += dotPacked4x8EXT(qs_a1, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); } #endif // MMQ_SHMEM #endif @@ -191,16 +191,16 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { } } -ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { +ACC_TYPE mmq_dot_product(const uint ib_a) { int32_t q_sum = 0; [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { const int32_t qs_a = cache_a[ib_a].qs[iqs]; - const int32_t qs_b = cache_b[ib_b].qs[iqs]; + const int32_t qs_b = cache_b.qs[iqs]; q_sum += dotPacked4x8EXT(qs_a, qs_b); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); } #endif // MMQ_SHMEM #endif @@ -247,7 +247,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); if (iqs == 0) { - buf_a[buf_ib].scales = u8vec2(data_a[ib_k].scales[iqs_k / 4], data_a[ib_k].scales[iqs_k / 4 + 1]); + buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]); buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); } } @@ -261,26 +261,39 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { } } -ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { +ACC_TYPE mmq_dot_product(const uint ib_a) { int32_t sum_d = 0; int32_t sum_m = 0; - const i32vec2 scales = i32vec2(cache_a[ib_a].scales); - i32vec2 scale_m = scales >> 4; + uint8_t scale = cache_a[ib_a].scales[0]; + int32_t scale_m = int32_t(scale >> 4); scale_m |= scale_m << 8; scale_m |= scale_m << 16; - [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { - const uint idx_half = iqs / 4; - const uint qs_shift = (iqs % 4) * 2; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint qs_shift = iqs * 2; - const int32_t qs_a = int32_t((cache_a[ib_a].qs[idx_half] >> qs_shift) & 0x03030303); + const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303); - sum_d += dotPacked4x8EXT(qs_a, cache_b[ib_b].qs[iqs]) * (scales[idx_half] & 0xF); - sum_m += dotPacked4x8EXT(scale_m[idx_half], cache_b[ib_b].qs[iqs]); + sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); } - return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); + scale = cache_a[ib_a].scales[1]; + scale_m = int32_t(scale >> 4); + scale_m |= scale_m << 8; + scale_m |= scale_m 
<< 16;
+
+    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+        const uint qs_shift = (iqs - 4) * 2;
+
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303);
+
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
 #endif
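
Note (not part of the patch itself): the register saving comes from two changes visible in the diff above. First, only a single B block (cache_b) is kept in registers per iteration instead of an array of TN blocks, since mmq_dot_product now takes only the A-side index. Second, the per-thread accumulators are packed pairwise into ACC_TYPE_VEC2, so sums[] holds half as many entries and each entry collects two adjacent output rows of the thread's tile. The sketch below is a minimal standalone compute shader illustrating only that accumulator-packing pattern; everything in it (TM_SKETCH, TN_SKETCH, the buffer layout, plain float math standing in for the dotPacked4x8EXT quantized dot products) is made up for illustration and is not the real shader interface.

#version 450
#extension GL_EXT_control_flow_attributes : enable

// Hypothetical tile sizes, for illustration only; the real shader takes
// TM/TN/WMITER/WNITER from specialization constants and BK from a #define.
// TM_SKETCH must be even so rows can be paired into one vec2 accumulator.
#define TM_SKETCH 8
#define TN_SKETCH 4

layout (local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

layout (binding = 0) readonly buffer A { float data_a[]; };
layout (binding = 1) readonly buffer B { float data_b[]; };
layout (binding = 2) writeonly buffer D { float data_d[]; };

layout (push_constant) uniform parameter { uint M; uint N; uint K; } p;

void main() {
    // Each invocation owns a TM_SKETCH x TN_SKETCH tile of the output.
    const uint dr = gl_GlobalInvocationID.x * TM_SKETCH; // first output row of the tile
    const uint dc = gl_GlobalInvocationID.y * TN_SKETCH; // first output column of the tile

    // Pack two adjacent output rows into each vec2, as the ACC_TYPE_VEC2 sums[]
    // array above does: half as many accumulator entries stay live across the
    // reduction loop.
    vec2 sums[TM_SKETCH / 2 * TN_SKETCH];
    [[unroll]] for (uint i = 0; i < TM_SKETCH / 2 * TN_SKETCH; i++) {
        sums[i] = vec2(0.0f);
    }

    // Plain float multiply-add stands in for the quantized dot-product path.
    for (uint k = 0; k < p.K; k++) {
        [[unroll]] for (uint cc = 0; cc < TN_SKETCH; cc++) {
            const float b = data_b[(dc + cc) * p.K + k];
            [[unroll]] for (uint cr = 0; cr < TM_SKETCH / 2; cr++) {
                // Same addressing as the patched shader: TM/2 entries per tile column.
                const uint sums_idx = cc * (TM_SKETCH / 2) + cr;
                sums[sums_idx].x += data_a[(dr + 2 * cr    ) * p.K + k] * b;
                sums[sums_idx].y += data_a[(dr + 2 * cr + 1) * p.K + k] * b;
            }
        }
    }

    // Unpack the pairs on store, bounds-checking the even and odd row separately,
    // mirroring the non-MUL_MAT_ID store loop in the patch.
    [[unroll]] for (uint cc = 0; cc < TN_SKETCH; cc++) {
        [[unroll]] for (uint cr = 0; cr < TM_SKETCH / 2; cr++) {
            const uint sums_idx = cc * (TM_SKETCH / 2) + cr;
            if (dr + 2 * cr < p.M && dc + cc < p.N) {
                data_d[(dc + cc) * p.M + dr + 2 * cr] = sums[sums_idx].x;
            }
            if (dr + 2 * cr + 1 < p.M && dc + cc < p.N) {
                data_d[(dc + cc) * p.M + dr + 2 * cr + 1] = sums[sums_idx].y;
            }
        }
    }
}

The same idea carries over unchanged when the inner product is the packed 8-bit dot product: the patch simply calls mmq_dot_product twice per pair, writing one result into .x and one into .y, and splits the row bounds checks accordingly at store time.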