mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
ggml : implement GEGLU_ERF and GEGLU_QUICK ops (#14445)
This commit is contained in:
@@ -402,8 +402,8 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_relu;
|
||||
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
|
||||
cl_kernel kernel_clamp;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
|
||||
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
|
||||
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
|
||||
cl_kernel kernel_norm;
|
||||
cl_kernel kernel_rms_norm;
|
||||
cl_kernel kernel_group_norm;
|
||||
@@ -753,12 +753,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
backend_ctx->program_glu =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
@@ -2277,6 +2281,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
@@ -6254,6 +6260,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
kernel = backend_ctx->kernel_swiglu_f16;
|
||||
}
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_ERF:
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_geglu_erf;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_geglu_erf_f16;
|
||||
}
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU_QUICK:
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_geglu_quick;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_geglu_quick_f16;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported glu op");
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define GELU_COEF_A 0.044715f
|
||||
#define GELU_QUICK_COEF -1.702f
|
||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
||||
#define SQRT_2_INV 0.70710678118654752440084436210484f
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// geglu
|
||||
@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
|
||||
dst_row[i0] = silu*x1;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// geglu_erf
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_geglu_erf(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
ulong nb01,
|
||||
ulong nb11,
|
||||
int ne0,
|
||||
ulong nb1,
|
||||
int ne00_off,
|
||||
int ne10_off
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global char*)((global char*)dst + offsetd);
|
||||
|
||||
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||||
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||||
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
||||
|
||||
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
|
||||
|
||||
dst_row[i0] = gelu_erf*x1;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_geglu_erf_f16(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
ulong nb01,
|
||||
ulong nb11,
|
||||
int ne0,
|
||||
ulong nb1,
|
||||
int ne00_off,
|
||||
int ne10_off
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global char*)((global char*)dst + offsetd);
|
||||
|
||||
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||||
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||||
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
||||
|
||||
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||
const half x0 = src0_row[i0];
|
||||
const half x1 = src1_row[i0];
|
||||
|
||||
const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
|
||||
|
||||
dst_row[i0] = gelu_erf*x1;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// geglu_quick
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_geglu_quick(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
ulong nb01,
|
||||
ulong nb11,
|
||||
int ne0,
|
||||
ulong nb1,
|
||||
int ne00_off,
|
||||
int ne10_off
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global char*)((global char*)dst + offsetd);
|
||||
|
||||
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||||
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||||
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
||||
|
||||
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
|
||||
|
||||
dst_row[i0] = gelu_quick*x1;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_geglu_quick_f16(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
ulong nb01,
|
||||
ulong nb11,
|
||||
int ne0,
|
||||
ulong nb1,
|
||||
int ne00_off,
|
||||
int ne10_off
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global char*)((global char*)dst + offsetd);
|
||||
|
||||
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
||||
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
||||
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
||||
|
||||
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
||||
const half x0 = src0_row[i0];
|
||||
const half x1 = src1_row[i0];
|
||||
|
||||
const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
|
||||
|
||||
dst_row[i0] = gelu_quick*x1;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user