mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-21 12:16:57 +00:00
CUDA: backwards pass for misc. ops, add tests (#11257)
* CUDA: backwards pass for misc. ops, add tests * remove restrict from pointers
This commit is contained in:
@@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
static __global__ void silu_back_f32(
|
||||
const float * grad, const float * xf, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float xfi = xf[i];
|
||||
const float s = 1.0f / (1.0f + expf(-xfi));
|
||||
dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
|
||||
}
|
||||
|
||||
static __global__ void tanh_f32(const float * x, float * dst, int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
if (i >= k) {
|
||||
@@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
|
||||
silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
|
||||
}
|
||||
|
||||
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
|
||||
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
@@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0]; // input from forward pass
|
||||
const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
||||
Reference in New Issue
Block a user