mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
CUDA: use CUB for arbitary size argsort (#16754)
This commit is contained in:
@@ -1,5 +1,81 @@
|
|||||||
#include "argsort.cuh"
|
#include "argsort.cuh"
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_USE_CUB
|
||||||
|
# include <cub/cub.cuh>
|
||||||
|
using namespace cub;
|
||||||
|
#endif // GGML_CUDA_USE_CUB
|
||||||
|
|
||||||
|
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
|
||||||
|
const int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int row = blockIdx.y;
|
||||||
|
|
||||||
|
if (col < ncols && row < nrows) {
|
||||||
|
indices[row * ncols + col] = col;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||||
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx <= nrows) {
|
||||||
|
offsets[idx] = idx * ncols;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_USE_CUB
|
||||||
|
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||||
|
const float * x,
|
||||||
|
int * dst,
|
||||||
|
const int ncols,
|
||||||
|
const int nrows,
|
||||||
|
ggml_sort_order order,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||||
|
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||||
|
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||||
|
|
||||||
|
int * temp_indices = temp_indices_alloc.get();
|
||||||
|
float * temp_keys = temp_keys_alloc.get();
|
||||||
|
int * d_offsets = offsets_alloc.get();
|
||||||
|
|
||||||
|
static const int block_size = 256;
|
||||||
|
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||||
|
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||||
|
|
||||||
|
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||||
|
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||||
|
|
||||||
|
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||||
|
|
||||||
|
size_t temp_storage_bytes = 0;
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
|
temp_indices, dst, // values (indices)
|
||||||
|
ncols * nrows, nrows, // num items, num segments
|
||||||
|
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||||
|
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
|
||||||
|
sizeof(float) * 8, stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||||
|
void * d_temp_storage = temp_storage_alloc.get();
|
||||||
|
|
||||||
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
|
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||||
|
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||||
|
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||||
|
0, sizeof(float) * 8, stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif // GGML_CUDA_USE_CUB
|
||||||
|
|
||||||
|
// Bitonic sort implementation
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||||
T tmp = a;
|
T tmp = a;
|
||||||
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
|
|||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
|
static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||||
|
int * dst,
|
||||||
|
const int ncols,
|
||||||
|
const int nrows,
|
||||||
|
ggml_sort_order order,
|
||||||
|
cudaStream_t stream) {
|
||||||
// bitonic sort requires ncols to be power of 2
|
// bitonic sort requires ncols to be power of 2
|
||||||
const int ncols_pad = next_power_of_2(ncols);
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
|
||||||
@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
|||||||
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
|
||||||
|
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
|
||||||
|
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
|
||||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
#ifdef GGML_CUDA_USE_CUB
|
||||||
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||||
|
|
||||||
|
if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||||
|
ggml_cuda_pool & pool = ctx.pool();
|
||||||
|
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
} else {
|
||||||
|
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3642,8 +3642,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
return ggml_is_contiguous_rows(op->src[0]);
|
return ggml_is_contiguous_rows(op->src[0]);
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
// TODO: Support arbitrary column width
|
#ifndef GGML_CUDA_USE_CUB
|
||||||
return op->src[0]->ne[0] <= 1024;
|
return op->src[0]->ne[0] <= 1024;
|
||||||
|
#else
|
||||||
|
return true;
|
||||||
|
#endif
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
case GGML_OP_MEAN:
|
case GGML_OP_MEAN:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
|
|||||||
Reference in New Issue
Block a user