mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
ggml : use std::sort in ggml_argsort CPU implementation (#17211)
* ggml : use std::sort in ggml_argsort CPU implementation * cont : add missing header
This commit is contained in:
@@ -7,8 +7,9 @@
|
||||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <cfloat>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
|
||||
@@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32(
|
||||
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
for (int64_t i = ith; i < nr; i += nth) {
|
||||
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||
const float * src_data = (float *)((char *) src0->data + i*nb01);
|
||||
|
||||
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|
||||
|
||||
for (int64_t j = 0; j < ne0; j++) {
|
||||
dst_data[j] = j;
|
||||
}
|
||||
|
||||
// C doesn't have a functional sort, so we do a bubble sort instead
|
||||
for (int64_t j = 0; j < ne0; j++) {
|
||||
for (int64_t k = j + 1; k < ne0; k++) {
|
||||
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
|
||||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
|
||||
int32_t tmp = dst_data[j];
|
||||
dst_data[j] = dst_data[k];
|
||||
dst_data[k] = tmp;
|
||||
}
|
||||
}
|
||||
std::function<bool(int32_t, int32_t)> cmp;
|
||||
|
||||
// note: this might be causing memory allocations? ideally should be avoided if it's the case
|
||||
switch (order) {
|
||||
case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
|
||||
case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
|
||||
default: GGML_ABORT("invalid sort order");
|
||||
}
|
||||
|
||||
std::sort(dst_data, dst_data + ne0, cmp);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user