mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-10 10:27:03 +00:00
opencl: support sink in soft_max (attn sinks) (#15152)
This commit is contained in:
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * src2,
|
||||
ulong offset2,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
src2 = src2 + offset2;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
|
||||
|
||||
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
||||
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
||||
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
|
||||
}
|
||||
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
|
||||
}
|
||||
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max);
|
||||
}
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
pdst4[i00] /= sum;
|
||||
|
||||
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * src2,
|
||||
ulong offset2,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
src2 = src2 + offset2;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
|
||||
|
||||
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
||||
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
||||
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
|
||||
}
|
||||
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max);
|
||||
}
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
|
||||
pdst4[i00] /= sum;
|
||||
|
||||
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * src2,
|
||||
ulong offset2,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
src2 = src2 + offset2;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
|
||||
|
||||
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
||||
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
||||
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max);
|
||||
}
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
pdst[i00] /= sum;
|
||||
|
||||
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * src2,
|
||||
ulong offset2,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
src2 = src2 + offset2;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i03 = get_group_id(2);
|
||||
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
|
||||
|
||||
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
||||
global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
|
||||
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
const float sum = sub_group_reduce_add(lsum);
|
||||
float sum = sub_group_reduce_add(lsum);
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max);
|
||||
}
|
||||
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
pdst[i00] /= sum;
|
||||
|
||||
Reference in New Issue
Block a user