mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-15 11:17:31 +00:00
* Factor out `reduce_rows_f32` from common.cuh This increases iteration cycle speed by not having to recompile every kernel all the time * Hide memory-latency by loop unrolling in reduce_rows_f32 * Further optimizations to `reduce_rows_f32` 1. Increase threadblock size to better hide latency of memory requests. As a consequence of bigger threadblocks, do 2-step summation, using shared memory to communicate results between invocations 2. Use sum_temp array to reduce waits on sum 3. Adjust num_unroll to reflect the bigger threadblock 4. Improve default block_dims, increase support for more block_dims * Add perf tests for `reduce_rows_f32` kernel * Add heuristic to toggle 128/512 threads based on sm count Break-even point was the minimum of the following multiples. | GPU Model | Nrow SM Count Multiple | | ----------- | ----------- | | RTX 4000 SFF ADA | 2.0x | | RTX 6000 ADA | 2.5x | | RTX PRO 6000 Blackwell Max-Q | 3.04x | | RTX PRO 4500 Blackwell | 3.15x | * Ensure perf gains also for small ncols and large nrows As an alternative, one could also have made the number of unrollings template-able, but that would require compiling the kernel multiple times, increasing binary size unnecessarily * Modify perf and unit-tests * Apply auto-formatting by clang * Fix CI build failure See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486 Building with VS generator worked though. 
* Remove sm_count property from `ggml_backend_cuda_context` Requested by @JohannesGaessler, and should fix remaining CI issues as a side-effect * Add CUB-based implementation for GGML_OP_MEAN Currently this branch is only executed for nrows==1 * Add heuristics to execute CUB branch only when it brings perf Heuristics were determined on the following HW: * RTX 4000 SFF ADA * RTX 6000 ADA * RTX PRO 6000 Blackwell Max-Q * RTX PRO 4500 Blackwell * Add unit-test for CUB-based mean Tests should run with CUDA Graphs enabled per default on NVGPUs * Rename `USE_CUB` to `GGML_CUDA_USE_CUB` Suggested by @JohannesGaessler * Unindent Preprocessor directives See https://github.com/ggml-org/llama.cpp/pull/15132#discussion_r2269213506
42 lines
1.3 KiB
Plaintext
42 lines
1.3 KiB
Plaintext
#include "sum.cuh"
|
|
#include "sumrows.cuh"
|
|
|
|
#ifdef GGML_CUDA_USE_CUB
|
|
#include <cub/cub.cuh>
|
|
using namespace cub;
|
|
#endif // GGML_CUDA_USE_CUB
|
|
|
|
#include <cstdint>
|
|
|
|
// Sums the `ne` contiguous floats starting at `x` into the single float at `dst`.
//
// pool   : used only on the CUB path, to allocate CUB's temporary scratch storage.
// x      : device-accessible pointer to `ne` input values.
// dst    : device-accessible pointer receiving the scalar result.
// ne     : number of input elements.
// stream : CUDA stream on which all reduction work is enqueued.
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
#ifdef GGML_CUDA_USE_CUB
    // CUB two-phase protocol: the first call (null temp-storage pointer) only
    // writes the required scratch size into tmp_size; the second call performs
    // the reduction. Both calls return a cudaError_t. Check them so that a
    // failed size query or launch is reported here instead of being dropped
    // (an unchecked failed query could leave tmp_size unset for the second call).
    size_t tmp_size = 0;
    CUDA_CHECK(DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream));
    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
    CUDA_CHECK(DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream));
#else
    // Use (inefficient) sum_rows implementation as a fallback.
    // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
    // The input is treated as a single row of `ne` columns.
    // NOTE(review): `ne` is int64_t; if sum_rows_f32_cuda takes a narrower `int`
    // column count, inputs with more than INT_MAX elements would be truncated —
    // confirm against sumrows.cuh.
    sum_rows_f32_cuda(x, dst, ne, 1, stream);
    GGML_UNUSED(pool);
#endif // GGML_CUDA_USE_CUB
}
|
|
|
|
// CUDA backend handler for the sum op: reduces every element of dst->src[0]
// into dst->data via sum_f32_cuda, using the context's memory pool and stream.
//
// Preconditions (checked with GGML_ASSERT):
//   - the source tensor and dst are both GGML_TYPE_F32;
//   - the source is contiguously allocated, so its ggml_nelements(src0) floats
//     can be read as one flat array.
// NOTE(review): dst presumably holds a single float (the scalar sum) — confirm
// against the op's graph-construction code.
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguously_allocated(src0));

    // View the whole source tensor as a flat vector of floats.
    const int64_t n_elems = ggml_nelements(src0);
    const float * in      = static_cast<const float *>(src0->data);
    float *       out     = static_cast<float *>(dst->data);

    // Fetch the pool before the stream, matching the original acquisition order.
    ggml_cuda_pool & scratch_pool = ctx.pool();
    cudaStream_t     exec_stream  = ctx.stream();

    sum_f32_cuda(scratch_pool, in, out, n_elems, exec_stream);
}
|