mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-03 09:22:01 +00:00
vulkan: optimizations for deepseek prompt processing (#14555)
* vulkan: allow unclamped loads in coopmat2 mul_mat_id shader * vulkan: increase coopmat2 mul_mat_id tile size * vulkan: optimize mat_mul_id row_ids search to batch loads, and port to coopmat1 path * vulkan: use smaller FA row size when head size is large. applies to both scalar and CM2 paths (CM1 isn't used due to shared memory limits)
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[4096];
|
||||
uint _ne1;
|
||||
#ifdef COOPMAT
|
||||
shared uint _ne1_sh;
|
||||
#endif
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
@@ -172,7 +177,47 @@ void main() {
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
#ifdef COOPMAT
|
||||
// Spread the search across all elements in the first subgroup
|
||||
if (gl_SubgroupID == 0) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
_ne1 = _ne1_sh;
|
||||
#else
|
||||
_ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
@@ -183,6 +228,7 @@ void main() {
|
||||
}
|
||||
|
||||
barrier();
|
||||
#endif
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
|
||||
@@ -162,17 +162,32 @@ void main() {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
|
||||
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
|
||||
uint ids[16];
|
||||
uint iter = 0;
|
||||
|
||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
||||
// prefetch up to 16 elements
|
||||
if (iter == 0) {
|
||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint ii0 = i % p.nei0;
|
||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
}
|
||||
}
|
||||
uint i = j + gl_SubgroupInvocationID;
|
||||
bool in_range = i < num_elements;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint ii1 = i / p.nei0;
|
||||
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||
uint ii0 = i % p.nei0;
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += subgroupBallotBitCount(ballot);
|
||||
iter &= 15;
|
||||
}
|
||||
_ne1_sh = _ne1;
|
||||
}
|
||||
@@ -414,17 +429,31 @@ void main() {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
#endif
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
|
||||
Reference in New Issue
Block a user