cuda : fix argsort with 64k+ rows (#16849)

This commit is contained in:
Sigbjørn Skjæret
2025-10-30 08:56:28 +01:00
committed by GitHub
parent d7395115ba
commit 229bf68628
2 changed files with 4 additions and 3 deletions

View File

@@ -87,7 +87,7 @@ template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
// bitonic sort
int col = threadIdx.x;
- int row = blockIdx.y;
+ int row = blockIdx.x;
if (col >= ncols_pad) {
return;
@@ -151,7 +151,7 @@ static void argsort_f32_i32_cuda_bitonic(const float * x,
const int ncols_pad = next_power_of_2(ncols);
const dim3 block_dims(ncols_pad, 1, 1);
- const dim3 block_nums(1, nrows, 1);
+ const dim3 block_nums(nrows, 1, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// FIXME: this limit could be raised by ~2-4x on Ampere or newer