From 385e8270572561b8766e3860f49999c0e3da3d6a Mon Sep 17 00:00:00 2001
From: 0cc4m
Date: Sun, 12 Oct 2025 15:36:21 +0000
Subject: [PATCH] Refactor mmq caching

---
 .../ggml-vulkan/vulkan-shaders/mul_mmq.comp   | 135 ++--------
 .../vulkan-shaders/mul_mmq_funcs.glsl         | 240 +++++++++++++++---
 .../vulkan-shaders/mul_mmq_shmem_types.glsl   |  47 ++++
 3 files changed, 278 insertions(+), 144 deletions(-)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index cc3f617f69..ad16a75787 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -81,33 +81,23 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #ifdef COOPMAT
 #define SHMEM_STRIDE (BK / 4 + 4)
-#else
-#define SHMEM_STRIDE (BK / 4 + 1)
 #endif
 
-shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#define MMQ_SHMEM
 
-#ifdef DATA_A_QUANT_K
-#define SHMEM_SCALES_STRIDE (SCALES_PER_32 + 1)
-shared uint8_t buf_a_scales[BM * SHMEM_SCALES_STRIDE];
-#endif
+#include "mul_mmq_shmem_types.glsl"
 
-#ifndef COOPMAT
-#if QUANT_AUXF == 1
-shared FLOAT_TYPE buf_a_dm[BM];
-#else
-shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
-#endif
-#endif
+// Shared memory cache
+shared block_a_cache buf_a[BM];
+shared block_b_cache buf_b[BN];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b[TN];
 
-shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
-#ifndef COOPMAT
-shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
-#endif
-
-#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
 #define LOAD_VEC_B 16
 
+// TODO: Recheck if this can work with mul_mat_id
 #ifdef MUL_MAT_ID
 shared u16vec2 row_ids[4096];
 #endif // MUL_MAT_ID
@@ -230,13 +220,6 @@ void main() {
         sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
     }
 #else
-    int32_t cache_a_qs[WMITER * TM * BK / 4];
-
-#ifdef DATA_A_QUANT_K
-    uint8_t cache_a_scales[WMITER * TM * SCALES_PER_32];
-#endif
-
-    int32_t cache_b_qs[TN * BK / 4];
-
     ACC_TYPE sums[WMITER * TM * WNITER * TN];
 
     [[unroll]] for (uint i = 0; i < WMITER * TM * WNITER * TN; i++) {
         sums[i] = ACC_TYPE(0.0f);
     }
 #endif
 
-#if QUANT_AUXF == 1
-    FLOAT_TYPE cache_a_dm[WMITER * TM];
-#else
-    FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
-#endif
-
-    FLOAT_TYPE_VEC2 cache_b_ds[TN];
-
     for (uint block = start_k; block < end_k; block += BK) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
             const uint buf_ib = loadc_a + l;
             const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
             const uint iqs = loadr_a;
 
-            if (iqs == 0) {
-#if QUANT_AUXF == 1
-                buf_a_dm[buf_ib] = get_d(ib);
-#else
-                buf_a_dm[buf_ib] = get_dm(ib);
-#endif
-            }
-#if QUANT_R == 1
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
-#else
-            const i32vec2 vals = repack(ib, iqs);
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs    ] = vals.x;
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
-#endif
-
-#ifdef DATA_A_QUANT_K
-            if (iqs % 4 == 0) {
-                buf_a_scales[buf_ib * SHMEM_SCALES_STRIDE + iqs / 4] = get_scale(ib, iqs);
-            }
-#endif
+            block_a_to_shmem(buf_ib, ib, iqs);
         }
         [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
 #ifdef MUL_MAT_ID
 #endif
             const uint buf_ib = loadc_b + l;
 
             if (iqs == 0) {
-                buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+                buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
             }
             const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4    ] = values.x;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
+            buf_b[buf_ib].qs[iqs * 4    ] = values.x;
+            buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+            buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+            buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
         }
 
         barrier();
 
                 // Load from shared into cache
                 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                     [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                        const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
-                        cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
-                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                            cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
-                        }
-#ifdef DATA_A_QUANT_K
-                        [[unroll]] for (uint s = 0; s < SCALES_PER_32; s++) {
-                            cache_a_scales[(wsir * TM + cr) * SCALES_PER_32 + s] = buf_a_scales[ib * SHMEM_SCALES_STRIDE + s];
-                        }
-#endif
+                        const uint reg_ib = wsir * TM + cr;
+                        const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+
+                        block_a_to_registers(reg_ib, buf_ib);
                     }
                 }
 
                 [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                     [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                         const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
-                        cache_b_ds[cc] = buf_b_ds[ib];
-                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                            cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+                        cache_b[cc].ds = buf_b[ib].ds;
+                        [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+                            cache_b[cc].qs[iqs] = buf_b[ib].qs[iqs];
                         }
                     }
 
                         const uint cache_a_idx = wsir * TM + cr;
                         const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-#if defined(DATA_A_QUANT_LEGACY)
-                        int32_t q_sum = 0;
-                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                            q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
-                                                     cache_b_qs[cc * (BK / 4) + idx_k]);
-                        }
-
-                        sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
-#elif defined(DATA_A_QUANT_K)
-                        int32_t sum_d = 0;
-                        int32_t sum_m = 0;
-
-                        const int32_t scale0 = cache_a_scales[cache_a_idx * SCALES_PER_32];
-                        const int32_t scale1 = cache_a_scales[cache_a_idx * SCALES_PER_32 + 1];
-                        int32_t scale_m = scale0 >> 4;
-                        scale_m |= scale_m << 8;
-                        scale_m |= scale_m << 16;
-
-                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 8; idx_k++) {
-                            sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
-                                                     cache_b_qs[cc * (BK / 4) + idx_k]) * (scale0 & 0xF);
-                            sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
-                        }
-
-                        scale_m = scale1 >> 4;
-                        scale_m |= scale_m << 8;
-                        scale_m |= scale_m << 16;
-
-                        [[unroll]] for (uint idx_k = BK / 8; idx_k < BK / 4; idx_k++) {
-                            sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
-                                                     cache_b_qs[cc * (BK / 4) + idx_k]) * (scale1 & 0xF);
-                            sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
-                        }
-
-                        sums[sums_idx] += mul_q8_1(sum_d, sum_m, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
-#else
-#error unsupported
-#endif
+                        sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
                     }
                 }
             }
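The mul_mmq.comp changes above reduce the matmul kernel to a quant-agnostic skeleton: the shared-memory tiles and register tiles become arrays of per-type cache structs, and all per-quant logic moves behind three functions. As a reading aid (an annotation, not part of the patch), this is the contract each quantization type implements in mul_mmq_funcs.glsl, with struct layouts supplied by mul_mmq_shmem_types.glsl:

    // Cooperative global -> shared load: the thread handling slice iqs packs
    // one 32-bit group of quants from block ib into buf_a[buf_ib]; the thread
    // with iqs == 0 also loads the scale data (d, dm, scales, qh as needed).
    void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs);

    // Shared -> register copy of one cached A block for this thread's subtile.
    void block_a_to_registers(const uint reg_ib, const uint buf_ib);

    // Integer dot product of cached A block ib_a against cached B (Q8_1)
    // block ib_b, dequantized through the matching mul_q8_1 variant.
    ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b);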
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index 31503cf71c..b7e1ff81ab 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -6,41 +6,92 @@
 
 // Each iqs value maps to a 32-bit integer
 
-#if defined(DATA_A_Q4_0)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+// 4-byte loads for Q4_1 blocks (20 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
+#ifdef DATA_A_Q4_0
     const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                    data_a_packed16[ib].qs[iqs * 2 + 1]);
     const uint32_t vui = pack32(quants);
     return i32vec2( vui       & 0x0F0F0F0F,
                    (vui >> 4) & 0x0F0F0F0F);
+#else // DATA_A_Q4_1
+    const uint32_t vui = data_a_packed32[ib].qs[iqs];
+    return i32vec2( vui       & 0x0F0F0F0F,
+                   (vui >> 4) & 0x0F0F0F0F);
+#endif
 }
 
+#ifdef DATA_A_Q4_0
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
 }
-#endif
-
-#if defined(DATA_A_Q4_1)
-i32vec2 repack(uint ib, uint iqs) {
-    // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    return i32vec2( vui       & 0x0F0F0F0F,
-                   (vui >> 4) & 0x0F0F0F0F);
-}
-
+#else // DATA_A_Q4_1
 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
 }
 #endif
 
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q4_0
+    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+    }
+#else // DATA_A_Q4_1
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+    }
+#endif
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint32_t vui = cache_a[ib_a].qs[iqs];
+        const i32vec2 qs_a = i32vec2( vui       & 0x0F0F0F0F,
+                                     (vui >> 4) & 0x0F0F0F0F);
+
+        const int32_t qs_b0 = cache_b[ib_b].qs[iqs];
+        const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4];
+
+        q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
+        q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+}
+#endif // MMQ_SHMEM
+
-#if defined(DATA_A_Q5_0)
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+// 4-byte loads for Q5_1 blocks (24 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
     const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                                    data_a_packed16[ib].qs[iqs * 2 + 1]);
     const uint32_t vui = pack32(quants);
-    const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+#ifdef DATA_A_Q5_0
+    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+#else // DATA_A_Q5_1
+    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+#endif
     const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                        | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
@@ -50,33 +101,70 @@ i32vec2 repack(uint ib, uint iqs) {
     return i32vec2(v0, v1);
 }
 
+#ifdef DATA_A_Q5_0
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
 }
-#endif
-
-#if defined(DATA_A_Q5_1)
-i32vec2 repack(uint ib, uint iqs) {
-    // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
-    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
-                       | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
-
-    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
-                       | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
-
-    return i32vec2(v0, v1);
-}
-
+#else // DATA_A_Q5_1
 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
 }
 #endif
 
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q5_0
+    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+        buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
+    }
+#else // DATA_A_Q5_1
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+        buf_a[buf_ib].qh = data_a_packed32[ib].qh;
+    }
+#endif
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].qh = buf_a[buf_ib].qh;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint32_t vui = cache_a[ib_a].qs[iqs];
+        const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
+        const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
+                              | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+        const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                              | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+        const int32_t qs_b0 = cache_b[ib_b].qs[iqs];
+        const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4];
+
+        q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
+        q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
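Both Q5 variants now share one high-bit path, so the multiply trick used in repack and mmq_dot_product is worth spelling out (an annotation, not part of the patch). The constant 0x02040810 has single bits at positions 4, 11, 18 and 25, so multiplying a 4-bit value by it yields four carry-free shifted copies in the disjoint bit ranges [4..7], [11..14], [18..21] and [25..28]:

    ((qh & 0xF) * 0x02040810) & 0x10101010
    // copies: (qh << 4) | (qh << 11) | (qh << 18) | (qh << 25)
    // mask:   keeps bits 4, 12, 20, 28 only
    // result: qh bit 0 -> bit 4, bit 1 -> bit 12, bit 2 -> bit 20, bit 3 -> bit 28

That is, each of the four high bits lands in bit 4 of one packed byte, which is exactly the (0,1,2,3) -> (4,12,20,28) mapping noted in the comments, and it composes with the low nibbles via a plain bitwise OR.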
 #if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
 int32_t repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
     return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],
                           data_a_packed16[ib].qs[iqs * 2 + 1]));
 }
@@ -84,11 +172,43 @@ int32_t repack(uint ib, uint iqs) {
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * da * dsb.x);
 }
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+        const int32_t qs_b = cache_b[ib_b].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, qs_b);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+}
+#endif // MMQ_SHMEM
 #endif
 
 // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
-// iqs still refers to a 32-bit integer, meaning 0..r for 32-wide quants
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
 #if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
 int32_t repack(uint ib, uint iqs) {
     const uint ib_k = ib / 8;
     const uint iqs_k = (ib % 8) * 8 + iqs;
@@ -109,6 +229,60 @@ uint8_t get_scale(uint ib, uint iqs) {
 ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
 }
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs * 4;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+    // Repack 4x4 quants into one int
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+
+    if (iqs == 0) {
+        buf_a[buf_ib].scales = u8vec2(data_a[ib_k].scales[iqs_k / 4], data_a[ib_k].scales[iqs_k / 4 + 1]);
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) {
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;
+
+    const i32vec2 scales = i32vec2(cache_a[ib_a].scales);
+    i32vec2 scale_m = scales >> 4;
+    scale_m |= scale_m << 8;
+    scale_m |= scale_m << 16;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const uint idx_half = iqs / 4;
+        const uint qs_shift = (iqs % 4) * 2;
+
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[idx_half] >> qs_shift) & 0x03030303);
+
+        sum_d += dotPacked4x8EXT(qs_a, cache_b[ib_b].qs[iqs]) * (scales[idx_half] & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m[idx_half], cache_b[ib_b].qs[iqs]);
+    }
+
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b[ib_b].ds, 1);
+}
+#endif // MMQ_SHMEM
 #endif
 
 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
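Before the new file below, one note on the two mul_q8_1 shapes kept by this refactor (an annotation, not part of the patch). A legacy-quant weight is a_i = d_a * q_i + m_a, where m_a is a stored offset for Q4_1/Q5_1 and a constant -8 * d_a or -16 * d_a for Q4_0/Q5_0, while a Q8_1 activation is b_i = d_b * p_i with ds_b = (d_b, s_b) and s_b = d_b * sum(p_i) precomputed by the Q8_1 quantization path. The per-block dot product therefore factors as

    sum_i a_i * b_i = d_a * d_b * sum_i q_i * p_i + m_a * s_b

which is float(q_sum) * dma.x * dsb.x + dma.y * dsb.y for the stored-offset types, and da * (float(q_sum) * dsb.x - 8 * dsb.y) (16 for Q5_0) once the constant offset is folded in. This identity is why block_b_cache below only ever needs the packed qs plus the single (d, s) pair.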
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
new file mode 100644
index 0000000000..e445c5646b
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
@@ -0,0 +1,47 @@
+#if defined(DATA_A_Q4_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q4_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    uint32_t qh;
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q5_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    uint32_t qh;
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q8_0)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+    int32_t qs[32/4];
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q2_K)
+#define QUANT_R_MMQ 4
+struct block_a_cache
+{
+    uint32_t qs[2];
+    u8vec2 scales;
+    FLOAT_TYPE_VEC2 dm;
+};
+#endif
+
+struct block_b_cache
+{
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 ds;
+};
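Putting the three files together, the per-block flow of the refactored kernel looks like this (a condensed sketch of the main loop in mul_mmq.comp above, with index computations, bounds and MUL_MAT_ID handling elided, not a verbatim excerpt):

    for (uint block = start_k; block < end_k; block += BK) {
        // 1. Cooperative global -> shared load: quant-specific packing on the
        //    A side, plain ivec4 + (d, s) stores into buf_b on the B side.
        block_a_to_shmem(buf_ib, ib, iqs);
        barrier();

        // 2. Shared -> register copy for this thread's subtile; the B side is
        //    a plain struct copy into cache_b.
        block_a_to_registers(reg_ib, buf_ib);

        // 3. Quant-specific integer dot product over the register cache.
        sums[sums_idx] += mmq_dot_product(cache_a_idx, cc);
        barrier();
    }

Since block_a_cache is sized per quantization type, the old SHMEM_STRIDE padding and the parallel buf_a_qs/buf_a_dm/buf_a_scales arrays disappear; each type carries exactly the fields its own mmq_dot_product consumes.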