mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
vulkan: optimize mul_mat_id loading row ids into shared memory (#15427)
- Spread the work across the whole workgroup. Using more threads seems to far outweigh the synchronization overhead. - Specialize the code for when the division is by a power of two.
This commit is contained in:
@@ -2168,9 +2168,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
s_mmq_wg_denoms_k = { 32, 64, 1 };
|
s_mmq_wg_denoms_k = { 32, 64, 1 };
|
||||||
|
|
||||||
// spec constants and tile sizes for quant matmul_id
|
// spec constants and tile sizes for quant matmul_id
|
||||||
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
|
l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
|
||||||
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
|
||||||
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
|
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
|
||||||
l_mmqid_wg_denoms = { 128, 128, 1 };
|
l_mmqid_wg_denoms = { 128, 128, 1 };
|
||||||
m_mmqid_wg_denoms = { 128, 64, 1 };
|
m_mmqid_wg_denoms = { 128, 64, 1 };
|
||||||
s_mmqid_wg_denoms = { 128, 64, 1 };
|
s_mmqid_wg_denoms = { 128, 64, 1 };
|
||||||
|
|||||||
@@ -103,16 +103,74 @@ layout (constant_id = 10) const uint WARP = 32;
|
|||||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||||
|
|
||||||
|
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
shared u16vec2 row_ids[4096];
|
shared u16vec2 row_ids[4096];
|
||||||
uint _ne1;
|
uint _ne1;
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
shared uint _ne1_sh;
|
shared uvec4 ballots_sh[NUM_WARPS];
|
||||||
|
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
||||||
|
_ne1 = 0;
|
||||||
|
uint num_elements = p.nei1 * p.nei0;
|
||||||
|
uint nei0shift = findLSB(p.nei0);
|
||||||
|
|
||||||
|
uint ids[16];
|
||||||
|
uint iter = 0;
|
||||||
|
|
||||||
|
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||||
|
// prefetch up to 16 elements
|
||||||
|
if (iter == 0) {
|
||||||
|
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||||
|
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||||
|
bool in_range = i < num_elements;
|
||||||
|
uint ii1;
|
||||||
|
if (nei0_is_pow2) {
|
||||||
|
ii1 = i >> nei0shift;
|
||||||
|
} else {
|
||||||
|
ii1 = i / p.nei0;
|
||||||
|
}
|
||||||
|
uint ii0 = i - ii1 * p.nei0;
|
||||||
|
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
uint i = j + gl_LocalInvocationIndex;
|
||||||
|
bool in_range = i < num_elements;
|
||||||
|
uint ii1;
|
||||||
|
if (nei0_is_pow2) {
|
||||||
|
ii1 = i >> nei0shift;
|
||||||
|
} else {
|
||||||
|
ii1 = i / p.nei0;
|
||||||
|
}
|
||||||
|
uint ii0 = i - ii1 * p.nei0;
|
||||||
|
uint id = ids[iter++];
|
||||||
|
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||||
|
|
||||||
|
ballots_sh[gl_SubgroupID] = ballot;
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint subgroup_base = 0;
|
||||||
|
uint total = 0;
|
||||||
|
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||||
|
if (k == gl_SubgroupID) {
|
||||||
|
subgroup_base = total;
|
||||||
|
}
|
||||||
|
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||||
|
if (in_range && id == expert_idx) {
|
||||||
|
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
||||||
|
}
|
||||||
|
_ne1 += total;
|
||||||
|
iter &= 15;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
#endif // MUL_MAT_ID
|
#endif // MUL_MAT_ID
|
||||||
|
|
||||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
|
||||||
|
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||||
#endif
|
#endif
|
||||||
@@ -178,44 +236,11 @@ void main() {
|
|||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
// Spread the search across all elements in the first subgroup
|
if (bitCount(p.nei0) == 1) {
|
||||||
if (gl_SubgroupID == 0) {
|
load_row_ids(expert_idx, true);
|
||||||
_ne1 = 0;
|
} else {
|
||||||
uint num_elements = p.nei1 * p.nei0;
|
load_row_ids(expert_idx, false);
|
||||||
|
|
||||||
uint ids[16];
|
|
||||||
uint iter = 0;
|
|
||||||
|
|
||||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
|
||||||
// prefetch up to 16 elements
|
|
||||||
if (iter == 0) {
|
|
||||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
|
||||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
|
||||||
bool in_range = i < num_elements;
|
|
||||||
uint ii1 = i / p.nei0;
|
|
||||||
uint ii0 = i % p.nei0;
|
|
||||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
uint i = j + gl_SubgroupInvocationID;
|
|
||||||
bool in_range = i < num_elements;
|
|
||||||
uint ii1 = i / p.nei0;
|
|
||||||
uint ii0 = i % p.nei0;
|
|
||||||
uint id = ids[iter++];
|
|
||||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
|
||||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
|
||||||
if (in_range && id == expert_idx) {
|
|
||||||
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
|
||||||
}
|
|
||||||
_ne1 += subgroupBallotBitCount(ballot);
|
|
||||||
iter &= 15;
|
|
||||||
}
|
|
||||||
_ne1_sh = _ne1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
_ne1 = _ne1_sh;
|
|
||||||
#else
|
#else
|
||||||
_ne1 = 0;
|
_ne1 = 0;
|
||||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "types.comp"
|
#include "types.comp"
|
||||||
|
#include "utils.comp"
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
@@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
|||||||
};
|
};
|
||||||
|
|
||||||
uint _ne1;
|
uint _ne1;
|
||||||
shared uint _ne1_sh;
|
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||||
|
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||||
|
|
||||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||||
{
|
{
|
||||||
@@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
|||||||
return elem;
|
return elem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
||||||
|
_ne1 = 0;
|
||||||
|
uint num_elements = p.nei1 * p.nei0;
|
||||||
|
uint nei0shift = findLSB(p.nei0);
|
||||||
|
|
||||||
|
uint ids[16];
|
||||||
|
uint iter = 0;
|
||||||
|
|
||||||
|
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
|
||||||
|
// prefetch up to 16 elements
|
||||||
|
if (iter == 0) {
|
||||||
|
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
||||||
|
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
|
||||||
|
bool in_range = i < num_elements;
|
||||||
|
uint ii1;
|
||||||
|
if (nei0_is_pow2) {
|
||||||
|
ii1 = i >> nei0shift;
|
||||||
|
} else {
|
||||||
|
ii1 = i / p.nei0;
|
||||||
|
}
|
||||||
|
uint ii0 = i - ii1 * p.nei0;
|
||||||
|
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
uint i = j + gl_LocalInvocationIndex;
|
||||||
|
bool in_range = i < num_elements;
|
||||||
|
uint ii1;
|
||||||
|
if (nei0_is_pow2) {
|
||||||
|
ii1 = i >> nei0shift;
|
||||||
|
} else {
|
||||||
|
ii1 = i / p.nei0;
|
||||||
|
}
|
||||||
|
uint ii0 = i - ii1 * p.nei0;
|
||||||
|
uint id = ids[iter++];
|
||||||
|
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||||
|
|
||||||
|
ballots_sh[gl_SubgroupID] = ballot;
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint subgroup_base = 0;
|
||||||
|
uint total = 0;
|
||||||
|
for (uint k = 0; k < gl_NumSubgroups; ++k) {
|
||||||
|
if (k == gl_SubgroupID) {
|
||||||
|
subgroup_base = total;
|
||||||
|
}
|
||||||
|
total += subgroupBallotBitCount(ballots_sh[k]);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||||
|
if (in_range && id == expert_idx) {
|
||||||
|
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
||||||
|
}
|
||||||
|
_ne1 += total;
|
||||||
|
iter &= 15;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
@@ -157,45 +217,12 @@ void main() {
|
|||||||
const uint ic = gl_WorkGroupID.y;
|
const uint ic = gl_WorkGroupID.y;
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
// Spread the search across all elements in the first subgroup
|
if (bitCount(p.nei0) == 1) {
|
||||||
if (gl_SubgroupID == 0) {
|
load_row_ids(expert_idx, true);
|
||||||
_ne1 = 0;
|
} else {
|
||||||
uint num_elements = p.nei1 * p.nei0;
|
load_row_ids(expert_idx, false);
|
||||||
|
|
||||||
uint ids[16];
|
|
||||||
uint iter = 0;
|
|
||||||
|
|
||||||
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
|
||||||
// prefetch up to 16 elements
|
|
||||||
if (iter == 0) {
|
|
||||||
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
|
||||||
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
|
||||||
bool in_range = i < num_elements;
|
|
||||||
uint ii1 = i / p.nei0;
|
|
||||||
uint ii0 = i % p.nei0;
|
|
||||||
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
uint i = j + gl_SubgroupInvocationID;
|
|
||||||
bool in_range = i < num_elements;
|
|
||||||
uint ii1 = i / p.nei0;
|
|
||||||
uint ii0 = i % p.nei0;
|
|
||||||
uint id = ids[iter++];
|
|
||||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
|
||||||
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
|
||||||
if (in_range && id == expert_idx) {
|
|
||||||
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
|
||||||
}
|
|
||||||
_ne1 += subgroupBallotBitCount(ballot);
|
|
||||||
iter &= 15;
|
|
||||||
}
|
|
||||||
_ne1_sh = _ne1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
_ne1 = _ne1_sh;
|
|
||||||
|
|
||||||
// Workgroup has no work
|
// Workgroup has no work
|
||||||
if (ic * BN >= _ne1) return;
|
if (ic * BN >= _ne1) return;
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user