mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
CUDA: add dynamic shared mem to softmax, refactor general usage (#14497)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include "ggml.h"
|
||||
#include "softmax.cuh"
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ float t2f32(T val) {
|
||||
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// Launches the soft_max_f32 kernel (first template argument `true`), choosing a
// compile-time specialization when p.ncols equals one of the widths in Ns.
//
// Template parameters:
//   Ns - candidate column counts, tried left to right; the first one equal to
//        p.ncols wins and the remaining candidates are not evaluated.
//   T  - element type of `mask`.
//
// Parameters:
//   x, mask, dst, p - forwarded unchanged to the kernel.
//   stream          - stream the kernel is launched on.
//   block_dims      - threads-per-block argument of the launch.
//   block_nums      - grid argument of the launch.
//   nbytes_shared   - dynamic shared memory size passed to the launch.
//
// For a matching width N the instantiation is soft_max_f32<true, N, min(N, 1024), T>.
// If no candidate matches, soft_max_f32<true, 0, 0, T> is launched instead.
// Before either launch, CUDA_SET_SHARED_MEMORY_LIMIT is invoked for that exact
// instantiation with the current device's `smpbo` value.
// NOTE(review): smpbo is presumably the opt-in per-block shared memory maximum
// reported for the device - confirm against ggml_cuda_info().
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
                             const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
    const int    cur_device = ggml_cuda_get_device();
    const size_t smpbo      = ggml_cuda_info().devices[cur_device].smpbo;

    // Attempts a launch for one candidate width carried by the tag type.
    // Returns true when the kernel was launched, false when the width does not match.
    auto try_launch = [=](auto width_tag) -> bool {
        constexpr int ncols_fixed = decltype(width_tag)::value;
        // The block-size template argument is capped at 1024.
        constexpr int block_fixed = (ncols_fixed <= 1024 ? ncols_fixed : 1024);

        if (p.ncols != ncols_fixed) {
            return false;
        }

        CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols_fixed, block_fixed, T>), smpbo);
        soft_max_f32<true, ncols_fixed, block_fixed><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
        return true;
    };

    // Fold over || : short-circuits at the first candidate that launches.
    const bool launched = (try_launch(std::integral_constant<int, Ns>{}) || ...);

    if (!launched) {
        // No specialized width matched: fall back to the <0, 0> instantiation.
        CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
        soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
    }
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
|
||||
|
||||
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
||||
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, dst, params);
|
||||
break;
|
||||
}
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
||||
} else {
|
||||
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
|
||||
|
||||
Reference in New Issue
Block a user