From 16bc059660c1c59e566628201c0ca2c20c9f4bc3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 22 Jul 2025 11:37:43 -0600 Subject: [PATCH] feat: Parallel sum in SSM_CONV Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.m | 21 ++++++++++++++- ggml/src/ggml-metal/ggml-metal.metal | 40 ++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9c3bba5f3e..51ea6d217b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&args length:sizeof(args) atIndex:3]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + const int64_t d_state = ne10; + + // One shared memory bucket for each simd group in the threadgroup + if (d_state >= 32) { + const int64_t shmem_size = d_state / 32; + + // The final simd_sum won't work if the number of simd groups is + // larger than the size of a single simd group. If this case is + // hit at some point, the logic in the second simd_sum could be + // expanded to handle this with one more sequential simd_sum to + // collapse simd group sums another time. 
+ GGML_ASSERT(shmem_size <= 32); + // One thread per element in d_state + GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); + + [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; + } + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } break; case GGML_OP_SSM_SCAN: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4ffa56d45b..c687145790 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32( device const void * src0, device const void * src1, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_conv & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t ir = tgpig.x; const int64_t i2 = tgpig.y; const int64_t i3 = tgpig.z; @@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32( device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); - float sumf = 0.0f; + float sumf = s[i0] * c[i0]; - for (int64_t i0 = 0; i0 < nc; ++i0) { - sumf += s[i0] * c[i0]; + // Parallel sum: first sum over threads in simd group, then sum over simd + // group sums + sumf = simd_sum(sumf); + + // If multiple simd groups per threadgroup, sum over simd group sums + if (sgptg > 1) { + if (tiisg == 0) { + shared[sgitg] = sumf; + } + 
threadgroup_barrier(mem_flags::mem_threadgroup); + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + x[0] = sumf; + } + } + } else if (tiisg == 0) { + x[0] = sumf; } - - x[0] = sumf; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part