fix: Update logic to correctly do the multi-layer parallel sum

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-07-18 10:49:06 -06:00
parent 8d5a25d356
commit e16e24bebd

View File

@@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group(
uint3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
@@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group(
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
threadgroup_barrier(mem_flags::mem_threadgroup);
float sumf = 0.0f;
const int64_t i = tpitg.x + i1*nc;
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
sumf += state * C[tpitg.x];
s[i] = state;
// Parallel sum: This relies on the fact that this kernel will be
// dispatched with each threadgroup having (d_state, 1, 1) threads which
// are subdivided into SIMD groups of size `sgptg`. The goal is to
// compute y = sum({state * C[i] for i in range(d_state)}).
// To parallelize this effectively, we first use simd_sum over each SIMD
// group to compute the sum of each SIMD group, then place the result in
// the SIMD group's indexed bucket in the shared memory. We then sum
// over the individual group sums to compute the final sum.
// Computed for each thread
float sumf = state * C[tpitg.x];
// Sum the threads in the simd group => simd sum
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Use the shared buffer to hold the sum of each simd group
// Once per simd group, place the group sum into the shared buffer
if (tiisg == 0) {
shared[sgitg] = sumf;
}
// Wait for all threads in the threadgroup to reach this point. This
// ensures that all elements of the shared buffer are populated with the
// sum of the individual simd groups.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Sum the simd buckets
sumf = shared[tiisg];
sumf = simd_sum(sumf);
// Sum the simd buckets => threadgroup sum
sumf = 0.0f;
for (int64_t i0 = 0; i0 < sgptg; ++i0) {
sumf += shared[i0];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
y[0] = sumf;