mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
fix: Update logic to correctly do the multi-layer parallel sum
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
@@ -1767,7 +1767,8 @@ kernel void kernel_ssm_scan_f32_group(
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
ushort sgptg[[simdgroups_per_threadgroup]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||
|
||||
const int64_t i1 = tgpig.x;
|
||||
const int64_t ir = tgpig.y; // current head
|
||||
@@ -1802,29 +1803,42 @@ kernel void kernel_ssm_scan_f32_group(
|
||||
const float x_dt = x[0] * dt_soft_plus;
|
||||
const float dA = exp(dt_soft_plus * A[0]);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
const int64_t i = tpitg.x + i1*nc;
|
||||
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
|
||||
sumf += state * C[tpitg.x];
|
||||
s[i] = state;
|
||||
|
||||
// Parallel sum: This relies on the fact that this kernel will be
|
||||
// dispatched with each threadgroup having (d_state, 1, 1) threads which
|
||||
// are subdivided into `sgptg` SIMD groups. The goal is to
|
||||
// compute y = sum({state * C[i] for i in range(d_state)}).
|
||||
// To parallelize this effectively, we first use simd_sum over each SIMD
|
||||
// group to compute the sum of each SIMD group, then place the result in
|
||||
// the SIMD group's indexed bucket in the shared memory. We then sum
|
||||
// over the individual group sums to compute the final sum.
|
||||
|
||||
// Computed for each thread
|
||||
float sumf = state * C[tpitg.x];
|
||||
|
||||
// Sum the threads in the simd group => simd sum
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Use the shared buffer to hold the sum of each simd group
|
||||
// Once per simd group, place the group sum into the shared buffer
|
||||
if (tiisg == 0) {
|
||||
shared[sgitg] = sumf;
|
||||
}
|
||||
|
||||
// Wait for all threads in the threadgroup to reach this point. This
|
||||
// ensures that all elements of the shared buffer are populated with the
|
||||
// sum of the individual simd groups.
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Sum the simd buckets
|
||||
sumf = shared[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
// Sum the simd buckets => threadgroup sum
|
||||
sumf = 0.0f;
|
||||
for (int64_t i0 = 0; i0 < sgptg; ++i0) {
|
||||
sumf += shared[i0];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
y[0] = sumf;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user