ggml : use std::sort in ggml_argsort CPU implementation (#17211)

* ggml : use std::sort in ggml_argsort CPU implementation

* cont : add missing header
This commit is contained in:
Georgi Gerganov
2025-11-12 20:43:38 +02:00
committed by GitHub
parent 8e878f0cb4
commit 374fe09cdd

View File

@@ -7,8 +7,9 @@
#include "unary-ops.h" #include "unary-ops.h"
#include "vec.h" #include "vec.h"
#include <float.h> #include <cfloat>
#include <algorithm> #include <algorithm>
#include <functional>
// ggml_compute_forward_dup // ggml_compute_forward_dup
@@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32(
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) { for (int64_t i = ith; i < nr; i += nth) {
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const float * src_data = (float *)((char *) src0->data + i*nb01); const float * src_data = (float *)((char *) src0->data + i*nb01);
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
for (int64_t j = 0; j < ne0; j++) { for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j; dst_data[j] = j;
} }
// C doesn't have a functional sort, so we do a bubble sort instead std::function<bool(int32_t, int32_t)> cmp;
for (int64_t j = 0; j < ne0; j++) {
for (int64_t k = j + 1; k < ne0; k++) { // note: this might be causing memory allocations? ideally should be avoided if it's the case
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || switch (order) {
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
int32_t tmp = dst_data[j]; case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
dst_data[j] = dst_data[k]; default: GGML_ABORT("invalid sort order");
dst_data[k] = tmp;
}
}
} }
std::sort(dst_data, dst_data + ne0, cmp);
} }
} }