Load 4 quant blocks into shared memory in one step

This commit is contained in:
0cc4m
2025-10-25 10:46:11 +00:00
parent ac7a7aa2c6
commit 45b9ff5fcf
3 changed files with 134 additions and 80 deletions

View File

@@ -77,9 +77,13 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
#ifndef BK_STEP
#define BK_STEP 4
#endif
// Shared memory cache
shared block_a_cache buf_a[BM];
shared block_b_cache buf_b[BN];
shared block_a_cache buf_a[BM * BK_STEP / QUANT_BLOCK_FACTOR];
shared block_b_cache buf_b[BN * BK_STEP / QUANT_BLOCK_FACTOR];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;
@@ -185,70 +189,64 @@ void main() {
sums[i] = ACC_TYPE_VEC2(0.0f);
}
for (uint block = start_k; block < end_k; block += BK) {
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;
block_a_to_shmem(buf_ib, ib, iqs);
[[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
const uint buf_ib = loadc_b + l;
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const u16vec2 row_idx = row_ids[ic * BN + buf_ib];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
const uint iqs = loadr_b;
#endif
const uint buf_ib = loadc_b + l;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
[[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
pos_a_ib += BK_STEP;
pos_b_ib += BK_STEP;
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
block_a_to_registers(reg_ib, buf_ib);
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint cache_a_idx = wsir * TM + cr * 2;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
block_b_to_registers(ib);
sums[sums_idx].x += mmq_dot_product(cache_a_idx);
sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
const uint cache_a_idx = wsir * TM + cr * 2;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
sums[sums_idx].x += mmq_dot_product(cache_a_idx);
sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
}
}
}
}

View File

@@ -233,71 +233,113 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * 4;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
// const uint qs_shift = ((iqs_k % 32) / 8) * 2;
// Repack 4x4 quants into one int
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
// const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
// const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
// const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
// const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
if (iqs == 0) {
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales[0] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16]);
}
if (iqs == 1) {
buf_a[buf_ib].scales[1] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1]);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
}
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t sum_d = 0;
int32_t sum_m = 0;
float sum_d = 0;
float sum_m = 0;
uint8_t scale = cache_a[ib_a].scales[0];
int32_t scale_m = int32_t(scale >> 4);
scale_m |= scale_m << 8;
scale_m |= scale_m << 16;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
[[unroll]] for (uint ib_b = 0; ib_b < 4; ib_b++) {
const uint8_t scale = cache_a[ib_a].scales[ib_b / 2][(ib_b % 2) * 2 + (iqs / 4)];
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2)) & 0x03030303);
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint qs_shift = iqs * 2;
const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303);
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
sum_d += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
sum_m += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
}
}
scale = cache_a[ib_a].scales[1];
scale_m = int32_t(scale >> 4);
scale_m |= scale_m << 8;
scale_m |= scale_m << 16;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const uint qs_shift = (iqs - 4) * 2;
const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303);
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
#if defined(DATA_A_QUANT_LEGACY)
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#elif defined(DATA_A_QUANT_K)
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;
buf_b[buf_ib].ds[iqs * 2 ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 ]);
buf_b[buf_ib].ds[iqs * 2 + 1] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1]);
[[unroll]] for (uint ib_inner = 0; ib_inner < 4; ib_inner++) {
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 ] = values.x;
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3] = values.w;
}
}
void block_b_to_registers(const uint ib) {
[[unroll]] for (uint i = 0; i < 4; i++) {
cache_b.ds[i] = buf_b[ib].ds[i];
}
[[unroll]] for (uint iqs = 0; iqs < 32; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#else
#error unimplemented
#endif
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);

View File

@@ -31,17 +31,31 @@ struct block_a_cache {
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
#define QUANT_R_MMQ 1
struct block_a_cache
{
uint32_t qs[2];
u8vec2 scales;
uint32_t qs[8];
u8vec4 scales[2];
FLOAT_TYPE_VEC2 dm;
};
#endif
#if defined(DATA_A_QUANT_LEGACY)
#define QUANT_BLOCK_FACTOR 1
struct block_b_cache
{
int32_t qs[8];
FLOAT_TYPE_VEC2 ds;
};
#elif defined(DATA_A_QUANT_K)
#define QUANT_BLOCK_FACTOR 4
struct block_b_cache
{
int32_t qs[32];
FLOAT_TYPE_VEC2 ds[4];
};
#else
#error unimplemented
#endif