vulkan: In coopmat2 mmq, load q4_k/q5_k scales through shared memory (#12833)

q4_k and q5_k had a lot of redundant global loads: the same 16B of scale
information was loaded and decoded again on every loop iteration. This change
restructures the loops to iterate explicitly over whole blocks in the outer
loop (with an unrolled inner loop) and to copy/decode the scale data into
shared memory once at the start of each outer-loop iteration. The copy is
pipelined so the scale load from global memory is relatively cheap.
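
As a rough sketch of the new structure (mirroring the fast-path loop in the diff below,
not exact shader code, and only for QUANT_K == 256): fetch_scales() reads the next tile's
scale data from global memory into registers, and store_scales() commits the previously
fetched values to shared memory, so the global load overlaps the matrix multiplies:

    fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);  // prefetch scales for the first tile
    for (uint i = 0; i < k_iters; ++i) {
        store_scales(tid);                                        // write prefetched scales to shared memory
        if (block_k + UNROLL_K < end_k) {
            // start loading the next tile's scales while this tile is being multiplied
            fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
        }
        [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
            // load/decode one BK-wide slice of A (scales read from shared memory) and of B,
            // then accumulate with coopMatMulAdd
            block_k += BK;
        }
    }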

This improves q4_k/q5_k model prompt processing performance by around 5-7%.
I briefly tried applying this to q6_k and q4_0, and it didn't help for q6_k
and hurt for q4_0.

The big "else" path in mul_mm_cm2.comp that had all the clamped/unclamped
variants isn't used as often as it originally was (e.g. due to the padded_N
change), so I trimmed it down to offset some of the new complexity of the
semi-manual loop unrolling.
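
For reference, the trimmed fallback path (non-MUL_MAT_ID case shown; see the last hunk
below) now always loads through the clamped tensor layouts instead of branching on
per-iteration unclampedA/unclampedB flags, so its loop body reduces to roughly:

    coopMatLoadTensorNV(mat_a, data_a, pos_a,
                        sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
    coopMatLoadTensorNV(mat_b, data_b, pos_b,
                        sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
    sum = coopMatMulAdd(mat_a, mat_b, sum);

with the same store_scales/fetch_scales pipelining as above, but per BK block rather
than per 256-wide tile.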
Author: Jeff Bolz (committed by GitHub)
Date:   2025-04-09 00:25:08 -05:00
Commit: 0090950f67 (parent 7ecd780b1a)
3 changed files with 235 additions and 42 deletions

mul_mm_cm2.comp

@@ -19,6 +19,9 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define IS_MUL_MM2 1
layout (constant_id = 0) const uint BLOCK_SIZE = 256;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@@ -70,6 +73,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#define DECODEFUNCA
#endif
#if !defined(fetch_scales)
#define fetch_scales(a, b, c, d, e, f)
#endif
#if !defined(store_scales)
#define store_scales(a)
#endif
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
@@ -116,6 +126,8 @@ void main() {
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint tid = gl_LocalInvocationIndex;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
@@ -218,14 +230,21 @@ void main() {
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID)
const uint START_ALIGN_K = 256;
// For Qi_K (block size 256), unroll whole 256 element tiles.
// For legacy quants (block size 32), unroll 8x.
const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
const uint unroll_count = UNROLL_K / BK;
// Detect a fast path where all loads are entirely in bounds and no clamping is required
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
(stride_a % 8) == 0 &&
#endif
(stride_b % 8) == 0 && (start_k % 8) == 0) {
(stride_b % 8) == 0) {
// Hint to the compiler that values are aligned (want 16B alignment)
start_k &= ~7;
start_k &= ~(START_ALIGN_K-1);
stride_b &= ~7;
#if QUANT_K == 1
stride_a &= ~7;
@@ -234,11 +253,39 @@ void main() {
tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
uint k_iters = (end_k - start_k + BK - 1) / BK;
uint k_iters = (end_k - start_k) / UNROLL_K;
uint block_k = start_k;
// fetch scale values for a tile of quants. These will be copied into shared memory.
// The fetches and stores are pipelined to hide the latency.
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
for (uint i = 0; i < k_iters; ++i) {
store_scales(tid);
if (block_k + UNROLL_K < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
}
// Manually partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
}
// Do any remaining iterations that were not unrolled
if (block_k < end_k) {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
@@ -246,6 +293,7 @@ void main() {
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
@@ -253,8 +301,30 @@ void main() {
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
for (uint i = 0; i < k_iters; ++i) {
store_scales(tid);
if (block_k + UNROLL_K < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
}
// Manually partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
}
// Do any remaining iterations that were not unrolled
if (block_k < end_k) {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
@@ -262,6 +332,7 @@ void main() {
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
@@ -269,8 +340,31 @@ void main() {
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
for (uint i = 0; i < k_iters; ++i) {
store_scales(tid);
if (block_k + UNROLL_K < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
}
// Manually partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
}
// Do any remaining iterations that were not unrolled
if (block_k < end_k) {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
@@ -278,6 +372,7 @@ void main() {
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
@@ -298,47 +393,29 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
uint k_iters = (end_k - start_k + BK - 1) / BK;
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
[[dont_unroll]]
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
store_scales(tid);
if (block_k + BK < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
// Clamping is expensive, so detect different code paths for each combination
// of A and B needing clamping.
bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
bool unclampedB = true;
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
if (unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (unclampedA && !unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (!unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (!unclampedA && !unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
// Convert from ACC_TYPE to D_TYPE