	perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state
This is a first attempt at optimizing the metal kernel. The changes here are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and 15% speedup on decode.

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
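The summation for y uses the standard two-stage threadgroup reduction on Apple GPUs: each simdgroup reduces its lanes with `simd_sum`, lane 0 of every simdgroup writes its partial to threadgroup memory, and a second `simd_sum` folds the partials. A minimal standalone sketch of that pattern (the kernel name and buffer bindings here are illustrative, not part of this commit):

```metal
#include <metal_stdlib>
using namespace metal;

// Illustrative only: sum ntg.x floats per threadgroup using the same
// simdgroup + shared-memory pattern as the SSM_SCAN change.
kernel void reduce_sum_example(
        device const float * src,
        device       float * dst,
        threadgroup  float * shared [[threadgroup(0)]], // >= 1 slot per simdgroup
        uint3  tgpig [[threadgroup_position_in_grid]],
        uint3  tpitg [[thread_position_in_threadgroup]],
        ushort sgitg [[simdgroup_index_in_threadgroup]],
        ushort tiisg [[thread_index_in_simdgroup]],
        uint3    ntg [[threads_per_threadgroup]]) {
    // Stage 1: each lane loads one element; reduce within the simdgroup.
    float sumf = src[tgpig.x*ntg.x + tpitg.x];
    sumf = simd_sum(sumf);

    // Lane 0 of each simdgroup publishes its partial sum.
    if (tiisg == 0) {
        shared[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Stage 2: fold the per-simdgroup partials (assumes ntg.x <= 32*32).
    const uint nsg = (ntg.x + 31) / 32; // number of simdgroups
    sumf = (tiisg < nsg) ? shared[tiisg] : 0.0f;
    sumf = simd_sum(sumf);

    if (tpitg.x == 0) {
        dst[tgpig.x] = sumf;
    }
}
```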
Host-side changes in `ggml_metal_encode_node`:

```diff
@@ -2986,6 +2986,7 @@ static bool ggml_metal_encode_node(
                     /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
                     /*.n_seqs       =*/ n_seqs,
+                    /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                     /*.nb01         =*/ nb01,
                     /*.nb02         =*/ nb02,
                     /*.nb03         =*/ nb03,
@@ -3016,7 +3017,8 @@ static bool ggml_metal_encode_node(
 
                 if (ne30 == 1) {
                     // Mamba-2
-                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; // SIMD size
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
                     [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
```
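On the sizing above: the fixed `32*sizeof(float)` of threadgroup memory reserves one partial-sum slot per simdgroup. With Apple's SIMD width of 32 and a typical maximum of 1024 threads per threadgroup, a threadgroup of `d_state` threads contains at most 32 simdgroups, so 32 slots always suffice. A hypothetical helper (not in the commit) that makes the arithmetic explicit:

```metal
// Hypothetical, for illustration: float slots of threadgroup memory needed
// for the per-simdgroup partial sums, assuming a SIMD width of 32.
inline uint ssm_scan_shared_slots(uint d_state) {
    // d_state = 128 -> 4 slots; the fixed 32 covers any d_state <= 1024.
    return (d_state + 31) / 32;
}
```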
Kernel-side changes (`kernel_ssm_scan_f32_group` in the Metal source):

```diff
@@ -1752,7 +1752,6 @@ kernel void kernel_ssm_scan_f32(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1762,10 +1761,14 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
         uint3    ntg[[threads_per_threadgroup]]) {
+
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1780,7 +1783,7 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
```
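The `s_off` hunk is a refactor, not a behavior change: `src1` holds `nr * nh * n_t * n_seqs` floats, so `ggml_nelements(src1) * sizeof(float)` computed once on the host equals the value every thread previously recomputed. Restated as a hypothetical helper (illustrative, reusing the kernel's own variable names):

```metal
// Hypothetical restatement, not code from the commit: the byte offset of the
// updated SSM states within dst, formerly recomputed per thread.
inline int64_t state_offset_bytes(int64_t nr, int64_t nh, int64_t n_t, int64_t n_seqs) {
    // Equals ggml_nelements(src1) * sizeof(float), now passed in as args.s_off.
    return nr * nh * n_t * n_seqs * (int64_t) sizeof(float);
}
```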
The core of the change, the parallelized y summation:

```diff
@@ -1798,15 +1801,31 @@ kernel void kernel_ssm_scan_f32_group(
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
 
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
         float sumf = 0.0f;
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+
+        const int64_t i = tpitg.x + i1*nc;
+        const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
+        sumf += state * C[tpitg.x];
+        s[i] = state;
+
+        sumf = simd_sum(sumf);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Use the shared buffer to hold the sum of each simd group
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // Sum the simd buckets
+        sumf = shared[tiisg];
+        sumf = simd_sum(sumf);
 
         y[0] = sumf;
 
         // recurse
```
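Reading this hunk against the new launch geometry: the serial loop over `i0` in `[0, nc)` (where `nc` is `d_state`) disappears because each of the `d_state` threads now owns exactly one state column, with `tpitg.x` taking the role of `i0`. A condensed view of the mapping (illustrative, reusing the kernel's names):

```metal
// Old: one thread per (i1, head, seq) iterated i0 = 0..nc-1 serially.
// New: thread tpitg.x handles the single column i0 == tpitg.x.
const int64_t i     = tpitg.x + i1*nc;                      // flat state index
const float   state = (s0[i] * dA) + (B[tpitg.x] * x_dt);   // per-column update
// The per-thread products state * C[tpitg.x] are then combined across the
// threadgroup: simd_sum within each simdgroup, one partial per simdgroup in
// `shared`, and a second simd_sum over those partials yields y[0].
```

With the typical Mamba-2 `d_state` of 128 this is 4 simdgroups, so the second stage folds `shared[0..3]` into the final sum.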