From 3866f766fe8e508354013ea9e10a5c87b31e7681 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 21 Jul 2025 09:31:43 -0600 Subject: [PATCH] feat: Use local variable for state recursion Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart --- ggml/src/ggml-metal/ggml-metal.metal | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c50236ba49..0397cd9b53 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1788,8 +1788,11 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = tpitg.x + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); @@ -1809,9 +1812,8 @@ kernel void kernel_ssm_scan_f32_group( const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - const int64_t i = tpitg.x + i1*nc; - const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt); - s[i] = state; + const float state = (s0 * dA) + (B[tpitg.x] * x_dt); + s = state; // Parallel sum: This relies on the fact that this kernel will be // dispatched with each threadgroup having (d_state, 1, 1) threads which @@ -1851,6 +1853,9 @@ kernel void kernel_ssm_scan_f32_group( // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } kernel void kernel_rwkv_wkv6_f32(