mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-14 11:07:10 +00:00
add CUDA
This commit is contained in:
@@ -1,18 +1,18 @@
|
|||||||
#include "scale.cuh"
|
#include "scale.cuh"
|
||||||
|
|
||||||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
|
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[i] = scale * x[i];
|
dst[i] = scale * x[i] + bias;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
||||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
@@ -24,8 +24,8 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
float scale;
|
float scale = ((const float *)(dst->op_params))[0];
|
||||||
memcpy(&scale, dst->op_params, sizeof(float));
|
float bias = ((const float *)(dst->op_params))[1];
|
||||||
|
|
||||||
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
|
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user