mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
feat: Parallel sum in SSM_CONV
Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
@@ -2909,7 +2909,26 @@ static bool ggml_metal_encode_node(
|
|||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
const int64_t d_state = ne10;
|
||||||
|
|
||||||
|
// One shared memory bucket for each simd group in the threadgroup
|
||||||
|
if (d_state >= 32) {
|
||||||
|
const int64_t shmem_size = d_state / 32;
|
||||||
|
|
||||||
|
// The final simd_sum won't work if the number of simd groups is
|
||||||
|
// larger than the size of a single simd group. If this case is
|
||||||
|
// hit at some point, the logic in the second simd_sum could be
|
||||||
|
// expanded to handle this with one more sequential simd_sum to
|
||||||
|
// collapse simd group sums another time.
|
||||||
|
GGML_ASSERT(shmem_size <= 32);
|
||||||
|
|
||||||
|
// One thread per element in d_state
|
||||||
|
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
|
||||||
|
|
||||||
|
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
|
|||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const void * src1,
|
device const void * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
threadgroup float * shared [[threadgroup(0)]],
|
||||||
constant ggml_metal_kargs_ssm_conv & args,
|
constant ggml_metal_kargs_ssm_conv & args,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgptg[[simdgroups_per_threadgroup]],
|
||||||
|
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||||
|
|
||||||
|
const int64_t i0 = tpitg.x;
|
||||||
const int64_t ir = tgpig.x;
|
const int64_t ir = tgpig.x;
|
||||||
const int64_t i2 = tgpig.y;
|
const int64_t i2 = tgpig.y;
|
||||||
const int64_t i3 = tgpig.z;
|
const int64_t i3 = tgpig.z;
|
||||||
@@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32(
|
|||||||
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
||||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||||
|
|
||||||
float sumf = 0.0f;
|
float sumf = s[i0] * c[i0];
|
||||||
|
|
||||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
// Parallel sum: first sum over threads in simd group, then sum over simd
|
||||||
sumf += s[i0] * c[i0];
|
// group sums
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
|
// If multiple simd groups per threadgroup, sum over simd group sums
|
||||||
|
if (sgptg > 1) {
|
||||||
|
if (tiisg == 0) {
|
||||||
|
shared[sgitg] = sumf;
|
||||||
}
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
sumf = 0.0f;
|
||||||
|
if (sgitg == 0) {
|
||||||
|
if (tiisg < sgptg) {
|
||||||
|
sumf = shared[tiisg];
|
||||||
|
}
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
x[0] = sumf;
|
x[0] = sumf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (tiisg == 0) {
|
||||||
|
x[0] = sumf;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
|
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
|
||||||
|
|||||||
Reference in New Issue
Block a user