mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-02 09:12:03 +00:00
perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state
This is a first attempt at optimizing the metal kernel. The changes here are: - Launch the kernel with a thread group of size d_state - Use simd groups and shared memory to do the summation for the y computation When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and 15% speedup on decode. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
@@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
|
|||||||
/*.n_group =*/ n_group,
|
/*.n_group =*/ n_group,
|
||||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||||
/*.n_seqs =*/ n_seqs,
|
/*.n_seqs =*/ n_seqs,
|
||||||
|
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
|
||||||
/*.nb01 =*/ nb01,
|
/*.nb01 =*/ nb01,
|
||||||
/*.nb02 =*/ nb02,
|
/*.nb02 =*/ nb02,
|
||||||
/*.nb03 =*/ nb03,
|
/*.nb03 =*/ nb03,
|
||||||
@@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node(
|
|||||||
|
|
||||||
if (ne30 == 1) {
|
if (ne30 == 1) {
|
||||||
// Mamba-2
|
// Mamba-2
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(d_inner == 1);
|
GGML_ASSERT(d_inner == 1);
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
|||||||
@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
||||||
// TODO: optimize (e.g. by parallelizing over d_state)
|
|
||||||
kernel void kernel_ssm_scan_f32_group(
|
kernel void kernel_ssm_scan_f32_group(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const void * src1,
|
device const void * src1,
|
||||||
@@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
device const void * src5,
|
device const void * src5,
|
||||||
device const void * src6,
|
device const void * src6,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
threadgroup float * shared [[threadgroup(0)]],
|
||||||
constant ggml_metal_kargs_ssm_scan & args,
|
constant ggml_metal_kargs_ssm_scan & args,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
const int64_t i1 = tgpig.x;
|
const int64_t i1 = tgpig.x;
|
||||||
const int64_t ir = tgpig.y; // current head
|
const int64_t ir = tgpig.y; // current head
|
||||||
const int64_t i3 = tgpig.z; // current seq
|
const int64_t i3 = tgpig.z; // current seq
|
||||||
@@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
const int64_t ng = args.n_group;
|
const int64_t ng = args.n_group;
|
||||||
const int64_t n_t = args.n_seq_tokens;
|
const int64_t n_t = args.n_seq_tokens;
|
||||||
|
|
||||||
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
const int64_t s_off = args.s_off;
|
||||||
|
|
||||||
device const int32_t * ids = (device const int32_t *) src6;
|
device const int32_t * ids = (device const int32_t *) src6;
|
||||||
|
|
||||||
@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
||||||
const float x_dt = x[0] * dt_soft_plus;
|
const float x_dt = x[0] * dt_soft_plus;
|
||||||
const float dA = exp(dt_soft_plus * A[0]);
|
const float dA = exp(dt_soft_plus * A[0]);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
const int64_t i = tpitg.x + i1*nc;
|
||||||
const int64_t i = i0 + i1*nc;
|
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
|
||||||
const float state = (s0[i] * dA) + (B[i0] * x_dt);
|
sumf += state * C[tpitg.x];
|
||||||
sumf += state * C[i0];
|
s[i] = state;
|
||||||
s[i] = state;
|
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Use the shared buffer to hold the sum of each simd group
|
||||||
|
if (tiisg == 0) {
|
||||||
|
shared[sgitg] = sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Sum the simd buckets
|
||||||
|
sumf = shared[tiisg];
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
y[0] = sumf;
|
y[0] = sumf;
|
||||||
|
|
||||||
// recurse
|
// recurse
|
||||||
|
|||||||
Reference in New Issue
Block a user