perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state

This is a first attempt at optimizing the metal kernel. The changes here
are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y
  computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on
prefill and 15% speedup on decode.

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-07-15 17:04:31 -06:00
parent ba74a24730
commit 8d5a25d356
2 changed files with 32 additions and 11 deletions

View File

@@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
@@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node(
if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];

View File

@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
@@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group(
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
@@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group(
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
const int64_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
threadgroup_barrier(mem_flags::mem_threadgroup);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
const int64_t i = tpitg.x + i1*nc;
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
sumf += state * C[tpitg.x];
s[i] = state;
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Use the shared buffer to hold the sum of each simd group
if (tiisg == 0) {
shared[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Sum the simd buckets
sumf = shared[tiisg];
sumf = simd_sum(sumf);
y[0] = sumf;
// recurse