Pack q2_k blocks into caches of 32

0cc4m
2025-10-25 13:31:07 +00:00
parent 45b9ff5fcf
commit 7984fc57e0
4 changed files with 27 additions and 77 deletions
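
For context: q2_k stores its 2-bit quants in 256-quant superblocks, while the MMQ shader caches operate on 32-quant blocks. This commit repacks the q2_k data on load so that one shared-memory cache entry covers exactly 32 quants, which lets the QUANT_BLOCK_FACTOR special-casing for k-quants disappear from the files below. The following is an illustrative C++ sketch (not part of the commit) mirroring the repack arithmetic that block_a_to_shmem now performs: four source words, each contributing one 2-bit quant per byte lane after the shift/mask, are merged into a single word holding 16 quants.

    #include <cstdint>

    // Mirror of the "Repack 4x4 quants into one int" step: after shifting the
    // selected bit-plane down and masking, each input word holds one 2-bit
    // quant per byte lane; merging four words at bit offsets 0/2/4/6 yields
    // 4 quants per byte, i.e. 16 quants per 32-bit word.
    uint32_t repack_4x4(uint32_t q0, uint32_t q1, uint32_t q2, uint32_t q3,
                        uint32_t qs_shift) {
        const uint32_t vals0 = (q0 >> qs_shift) & 0x03030303u;
        const uint32_t vals1 = (q1 >> qs_shift) & 0x03030303u;
        const uint32_t vals2 = (q2 >> qs_shift) & 0x03030303u;
        const uint32_t vals3 = (q3 >> qs_shift) & 0x03030303u;
        return vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
    }

Two such words per cache entry hold the 32 quants of one block, which is what the new block_a_cache layout in the last file reflects.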

File 1 of 4:

@@ -2511,6 +2511,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
     s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
+    // Integer MMQ has a smaller shared memory profile, but heavier register use
     l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
     m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
     s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
@@ -3148,7 +3149,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
     // reusing CREATE_MM from the fp32 path
     if ((device->coopmat2 || device->coopmat_support)
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         && !device->coopmat_bf16_support
 #endif
     ) {

File 2 of 4:

@@ -82,8 +82,8 @@ layout (constant_id = 10) const uint WARP = 32;
 #endif
 
 // Shared memory cache
-shared block_a_cache buf_a[BM * BK_STEP / QUANT_BLOCK_FACTOR];
-shared block_b_cache buf_b[BN * BK_STEP / QUANT_BLOCK_FACTOR];
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
 // Register cache
 block_a_cache cache_a[WMITER * TM];
 block_b_cache cache_b;
@@ -195,7 +195,7 @@ void main() {
         const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
         const uint iqs = loadr_a;
 
-        [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+        [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
         }
     }
@@ -213,7 +213,7 @@ void main() {
         const uint iqs = loadr_b;
 #endif
 
-        [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+        [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
         }
     }
@@ -223,7 +223,7 @@ void main() {
         pos_a_ib += BK_STEP;
         pos_b_ib += BK_STEP;
 
-        for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             // Load from shared into cache
             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
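
With every cache entry now holding one 32-quant block regardless of quant type, the QUANT_BLOCK_FACTOR divisor drops out of both the shared-array sizes and the k_step loops above. A minimal C++ sketch of the sizing arithmetic, with tile sizes picked purely for illustration:

    #include <cstddef>

    constexpr std::size_t BM = 64, BK_STEP = 4; // hypothetical tile sizes

    // Before: for k-quants, QUANT_BLOCK_FACTOR was 4, so buf_a held
    // BM * BK_STEP / 4 large entries, each spanning four 32-quant sub-blocks.
    constexpr std::size_t entries_old = BM * BK_STEP / 4;

    // After: one entry per 32-quant block for every quant type.
    constexpr std::size_t entries_new = BM * BK_STEP;

    static_assert(entries_new == 4 * entries_old,
                  "k-quants now use four times as many, smaller cache entries");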

File 3 of 4:

@@ -233,63 +233,53 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
 #ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
-    const uint iqs_k = (ib % 8) * 8 + iqs;
+    const uint iqs_k = (ib % 8) * 8 + iqs * 4;
 
     const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
-    // const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
 
     // Repack 4x4 quants into one int
-    // const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
-    // const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
-    // const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
-    // const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
 
-    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
 
     if (iqs == 0) {
         buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
-        buf_a[buf_ib].scales[0] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16]);
-    }
-    if (iqs == 1) {
-        buf_a[buf_ib].scales[1] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1]);
+        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
     }
 }
 
 void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
 
     [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
-        cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
-    }
-    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
         cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
     }
 }
 
 ACC_TYPE mmq_dot_product(const uint ib_a) {
-    float sum_d = 0;
-    float sum_m = 0;
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;
 
     [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
-        const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
-        [[unroll]] for (uint ib_b = 0; ib_b < 4; ib_b++) {
-            const uint8_t scale = cache_a[ib_a].scales[ib_b / 2][(ib_b % 2) * 2 + (iqs / 4)];
-            const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
-            const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2)) & 0x03030303);
+        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
 
-            sum_d += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
-            sum_m += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
-        }
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
     }
 
-    return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
 }
 #endif // MMQ_SHMEM
 #endif
 
 #ifdef MMQ_SHMEM
-#if defined(DATA_A_QUANT_LEGACY)
 void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_outer = ib / 4;
     const uint ib_inner = ib % 4;
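
The rewritten mmq_dot_product keeps both accumulators in int32 and leaves all floating-point scaling to mul_q8_1. The byte-broadcast trick used for the per-group minimum (scale >> 4) can be checked on the CPU; the following C++ sketch emulates dotPacked4x8EXT with signed 8-bit lanes (illustrative code, not part of the commit):

    #include <cassert>
    #include <cstdint>

    // Emulation of GLSL dotPacked4x8EXT: dot product of four signed 8-bit lanes.
    int32_t dot_packed_4x8(int32_t a, int32_t b) {
        int32_t sum = 0;
        for (int lane = 0; lane < 4; lane++) {
            const int8_t av = int8_t((a >> (8 * lane)) & 0xFF);
            const int8_t bv = int8_t((b >> (8 * lane)) & 0xFF);
            sum += int32_t(av) * int32_t(bv);
        }
        return sum;
    }

    int main() {
        // scale_m trick: multiplying a 4-bit value by 0x01010101 copies it into
        // all four byte lanes, so one packed dot product computes m * (b0+b1+b2+b3).
        const uint8_t scale = 0xA7; // high nibble = minimum scale m, low nibble = d scale
        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101;
        const int32_t b = 0x01020304; // four q8_1 values: 4, 3, 2, 1
        assert(dot_packed_4x8(scale_m, b) == 10 * (4 + 3 + 2 + 1));
        return 0;
    }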
@@ -311,33 +301,6 @@ void block_b_to_registers(const uint ib) {
         cache_b.qs[iqs] = buf_b[ib].qs[iqs];
     }
 }
-#elif defined(DATA_A_QUANT_K)
-void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
-    const uint ib_outer = ib / 4;
-
-    buf_b[buf_ib].ds[iqs * 2    ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2    ]);
-    buf_b[buf_ib].ds[iqs * 2 + 1] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1]);
-
-    [[unroll]] for (uint ib_inner = 0; ib_inner < 4; ib_inner++) {
-        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4    ] = values.x;
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1] = values.y;
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2] = values.z;
-        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3] = values.w;
-    }
-}
-
-void block_b_to_registers(const uint ib) {
-    [[unroll]] for (uint i = 0; i < 4; i++) {
-        cache_b.ds[i] = buf_b[ib].ds[i];
-    }
-    [[unroll]] for (uint iqs = 0; iqs < 32; iqs++) {
-        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
-    }
-}
-#else
-#error unimplemented
-#endif
 #endif
 
 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
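
With the DATA_A_QUANT_K branch gone, all quant types share the single 8-word B cache, and the q2_k path reuses the common mul_q8_1 helper, whose body lies outside this hunk. Based on the float expression it replaces (dm.x * sum_d - dm.y * sum_m, with each term scaled by the q8_1 block scale ds.x), a plausible reading is sketched below in C++; treat the exact formula as an assumption:

    #include <cstdint>

    struct vec2f { float x, y; };

    // Assumed equivalent of mul_q8_1(sum_d, sum_m, dma, dsb, 1) for q2_k:
    // dma = (d, m) from the q2_k superblock, dsb.x = the q8_1 block scale.
    // This mirrors the old float path, which multiplied both sums by ds.x.
    float mul_q8_1_sketch(int32_t sum_d, int32_t sum_m, vec2f dma, vec2f dsb) {
        return dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m));
    }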

File 4 of 4:

@@ -31,31 +31,17 @@ struct block_a_cache {
     FLOAT_TYPE dm;
 };
 #elif defined(DATA_A_Q2_K)
-#define QUANT_R_MMQ 1
+#define QUANT_R_MMQ 4
 struct block_a_cache
 {
-    uint32_t qs[8];
-    u8vec4 scales[2];
+    uint32_t qs[2];
+    u8vec2 scales;
     FLOAT_TYPE_VEC2 dm;
 };
 #endif
 
-#if defined(DATA_A_QUANT_LEGACY)
-#define QUANT_BLOCK_FACTOR 1
-
 struct block_b_cache
 {
     int32_t qs[8];
     FLOAT_TYPE_VEC2 ds;
 };
-#elif defined(DATA_A_QUANT_K)
-#define QUANT_BLOCK_FACTOR 4
-
-struct block_b_cache
-{
-    int32_t qs[32];
-    FLOAT_TYPE_VEC2 ds[4];
-};
-#else
-#error unimplemented
-#endif
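
Net effect on the q2_k A-side cache entry: two repacked qs words and a two-byte scale pair replace eight raw words and eight scale bytes. A C++ mirror of the two layouts, with FLOAT_TYPE_VEC2 approximated as two floats purely to illustrate the size difference:

    #include <cstdint>

    struct vec2f { float x, y; };

    // Old q2_k A cache entry: 8 raw qs words + 2x4 scale bytes + dm.
    struct block_a_cache_old {
        uint32_t qs[8];
        uint8_t  scales[2][4];
        vec2f    dm;
    };

    // New q2_k A cache entry: 2 repacked qs words + 2 scale bytes + dm.
    struct block_a_cache_new {
        uint32_t qs[2];
        uint8_t  scales[2];
        vec2f    dm;
    };

    static_assert(sizeof(block_a_cache_new) < sizeof(block_a_cache_old),
                  "repacking shrinks the per-block cache footprint");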