From b8a5cfd11abc066f165e924136d0fa1db44d2b95 Mon Sep 17 00:00:00 2001 From: SavicStefan <50296686+SavicStefan@users.noreply.github.com> Date: Sat, 8 Nov 2025 09:28:22 +0100 Subject: [PATCH] vulkan: Increase BK to 32; use BK/4 for non-CM mul_mm.comp (#16636) Signed-off-by: Stefan Savic Co-authored-by: Stefan Savic --- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index d260969f07..5c5251da39 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -100,7 +100,6 @@ layout (push_constant) uniform parameter layout (constant_id = 0) const uint BLOCK_SIZE = 64; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; -layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 4) const uint WM = 32; layout (constant_id = 5) const uint WN = 32; layout (constant_id = 6) const uint WMITER = 2; @@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat layout (constant_id = 10) const uint WARP = 32; +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#define BK 32 +#define BK_STEP 4 +#else +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +#define BK_STEP 2 +#endif + #ifdef COOPMAT #define SHMEM_STRIDE (BK / 2 + 4) #else @@ -244,8 +251,13 @@ void main() { } #else ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; +#if defined(DATA_A_F32) || defined(DATA_A_F16) + FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC4 cache_b; +#else FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; FLOAT_TYPE_VEC2 cache_b; +#endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); @@ -283,24 +295,41 @@ void main() { } } #else - [[unroll]] for (uint i = 0; i < BK / 2; i++) { + [[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { + #if defined(DATA_A_F32) || defined(DATA_A_F16) + cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i ]; + cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1]; + #else cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + #endif } } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { + #if defined(DATA_A_F32) || defined(DATA_A_F16) + cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i ]; + cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1]; + #else cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i]; + #endif [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; + #if defined(DATA_A_F32) || defined(DATA_A_F16) + sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), + fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); + sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), + fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); + #else sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); + #endif } } }