feat: Parallel sum in SSM_CONV

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-07-22 11:37:43 -06:00
parent 80545ef568
commit 16bc059660
2 changed files with 52 additions and 9 deletions

View File

@@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
const int64_t d_state = ne10;
// One shared memory bucket for each simd group in the threadgroup
if (d_state >= 32) {
const int64_t shmem_size = d_state / 32;
// The final simd_sum won't work if the number of simd groups is
// larger than the size of a single simd group. If this case is
// hit at some point, the logic in the second simd_sum could be
// expanded to handle this with one more sequential simd_sum to
// collapse simd group sums another time.
GGML_ASSERT(shmem_size <= 32);
// One thread pre element in d_state
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
}
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} break;
case GGML_OP_SSM_SCAN:
{

View File

@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
device const void * src0,
device const void * src1,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_conv & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
const int64_t i0 = tpitg.x;
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
@@ -1681,14 +1687,32 @@ kernel void kernel_ssm_conv_f32(
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
float sumf = s[i0] * c[i0];
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
// Parallel sum: first sum over threads in simd group, then sum over simd
// group sums
sumf = simd_sum(sumf);
// If multiple simd groups per threadgroup, sum over simd group sums
if (sgptg > 1) {
if (tiisg == 0) {
shared[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = 0.0f;
if (sgitg == 0) {
if (tiisg < sgptg) {
sumf = shared[tiisg];
}
sumf = simd_sum(sumf);
if (tiisg == 0) {
x[0] = sumf;
}
}
} else if (tiisg == 0) {
x[0] = sumf;
}
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(