Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-27 08:21:30 +00:00
feat: Parallel sum in SSM_CONV
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                 [encoder setBytes:&args length:sizeof(args) atIndex:3];

-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                const int64_t d_state = ne10;
+
+                // One shared memory bucket for each simd group in the threadgroup
+                if (d_state >= 32) {
+                    const int64_t shmem_size = d_state / 32;
+
+                    // The final simd_sum won't work if the number of simd groups is
+                    // larger than the size of a single simd group. If this case is
+                    // hit at some point, the logic in the second simd_sum could be
+                    // expanded to handle this with one more sequential simd_sum to
+                    // collapse simd group sums another time.
+                    GGML_ASSERT(shmem_size <= 32);
+
+                    // One thread per element in d_state
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
             } break;
         case GGML_OP_SSM_SCAN:
             {
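
For illustration only (not part of the commit): a minimal C++ sketch of the launch-size arithmetic the hunk above performs. The concrete d_state and thread-limit values here are assumptions; in the hunk they come from ne10 and pipeline.maxTotalThreadsPerThreadgroup.

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t d_state     = 128;  // hypothetical ne10
    const int64_t max_threads = 1024; // hypothetical maxTotalThreadsPerThreadgroup

    // One thread per element of d_state.
    assert(d_state <= max_threads);

    if (d_state >= 32) {
        // One threadgroup-memory bucket per 32-wide simd group.
        const int64_t shmem_size = d_state / 32;
        // The second simd_sum can only collapse up to 32 simd-group sums.
        assert(shmem_size <= 32);
        printf("threads per threadgroup: %lld, threadgroup floats: %lld\n",
               (long long) d_state, (long long) shmem_size);
    }
    return 0;
}

With d_state = 128 this reserves 4 floats of threadgroup memory, one per simd group, and launches 128 threads per threadgroup instead of the previous single thread.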
@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
+        threadgroup  float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_conv & args,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]]) {
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t ir = tgpig.x;
     const int64_t i2 = tgpig.y;
     const int64_t i3 = tgpig.z;
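
For illustration only (not part of the commit): a small C++ sketch of how the new kernel arguments relate to one another under the dispatch change in the first hunk, assuming the 32-wide simd groups the Metal code relies on; the d_state value is made up.

#include <cstdio>

int main() {
    const int d_state = 128; // hypothetical ne10, also the x extent of the threadgroup
    for (int tpitg_x = 0; tpitg_x < d_state; tpitg_x += 37) {
        const int i0    = tpitg_x;             // conv-state element owned by this thread
        const int sgitg = tpitg_x / 32;        // simdgroup_index_in_threadgroup
        const int tiisg = tpitg_x % 32;        // thread_index_in_simdgroup
        const int sgptg = (d_state + 31) / 32; // simdgroups_per_threadgroup
        printf("i0=%3d sgitg=%d tiisg=%2d sgptg=%d\n", i0, sgitg, tiisg, sgptg);
    }
    return 0;
}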
@@ -1681,14 +1687,32 @@ kernel void kernel_ssm_conv_f32(
     device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
     device       float * x = (device       float *) ((device       char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

-    float sumf = 0.0f;
-
-    for (int64_t i0 = 0; i0 < nc; ++i0) {
-        sumf += s[i0] * c[i0];
-    }
-
-    x[0] = sumf;
+    float sumf = s[i0] * c[i0];
+    // Parallel sum: first sum over threads in simd group, then sum over simd
+    // group sums
+    sumf = simd_sum(sumf);
+
+    // If multiple simd groups per threadgroup, sum over simd group sums
+    if (sgptg > 1) {
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                x[0] = sumf;
+            }
+        }
+    } else if (tiisg == 0) {
+        x[0] = sumf;
+    }
 }

 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
 kernel void kernel_ssm_scan_f32(
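
For illustration only (not part of the commit): a self-contained C++ sketch that emulates the two-stage reduction above on the CPU and checks it against the sequential loop it replaces. The simd width of 32 and the d_state value are assumptions; simd_sum, the threadgroup buffer, and the tiisg/sgitg roles are modelled with plain loops and an array.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int simd_width = 32;
    const int d_state    = 128; // hypothetical ne10; a multiple of 32 for simplicity

    std::vector<float> s(d_state), c(d_state);
    for (int i = 0; i < d_state; ++i) {
        s[i] = 0.01f * i;
        c[i] = 1.0f - 0.005f * i;
    }

    // Old kernel: a single thread accumulates the whole dot product.
    float ref = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        ref += s[i] * c[i];
    }

    // New kernel: each thread contributes one product; simd_sum collapses each
    // 32-thread simd group, lane 0 of each group stages the partial sum in
    // threadgroup memory, and simd group 0 then sums those partials.
    const int sgptg = d_state / simd_width;
    std::vector<float> shared(sgptg, 0.0f);   // stands in for the threadgroup buffer
    for (int sgitg = 0; sgitg < sgptg; ++sgitg) {
        float group_sum = 0.0f;
        for (int tiisg = 0; tiisg < simd_width; ++tiisg) {
            const int i0 = sgitg * simd_width + tiisg;
            group_sum += s[i0] * c[i0];       // first simd_sum
        }
        shared[sgitg] = group_sum;            // write performed by tiisg == 0
    }
    float sumf = 0.0f;
    for (int sgitg = 0; sgitg < sgptg; ++sgitg) {
        sumf += shared[sgitg];                // second simd_sum, in simd group 0
    }

    printf("sequential: %f  two-stage: %f\n", ref, sumf);
    assert(std::fabs(sumf - ref) < 1e-3f);
    return 0;
}

With d_state = 128 there are only 4 simd groups, so just 4 of the 32 lanes in simd group 0 read a staged partial sum; the remaining lanes contribute 0.0f, which is what the `tiisg < sgptg` guard in the kernel arranges.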