mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	vulkan: optimize mul_mat_id loading row ids into shared memory (#15427)
- Spread the work across the whole workgroup. Using more threads seems to far outweigh the synchronization overhead. - Specialize the code for when the division is by a power of two.
This commit is contained in:
		| @@ -2168,9 +2168,9 @@ static void ggml_vk_load_shaders(vk_device& device) { | |||||||
|         s_mmq_wg_denoms_k = { 32,  64,  1 }; |         s_mmq_wg_denoms_k = { 32,  64,  1 }; | ||||||
|  |  | ||||||
|         // spec constants and tile sizes for quant matmul_id |         // spec constants and tile sizes for quant matmul_id | ||||||
|         l_warptile_mmqid = { 256, 128, 128, 16, 0 }; |         l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size }; | ||||||
|         m_warptile_mmqid = { 256, 128, 64, 16, 0 }; |         m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; | ||||||
|         s_warptile_mmqid = { 256, 128, 64, 16, 0 }; |         s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; | ||||||
|         l_mmqid_wg_denoms = { 128, 128, 1 }; |         l_mmqid_wg_denoms = { 128, 128, 1 }; | ||||||
|         m_mmqid_wg_denoms = { 128, 64, 1 }; |         m_mmqid_wg_denoms = { 128, 64, 1 }; | ||||||
|         s_mmqid_wg_denoms = { 128, 64, 1 }; |         s_mmqid_wg_denoms = { 128, 64, 1 }; | ||||||
|   | |||||||
| @@ -103,16 +103,74 @@ layout (constant_id = 10) const uint WARP = 32; | |||||||
| shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; | shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; | ||||||
| shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; | shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; | ||||||
|  |  | ||||||
|  | #define NUM_WARPS (BLOCK_SIZE / WARP) | ||||||
|  |  | ||||||
| #ifdef MUL_MAT_ID | #ifdef MUL_MAT_ID | ||||||
| shared u16vec2 row_ids[4096]; | shared u16vec2 row_ids[4096]; | ||||||
| uint _ne1; | uint _ne1; | ||||||
| #ifdef COOPMAT | #ifdef COOPMAT | ||||||
| shared uint _ne1_sh; | shared uvec4 ballots_sh[NUM_WARPS]; | ||||||
|  | void load_row_ids(uint expert_idx, bool nei0_is_pow2) { | ||||||
|  |     _ne1 = 0; | ||||||
|  |     uint num_elements = p.nei1 * p.nei0; | ||||||
|  |     uint nei0shift = findLSB(p.nei0); | ||||||
|  |  | ||||||
|  |     uint ids[16]; | ||||||
|  |     uint iter = 0; | ||||||
|  |  | ||||||
|  |     for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { | ||||||
|  |         // prefetch up to 16 elements | ||||||
|  |         if (iter == 0) { | ||||||
|  |             [[unroll]] for (uint k = 0; k < 16; ++k) { | ||||||
|  |                 uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; | ||||||
|  |                 bool in_range = i < num_elements; | ||||||
|  |                 uint ii1; | ||||||
|  |                 if (nei0_is_pow2) { | ||||||
|  |                     ii1 = i >> nei0shift; | ||||||
|  |                 } else { | ||||||
|  |                     ii1 = i / p.nei0; | ||||||
|  |                 } | ||||||
|  |                 uint ii0 = i - ii1 * p.nei0; | ||||||
|  |                 ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         uint i = j + gl_LocalInvocationIndex; | ||||||
|  |         bool in_range = i < num_elements; | ||||||
|  |         uint ii1; | ||||||
|  |         if (nei0_is_pow2) { | ||||||
|  |             ii1 = i >> nei0shift; | ||||||
|  |         } else { | ||||||
|  |             ii1 = i / p.nei0; | ||||||
|  |         } | ||||||
|  |         uint ii0 = i - ii1 * p.nei0; | ||||||
|  |         uint id = ids[iter++]; | ||||||
|  |         uvec4 ballot = subgroupBallot(in_range && id == expert_idx); | ||||||
|  |  | ||||||
|  |         ballots_sh[gl_SubgroupID] = ballot; | ||||||
|  |         barrier(); | ||||||
|  |  | ||||||
|  |         uint subgroup_base = 0; | ||||||
|  |         uint total = 0; | ||||||
|  |         for (uint k = 0; k < gl_NumSubgroups; ++k) { | ||||||
|  |             if (k == gl_SubgroupID) { | ||||||
|  |                 subgroup_base = total; | ||||||
|  |             } | ||||||
|  |             total += subgroupBallotBitCount(ballots_sh[k]); | ||||||
|  |         } | ||||||
|  |         barrier(); | ||||||
|  |  | ||||||
|  |         uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); | ||||||
|  |         if (in_range && id == expert_idx) { | ||||||
|  |             row_ids[_ne1 + idx] = u16vec2(ii0, ii1); | ||||||
|  |         } | ||||||
|  |         _ne1 += total; | ||||||
|  |         iter &= 15; | ||||||
|  |     } | ||||||
|  |     barrier(); | ||||||
|  | } | ||||||
| #endif | #endif | ||||||
| #endif // MUL_MAT_ID | #endif // MUL_MAT_ID | ||||||
|  |  | ||||||
| #define NUM_WARPS (BLOCK_SIZE / WARP) |  | ||||||
|  |  | ||||||
| #ifdef COOPMAT | #ifdef COOPMAT | ||||||
| shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; | shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; | ||||||
| #endif | #endif | ||||||
| @@ -178,44 +236,11 @@ void main() { | |||||||
|  |  | ||||||
| #ifdef MUL_MAT_ID | #ifdef MUL_MAT_ID | ||||||
| #ifdef COOPMAT | #ifdef COOPMAT | ||||||
|     // Spread the search across all elements in the first subgroup |     if (bitCount(p.nei0) == 1) { | ||||||
|     if (gl_SubgroupID == 0) { |         load_row_ids(expert_idx, true); | ||||||
|         _ne1 = 0; |     } else { | ||||||
|         uint num_elements = p.nei1 * p.nei0; |         load_row_ids(expert_idx, false); | ||||||
|  |  | ||||||
|         uint ids[16]; |  | ||||||
|         uint iter = 0; |  | ||||||
|  |  | ||||||
|         for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { |  | ||||||
|             // prefetch up to 16 elements |  | ||||||
|             if (iter == 0) { |  | ||||||
|                 [[unroll]] for (uint k = 0; k < 16; ++k) { |  | ||||||
|                     uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; |  | ||||||
|                     bool in_range = i < num_elements; |  | ||||||
|                     uint ii1 = i / p.nei0; |  | ||||||
|                     uint ii0 = i % p.nei0; |  | ||||||
|                     ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             uint i = j + gl_SubgroupInvocationID; |  | ||||||
|             bool in_range = i < num_elements; |  | ||||||
|             uint ii1 = i / p.nei0; |  | ||||||
|             uint ii0 = i % p.nei0; |  | ||||||
|             uint id = ids[iter++]; |  | ||||||
|             uvec4 ballot = subgroupBallot(in_range && id == expert_idx); |  | ||||||
|             uint idx = subgroupBallotExclusiveBitCount(ballot); |  | ||||||
|             if (in_range && id == expert_idx) { |  | ||||||
|                 row_ids[_ne1 + idx] = u16vec2(ii0, ii1); |  | ||||||
|             } |  | ||||||
|             _ne1 += subgroupBallotBitCount(ballot); |  | ||||||
|             iter &= 15; |  | ||||||
|         } |  | ||||||
|         _ne1_sh = _ne1; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     barrier(); |  | ||||||
|  |  | ||||||
|     _ne1 = _ne1_sh; |  | ||||||
| #else | #else | ||||||
|     _ne1 = 0; |     _ne1 = 0; | ||||||
|     for (uint ii1 = 0; ii1 < p.nei1; ii1++) { |     for (uint ii1 = 0; ii1 < p.nei1; ii1++) { | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ | |||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #include "types.comp" | #include "types.comp" | ||||||
|  | #include "utils.comp" | ||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
| @@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { | |||||||
| }; | }; | ||||||
|  |  | ||||||
| uint _ne1; | uint _ne1; | ||||||
| shared uint _ne1_sh; | layout (constant_id = 5) const uint subgroup_size = 32; | ||||||
|  | shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; | ||||||
|  |  | ||||||
| B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||||
| { | { | ||||||
| @@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem | |||||||
|     return elem; |     return elem; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void load_row_ids(uint expert_idx, bool nei0_is_pow2) { | ||||||
|  |     _ne1 = 0; | ||||||
|  |     uint num_elements = p.nei1 * p.nei0; | ||||||
|  |     uint nei0shift = findLSB(p.nei0); | ||||||
|  |  | ||||||
|  |     uint ids[16]; | ||||||
|  |     uint iter = 0; | ||||||
|  |  | ||||||
|  |     for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { | ||||||
|  |         // prefetch up to 16 elements | ||||||
|  |         if (iter == 0) { | ||||||
|  |             [[unroll]] for (uint k = 0; k < 16; ++k) { | ||||||
|  |                 uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; | ||||||
|  |                 bool in_range = i < num_elements; | ||||||
|  |                 uint ii1; | ||||||
|  |                 if (nei0_is_pow2) { | ||||||
|  |                     ii1 = i >> nei0shift; | ||||||
|  |                 } else { | ||||||
|  |                     ii1 = i / p.nei0; | ||||||
|  |                 } | ||||||
|  |                 uint ii0 = i - ii1 * p.nei0; | ||||||
|  |                 ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         uint i = j + gl_LocalInvocationIndex; | ||||||
|  |         bool in_range = i < num_elements; | ||||||
|  |         uint ii1; | ||||||
|  |         if (nei0_is_pow2) { | ||||||
|  |             ii1 = i >> nei0shift; | ||||||
|  |         } else { | ||||||
|  |             ii1 = i / p.nei0; | ||||||
|  |         } | ||||||
|  |         uint ii0 = i - ii1 * p.nei0; | ||||||
|  |         uint id = ids[iter++]; | ||||||
|  |         uvec4 ballot = subgroupBallot(in_range && id == expert_idx); | ||||||
|  |  | ||||||
|  |         ballots_sh[gl_SubgroupID] = ballot; | ||||||
|  |         barrier(); | ||||||
|  |  | ||||||
|  |         uint subgroup_base = 0; | ||||||
|  |         uint total = 0; | ||||||
|  |         for (uint k = 0; k < gl_NumSubgroups; ++k) { | ||||||
|  |             if (k == gl_SubgroupID) { | ||||||
|  |                 subgroup_base = total; | ||||||
|  |             } | ||||||
|  |             total += subgroupBallotBitCount(ballots_sh[k]); | ||||||
|  |         } | ||||||
|  |         barrier(); | ||||||
|  |  | ||||||
|  |         uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); | ||||||
|  |         if (in_range && id == expert_idx) { | ||||||
|  |             row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); | ||||||
|  |         } | ||||||
|  |         _ne1 += total; | ||||||
|  |         iter &= 15; | ||||||
|  |     } | ||||||
|  |     barrier(); | ||||||
|  | } | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| void main() { | void main() { | ||||||
| @@ -157,45 +217,12 @@ void main() { | |||||||
|     const uint ic = gl_WorkGroupID.y; |     const uint ic = gl_WorkGroupID.y; | ||||||
|  |  | ||||||
| #ifdef MUL_MAT_ID | #ifdef MUL_MAT_ID | ||||||
|     // Spread the search across all elements in the first subgroup |     if (bitCount(p.nei0) == 1) { | ||||||
|     if (gl_SubgroupID == 0) { |         load_row_ids(expert_idx, true); | ||||||
|         _ne1 = 0; |     } else { | ||||||
|         uint num_elements = p.nei1 * p.nei0; |         load_row_ids(expert_idx, false); | ||||||
|  |  | ||||||
|         uint ids[16]; |  | ||||||
|         uint iter = 0; |  | ||||||
|  |  | ||||||
|         for (uint j = 0; j < num_elements; j += gl_SubgroupSize) { |  | ||||||
|             // prefetch up to 16 elements |  | ||||||
|             if (iter == 0) { |  | ||||||
|                 [[unroll]] for (uint k = 0; k < 16; ++k) { |  | ||||||
|                     uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize; |  | ||||||
|                     bool in_range = i < num_elements; |  | ||||||
|                     uint ii1 = i / p.nei0; |  | ||||||
|                     uint ii0 = i % p.nei0; |  | ||||||
|                     ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             uint i = j + gl_SubgroupInvocationID; |  | ||||||
|             bool in_range = i < num_elements; |  | ||||||
|             uint ii1 = i / p.nei0; |  | ||||||
|             uint ii0 = i % p.nei0; |  | ||||||
|             uint id = ids[iter++]; |  | ||||||
|             uvec4 ballot = subgroupBallot(in_range && id == expert_idx); |  | ||||||
|             uint idx = subgroupBallotExclusiveBitCount(ballot); |  | ||||||
|             if (in_range && id == expert_idx) { |  | ||||||
|                 row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); |  | ||||||
|             } |  | ||||||
|             _ne1 += subgroupBallotBitCount(ballot); |  | ||||||
|             iter &= 15; |  | ||||||
|         } |  | ||||||
|         _ne1_sh = _ne1; |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     barrier(); |  | ||||||
|  |  | ||||||
|     _ne1 = _ne1_sh; |  | ||||||
|  |  | ||||||
|     // Workgroup has no work |     // Workgroup has no work | ||||||
|     if (ic * BN >= _ne1) return; |     if (ic * BN >= _ne1) return; | ||||||
| #endif | #endif | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz