vulkan: add mmq q2_k integer dot support

0cc4m
2025-09-28 16:34:27 +00:00
parent 226f295f4d
commit cc71ccca82
10 changed files with 128 additions and 26 deletions

View File

@@ -2966,6 +2966,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif
@@ -3085,6 +3087,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif

View File

@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
-const vec2 d = vec2(data_a[a_offset + ib].d);
+const vec2 dm = vec2(data_a[a_offset + ib].dm);
-return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
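For reference, the per-value math these renamed fields feed is unchanged: dm.x is the super-block scale and dm.y the super-block min. A minimal scalar sketch in C++ (illustrative only, not part of the commit):

#include <cstdint>

// Each 4-bit scale nibble pairs a multiplier for dm.x (low nibble) with one
// for dm.y (high nibble); q2 is the 2-bit quant value 0..3.
static float dequant_q2_k_value(float d, float dmin, uint8_t scale, uint8_t q2) {
    return d * float(scale & 0xF) * float(q2) - dmin * float(scale >> 4);
}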

View File

@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
-const f16vec2 d = bl.block.d;
+const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
-float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}

View File

@@ -24,8 +24,8 @@ void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
-FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));

View File

@@ -41,9 +41,9 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
-vec2 d = vec2(data_a[ib0 + i].d);
-const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+vec2 dm = vec2(data_a[ib0 + i].dm);
+const FLOAT_TYPE dall = FLOAT_TYPE(dm.x);
+const FLOAT_TYPE dmin = FLOAT_TYPE(dm.y);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);

View File

@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
-const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
-const vec2 d = vec2(data_a[ib].d);
+const vec2 dm = vec2(data_a[ib].dm);
-const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
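The qsi change above halves the index because the load now fetches one 16-bit word instead of two bytes. A small host-side sketch of the assumed equivalence (little-endian packing, as the *_packed16 structs imply):

#include <cstdint>
#include <utility>

// unpack8() on a uint16 yields the low byte then the high byte, so
// qs_packed16[qsi] carries the same pair the old code read as qs[2*qsi], qs[2*qsi + 1].
static std::pair<uint8_t, uint8_t> unpack8_u16(uint16_t v) {
    return { uint8_t(v & 0xFF), uint8_t(v >> 8) };
}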

View File

@@ -24,7 +24,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -84,6 +87,11 @@ layout (constant_id = 10) const uint WARP = 32;
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#ifdef DATA_A_QUANT_K
+#define SHMEM_SCALES_STRIDE (SCALES_PER_32 + 1)
+shared uint8_t buf_a_scales[BM * SHMEM_SCALES_STRIDE];
+#endif
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
@@ -224,6 +232,10 @@ void main() {
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
+#ifdef DATA_A_QUANT_K
+uint8_t cache_a_scales[WMITER * TM * SCALES_PER_32];
+#endif
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
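A quick size illustration for the new scale staging buffer, using Q2_K's SCALES_PER_32 of 2 and an example tile height; the +1 in SHMEM_SCALES_STRIDE is presumably padding to stagger shared-memory accesses:

constexpr int SCALES_PER_32       = 2;                                 // two scale bytes per 32 values (Q2_K)
constexpr int SHMEM_SCALES_STRIDE = SCALES_PER_32 + 1;                 // 3 bytes per staged row
constexpr int BM_EXAMPLE          = 64;                                // BM is a specialization constant; 64 is only an example
constexpr int buf_a_scales_bytes  = BM_EXAMPLE * SHMEM_SCALES_STRIDE;  // 192 bytes of shared memory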
@@ -243,9 +255,9 @@ void main() {
for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
-const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
-const uint iqs = loadr_a;
+const uint buf_ib = loadc_a + l;
+const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+const uint iqs = loadr_a;
if (iqs == 0) {
#if QUANT_AUXF == 1
@@ -261,6 +273,12 @@ void main() {
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
+#ifdef DATA_A_QUANT_K
+if (iqs % 4 == 0) {
+buf_a_scales[buf_ib * SHMEM_SCALES_STRIDE + iqs / 4] = get_scale(ib, iqs);
+}
+#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
@@ -333,6 +351,11 @@ void main() {
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
+#ifdef DATA_A_QUANT_K
+[[unroll]] for (uint s = 0; s < SCALES_PER_32; s++) {
+cache_a_scales[(wsir * TM + cr) * SCALES_PER_32 + s] = buf_a_scales[ib * SHMEM_SCALES_STRIDE + s];
+}
+#endif
}
}
@@ -350,6 +373,8 @@ void main() {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+#if defined(DATA_A_QUANT_LEGACY)
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
@@ -357,6 +382,36 @@ void main() {
}
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+#elif defined(DATA_A_QUANT_K)
+int32_t sum_d = 0;
+int32_t sum_m = 0;
+const int32_t scale0 = cache_a_scales[cache_a_idx * SCALES_PER_32];
+const int32_t scale1 = cache_a_scales[cache_a_idx * SCALES_PER_32 + 1];
+int32_t scale_m = scale0 >> 4;
+scale_m |= scale_m << 8;
+scale_m |= scale_m << 16;
+[[unroll]] for (uint idx_k = 0; idx_k < BK / 8; idx_k++) {
+sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
+cache_b_qs[cc * (BK / 4) + idx_k]) * (scale0 & 0xF);
+sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
+}
+scale_m = scale1 >> 4;
+scale_m |= scale_m << 8;
+scale_m |= scale_m << 16;
+[[unroll]] for (uint idx_k = BK / 8; idx_k < BK / 4; idx_k++) {
+sum_d += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
+cache_b_qs[cc * (BK / 4) + idx_k]) * (scale1 & 0xF);
+sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[cc * (BK / 4) + idx_k]);
+}
+sums[sums_idx] += mul_q8_1(sum_d, sum_m, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+#else
+#error unsupported
+#endif
}
}
}
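To summarize the new DATA_A_QUANT_K branch above: for each 16-value half of a 32-value slice, sum_d accumulates the a·b integer dot scaled by the 4-bit d-scale, while sum_m recovers the plain sum of b by dotting it against the 4-bit min-scale broadcast into all four bytes (dot(s,s,s,s, b0..b3) = s * Σb). A scalar reference sketch of what one such half contributes (illustrative C++, not the shader):

#include <cstdint>

// dall/dmin are the Q2_K super-block scale and min (dma.x/dma.y); d_b is the
// q8_1 scale of the B column (dsb.x). scale packs the d-scale in the low nibble
// and the min-scale in the high nibble, as in cache_a_scales.
static float q2_k_halfslice(const int8_t a[16], const int8_t b[16], uint8_t scale,
                            float dall, float dmin, float d_b) {
    int32_t sum_d = 0, sum_m = 0;
    for (int i = 0; i < 16; ++i) {
        sum_d += int32_t(a[i]) * int32_t(b[i]) * int32_t(scale & 0xF);
        sum_m += int32_t(scale >> 4) * int32_t(b[i]);
    }
    // mirrors mul_q8_1(sum_d, sum_m, dma, dsb, 1) above
    return d_b * (dall * float(sum_d) - dmin * float(sum_m));
}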

View File

@@ -9,8 +9,8 @@
#if defined(DATA_A_Q4_0)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
-const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
-data_a[ib].qs[iqs * 2 + 1]);
+const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
@@ -37,8 +37,8 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int
#if defined(DATA_A_Q5_0)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
-const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
-data_a[ib].qs[iqs * 2 + 1]);
+const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
@@ -77,8 +77,8 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int
#if defined(DATA_A_Q8_0)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
-return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
-data_a[ib].qs[iqs * 2 + 1]));
+return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
@@ -86,6 +86,31 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int
}
#endif
+// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
+#if defined(DATA_A_Q2_K)
+int32_t repack(uint ib, uint iqs) {
+const uint ib_k = ib / 8;
+const uint iqs_k = (ib % 8) * 8 + iqs;
+const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
+}
+uint8_t get_scale(uint ib, uint iqs) {
+const uint ib_k = ib / 8;
+const uint iqs_k = (ib % 8) * 8 + iqs;
+return data_a[ib_k].scales[iqs_k / 4];
+}
+ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+}
+#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
@@ -103,3 +128,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+const uint ib_k = ib / 8;
+return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
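A worked example of the k-quant index mapping used by repack() and get_scale() above, where ib counts 32-value blocks, iqs counts 32-bit words (0..7) within one, and eight such blocks form one 256-wide Q2_K block (illustrative C++, not part of the commit):

#include <cstdint>

constexpr uint32_t ib = 19, iqs = 3;                          // arbitrary example values
constexpr uint32_t ib_k     = ib / 8;                         // 2: third 256-wide Q2_K block
constexpr uint32_t iqs_k    = (ib % 8) * 8 + iqs;             // 27: word index inside that block
constexpr uint32_t qs_idx   = (iqs_k / 32) * 8 + (iqs_k % 8); // 3: which packed32 qs word to load
constexpr uint32_t qs_shift = ((iqs_k % 32) / 8) * 2;         // 6: which 2-bit plane of that word
static_assert(ib_k == 2 && iqs_k == 27 && qs_idx == 3 && qs_shift == 6, "Q2_K index mapping");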

View File

@@ -66,6 +66,7 @@ struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
-f16vec2 d;
+f16vec2 dm;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
-f16vec2 d;
+f16vec2 dm;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
-f16vec2 d;
+f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
+#define SCALES_PER_32 2
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
@@ -274,6 +281,7 @@ struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
@@ -310,6 +318,7 @@ struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
@@ -340,6 +349,7 @@ struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
@@ -365,6 +375,7 @@ struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
+#define DATA_A_QUANT_K
#endif
// IQuants
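For context, the renamed GLSL structs mirror ggml's host-side Q2_K block; a sketch of that layout as I understand ggml's block_q2_K (hedged, for orientation only):

#include <cstdint>

typedef uint16_t ggml_half;             // fp16 bit pattern
struct block_q2_K_host {                // sketch of ggml's block_q2_K
    uint8_t   scales[256 / 16];         // 4-bit d-scale (low nibble) + 4-bit min-scale (high nibble)
    uint8_t   qs[256 / 4];              // 2-bit quants, four per byte
    ggml_half d;                        // super-block scale      -> dm.x in the shaders
    ggml_half dmin;                     // super-block min scale  -> dm.y in the shaders
};
static_assert(sizeof(block_q2_K_host) == 84, "16 + 64 + 2 + 2 bytes per 256 values");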

View File

@@ -566,7 +566,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
+if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
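With the q2_k case admitted, the generator emits the integer-dot variant whose name the CREATE_MMQ calls in the first hunk consume. A tiny illustration of the name it builds (the shader_name value "matmul" for the non-MUL_MAT_ID path is an assumption here):

#include <iostream>
#include <string>

int main() {
    const std::string shader_name = "matmul";   // assumption: non-MUL_MAT_ID path
    const std::string tname = "q2_k";
    std::cout << shader_name + "_" + tname + "_q8_1" << "\n";  // -> matmul_q2_k_q8_1
    return 0;
}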