diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0397cd9b53..b7c474de86 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1840,16 +1840,19 @@ kernel void kernel_ssm_scan_f32_group( // sum of the individual simd groups. threadgroup_barrier(mem_flags::mem_threadgroup); - // Sum the simd buckets => threadgroup sum + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum sumf = 0.0f; - for (int64_t i0 = 0; i0 < sgptg; ++i0) { - sumf += shared[i0]; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - y[0] = sumf; - // recurse s0 = s; }