mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-17 11:37:10 +00:00
cuda/cpu: Increase support for fp16 unary operations (ggml/1125)
* Support fp16 unary operations in the CUDA backend * cpu: increase fp16 support for unary operators in the CPU backend * cuda: increase fp16 support for unary operators in the CUDA backend * Add test cases for fp16 unary operators * metal: update supports_op for unary operators that don't support fp16, to prevent test-backend-ops from failing * metal: fix PR comments for unary op support after fp16 unary tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include "clamp.cuh"
|
||||
|
||||
static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
|
||||
template <class T>
|
||||
static __global__ void op_clamp(const T * x, T * dst, const T min, const T max, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
@@ -10,25 +11,31 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||
}
|
||||
|
||||
static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
|
||||
template <class T>
|
||||
static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
|
||||
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
|
||||
op_clamp<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
|
||||
}
|
||||
|
||||
|
||||
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
const void * src0_d = src0->data;
|
||||
void * dst_d = dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, dst->op_params, sizeof(float));
|
||||
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
|
||||
if (src0->type == GGML_TYPE_F16) {
|
||||
clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
|
||||
} else {
|
||||
clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user