opencl : broadcast for soft_max (#14510)

This commit is contained in:
lhez
2025-07-03 11:22:24 -07:00
committed by GitHub
parent 2b72bedec1
commit bee28421be
5 changed files with 132 additions and 59 deletions

View File

@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;