vulkan: optimizations for deepseek prompt processing (#14555)

* vulkan: allow unclamped loads in coopmat2 mul_mat_id shader

* vulkan: increase coopmat2 mul_mat_id tile size

* vulkan: optimize mat_mul_id row_ids search to batch loads, and port to coopmat1 path

* vulkan: use smaller FA row size when head size is large. applies to both scalar and CM2 paths (CM1 isn't used due to shared memory limits)
This commit is contained in:
Jeff Bolz
2025-07-12 04:51:58 -05:00
committed by GitHub
parent f5e96b368f
commit 98197e5c98
3 changed files with 104 additions and 18 deletions

View File

@@ -18,6 +18,7 @@
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
shared uint _ne1_sh;
#endif
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
#ifdef COOPMAT
// Spread the search across all elements in the first subgroup
if (gl_SubgroupID == 0) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_SubgroupInvocationID;
bool in_range = i < num_elements;
uint ii1 = i / p.nei0;
uint ii0 = i % p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
uint idx = subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
}
_ne1 += subgroupBallotBitCount(ballot);
iter &= 15;
}
_ne1_sh = _ne1;
}
barrier();
_ne1 = _ne1_sh;
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;