Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-28 08:31:25 +00:00
CUDA: use registers instead of smem in topk-moe (#16647)
Uses the technique from the Vulkan PR #16641. Neat trick!
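The trick, in isolation: instead of parking the k-th selected weight in a shared-memory buffer indexed by k, the lane whose id is k mod WARP_SIZE keeps it in a small per-lane register array at slot k / WARP_SIZE, and a fully unrolled write-back loop scatters the slots to global memory at the end. Below is a minimal standalone CUDA sketch of the idea, not the llama.cpp kernel: the kernel name topk_regs_demo and the constants N_EXPERTS / N_USED are made up for illustration, and it handles a single row with a single warp.

// A minimal standalone sketch of the register-based top-k trick, not the
// llama.cpp kernel: the kernel name and the N_EXPERTS / N_USED constants
// are made up for illustration. One warp handles one row of expert scores.
#include <cstdio>
#include <cmath>

constexpr int WARP      = 32;
constexpr int N_EXPERTS = 128;              // compile-time, like the template parameter
constexpr int N_USED    = 8;                // top-k experts to select
constexpr int PER_LANE  = N_EXPERTS / WARP; // register slots per lane

__global__ void topk_regs_demo(const float * logits, float * weights, int * ids) {
    // each lane owns experts lane, lane+32, lane+64, ... in registers
    float wt[PER_LANE];
    for (int i = 0; i < PER_LANE; i++) {
        wt[i] = logits[i * WARP + threadIdx.x];
    }

    float out[PER_LANE]; // replaces the shared-memory buffer of the old kernel

    for (int k = 0; k < N_USED; k++) {
        // lane-local argmax over this lane's register slots
        float max_val = wt[0];
        int   max_idx = threadIdx.x;
        for (int i = 1; i < PER_LANE; i++) {
            const int expert = i * WARP + threadIdx.x;
            if (wt[i] > max_val) {
                max_val = wt[i];
                max_idx = expert;
            }
        }
        // warp-wide argmax via a shuffle butterfly (ties broken by lower index)
        for (int off = WARP / 2; off > 0; off >>= 1) {
            const float v = __shfl_xor_sync(0xffffffff, max_val, off);
            const int   e = __shfl_xor_sync(0xffffffff, max_idx, off);
            if (v > max_val || (v == max_val && e < max_idx)) {
                max_val = v;
                max_idx = e;
            }
        }
        // the trick: lane (k % 32) parks the k-th weight in register slot (k / 32)
        if ((k & (WARP - 1)) == (int) threadIdx.x) {
            out[k / WARP] = max_val;
        }
        // the lane owning the winning expert knocks it out and records its id
        if ((max_idx & (WARP - 1)) == (int) threadIdx.x) {
            wt[max_idx / WARP] = -INFINITY;
            ids[k] = max_idx;
        }
    }

    // write-back: compile-time trip count so "out" can stay in registers
#pragma unroll
    for (int i = 0; i < PER_LANE; i++) {
        const int idx = i * WARP + threadIdx.x;
        if (idx < N_USED) {
            weights[idx] = out[i];
        }
    }
}

int main() {
    float h_logits[N_EXPERTS];
    for (int i = 0; i < N_EXPERTS; i++) {
        h_logits[i] = sinf(0.37f * i); // arbitrary scores
    }

    float * d_logits, * d_weights;
    int   * d_ids;
    cudaMalloc(&d_logits, sizeof(h_logits));
    cudaMalloc(&d_weights, N_USED * sizeof(float));
    cudaMalloc(&d_ids, N_USED * sizeof(int));
    cudaMemcpy(d_logits, h_logits, sizeof(h_logits), cudaMemcpyHostToDevice);

    topk_regs_demo<<<1, WARP>>>(d_logits, d_weights, d_ids);

    float h_weights[N_USED];
    int   h_ids[N_USED];
    cudaMemcpy(h_weights, d_weights, sizeof(h_weights), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_ids, d_ids, sizeof(h_ids), cudaMemcpyDeviceToHost);
    for (int k = 0; k < N_USED; k++) {
        printf("top-%d: expert %3d  weight %.4f\n", k + 1, h_ids[k], h_weights[k]);
    }

    cudaFree(d_logits);
    cudaFree(d_weights);
    cudaFree(d_ids);
    return 0;
}

Built with a plain nvcc invocation, this prints the eight selected experts in descending score order. The real kernel additionally handles multiple rows per block (threadIdx.y), optional softmax normalization, and uses ggml's own warp reduction helpers.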
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -73,8 +73,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt_sum = 0.f;
 
-    extern __shared__ float data_topk_shared[];
-    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
+    float output_weights[experts_per_thread];
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
@@ -99,11 +98,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             }
         }
 
+        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
-            wt_shared_ptr[k] = max_val;
-            ids[k]           = max_expert;
+            ids[k] = max_expert;
             if constexpr (with_norm) {
                 wt_sum += max_val;
             }
@@ -115,12 +117,16 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            output_weights[i] *= inv_sum;
         }
     }
 
-    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
-        weights[i] = wt_shared_ptr[i];
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = i * WARP_SIZE + threadIdx.x;
+        if (idx < n_expert_used) {
+            weights[idx] = output_weights[i];
+        }
     }
 }
 
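Two details in this hunk are worth noting. The normalization loop now scales the lane-private registers in place instead of a shared buffer, and the final write-back iterates over the compile-time bound experts_per_thread with a runtime guard (idx < n_expert_used) rather than striding a runtime bound. The likely reason for that shape: a local array indexed by a value the compiler cannot resolve at compile time is typically spilled to local memory, while a fully unrolled loop with constant indices lets output_weights actually live in registers, which is the point of the change.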
@@ -137,48 +143,46 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
-    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
-
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
             topk_moe_cuda<2, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
             topk_moe_cuda<4, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
             topk_moe_cuda<8, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
             topk_moe_cuda<16, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
             topk_moe_cuda<32, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
             topk_moe_cuda<64, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
             topk_moe_cuda<128, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
             topk_moe_cuda<256, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
         case 512:
             topk_moe_cuda<512, with_norm>
-                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
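With the shared buffer gone from the kernel, the launcher no longer computes nbytes_shared (previously n_expert_used * rows_per_block * sizeof(float) per block), and every instantiation passes 0 bytes of dynamic shared memory as the third launch parameter; the selected weights now travel from warp-private registers straight to the weights output.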