Add q4_k mmq

This commit is contained in:
0cc4m
2025-10-27 13:51:03 +00:00
parent 33394ca1ba
commit b684d69338
9 changed files with 79 additions and 14 deletions

View File

@@ -2972,6 +2972,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
#endif
@@ -3093,6 +3094,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
#endif

View File

@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);

View File

@@ -20,8 +20,8 @@ void main() {
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;

View File

@@ -14,9 +14,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].dm.y);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];

View File

@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);

View File

@@ -233,7 +233,7 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * 4;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
@@ -279,6 +279,63 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K)
// 4-byte loads for Q4_K blocks (144 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
// Repack 2x4 quants into one int
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;

View File

@@ -26,18 +26,25 @@ struct block_a_cache {
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
#define BK_STEP 4
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
struct block_a_cache
{
struct block_a_cache {
uint32_t qs[2];
u8vec2 scales;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 dm;
};
#endif
struct block_b_cache

View File

@@ -288,21 +288,21 @@ struct block_q3_K_packed16
struct block_q4_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};
struct block_q4_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};
struct block_q4_K_packed32
{
f16vec2 d;
f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};

View File

@@ -567,7 +567,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k")) {
if (!f16acc && !coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (is_legacy_quant(tname) || tname == "q2_k" || tname == "q4_k")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif