mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-21 12:16:57 +00:00
ggml : add ggml_gelu_erf() CUDA kernel (#13719)
* ggml : add ggml_gelu_erf() CUDA kernel

* missing semicolon
This commit is contained in:
@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
// Exact GELU activation (erf formulation): 0.5 * x * (1 + erf(x / sqrt(2))).
// Per-element device helper; intended to be plugged into the generic unary-op
// launcher. Uses the single-precision erff() overload throughout.
static __device__ __forceinline__ float op_gelu_erf(float x) {
    const float kInvSqrt2 = 0.70710678118654752440084436210484f; // 1/sqrt(2)

    // erf of the scaled input; kept as a named intermediate for clarity.
    const float erf_term = erff(x*kInvSqrt2);

    return 0.5f*x*(1.0f + erf_term);
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_quick(float x) {
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
|
||||
@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu>(ctx, dst);
|
||||
}
|
||||
|
||||
// Applies the exact (erf-based) GELU activation to dst on the GPU.
// Thin host-side wrapper: delegates to the generic unary-op launcher with
// op_gelu_erf as the per-element device function. Launch configuration,
// stream handling, and type dispatch are presumably handled inside
// ggml_cuda_op_unary (defined elsewhere in this file; not visible here).
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
}
|
||||
|
||||
// Applies the "quick" (sigmoid-approximation) GELU activation to dst on the GPU.
// Thin host-side wrapper: delegates to the generic unary-op launcher with
// op_gelu_quick as the per-element device function. Launch configuration,
// stream handling, and type dispatch are presumably handled inside
// ggml_cuda_op_unary (defined elsewhere in this file; not visible here).
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}
|
||||
|
||||
Reference in New Issue
Block a user