vulkan: mul_mat_id coopmat2 optimizations (#15546)

* vulkan: mul_mat_id coopmat2 optimizations

Add a path for when the tile fits in BN/2, similar to what we have for mul_mat.

Only call fetch_scales/store_scales once per QUANT_K block, and once at the
beginning in case start_k is not aligned.
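To make that cadence concrete, here is a minimal standalone C++ sketch (not part of the commit) that simulates when scales would be fetched/stored under this scheme; BK, QUANT_K, start_k, and end_k are stand-in values mirroring the shader's names:

    #include <cstdio>

    // Simulates the new scale-handling cadence: one fetch+store up front (start_k
    // may not be QUANT_K-aligned), then store only when block_k enters a new
    // QUANT_K block and fetch only when the next iteration will enter one.
    int main() {
        const unsigned BK = 32, QUANT_K = 256;      // stand-in tile/block sizes
        const unsigned start_k = 96, end_k = 512;   // deliberately unaligned start
        const unsigned k_iters = (end_k - start_k + BK - 1) / BK;

        printf("initial fetch_scales + store_scales at k=%u\n", start_k);
        for (unsigned block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
            if ((block_k % QUANT_K) == 0)
                printf("store_scales at k=%u\n", block_k);
            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0)
                printf("fetch_scales for k=%u\n", block_k + BK);
        }
        return 0;
    }

Since BK is much smaller than QUANT_K, this replaces a fetch/store on every k iteration with one per QUANT_K block, plus the single unaligned fetch/store at the start.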

* Also add a path for BN/4 - worth a couple more percent
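For reference, the BN/2 and BN/4 paths amount to picking the narrowest N-tile that still covers the remaining _ne1 entries for this workgroup; a rough standalone C++ sketch (not from the commit, names mirror the shader's BN/BNover2/BNover4/ic/_ne1):

    #include <cstdio>

    // Pick the narrowest N tile width that still covers everything this
    // workgroup's tile is responsible for, mirroring the BNover4/BNover2 checks
    // guarded by enable_smaller_matrices in the shader.
    unsigned pick_tile_n(unsigned BN, unsigned ic, unsigned ne1) {
        const unsigned BNover2 = BN / 2, BNover4 = BN / 4;
        if (ic * BN + BNover4 >= ne1) return BNover4;   // quarter tile covers the tail
        if (ic * BN + BNover2 >= ne1) return BNover2;   // half tile covers it
        return BN;                                      // full tile needed
    }

    int main() {
        printf("%u\n", pick_tile_n(64, 0, 10));  // 16: only 10 entries left -> BN/4
        printf("%u\n", pick_tile_n(64, 0, 30));  // 32: fits in BN/2
        printf("%u\n", pick_tile_n(64, 0, 60));  // 64: needs the full tile
        return 0;
    }

A narrower B tile means smaller mat_b loads and smaller coopMatMulAdd operations per k iteration, which is where the extra couple of percent comes from when only a few rows remain.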
Author: Jeff Bolz
Date: 2025-08-31 02:06:43 -05:00 (committed by GitHub)
Parent: 5c16b9c87d
Commit: c37052ab4d
2 changed files with 93 additions and 6 deletions


@@ -456,18 +456,105 @@ void main() {
     tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
     uint k_iters = (end_k - start_k + BK - 1) / BK;
     fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+    store_scales(tid);
+#ifdef MUL_MAT_ID
+    if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
+        [[dont_unroll]]
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+            if ((block_k % QUANT_K) == 0) {
+                store_scales(tid);
+            }
+            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+            }
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+        }
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        return;
+    }
+    if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+        [[dont_unroll]]
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+            if ((block_k % QUANT_K) == 0) {
+                store_scales(tid);
+            }
+            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+            }
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+        }
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        return;
+    }
+#endif
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
     [[dont_unroll]]
     for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
-        store_scales(tid);
-        if (block_k + BK < end_k) {
+        if ((block_k % QUANT_K) == 0) {
+            store_scales(tid);
+        }
+        if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
             fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
         }