Files
llama.cpp/ggml/src/ggml-cuda/sumrows.cu
Oliver Simons 6028bf7435 CUDA: Optimize reduce_rows_f32 kernel, leading up to 25x perf improvement on kernel-level and 10% perf increase for Gemma3n (#15132)
* Factor out `reduce_rows_f32` from common.cuh

This increases iteration cycle speed by not having to recompile
every kernel all the time

* Hide memory-latency by loop unrolling in reduce_rows_f32

* Further optimizations to `reduce_rows_f32`

1. Increase threadblock size to better hide latency of memory requests.
   As a consequence of bigger threadblocks, do 2-step summation, using
   shared memory to communicate results between invocations
2. Use sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflext bigger threadblock
4. Improve default block_dims, increase support for more block_dims

* Add perf tests for `reduce_rows_f32` kernel

* Add heuristic to toggle 128/512 threads based on sm count

Break even point was the minimum of the following multiples.

| GPU Model                     | Nrow SM Count Multiple |
| -----------                   | -----------            |
| RTX 4000 SFF ADA              | 2.0x                   |
| RTX 6000 ADA                  | 2.5x                   |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                  |
| RTX PRO 4500 Blackwell	| 3.15x                  |

* Ensure perf gains also for small ncols and large nrows

Alternative to this, one could have also made the number of unrollings
template-able, but that would require compiling the kernel multiple
times, increasing binary size unnecessarily

* Modify perf and unit-tests

* Apply auto-formatting by clang

* Fix CI build failure

See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486
Building with VS generator worked though.

* Remove sm_count property from `ggml_backend_cuda_context`

Requested by @JohannesGaessler, and should fix remaining CI issues as a
side-effect

* Add CUB-based implementation for GGML_OP_MEAN

Currently this branch is only executed for nrows==1

* Add heuristics to execute CUB branch only when it brings perf

Heuristics were determined on the following HW:

* RTX 4000 SFF ADA
* RTX 6000 ADA
* RTX PRO 6000 Blackwell Max-Q
* RTX PRO 4500 Blackwell

* Add unit-test for CUB-based mean

Tests should run with CUDA Graphs enabled per default on NVGPUs

* Rename `USE_CUB` to `GGML_CUDA_USE_CUB`

Suggested by @JohannesGaessler

* Unindent Preprocessor directives

See
https://github.com/ggml-org/llama.cpp/pull/15132#discussion_r2269213506
2025-08-13 10:04:46 +02:00

44 lines
1.8 KiB
Plaintext

#include "reduce_rows.cuh"
#include "sumrows.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const dim3 block_nums(nrows, 1, 1);
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else {
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_nums(nrows, 1, 1);
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
if ((nrows / nsm) < 2) {
// Increase num threads to 512 for small nrows to better hide the latency
const dim3 block_dims(512, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
} else {
// Enough active SMs to hide latency, use smaller blocks to allow better scheduling
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
}