#include "unary.cuh"

// Element-wise unary operators.
//
// Layout of this file:
//   1. op_*          - device kernels, one thread per element
//   2. *_cuda        - host launchers that size the grid and enqueue the kernel
//   3. ggml_cuda_op_* - public entry points that validate the tensors and
//                       dispatch on the element type (float or half)
//
// Kernel contract (all op_* kernels):
//   - launched on a 1D grid of 1D blocks; thread i handles element i
//   - threads with i >= k return immediately, so the grid may overshoot k
//   - x and dst are indexed as flat arrays of k elements of type T
//   - math goes through the single-precision functions (fabsf, expf, ...)

// ---------------------------------------------------------------------------
// Device kernels
// ---------------------------------------------------------------------------

// dst[i] = |x[i]|
template <class T>
static __global__ void op_abs(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fabsf(x[i]);
}

// dst[i] = 1 if x[i] > 0, -1 if x[i] < 0, otherwise 0
template <class T>
static __global__ void op_sgn(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)(x[i] > (T)0.f ? 1.f : ((x[i] < (T)0.f ? -1.f : 0.f)));
}

// dst[i] = -x[i]
template <class T>
static __global__ void op_neg(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = -x[i];
}

// dst[i] = (x[i] > 0), i.e. the boolean result converted to T
template <class T>
static __global__ void op_step(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] > (T)0.0f;
}

// dst[i] = 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2)))
template <class T>
static __global__ void op_gelu(const T * x, T * dst, const int k) {
    const T GELU_COEF_A    = 0.044715f;
    const T SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    T xi = x[i];
    dst[i] = (T)0.5f*xi*((T)1.0f + (T)tanhf(SQRT_2_OVER_PI*xi*((T)1.0f + GELU_COEF_A*xi*xi)));
}

// dst[i] = x / (1 + exp(-1.702*x))
template <class T>
static __global__ void op_gelu_quick(const T * x, T * dst, int k) {
    const T GELU_QUICK_COEF = -1.702f;
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * ((T)1.0f / ((T)1.0f + (T)expf(GELU_QUICK_COEF * x[i])));
}

// dst[i] = x / (1 + exp(-x))
template <class T>
static __global__ void op_silu(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] / ((T)1.0f + (T)expf(-x[i]));
}

// With s = 1 / (1 + exp(-xf)):  dst[i] = grad * s * (1 + xf*(1 - s))
template <class T>
static __global__ void op_silu_back(
        const T * grad, const T * xf, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    const T xfi = xf[i];
    const T s   = (T)1.0f / ((T)1.0f + (T)expf(-xfi));
    dst[i] = grad[i] * s * ((T)1.0f + xfi * ((T)1.0f - s));
}

// dst[i] = tanh(x[i])
template <class T>
static __global__ void op_tanh(const T * x, T * dst, int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = tanhf(x[i]);
}

// dst[i] = max(x[i], 0)
template <class T>
static __global__ void op_relu(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fmaxf(x[i], 0);
}

// dst[i] = 1 / (1 + exp(-x[i]))
template <class T>
static __global__ void op_sigmoid(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)1.0f / ((T)1.0f + (T)expf(-x[i]));
}

// dst[i] = clamp((x[i] + 3) / 6, 0, 1)
template <class T>
static __global__ void op_hardsigmoid(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
}

// dst[i] = x[i] * clamp((x[i] + 3) / 6, 0, 1)
template <class T>
static __global__ void op_hardswish(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * (T)fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
}

// dst[i] = exp(x[i])
template <class T>
static __global__ void op_exp(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = expf(x[i]);
}

// dst[i] = max(x[i], 0) + min(x[i], 0) * negative_slope
template <class T>
static __global__ void op_leaky_relu(const T * x, T * dst, const int k, const float negative_slope) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)fmaxf(x[i], 0) + (T)fminf(x[i], 0.0f) * (T)negative_slope;
}

// dst[i] = x[i] * x[i]
template <class T>
static __global__ void op_sqr(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = x[i] * x[i];
}

// dst[i] = sqrt(x[i])
template <class T>
static __global__ void op_sqrt(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = sqrtf(x[i]);
}

// dst[i] = sin(x[i])
template <class T>
static __global__ void op_sin(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = sinf(x[i]);
}

// dst[i] = cos(x[i])
template <class T>
static __global__ void op_cos(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = cosf(x[i]);
}

// dst[i] = log(x[i])
template <class T>
static __global__ void op_log(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = logf(x[i]);
}

// ---------------------------------------------------------------------------
// Host launchers
//
// Each launcher rounds k up to a whole number of blocks and enqueues the
// matching kernel on `stream` with no dynamic shared memory. The block size
// passed to the launch is the same constant used to size the grid, so that
// num_blocks * block_size >= k always holds.
// ---------------------------------------------------------------------------

template <class T>
static void abs_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
    op_abs<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void sgn_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
    op_sgn<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void neg_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
    op_neg<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void step_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
    op_step<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void gelu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    op_gelu<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void gelu_quick_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    op_gelu_quick<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void silu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    op_silu<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
    // Round up and divide by the SAME constant that is used as the block size;
    // mixing CUDA_SILU_BACK_BLOCK_SIZE with CUDA_SILU_BLOCK_SIZE would size the
    // grid incorrectly as soon as the two values differ.
    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
    op_silu_back<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}

template <class T>
static void tanh_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
    op_tanh<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void relu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    op_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void sigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
    op_sigmoid<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void hardsigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
    op_hardsigmoid<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void hardswish_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
    op_hardswish<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void exp_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
    op_exp<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    op_leaky_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}

template <class T>
static void sqr_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
    op_sqr<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void sqrt_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
    op_sqrt<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void sin_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
    op_sin<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void cos_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
    op_cos<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

template <class T>
static void log_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
    // NOTE(review): log reuses CUDA_COS_BLOCK_SIZE; presumably there is no
    // log-specific block-size constant in unary.cuh - confirm.
    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
    op_log<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

// ---------------------------------------------------------------------------
// Public entry points
// ---------------------------------------------------------------------------

// Shared implementation of every parameter-free unary op.
//
// Reads the single input from dst->src[0], asserts that it is contiguous, that
// input and output are both F32 or both F16, and then calls the launcher that
// matches the element type on the context's stream.
//
//   ctx         backend context; only ctx.stream() is used
//   dst         output tensor; its src[0] is the input
//   launch_f32  launcher used when the tensors are GGML_TYPE_F32
//   launch_f16  launcher used when the tensors are GGML_TYPE_F16
//
// NOTE(review): the launchers take `int k`, so the value returned by
// ggml_nelements(src0) is narrowed to int at the call - callers are assumed
// to stay within int range.
static void ggml_cuda_op_unary(
        ggml_backend_cuda_context & ctx, ggml_tensor * dst,
        void (*launch_f32)(const float *, float *, const int, cudaStream_t),
        void (*launch_f16)(const half  *, half  *, const int, cudaStream_t)) {
    const ggml_tensor * src0 = dst->src[0];
    const void * src0_d = src0->data;
    void * dst_d = dst->data;
    cudaStream_t stream = ctx.stream();

    // The kernels index both buffers as flat arrays of one element type.
    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    if (src0->type == GGML_TYPE_F16) {
        launch_f16((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
    } else {
        launch_f32((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
    }
}

void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, abs_cuda<float>, abs_cuda<half>);
}

void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, sgn_cuda<float>, sgn_cuda<half>);
}

void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, neg_cuda<float>, neg_cuda<half>);
}

void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, step_cuda<float>, step_cuda<half>);
}

void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, gelu_cuda<float>, gelu_cuda<half>);
}

void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, silu_cuda<float>, silu_cuda<half>);
}

// Two-input op: computes the SiLU backward kernel over dst->src[0] and
// dst->src[1], element by element, for ggml_nelements(src0) elements.
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // input from forward pass
    const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output

    // NOTE(review): the two comments above label src0 as the forward input and
    // src1 as the gradient, yet src0 is passed below as the kernel's `grad`
    // argument and src1 as its `xf` argument. Either the comments or the
    // argument order is mislabeled - confirm against the op's operand order.
    // The argument order is kept exactly as it was.
    const void * src0_d = src0->data;
    const void * src1_d = src1->data;
    void * dst_d = dst->data;

    cudaStream_t stream = ctx.stream();

    // NOTE(review): only src0 is validated; src1 is read with the same element
    // type and flat indexing, so it presumably must match src0 in type, size
    // and contiguity - verify against callers.
    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    if (src0->type == GGML_TYPE_F16) {
        silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
    } else {
        silu_back_cuda((const float *)src0_d, (const float *)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
    }
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, gelu_quick_cuda<float>, gelu_quick_cuda<half>);
}

void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, tanh_cuda<float>, tanh_cuda<half>);
}

void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, relu_cuda<float>, relu_cuda<half>);
}

void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, sigmoid_cuda<float>, sigmoid_cuda<half>);
}

void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, hardsigmoid_cuda<float>, hardsigmoid_cuda<half>);
}

void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, hardswish_cuda<float>, hardswish_cuda<half>);
}

void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, exp_cuda<float>, exp_cuda<half>);
}

// Leaky ReLU carries one extra parameter, so it does not go through
// ggml_cuda_op_unary: the slope is read from the start of dst->op_params.
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const void * src0_d = src0->data;
    void * dst_d = dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    // memcpy rather than a pointer cast: op_params is reinterpreted as a float
    // without assuming anything about its declared type.
    float negative_slope;
    memcpy(&negative_slope, dst->op_params, sizeof(float));

    if (src0->type == GGML_TYPE_F16) {
        leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);
    } else {
        leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
    }
}

void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, sqr_cuda<float>, sqr_cuda<half>);
}

void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, sqrt_cuda<float>, sqrt_cuda<half>);
}

void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, sin_cuda<float>, sin_cuda<half>);
}

void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, cos_cuda<float>, cos_cuda<half>);
}

void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary(ctx, dst, log_cuda<float>, log_cuda<half>);
}