mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-21 12:16:57 +00:00
metal : faster argsort (#17315)
* metal : faster argsort * cont : keep data in registers
This commit is contained in:
@@ -3726,8 +3726,6 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
std::swap(bid_dst, bid_tmp);
|
||||
|
||||
@@ -4739,12 +4739,13 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int im = tgpig[0] / args.ne01;
|
||||
int i01 = tgpig[0] % args.ne01;
|
||||
int i02 = tgpig[1];
|
||||
int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2*args.len);
|
||||
const int im = tgpig[0] / args.ne01;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2 * args.len);
|
||||
|
||||
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
|
||||
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
|
||||
@@ -4768,54 +4769,101 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||
+ args.nb02*i02
|
||||
+ args.nb03*i03);
|
||||
|
||||
for (int k = tpitg.x; k < (int) total; k += ntg.x) {
|
||||
// find partition (i,j) such that i+j = k
|
||||
int low = k > len1 ? k - len1 : 0;
|
||||
int high = MIN(k, len0);
|
||||
if (total == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k - mid - 1];
|
||||
const int k0 = tpitg.x * chunk;
|
||||
const int k1 = min(k0 + chunk, total);
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
if (k0 >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
if (val0 <= val1) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
} else {
|
||||
if (val0 >= val1) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
int low = k0 > len1 ? k0 - len1 : 0;
|
||||
int high = MIN(k0, len0);
|
||||
|
||||
// binary-search partition (i, j) such that i + j = k
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
bool take_left;
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
const int i = low;
|
||||
const int j = k - i;
|
||||
if (take_left) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
|
||||
int i = low;
|
||||
int j = k0 - i;
|
||||
|
||||
// keep the merge fronts into registers
|
||||
int32_t idx0 = 0;
|
||||
float val0 = 0.0f;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
|
||||
int32_t idx1 = 0;
|
||||
float val1 = 0.0f;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
|
||||
for (int k = k0; k < k1; ++k) {
|
||||
int32_t out_idx;
|
||||
|
||||
if (i >= len0) {
|
||||
out_idx = tmp1[j];
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp1[j++];
|
||||
}
|
||||
break;
|
||||
} else if (j >= len1) {
|
||||
out_idx = tmp0[i];
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp0[i++];
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
const int32_t idx0 = tmp0[i];
|
||||
const int32_t idx1 = tmp1[j];
|
||||
bool take_left;
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
out_idx = (order == GGML_SORT_ORDER_ASC)
|
||||
? (val0 <= val1 ? idx0 : idx1)
|
||||
: (val0 >= val1 ? idx0 : idx1);
|
||||
if (take_left) {
|
||||
out_idx = idx0;
|
||||
++i;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
} else {
|
||||
out_idx = idx1;
|
||||
++j;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[k] = out_idx;
|
||||
|
||||
Reference in New Issue
Block a user