opencl: add fused rms_norm_mul (#14841)

* opencl: add fused `rms_norm` + `mul`

* opencl: improve workgroup size for `rms_norm_mul`
This commit is contained in:
lhez
2025-07-25 08:12:13 -07:00
committed by GitHub
parent e7fecba934
commit ce111d39d6
2 changed files with 240 additions and 2 deletions

View File

@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
}
}
}
//------------------------------------------------------------------------------
// rms_norm_mul
//------------------------------------------------------------------------------
// Fused RMS normalization followed by an element-wise multiply.
//
// Each work-group handles one row of src0; the row is selected by the group
// ids: i01 = get_group_id(0), i02 = get_group_id(1), i03 = get_group_id(2).
// For that row the kernel computes
//
//     scale = 1 / sqrt(sum(x[i]^2) / ne00 + eps)
//     dst   = (x * scale) * f
//
// where f is the src1 row addressed with the row indices taken modulo
// ne11/ne12/ne13, and indexed along the row modulo ne10/4 float4 vectors
// (i.e. src1 is repeated when it is smaller than src0).
//
// Parameters:
//   src0/offset0, src1/offset1, dst/offsetd - base pointers plus byte offsets
//       of the input, the multiplier and the output.
//   ne00          - row length in floats. Rows are read and written as float4,
//                   and only ne00/4 vectors are visited, so any trailing
//                   ne00 % 4 elements are neither summed nor written.
//   nb01..nb03    - byte strides of src0 for dims 1..3.
//   ne10..ne13    - extents of src1 used for the modulo indexing above.
//   nb11..nb13    - byte strides of src1 for dims 1..3.
//   nb1..nb3      - byte strides of dst for dims 1..3.
//   eps           - added to the mean of squares before the square root.
//   sum           - local scratch; one float is written per sub-group, so it
//                   must hold at least that many floats.
//   ne01, ne02, ne03 are not read by this kernel.
//
// NOTE(review): the tree reduction below starts at
// get_local_size(0) / get_max_sub_group_size() / 2 and halves each step, so it
// only folds in every partial sum when the number of sub-groups is a power of
// two - presumably the host picks the work-group size accordingly; confirm
// against the launch code.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_rms_norm_mul(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float eps,
        local float * sum
) {
    // Apply the byte offsets once so all later addressing is relative.
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    // One work-group per row.
    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    // x: the src0 row to normalize; f: the (possibly repeated) src1 row.
    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);

    float sumf = 0;

    // Each work-item accumulates the squared norm of a strided subset of the
    // row's float4 vectors.
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += dot(x[i00], x[i00]);
    }

    // Reduce within the sub-group, then publish one partial sum per sub-group.
    sumf = sub_group_reduce_add(sumf);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = sumf;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree-reduce the per-sub-group partial sums into sum[0].
    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
        if (get_local_id(0) < i) {
            sum[get_local_id(0)] += sum[get_local_id(0) + i];
        }
        // The next step reads entries that other work-items wrote in this
        // step; local-memory writes are only guaranteed visible across a
        // work-group barrier. The loop bound is uniform across the work-group,
        // so every work-item reaches this barrier the same number of times.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Turn the sum of squares into the mean of squares.
    if (get_local_id(0) == 0) {
        sum[0] /= ne00;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    float mean  = sum[0];
    float scale = 1.0f/sqrt(mean + eps);

    // Write the normalized row multiplied element-wise by f.
    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
    }
}