// Patch summary: Vulkan argsort -- let one workgroup process several rows.
//
// Files touched: ggml/src/ggml-vulkan/ggml-vulkan.cpp (host side) and
// ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp (compute shader).
//
// Host side:
//   * vk_op_argsort_push_constants gains `nrows`, inserted between `ncols` and
//     `order`. The shader's push_constant block gets the field in the same
//     position, so the two declarations stay in the same order.
//   * In the GGML_OP_ARGSORT case of ggml_vk_op_f32, elements[1] (initialised
//     from ggml_nrows(src0)) is clamped with std::min to
//     ctx->device->properties.limits.maxComputeWorkGroupCount[1].
//   * ggml_vk_argsort computes nrows = ggml_nrows(src0) and passes it as the
//     second push-constant value, between ncols and op_params[0].
//
// Shader side:
//   * argsort() takes the row index as a second parameter instead of reading
//     gl_WorkGroupID.y itself.
//   * main() starts at row = gl_WorkGroupID.y and calls argsort() repeatedly,
//     advancing row by gl_WorkGroupSize.y * gl_NumWorkGroups.y while
//     row < p.nrows. The loop is written out once per branch: argsort(false, row)
//     when p.ncols == BLOCK_SIZE, argsort(true, row) otherwise.
//
// NOTE(review): the start row is a workgroup ID while the stride multiplies in
//   gl_WorkGroupSize.y; the two only line up if the shader's local size in Y
//   is 1. The local_size layout is not part of these hunks -- confirm.
// NOTE(review): `order` is int32_t in the C++ struct and `uint` in the shader;
//   both declarations are unchanged context lines in this patch -- confirm the
//   mismatch is intentional.
// NOTE(review): nrows is stored in a uint32_t; the existing (uint32_t) cast on
//   ggml_nrows(src0) suggests a wider return type -- confirm row counts fit.
// NOTE(review): hunk headers and hunk bodies share physical lines below, i.e.
//   the original line structure of the patch appears to have been lost; it
//   presumably has to be restored before the patch can be applied.
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8a9f5980ea..d0976519f2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1082,6 +1082,7 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t nrows; int32_t order; }; @@ -8708,6 +8709,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; case GGML_OP_ARGSORT: elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); break; case GGML_OP_IM2COL: { @@ -9954,9 +9956,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c int32_t * op_params = (int32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; + uint32_t nrows = ggml_nrows(src0); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { ncols, + nrows, op_params[0], }, dryrun); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c81b84452e..c4e68bc023 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint nrows; uint order; } p; @@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) { dst_row[idx1] = tmp; } -void argsort(bool needs_bounds_check) { +void argsort(bool needs_bounds_check, const uint row) { // bitonic sort const int col = int(gl_LocalInvocationID.x); - const uint row = gl_WorkGroupID.y; const uint row_offset = row * p.ncols; @@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) { void main() { if (p.ncols == BLOCK_SIZE) { - argsort(false); + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * 
gl_NumWorkGroups.y; + } } else { - argsort(true); + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } } }