CUDA: add FLOOR, CEIL, ROUND, TRUNC unary ops (#16917)

mnehete32
2025-11-02 08:42:57 +05:30
committed by GitHub
parent a864132ba5
commit 7db35a7958
3 changed files with 56 additions and 0 deletions

ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2499,6 +2499,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_UNARY_OP_XIELU:
            ggml_cuda_op_xielu(ctx, dst);
            break;
        case GGML_UNARY_OP_FLOOR:
            ggml_cuda_op_floor(ctx, dst);
            break;
        case GGML_UNARY_OP_CEIL:
            ggml_cuda_op_ceil(ctx, dst);
            break;
        case GGML_UNARY_OP_ROUND:
            ggml_cuda_op_round(ctx, dst);
            break;
        case GGML_UNARY_OP_TRUNC:
            ggml_cuda_op_trunc(ctx, dst);
            break;
        default:
            return false;
    }
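For context, a CUDA-backed graph reaches this new dispatch path through ggml's generic unary-op API. A minimal host-side sketch (not part of this commit; it assumes the standard ggml C API, with ggml_unary() and the GGML_UNARY_OP_FLOOR enum value already present):

#include "ggml.h"

// Hypothetical helper, for illustration only: builds a graph node that
// ggml_cuda_compute_forward() routes to ggml_cuda_op_floor() via the switch above.
static struct ggml_tensor * build_floor(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_unary(ctx, a, GGML_UNARY_OP_FLOOR);
}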
@@ -3769,6 +3781,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_UNARY_OP_TANH:
        case GGML_UNARY_OP_EXP:
        case GGML_UNARY_OP_ELU:
        case GGML_UNARY_OP_FLOOR:
        case GGML_UNARY_OP_CEIL:
        case GGML_UNARY_OP_ROUND:
        case GGML_UNARY_OP_TRUNC:
            return ggml_is_contiguous(op->src[0]);
        default:
            return false;

ggml/src/ggml-cuda/unary.cu

@@ -85,6 +85,22 @@ static __device__ __forceinline__ float op_elu(float x) {
    return (x > 0.f) ? x : expm1f(x);
}

static __device__ __forceinline__ float op_floor(float x) {
    return floorf(x);
}

static __device__ __forceinline__ float op_ceil(float x) {
    return ceilf(x);
}

static __device__ __forceinline__ float op_round(float x) {
    return roundf(x);
}

static __device__ __forceinline__ float op_trunc(float x) {
    return truncf(x);
}

template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
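The four new device helpers differ only in rounding direction. A small host-side check (an illustrative sketch using the corresponding C math functions, not part of this commit) shows how they disagree on a negative half-way value:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float x = -2.5f;
    printf("floorf(%.1f) = %.1f\n", x, floorf(x)); // -3.0: toward -infinity
    printf("ceilf(%.1f)  = %.1f\n", x, ceilf(x));  // -2.0: toward +infinity
    printf("roundf(%.1f) = %.1f\n", x, roundf(x)); // -3.0: half away from zero
    printf("truncf(%.1f) = %.1f\n", x, truncf(x)); // -2.0: toward zero
    return 0;
}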
@@ -201,6 +217,22 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_elu>(ctx, dst);
}

void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_floor>(ctx, dst);
}

void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_ceil>(ctx, dst);
}

void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_round>(ctx, dst);
}

void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_trunc>(ctx, dst);
}

/* gated ops */

template <float (*op)(float), typename T>
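Each of the new host wrappers funnels into the same templated kernel, with the device function passed as a non-type template parameter. A self-contained sketch of that pattern (hypothetical names, outside the ggml build; compilable on its own with nvcc):

#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ float my_floor(float x) { return floorf(x); }

// Same idea as unary_op_kernel: the element-wise op is a template parameter,
// so a single kernel body serves every unary op.
template <float (*op)(float)>
__global__ void unary_kernel(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = op(x[i]);
}

int main() {
    const int k = 4;
    const float h_x[4] = {-2.5f, -0.5f, 0.5f, 2.5f};
    float h_y[4];
    float *d_x, *d_y;
    cudaMalloc(&d_x, k*sizeof(float));
    cudaMalloc(&d_y, k*sizeof(float));
    cudaMemcpy(d_x, h_x, k*sizeof(float), cudaMemcpyHostToDevice);
    unary_kernel<my_floor><<<1, 32>>>(d_x, d_y, k);
    cudaMemcpy(h_y, d_y, k*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < k; ++i) {
        printf("%g -> %g\n", h_x[i], h_y[i]);
    }
    cudaFree(d_x);
    cudaFree(d_y);
    return 0;
}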

ggml/src/ggml-cuda/unary.cuh

@@ -63,6 +63,14 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);