mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-10 10:27:03 +00:00
* add conv3d support * add ggml_pad_ext for cpu & cuda backend * cuda/cpu: add im2col_3d support * cuda: make im2col a little faster * fix cuda pad/scale/im2col3d * make im2col_3d faster * gguf: support loading tensors which n_dims > GGML_MAX_DIMS * fix cuda get_rows * avoid ggml_conv_3d conflict * correct GGML_OP_COUNT assertion * avoid build failure * avoid build failure on MacOS * cuda: remove unnecessary MIN define * fix cpu im2col_3d * adjust the code style * cuda: use simpler loop in get_rows * add test_im2col_3d to test-backend-ops * test-backend-ops.cpp: remove trailing whitespace * cpu: im2col_3d support non continuous src Co-authored-by: Jeff Bolz <jbolz@nvidia.com> * fix test_im2col_3d * remove unused variables * cuda: get_rows: dfloat2 -> float2 * add test_pad_ext to test-backend-ops.cpp * add gguf_init_from_file_ext impl * Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS" This reverts commit d8377a0a37. * Revert "add gguf_init_from_file_ext impl" This reverts commit d9f1d13208. * update ggml_backend_vk_device_supports_op * fix ggml_backend_vk_device_supports_op * update other backend supports op for ggml_pad_ext * metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op --------- Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
35 lines
1.4 KiB
Plaintext
35 lines
1.4 KiB
Plaintext
#include "scale.cuh"
|
|
|
|
// Cap on the number of blocks scale_f32_cuda launches along grid dimension x.
// 0x7FFFFFFF == 2^31 - 1.
// NOTE(review): presumably the device limit for gridDim.x - confirm against the CUDA programming guide.
#define MAX_GRIDDIM_X 0x7FFFFFFF
|
|
|
|
// Element-wise affine transform: dst[i] = scale * x[i] + bias for every i in [0, nelements).
//
// Launch layout: 1D grid of 1D blocks. The kernel walks the data with a grid-stride loop,
// so any grid size produces a complete result; a grid smaller than
// ceil(nelements / blockDim.x) simply makes each thread process several elements.
// All index arithmetic is done in 64 bits so that element counts above 2^31 do not overflow.
// Uses no shared memory.
//
//   x         - input values, read at indices [0, nelements)
//   dst       - output values, written at indices [0, nelements)
//   scale     - multiplicative factor
//   bias      - additive offset applied after scaling
//   nelements - number of elements to process
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
    const int64_t threads_per_block = (int64_t) blockDim.x;
    // Total number of threads in the grid == distance between two elements handled by one thread.
    const int64_t total_threads     = threads_per_block * (int64_t) gridDim.x;

    int64_t idx = (int64_t) blockIdx.x * threads_per_block + (int64_t) threadIdx.x;
    while (idx < nelements) {
        dst[idx] = scale * x[idx] + bias;
        idx += total_threads;
    }
}
|
|
|
|
// Host-side launcher for scale_f32: enqueues dst[i] = scale * x[i] + bias on `stream`.
//
//   x, dst    - pointers handed to the kernel unchanged (must be device-accessible)
//   scale     - multiplicative factor
//   bias      - additive offset
//   nelements - number of float elements to process; values <= 0 are a no-op
//   stream    - CUDA stream the kernel is enqueued on (the launch is asynchronous)
//
// The grid is sized to one thread per element (ceil-div by CUDA_SCALE_BLOCK_SIZE) and
// clamped to MAX_GRIDDIM_X; the kernel's grid-stride loop covers the remainder when the
// clamp applies.
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
    // An empty tensor would yield num_blocks == 0, and a kernel launch with a zero-sized
    // grid is rejected as an invalid configuration. There is nothing to compute, so skip
    // the launch entirely.
    if (nelements <= 0) {
        return;
    }

    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;

    scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
}
|
|
|
|
// CUDA implementation of the scale op: dst = scale * src0 + bias, element-wise.
//
//   ctx - backend context; supplies the stream the kernel is enqueued on
//   dst - output tensor. Its first source (dst->src[0]) is the input, and the first two
//         32-bit slots of dst->op_params hold the float `scale` and `bias` values.
//
// Both the input and the output must be F32 (asserted).
// NOTE(review): the data is treated as one flat run of ggml_nelements(src0) floats, so
// this presumably expects contiguous tensors - confirm against the backend's supports_op.
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * input = dst->src[0];

    const float * in_data  = (const float *) input->data;
    float *       out_data = (float *) dst->data;

    cudaStream_t cuda_stream = ctx.stream();

    GGML_ASSERT(input->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type   == GGML_TYPE_F32);

    // op_params is raw storage; copy the bytes out instead of dereferencing a casted
    // pointer. params[0] is the scale factor, params[1] is the bias.
    float params[2];
    memcpy(params, dst->op_params, sizeof(params));

    const float scale = params[0];
    const float bias  = params[1];

    scale_f32_cuda(in_data, out_data, scale, bias, ggml_nelements(input), cuda_stream);
}
|