mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
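cuda : fix argsort with 64k+ rows (#16849)
The kernel now takes its row index from blockIdx.x instead of blockIdx.y (and the launch uses block_nums(nrows, 1, 1)), because the CUDA grid y dimension is limited to 65535 blocks while the x dimension allows up to 2^31-1.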
This commit is contained in:
@@ -87,7 +87,7 @@ template<ggml_sort_order order>
|
|||||||
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
|
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
int col = threadIdx.x;
|
int col = threadIdx.x;
|
||||||
int row = blockIdx.y;
|
int row = blockIdx.x;
|
||||||
|
|
||||||
if (col >= ncols_pad) {
|
if (col >= ncols_pad) {
|
||||||
return;
|
return;
|
||||||
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
|
|||||||
const int ncols_pad = next_power_of_2(ncols);
|
const int ncols_pad = next_power_of_2(ncols);
|
||||||
|
|
||||||
const dim3 block_dims(ncols_pad, 1, 1);
|
const dim3 block_dims(ncols_pad, 1, 1);
|
||||||
const dim3 block_nums(1, nrows, 1);
|
const dim3 block_nums(nrows, 1, 1);
|
||||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||||
|
|
||||||
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
|
||||||
|
|||||||
@@ -7111,7 +7111,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection)
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
|
||||||
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||||
}
|
}
|
||||||
|
|
||||||
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
|
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
|
||||||
|
|||||||
Reference in New Issue
Block a user