mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-16 11:27:03 +00:00
opencl: split ggml-opencl.cl into multiple files and cleanup (#12886)
* opencl: refactor - split the kernel files --------- Co-authored-by: Shangqing Gu <quic_shawngu@quicinc.com> * opencl: split more kernels into separate files * opencl: specify subgroup size instead of querying it * opencl: refine Adreno cl compiler version parsing * opencl: skip some kernels not used by Adreno on old compilers * opencl: refine logic for selecting Adreno kernels * opencl: refine Adreno cl compiler version * opencl: cleanup preprocessor for kernels * opencl: consider Adreno CL compiler on Windows * opencl: add final newline for `mul_mv_f16_f16.cl` --------- Co-authored-by: Shangqing Gu <quic_shawngu@quicinc.com>
This commit is contained in:
58
ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
Normal file
58
ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// diag_mask_inf kernels
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_diag_mask_inf(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int n_past
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i02 = get_global_id(2);
|
||||
int i01 = get_global_id(1);
|
||||
int i00 = get_global_id(0);
|
||||
|
||||
if (i00 > n_past + i01) {
|
||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
||||
} else {
|
||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_diag_mask_inf_8(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int n_past
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
int i = 2*get_global_id(0);
|
||||
|
||||
dst[i+0] = src0[i+0];
|
||||
dst[i+1] = src0[i+1];
|
||||
int i4 = 4*i;
|
||||
int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
||||
int i01 = i4/(ne00); i4 -= i01*ne00;
|
||||
int i00 = i4;
|
||||
for (int k = 3; k >= 0; --k) {
|
||||
if (i00 + 4 + k <= n_past + i01) {
|
||||
break;
|
||||
}
|
||||
(&dst[i+1])[k] = -INFINITY;
|
||||
if (i00 + k > n_past + i01) {
|
||||
(&dst[i])[k] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user