mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-21 12:16:57 +00:00
ggml : add ops SOFTPLUS, EXPM1, TRI, SOLVE_TRI, CUMSUM (#17063)
* Add ops needed for new hybrid models: SOFTPLUS, EXPM1, TRI, SOLVE_TRI, CUMSUM * Update ggml/include/ggml.h Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update tests/test-backend-ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Code review * Whitespace * Update tests/test-backend-ops.cpp Co-authored-by: Diego Devesa <slarengh@gmail.com> * This is actually sigmoid, duh. * Add CONST, remove TRI_KEEP, other changes from review * Update tests/test-backend-ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml/src/ggml.c Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml/src/ggml.c Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml/src/ggml-cuda/unary.cu Co-authored-by: Aman Gupta <amangupta052@gmail.com> * Remove extra script * Update ggml/src/ggml.c Co-authored-by: Diego Devesa <slarengh@gmail.com> * Update tests/test-backend-ops.cpp Co-authored-by: Diego Devesa <slarengh@gmail.com> * moving changes from laptop [no ci] * pre-rebase * Update tests/test-backend-ops.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update tests/test-backend-ops.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Refactor tests * ggml : cleanup * cont : fix ggml_fill srcs * tests : add note * ggml : add ggml_fill_inplace * ggml : add asserts * ggml : fix ggml_fill constant cast * cont : ggml_tri minor * Use TENSOR_LOCALS * Fix regression from #14596, regenerate * Don't make commits at night... --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Diego Devesa <slarengh@gmail.com> Co-authored-by: Aman Gupta <amangupta052@gmail.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
committed by
GitHub
parent
a19bd6f7ce
commit
389ac78b26
@@ -2527,6 +2527,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_cuda_op_trunc(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
ggml_cuda_op_expm1(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
ggml_cuda_op_softplus(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3829,6 +3835,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
|
||||
@@ -81,6 +81,14 @@ static __device__ __forceinline__ float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_expm1(float x) {
|
||||
return expm1f(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_softplus(float x) {
|
||||
return (x > 20.0f) ? x : logf(1.0f + expf(x));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
@@ -233,6 +241,14 @@ void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_trunc>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_expm1>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_softplus>(ctx, dst);
|
||||
}
|
||||
/* gated ops */
|
||||
|
||||
template <float (*op)(float), typename T>
|
||||
|
||||
@@ -61,6 +61,10 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
Reference in New Issue
Block a user