diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ac2895b516..27005a053d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group( uint3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head @@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group( const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - threadgroup_barrier(mem_flags::mem_threadgroup); - - float sumf = 0.0f; - const int64_t i = tpitg.x + i1*nc; const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); - sumf += state * C[tpitg.x]; s[i] = state; + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[tpitg.x]; + + // Sum the threads in the simd group => simd sum sumf = simd_sum(sumf); - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Use the shared buffer to hold the sum of each simd group + // Once per simd group, place the group sum into the shared buffer if (tiisg == 0) { shared[sgitg] = sumf; } + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. threadgroup_barrier(mem_flags::mem_threadgroup); - // Sum the simd buckets - sumf = shared[tiisg]; - sumf = simd_sum(sumf); + // Sum the simd buckets => threadgroup sum + sumf = 0.0f; + for (int64_t i0 = 0; i0 < sgptg; ++i0) { + sumf += shared[i0]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); y[0] = sumf;