diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 08095dcf06..e61b097833 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -943,6 +943,34 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_ARGSORT);
+
+    char base[256];
+    char name[256];
+
+    ggml_sort_order order = (ggml_sort_order) op->op_params[0];
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 5a8bc0c1cc..5539abda33 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -125,6 +125,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id      (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id     (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort       (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin           (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm    (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 69c8820854..741b1a44db 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -904,8 +904,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
-            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
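With the chunked sort + merge scheme below, `GGML_OP_ARGSORT` no longer needs the 1024-column cap, so `supports_op` lets it fall through to the `return true` of `GGML_OP_ARANGE`. The new getter follows the same lookup-or-compile pipeline caching as its neighbors; roughly this pattern (plain C++ sketch with hypothetical `Pipeline`/`PipelineCache` types, not the actual ggml_metal_library internals):

```cpp
// Minimal sketch of the lookup-or-compile caching pattern, assuming
// hypothetical Pipeline/PipelineCache types (not the ggml API).
#include <string>
#include <unordered_map>

struct Pipeline {
    std::string base; // kernel template this pipeline was built from
};

struct PipelineCache {
    std::unordered_map<std::string, Pipeline> pipelines;

    // return the pipeline registered under `name`, "compiling" it on first use
    const Pipeline & get(const std::string & base, const std::string & name) {
        auto it = pipelines.find(name);
        if (it != pipelines.end()) {
            return it->second; // cache hit: reuse the existing pipeline
        }
        // cache miss: stand-in for ggml_metal_library_compile_pipeline()
        return pipelines.emplace(name, Pipeline{base}).first->second;
    }
};

int main() {
    PipelineCache cache;
    // e.g. "kernel_argsort_merge_f32_i32_asc" for (f32 src, i32 dst, asc)
    const Pipeline & p0 = cache.get("kernel_argsort_merge_f32_i32_asc", "kernel_argsort_merge_f32_i32_asc");
    const Pipeline & p1 = cache.get("kernel_argsort_merge_f32_i32_asc", "kernel_argsort_merge_f32_i32_asc");
    return &p0 == &p1 ? 0 : 1; // second lookup hits the cache
}
```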
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 6d02befa97..dd889cd90d 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -793,10 +793,28 @@ typedef struct {
 } ggml_metal_kargs_leaky_relu;
 
 typedef struct {
-    int64_t  ncols;
-    int64_t  ncols_pad;
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
 } ggml_metal_kargs_argsort;
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  len;
+} ggml_metal_kargs_argsort_merge;
+
 typedef struct {
     int64_t ne0;
     float   start;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index c48f7cd29f..ae098d371f 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -3530,38 +3530,95 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     ggml_metal_library_t lib = ctx->lib;
     ggml_metal_encoder_t enc = ctx->enc;
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
 
-    // bitonic sort requires the number of elements to be power of 2
-    int64_t ne00_padded = 1;
-    while (ne00_padded < ne00) {
-        ne00_padded *= 2;
-    }
-
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
 
-    const int64_t nrows = ggml_nrows(op->src[0]);
+    // bitonic sort requires the number of elements to be power of 2
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    const int nptg = (ne00 + nth - 1)/nth;
 
     // Metal kernels require the buffer size to be multiple of 16 bytes
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
-    const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+    const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += ggml_nbytes(op);
+
+    if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
+        std::swap(bid_dst, bid_tmp);
+    }
 
     ggml_metal_kargs_argsort args = {
-        /*.ncols     =*/ ne00,
-        /*.ncols_pad =*/ ne00_padded
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
     };
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
+
+    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+    int len = nth;
+
+    while (len < ne00) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        ggml_metal_kargs_argsort_merge args_merge = {
+            .ne00 = ne00,
+            .ne01 = ne01,
+            .ne02 = ne02,
+            .ne03 = ne03,
+            .nb00 = nb00,
+            .nb01 = nb01,
+            .nb02 = nb02,
+            .nb03 = nb03,
+            .len  = len,
+        };
+
+        // merges per row
+        const int nm = (ne00 + 2*len - 1) / (2*len);
+
+        const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+        ggml_metal_encoder_set_bytes   (enc, &args_merge, sizeof(args_merge), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  3);
+
+        ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
+
+        std::swap(bid_dst, bid_tmp);
+
+        len <<= 1;
+    }
 
     return 1;
 }
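The merge passes ping-pong between the real dst and a scratch region that lives in the second half of the dst allocation (`bid_tmp.offs += ggml_nbytes(op)`; the allocation is doubled in ggml-metal.cpp further below). The up-front swap keyed on `ceil(log2(nptg)) % 2` guarantees the final pass writes the real output. A throwaway C++ check of that parity argument (assumed chunk size `nth = 4`, label strings standing in for the buffer ids):

```cpp
// Parity check for the dst/tmp ping-pong scheduled above: the bitonic pass
// writes sorted chunks of length nth, each merge pass doubles `len` and
// swaps the two buffers, and the final pass must write the real dst.
#include <cassert>
#include <cmath>
#include <string>

int main() {
    const int nth = 4; // assumed bitonic chunk size (threads per threadgroup)

    for (int ne00 = 1; ne00 <= 1024; ++ne00) {
        const int nptg = (ne00 + nth - 1)/nth; // sorted chunks per row

        int passes = 0;
        for (int len = nth; len < ne00; len <<= 1) {
            passes++; // one merge dispatch per doubling of `len`
        }
        // the scheduling code computes the same count in closed form
        assert(passes == (int) std::ceil(std::log2((double) nptg)));

        std::string buf = "dst";   // where the bitonic output goes...
        if (passes % 2 == 1) {
            buf = "tmp";           // ...unless the up-front swap fires
        }
        for (int p = 0; p < passes; ++p) {
            buf = (buf == "dst") ? "tmp" : "dst"; // std::swap after each pass
        }
        assert(buf == "dst"); // the last merge always lands in the real output
    }
    return 0;
}
```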
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 7afc881fa7..35f07f3e71 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -197,6 +197,10 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
                     res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                     res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
                 } break;
+            case GGML_OP_ARGSORT:
+                {
+                    res *= 2;
+                } break;
             default:
                 break;
         }
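In the .metal hunk below, the bitonic kernel now sorts one `nth`-wide chunk of a row per threadgroup and treats indices past `ne00` as +inf-style padding, which is what makes the ragged tail of a non-power-of-two row sort correctly. A host-side C++ model of the same padded compare-exchange network (ascending case only; `bitonic_argsort_asc` is an illustrative name, not ggml code):

```cpp
// Host-side model of the padded bitonic network: idx entries >= ne00 are
// padding that always compares as larger, so they sink to the chunk's end.
#include <cstdio>
#include <utility>
#include <vector>

static void bitonic_argsort_asc(std::vector<int> & idx, const std::vector<float> & x, int ne00) {
    const int n = (int) idx.size(); // must be a power of two
    for (int k = 2; k <= n; k *= 2) {
        for (int j = k/2; j > 0; j /= 2) {
            for (int col = 0; col < n; ++col) { // one GPU thread per col
                const int ixj = col ^ j;
                if (ixj <= col) {
                    continue;
                }
                bool swp;
                if ((col & k) == 0) { // ascending run
                    swp = idx[col] >= ne00 ||
                         (idx[ixj] <  ne00 && x[idx[col]] > x[idx[ixj]]);
                } else {              // descending run
                    swp = idx[ixj] >= ne00 ||
                         (idx[col] <  ne00 && x[idx[col]] < x[idx[ixj]]);
                }
                if (swp) {
                    std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }
}

int main() {
    const std::vector<float> x = {0.7f, 0.1f, 0.5f, 0.3f, 0.9f}; // ne00 = 5
    std::vector<int> idx = {0, 1, 2, 3, 4, 5, 6, 7};             // padded to 8
    bitonic_argsort_asc(idx, x, 5);
    for (int i = 0; i < 5; ++i) {
        printf("%d ", idx[i]); // prints: 1 3 2 0 4
    }
    printf("\n");
    return 0;
}
```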
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 7f94419c3a..8afc7318f6 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4541,69 +4541,179 @@ kernel void kernel_timestep_embedding_f32(
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
         constant ggml_metal_kargs_argsort & args,
-        device const float   * x,
+        device const char    * src0,
         device       int32_t * dst,
-        threadgroup  int32_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]);
+        threadgroup  int32_t * smem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]);
 
 template<ggml_sort_order order>
 kernel void kernel_argsort_f32_i32(
         constant ggml_metal_kargs_argsort & args,
-        device const float   * x,
+        device const char    * src0,
         device       int32_t * dst,
-        threadgroup  int32_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]) {
+        threadgroup  int32_t * smem_i32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     // bitonic sort
-    int col = tpitg[0];
-    int row = tgpig[1];
+    const int col = tpitg[0];
 
-    if (col >= args.ncols_pad) return;
+    const int i00 = (tgpig[0]/args.ne01)*ntg.x;
+    const int i01 = tgpig[0]%args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
 
-    device const float * x_row = x + row * args.ncols;
-    threadgroup int32_t * dst_row = shared_values;
+    device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
 
     // initialize indices
-    dst_row[col] = col;
+    smem_i32[col] = i00 + col;
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= args.ncols_pad; k *= 2) {
+    for (int k = 2; k <= ntg.x; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (dst_row[col] >= args.ncols ||
-                        (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    if (smem_i32[col] >= args.ne00 ||
+                        (smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
+                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
                     ) {
-                        SWAP(dst_row[col], dst_row[ixj]);
+                        SWAP(smem_i32[col], smem_i32[ixj]);
                     }
                 } else {
-                    if (dst_row[ixj] >= args.ncols ||
-                        (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    if (smem_i32[ixj] >= args.ne00 ||
+                        (smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
+                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
                     ) {
-                        SWAP(dst_row[col], dst_row[ixj]);
+                        SWAP(smem_i32[col], smem_i32[ixj]);
                     }
                 }
             }
+
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
     }
 
     // copy the result to dst without the padding
-    if (col < args.ncols) {
-        dst[row * args.ncols + col] = dst_row[col];
+    if (i00 + col < args.ne00) {
+        dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+
+        dst[col] = smem_i32[col];
     }
 }
 
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
+typedef void (argsort_merge_t)(
+        constant ggml_metal_kargs_argsort_merge & args,
+        device const char    * src0,
+        device const int32_t * tmp,
+        device       int32_t * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_merge_f32_i32(
+        constant ggml_metal_kargs_argsort_merge & args,
+        device const char    * src0,
+        device const int32_t * tmp,
+        device       int32_t * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int im  = tgpig[0] / args.ne01;
+    int i01 = tgpig[0] % args.ne01;
+    int i02 = tgpig[1];
+    int i03 = tgpig[2];
+
+    const int start = im * (2*args.len);
+
+    const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
+    const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+
+    const int total = len0 + len1;
+
+    device const int32_t * tmp0 = tmp + start
+                                + i01*args.ne00
+                                + i02*args.ne00*args.ne01
+                                + i03*args.ne00*args.ne01*args.ne02;
+
+    device const int32_t * tmp1 = tmp0 + args.len;
+
+    dst += start
+         + i01*args.ne00
+         + i02*args.ne00*args.ne01
+         + i03*args.ne00*args.ne01*args.ne02;
+
+    device const float * src0_row = (device const float *)(src0
+                                  + args.nb01*i01
+                                  + args.nb02*i02
+                                  + args.nb03*i03);
+
+    for (int k = tpitg.x; k < (int) total; k += ntg.x) {
+        // find partition (i,j) such that i+j = k
+        int low  = k > len1 ? k - len1 : 0;
+        int high = MIN(k, len0);
+
+        while (low < high) {
+            const int mid = (low + high) >> 1;
+
+            const int32_t idx0 = tmp0[mid];
+            const int32_t idx1 = tmp1[k - mid - 1];
+
+            const float val0 = src0_row[idx0];
+            const float val1 = src0_row[idx1];
+
+            if (order == GGML_SORT_ORDER_ASC) {
+                if (val0 <= val1) {
+                    low = mid + 1;
+                } else {
+                    high = mid;
+                }
+            } else {
+                if (val0 >= val1) {
+                    low = mid + 1;
+                } else {
+                    high = mid;
+                }
+            }
+        }
+
+        const int i = low;
+        const int j = k - i;
+
+        int32_t out_idx;
+
+        if (i >= len0) {
+            out_idx = tmp1[j];
+        } else if (j >= len1) {
+            out_idx = tmp0[i];
+        } else {
+            const int32_t idx0 = tmp0[i];
+            const int32_t idx1 = tmp1[j];
+
+            const float val0 = src0_row[idx0];
+            const float val1 = src0_row[idx1];
+
+            out_idx = (order == GGML_SORT_ORDER_ASC)
+                    ? (val0 <= val1 ? idx0 : idx1)
+                    : (val0 >= val1 ? idx0 : idx1);
+        }
+
+        dst[k] = out_idx;
+    }
+}
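Each output slot `k` of a merge pass is computed independently: the kernel binary-searches the partition `(i, j)` with `i + j = k` between the two sorted runs (the merge-path/co-rank trick), and ties take from the left run (`val0 <= val1`) so results stay stable across passes, which is why no synchronization is needed. A plain C++ sketch of the same co-rank search (illustrative names, merging values directly instead of indices):

```cpp
// Co-rank merge sketch: for output position k, find i (elements taken from A)
// such that A[0..i) and B[0..k-i) form the first k elements of the merge.
#include <algorithm>
#include <cassert>
#include <vector>

static int merge_corank(int k, const std::vector<float> & A, const std::vector<float> & B) {
    int low  = k > (int) B.size() ? k - (int) B.size() : 0;
    int high = std::min(k, (int) A.size());
    while (low < high) {
        const int mid = (low + high) >> 1;
        if (A[mid] <= B[k - mid - 1]) { // ties prefer A -> stable merge
            low = mid + 1;
        } else {
            high = mid;
        }
    }
    return low; // i; the B-side count is j = k - i
}

int main() {
    const std::vector<float> A = {0.1f, 0.4f, 0.4f, 0.9f};
    const std::vector<float> B = {0.2f, 0.4f, 0.8f};

    std::vector<float> out;
    for (int k = 0; k < (int) (A.size() + B.size()); ++k) { // each k is independent
        const int i = merge_corank(k, A, B);
        const int j = k - i;
        if (i >= (int) A.size()) {
            out.push_back(B[j]);        // run A exhausted
        } else if (j >= (int) B.size()) {
            out.push_back(A[i]);        // run B exhausted
        } else {
            out.push_back(A[i] <= B[j] ? A[i] : B[j]);
        }
    }
    assert(std::is_sorted(out.begin(), out.end()));
    return 0;
}
```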
+
+template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
+
 kernel void kernel_leaky_relu_f32(
         constant ggml_metal_kargs_leaky_relu & args,
         device const float * src0,
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b11793963a..a7707eb03f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7492,8 +7492,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
-        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
     }
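The new shapes deliberately bracket the power-of-two chunk boundaries (1023/1024/1025 and 2047/2048/2049, with multiple rows and batches), so both the single-threadgroup path and the merge path with a ragged tail get exercised. For reference, the expected result is what a stable CPU argsort produces; a minimal sketch (illustrative helper only, test-backend-ops actually compares against the CPU backend):

```cpp
// Reference argsort for a single row, ascending: stable, so equal values
// keep their original order.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

static std::vector<int32_t> argsort_row_asc(const std::vector<float> & row) {
    std::vector<int32_t> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...
    std::stable_sort(idx.begin(), idx.end(),
        [&row](int32_t a, int32_t b) { return row[a] < row[b]; });
    return idx;
}

int main() {
    const std::vector<float> row = {0.5f, -1.0f, 3.0f, 0.0f};
    for (int32_t i : argsort_row_asc(row)) {
        printf("%d ", i); // prints: 1 3 0 2
    }
    printf("\n");
    return 0;
}
```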