diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 975db09c5b..1bcb92f909 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -77,9 +77,13 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #include "mul_mmq_shmem_types.glsl"
 
+#ifndef BK_STEP
+#define BK_STEP 4
+#endif
+
 // Shared memory cache
-shared block_a_cache buf_a[BM];
-shared block_b_cache buf_b[BN];
+shared block_a_cache buf_a[BM * BK_STEP / QUANT_BLOCK_FACTOR];
+shared block_b_cache buf_b[BN * BK_STEP / QUANT_BLOCK_FACTOR];
 // Register cache
 block_a_cache cache_a[WMITER * TM];
 block_b_cache cache_b;
@@ -185,70 +189,64 @@ void main() {
         sums[i] = ACC_TYPE_VEC2(0.0f);
     }
 
-    for (uint block = start_k; block < end_k; block += BK) {
+    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
             const uint buf_ib = loadc_a + l;
             const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
             const uint iqs = loadr_a;
 
-            block_a_to_shmem(buf_ib, ib, iqs);
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+                block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
+            }
         }
 
         [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+            const uint buf_ib = loadc_b + l;
+
 #ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+            const u16vec2 row_idx = row_ids[ic * BN + buf_ib];
             const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
             const uint ib = idx / 8;
             const uint iqs = idx & 0x7;
 #else
-            const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
-            const uint ib_outer = ib / 4;
-            const uint ib_inner = ib % 4;
+            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
             const uint iqs = loadr_b;
 #endif
 
-            const uint buf_ib = loadc_b + l;
-
-            if (iqs == 0) {
-                buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
             }
-            const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-            buf_b[buf_ib].qs[iqs * 4    ] = values.x;
-            buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
-            buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
-            buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
         }
 
         barrier();
 
-        pos_a_ib += 1;
-        pos_b_ib += 1;
+        pos_a_ib += BK_STEP;
+        pos_b_ib += BK_STEP;
 
-        // Load from shared into cache
-        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                const uint reg_ib = wsir * TM + cr;
-                const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+        for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+            // Load from shared into cache
+            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                    const uint reg_ib = wsir * TM + cr;
+                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
 
-                block_a_to_registers(reg_ib, buf_ib);
-            }
-        }
-
-        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
-                cache_b.ds = buf_b[ib].ds;
-                [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
-                    cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
                 }
+            }
 
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
-                        const uint cache_a_idx = wsir * TM + cr * 2;
-                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
+            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+                    block_b_to_registers(ib);
 
-                        sums[sums_idx].x += mmq_dot_product(cache_a_idx);
-                        sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
+                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                        [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                            const uint cache_a_idx = wsir * TM + cr * 2;
+                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
+
+                            sums[sums_idx].x += mmq_dot_product(cache_a_idx);
+                            sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
+                        }
                     }
                 }
            }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index cb539fa52e..b12599111f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -233,71 +233,113 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
 #ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
-    const uint iqs_k = (ib % 8) * 8 + iqs * 4;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
 
     const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
-    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    // const uint qs_shift = ((iqs_k % 32) / 8) * 2;
 
     // Repack 4x4 quants into one int
-    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
-    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
-    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
-    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+    // const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    // const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    // const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    // const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
 
-    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
 
     if (iqs == 0) {
-        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
         buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+        buf_a[buf_ib].scales[0] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16]);
+    }
+    if (iqs == 1) {
+        buf_a[buf_ib].scales[1] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1]);
     }
 }
 
 void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     cache_a[reg_ib].dm = buf_a[buf_ib].dm;
-    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
 
     [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+        cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
+    }
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
         cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
     }
 }
 
 ACC_TYPE mmq_dot_product(const uint ib_a) {
-    int32_t sum_d = 0;
-    int32_t sum_m = 0;
+    float sum_d = 0;
+    float sum_m = 0;
 
-    uint8_t scale = cache_a[ib_a].scales[0];
-    int32_t scale_m = int32_t(scale >> 4);
-    scale_m |= scale_m << 8;
-    scale_m |= scale_m << 16;
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
+        [[unroll]] for (uint ib_b = 0; ib_b < 4; ib_b++) {
+            const uint8_t scale = cache_a[ib_a].scales[ib_b / 2][(ib_b % 2) * 2 + (iqs / 4)];
+            const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+            const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2)) & 0x03030303);
 
-    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
-        const uint qs_shift = iqs * 2;
-
-        const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303);
-
-        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
-        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+            sum_d += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
+            sum_m += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
+        }
     }
 
-    scale = cache_a[ib_a].scales[1];
-    scale_m = int32_t(scale >> 4);
-    scale_m |= scale_m << 8;
-    scale_m |= scale_m << 16;
-
-    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
-        const uint qs_shift = (iqs - 4) * 2;
-
-        const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303);
-
-        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
-        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
-    }
-
-    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+    return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
 }
 #endif // MMQ_SHMEM
 #endif
 
+#ifdef MMQ_SHMEM
+#if defined(DATA_A_QUANT_LEGACY)
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_outer = ib / 4;
+    const uint ib_inner = ib % 4;
+
+    if (iqs == 0) {
+        buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+    }
+
+    const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+    buf_b[buf_ib].qs[iqs * 4    ] = values.x;
+    buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+    buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+    buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+    cache_b.ds = buf_b[ib].ds;
+    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
+}
+#elif defined(DATA_A_QUANT_K)
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_outer = ib / 4;
+
+    buf_b[buf_ib].ds[iqs * 2    ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2    ]);
+    buf_b[buf_ib].ds[iqs * 2 + 1] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1]);
+
+    [[unroll]] for (uint ib_inner = 0; ib_inner < 4; ib_inner++) {
+        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4    ] = values.x;
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1] = values.y;
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2] = values.z;
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3] = values.w;
+    }
+}
+
+void block_b_to_registers(const uint ib) {
+    [[unroll]] for (uint i = 0; i < 4; i++) {
+        cache_b.ds[i] = buf_b[ib].ds[i];
+    }
+    [[unroll]] for (uint iqs = 0; iqs < 32; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
+}
+#else
+#error unimplemented
+#endif
+#endif
+
 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
 FLOAT_TYPE get_d(uint ib) {
     return FLOAT_TYPE(data_a[ib].d);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
index e445c5646b..00286fc0c5 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -31,17 +31,31 @@ struct block_a_cache {
     FLOAT_TYPE dm;
 };
 #elif defined(DATA_A_Q2_K)
-#define QUANT_R_MMQ 4
+#define QUANT_R_MMQ 1
 struct block_a_cache {
-    uint32_t qs[2];
-    u8vec2 scales;
+    uint32_t qs[8];
+    u8vec4 scales[2];
     FLOAT_TYPE_VEC2 dm;
 };
 #endif
 
+#if defined(DATA_A_QUANT_LEGACY)
+#define QUANT_BLOCK_FACTOR 1
+
 struct block_b_cache {
     int32_t qs[8];
     FLOAT_TYPE_VEC2 ds;
 };
+#elif defined(DATA_A_QUANT_K)
+#define QUANT_BLOCK_FACTOR 4
+
+struct block_b_cache
+{
+    int32_t qs[32];
+    FLOAT_TYPE_VEC2 ds[4];
+};
+#else
+#error unimplemented
+#endif
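For reference, the sketch below is a standalone host-side C program (not part of the diff) that mirrors the integer tricks used by the reworked mmq_dot_product: selecting one 2-bit lane per byte of a packed quant word with (qs_a_packed >> (ib_b * 2)) & 0x03030303, broadcasting the 4-bit min scale across all four bytes by multiplying with 0x01010101, and weighting the per-sub-block d/m sums by the B-side scale. dot_packed_4x8 stands in for the dotPacked4x8EXT built-in, and the input constants (qs_a_packed, qs_b, scale, ds_b) are made-up example values rather than real shader data.

/*
 * Standalone C sketch of the Q2_K dot-product arithmetic from mmq_dot_product.
 * dot_packed_4x8() models the GLSL dotPacked4x8EXT built-in: a signed per-byte
 * dot product of two packed int8x4 words.
 */
#include <stdint.h>
#include <stdio.h>

static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; i++) {
        const int8_t va = (int8_t)(((uint32_t)a >> (8 * i)) & 0xFF);
        const int8_t vb = (int8_t)(((uint32_t)b >> (8 * i)) & 0xFF);
        sum += (int32_t)va * vb;
    }
    return sum;
}

int main(void) {
    /* Made-up example data: one packed word holding four 2-bit weights per
     * byte position, one word of int8 activations, and one scale byte
     * (low nibble = multiplier scale, high nibble = min scale). */
    const uint32_t qs_a_packed = 0xB1E4D2C3u;
    const int32_t  qs_b        = 0x01FF7F80;
    const uint8_t  scale       = 0xA5;
    const float    ds_b        = 0.5f; /* stand-in for cache_b.ds[ib_b].x */

    float sum_d = 0.0f;
    float sum_m = 0.0f;

    for (int ib_b = 0; ib_b < 4; ib_b++) {
        /* Pick the ib_b-th 2-bit lane of every byte, as the shader does. */
        const int32_t qs_a = (int32_t)((qs_a_packed >> (ib_b * 2)) & 0x03030303u);
        /* Duplicate the 4-bit min scale across all four bytes. */
        const int32_t scale_m = (int32_t)(scale >> 4) * 0x01010101;

        sum_d += ds_b * (float)(dot_packed_4x8(qs_a, qs_b) * (scale & 0xF));
        sum_m += ds_b * (float)dot_packed_4x8(scale_m, qs_b);
    }

    /* The shader's final result would then be dm.x * sum_d - dm.y * sum_m. */
    printf("sum_d = %f, sum_m = %f\n", sum_d, sum_m);
    return 0;
}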