metal: optimise GGML_OP_SUM (#16559)

* optimise GGML_OP_SUM

* add non-contiguous tests by permuting the input

* change tests to require full contiguity of OP_SUM

* cuda : add check GGML_OP_SUM

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Sam/Samuel
2025-10-15 23:05:56 +09:00
committed by GitHub
parent 17304cbcc1
commit f4ce81c45e
5 changed files with 71 additions and 11 deletions

View File

@@ -3625,9 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_ACC:
return true;
case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_ARGSORT:
// TODO: Support arbitrary column width
return op->src[0]->ne[0] <= 1024;

View File

@@ -662,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:

View File

@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
int nth = 32; // SIMD width
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, (int) n);
const int nsg = (nth + 31) / 32;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
return 1;
}

View File

@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
ushort tiitg[[thread_index_in_threadgroup]]) {
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (tiitg != 0) {
if (args.np == 0) {
return;
}
float acc = 0.0f;
for (ulong i = 0; i < args.np; ++i) {
acc += src0[i];
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}
dst[0] = acc;
sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
}
template <bool norm>