CUDA: use registers instead of smem in topk-moe (#16647)

Uses the technique used in the vulkan PR #16641. Neat trick!
This commit is contained in:
Aman Gupta
2025-10-18 17:52:53 +08:00
committed by GitHub
parent 81387858f1
commit 38355c6c8e

View File

@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt_sum = 0.f; float wt_sum = 0.f;
extern __shared__ float data_topk_shared[]; float output_weights[experts_per_thread];
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
for (int k = 0; k < n_expert_used; k++) { for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0]; float max_val = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
} }
} }
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
output_weights[k / WARP_SIZE] = max_val;
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY; wt[max_expert / WARP_SIZE] = -INFINITY;
wt_shared_ptr[k] = max_val; ids[k] = max_expert;
ids[k] = max_expert;
if constexpr (with_norm) { if constexpr (with_norm) {
wt_sum += max_val; wt_sum += max_val;
} }
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
const float inv_sum = 1.0f / wt_sum; const float inv_sum = 1.0f / wt_sum;
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum; output_weights[i] *= inv_sum;
} }
} }
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) { #pragma unroll
weights[i] = wt_shared_ptr[i]; for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
}
} }
} }
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
dim3 block_dims(WARP_SIZE, rows_per_block, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1);
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
switch (n_expert) { switch (n_expert) {
case 1: case 1:
topk_moe_cuda<1, with_norm> topk_moe_cuda<1, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 2: case 2:
topk_moe_cuda<2, with_norm> topk_moe_cuda<2, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 4: case 4:
topk_moe_cuda<4, with_norm> topk_moe_cuda<4, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 8: case 8:
topk_moe_cuda<8, with_norm> topk_moe_cuda<8, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 16: case 16:
topk_moe_cuda<16, with_norm> topk_moe_cuda<16, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 32: case 32:
topk_moe_cuda<32, with_norm> topk_moe_cuda<32, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 64: case 64:
topk_moe_cuda<64, with_norm> topk_moe_cuda<64, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 128: case 128:
topk_moe_cuda<128, with_norm> topk_moe_cuda<128, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 256: case 256:
topk_moe_cuda<256, with_norm> topk_moe_cuda<256, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
case 512: case 512:
topk_moe_cuda<512, with_norm> topk_moe_cuda<512, with_norm>
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used); <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break; break;
default: default:
GGML_ASSERT(false && "fatal error"); GGML_ASSERT(false && "fatal error");