Mirror of https://github.com/ggml-org/llama.cpp.git
ggml: add ops for WAN video model (cuda && cpu) (#15669)
* add conv3d support
* add ggml_pad_ext for cpu & cuda backend
* cuda/cpu: add im2col_3d support
* cuda: make im2col a little faster
* fix cuda pad/scale/im2col3d
* make im2col_3d faster
* gguf: support loading tensors which n_dims > GGML_MAX_DIMS
* fix cuda get_rows
* avoid ggml_conv_3d conflict
* correct GGML_OP_COUNT assertion
* avoid build failure
* avoid build failure on MacOS
* cuda: remove unnecessary MIN define
* fix cpu im2col_3d
* adjust the code style
* cuda: use simpler loop in get_rows
* add test_im2col_3d to test-backend-ops
* test-backend-ops.cpp: remove trailing whitespace
* cpu: im2col_3d support non-contiguous src

  Co-authored-by: Jeff Bolz <jbolz@nvidia.com>

* fix test_im2col_3d
* remove unused variables
* cuda: get_rows: dfloat2 -> float2
* add test_pad_ext to test-backend-ops.cpp
* add gguf_init_from_file_ext impl
* Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS"

  This reverts commit d8377a0a37.

* Revert "add gguf_init_from_file_ext impl"

  This reverts commit d9f1d13208.

* update ggml_backend_vk_device_supports_op
* fix ggml_backend_vk_device_supports_op
* update other backend supports op for ggml_pad_ext
* metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
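For orientation before the hunks: the user-facing entry points this commit wires up are the 3D im2col/conv ops and the extended pad op. Below is a minimal, hedged sketch of how they might be called; wan_blocks is a made-up helper name, and the exact argument order of the strides s0..s2, paddings p0..p2, dilations d0..d2, and the left/right pad pairs lp/rp is inferred from the op_params unpacked in the CUDA code further down, not copied from ggml.h:

// Hypothetical usage sketch — signatures inferred, see note above.
#include "ggml.h"

static struct ggml_tensor * wan_blocks(struct ggml_context * ctx,
                                       struct ggml_tensor * img,     // [IW, IH, ID, N*IC]
                                       struct ggml_tensor * kernel,  // [KW, KH, KD, IC*OC]
                                       int IC) {
    // 3D convolution, lowered to GGML_OP_IM2COL_3D + matmul internally.
    struct ggml_tensor * conv = ggml_conv_3d(ctx, kernel, img, IC,
                                             /*s0,s1,s2*/ 1, 1, 1,
                                             /*p0,p1,p2*/ 0, 0, 0,
                                             /*d0,d1,d2*/ 1, 1, 1);
    // Asymmetric padding: one leading and one trailing pad per dimension.
    return ggml_pad_ext(ctx, conv,
                        /*lp0,rp0*/ 1, 1, /*lp1,rp1*/ 1, 1,
                        /*lp2,rp2*/ 0, 0, /*lp3,rp3*/ 0, 0);
}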
@@ -2,6 +2,8 @@
 #include "dequantize.cuh"
+#include "convert.cuh"
 
+#define MAX_GRIDDIM_Y 65535
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
     const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
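Why 65535: on current CUDA devices only gridDim.x may go up to 2^31-1; gridDim.y and gridDim.z are capped at 65535, which is why the kernels below both swap x/y and stride over y rather than assuming one block per index. A standalone way to confirm the limits on a given device (standard CUDA runtime API; the printed fields are real cudaDeviceProp members):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, /*device=*/0);
    // maxGridSize[0] is typically 2147483647; [1] and [2] are 65535.
    printf("max grid: x=%d y=%d z=%d\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}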
@@ -11,32 +13,29 @@ static __global__ void k_get_rows(
     /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
     const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
-    const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
-    const int i10 = blockIdx.x;
-    const int i11 = blockIdx.z / ne12;
-    const int i12 = blockIdx.z % ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
-    const int ib  = i00/qk;       // block index
-    const int iqs = (i00%qk)/qr;  // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
-
-    // dequantize
-    float2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
-
-    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
-    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+    for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
+        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+        const int i10 = blockIdx.x;
+        const int i11 = blockIdx.z / ne12;
+        const int i12 = blockIdx.z % ne12;
+
+        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+        const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+        const int ib  = i00/qk;       // block index
+        const int iqs = (i00%qk)/qr;  // quant index
+        const int iybs = i00 - i00%qk; // dst block start index
+        const int y_offset = qr == 1 ? 1 : qk/2;
+
+        // dequantize
+        float2 v;
+        dequantize_kernel(src0_row, ib, iqs, v);
+
+        dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+        dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+    }
 }
 
 template<typename src0_t, typename dst_t>
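The pattern above (and the matching launch-side clamp further down) is the standard grid-stride loop: instead of requiring one thread per element, the grid is capped at the hardware limit and each thread walks forward by the total thread count. A minimal self-contained illustration of the same idea, independent of the ggml types (copy_strided and copy_strided_launch are made-up names for this example):

#include <cuda_runtime.h>

#define MAX_GRIDDIM_Y 65535

// Toy kernel: copies n floats, striding over the (swapped) y grid
// dimension exactly like k_get_rows_float does.
static __global__ void copy_strided(const float * src, float * dst, int64_t n) {
    for (int64_t i = blockIdx.y*blockDim.x + threadIdx.x; i < n; i += (int64_t)gridDim.y*blockDim.x) {
        dst[i] = src[i];
    }
}

static void copy_strided_launch(const float * src, float * dst, int64_t n, cudaStream_t stream) {
    const int     block_size = 256;
    const int64_t blocks     = (n + block_size - 1) / block_size;
    // Clamp to the 65535 limit on gridDim.y; the loop covers the remainder.
    const dim3 grid(1, (unsigned)(blocks < MAX_GRIDDIM_Y ? blocks : MAX_GRIDDIM_Y), 1);
    copy_strided<<<grid, block_size, 0, stream>>>(src, dst, n);
}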
@@ -48,22 +47,23 @@ static __global__ void k_get_rows_float(
     /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
     const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
-    const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
-    const int i10 = blockIdx.x;
-    const int i11 = blockIdx.z / ne12;
-    const int i12 = blockIdx.z % ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
-    dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+    for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
+        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+        const int i10 = blockIdx.x;
+        const int i11 = blockIdx.z / ne12;
+        const int i12 = blockIdx.z % ne12;
+
+        if (i00 >= ne00) {
+            return;
+        }
+
+        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+        const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+        dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+    }
 }
 
 template<typename grad_t, typename dst_t>
@@ -98,7 +98,7 @@ static void get_rows_cuda_q(
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(ne10, block_num_y, ne11*ne12);
+    const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -131,7 +131,7 @@ static void get_rows_cuda_float(
        cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-    const dim3 block_nums(ne10, block_num_y, ne11*ne12);
+    const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
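A worked example of when the clamp actually fires: with CUDA_GET_ROWS_BLOCK_SIZE = 256, the float path needs block_num_y = ceil(ne00 / 256) blocks in y, so the 65535 cap is only reached once a row exceeds 65535 * 256 = 16,776,960 elements; beyond that, MIN(block_num_y, MAX_GRIDDIM_Y) caps the grid and the grid-stride loop in the kernel picks up the remaining i00 values. (The ne00 figure is illustrative, not taken from the patch.)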
@@ -2452,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_IM2COL_3D:
+            ggml_cuda_op_im2col_3d(ctx, dst);
+            break;
         case GGML_OP_CONV_2D:
             ggml_cuda_op_conv2d(ctx, dst);
             break;
@@ -3559,6 +3562,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
             }
         case GGML_OP_IM2COL:
+        case GGML_OP_IM2COL_3D:
         case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     }
 }
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static __global__ void im2col_3d_kernel(
+        const float * src, T * dst,
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
+        int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
+        int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
+        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= IC_KD_KH_KW) {
+        return;
+    }
+
+    const int64_t iic = i / KD_KH_KW;
+    const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
+    const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
+    const int64_t ikw = i % KW;
+
+    const int64_t iow = blockIdx.y;
+    for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) {
+        const int64_t in  = iz / OD_OH;
+        const int64_t iod = (iz - in*OD_OH) / OH;
+        const int64_t ioh = iz % OH;
+
+        const int64_t iiw = iow * s0 + ikw * d0 - p0;
+        const int64_t iih = ioh * s1 + ikh * d1 - p1;
+        const int64_t iid = iod * s2 + ikd * d2 - p2;
+
+        const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw;
+
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+            dst[offset_dst] = 0.0f;
+        } else {
+            const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
+            dst[offset_dst] = src[offset_src];
+        }
+    }
+}
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static void im2col_3d_cuda(const float * src, T* dst,
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+    const int64_t OH_OW = OH*OW;
+    const int64_t KD_KH_KW = KD*KH*KW;
+    const int64_t ID_IH_IW = ID*IH*IW;
+    const int64_t KH_KW = KH*KW;
+    const int64_t IH_IW = IH*IW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+    const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
+    const int64_t N_OD_OH = N*OD*OH;
+    const int64_t OD_OH = OD*OH;
+    const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
+    const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
+    const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
+    const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
+    const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
+    im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE), 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                                                                                          OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
+                                                                                          IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
+                                                                                          OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
+                                                                                          s0, s1, s2, p0, p1, p2, d0, d1, d2);
+}
+
+static void im2col_3d_cuda_f16(const float * src, half * dst,
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+static void im2col_3d_cuda_f32(const float * src, float * dst,
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    if (dst->type == GGML_TYPE_F16) {
+        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    } else {
+        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    }
+}
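As a cross-check on the precomputed stride products above, here is a naive CPU reference of the same [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC*KD*KH*KW] mapping. It is a sketch written from the kernel's index arithmetic in this diff, not the repository's CPU backend, and im2col_3d_ref is a made-up name:

#include <cstdint>

// Naive reference: for each output position (in, od, oh, ow) gather the
// IC*KD*KH*KW patch, writing 0 where the (dilated, padded) tap is out of range.
static void im2col_3d_ref(const float * src, float * dst,
                          int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
                          int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
                          int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
    int64_t o = 0; // dst is written densely in the order of its flattened dims
    for (int64_t in = 0; in < N;  in++)
    for (int64_t od = 0; od < OD; od++)
    for (int64_t oh = 0; oh < OH; oh++)
    for (int64_t ow = 0; ow < OW; ow++)
    for (int64_t ic = 0; ic < IC; ic++)
    for (int64_t kd = 0; kd < KD; kd++)
    for (int64_t kh = 0; kh < KH; kh++)
    for (int64_t kw = 0; kw < KW; kw++, o++) {
        const int64_t id = od*s2 + kd*d2 - p2;
        const int64_t ih = oh*s1 + kh*d1 - p1;
        const int64_t iw = ow*s0 + kw*d0 - p0;
        const bool oob = id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 || iw >= IW;
        dst[o] = oob ? 0.0f : src[((in*IC + ic)*ID + id)*IH*IW + ih*IW + iw];
    }
}

The CUDA kernel computes exactly these offsets, but hoists every product of extents (KD_KH_KW, OW_IC_KD_KH_KW, ...) to the host so each thread does only multiplies and adds on precomputed strides.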
@@ -3,3 +3,4 @@
 #define CUDA_IM2COL_BLOCK_SIZE 256
 
 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -1,36 +1,50 @@
 #include "pad.cuh"
 
-static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
-    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
-    // blockIdx.y: idx of ne1
-    // blockIDx.x: idx of ne0 / BLOCK_SIZE
-    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
-    if (nidx >= ne0) {
+static __global__ void pad_f32(const float * src, float * dst,
+        const int lp0, const int rp0, const int lp1, const int rp1,
+        const int lp2, const int rp2, const int lp3, const int rp3,
+        const int ne0, const int ne1, const int ne2, const int ne3) {
+    // blockIdx.z: i3*ne2+i2
+    // blockIdx.y: i1
+    // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
+    // gridDim.y:  ne1
+    int i0 = threadIdx.x + blockIdx.x * blockDim.x;
+    int i1 = blockIdx.y;
+    int i2 = blockIdx.z % ne2;
+    int i3 = blockIdx.z / ne2;
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
-    // operation
-    int offset_dst =
-        nidx +
-        blockIdx.y * ne0 +
-        blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
-        int offset_src =
-            nidx +
-            blockIdx.y * ne00 +
-            blockIdx.z * ne00 * ne01;
-        dst[offset_dst] = x[offset_src];
+    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
+        (i1 >= lp1 && i1 < ne1 - rp1) &&
+        (i2 >= lp2 && i2 < ne2 - rp2) &&
+        (i3 >= lp3 && i3 < ne3 - rp3)) {
+        const int64_t i00 = i0 - lp0;
+        const int64_t i01 = i1 - lp1;
+        const int64_t i02 = i2 - lp2;
+        const int64_t i03 = i3 - lp3;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne00 = ne0 - lp0 - rp0;
+
+        const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+
+        dst[dst_idx] = src[src_idx];
     } else {
-        dst[offset_dst] = 0.0f;
+        dst[dst_idx] = 0.0f;
     }
 }
 
-static void pad_f32_cuda(const float * x, float * dst,
-    const int ne00, const int ne01, const int ne02, const int ne03,
+static void pad_f32_cuda(const float * src, float * dst,
+    const int lp0, const int rp0, const int lp1, const int rp1,
+    const int lp2, const int rp2, const int lp3, const int rp3,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2*ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
 }
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
 
     pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 }
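The semantics the new kernel implements: GGML_OP_PAD now carries a left and a right pad per dimension in op_params, and the kernel inverts the mapping, i.e. output index i0 reads input index i0 - lp0 when in range and writes zero otherwise. A one-dimensional CPU analogue, as a sketch with an illustrative name (not repository code):

#include <cstdint>

// 1-D analogue of pad_f32: dst has ne0 = ne00 + lp0 + rp0 elements;
// the middle ne00 of them copy src, the lp0/rp0 borders are zero-filled.
static void pad_ext_1d_ref(const float * src, float * dst, int64_t ne00, int lp0, int rp0) {
    const int64_t ne0 = ne00 + lp0 + rp0;
    for (int64_t i0 = 0; i0 < ne0; i0++) {
        dst[i0] = (i0 >= lp0 && i0 < ne0 - rp0) ? src[i0 - lp0] : 0.0f;
    }
}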
@@ -1,18 +1,19 @@
 #include "scale.cuh"
 
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+#define MAX_GRIDDIM_X 0x7FFFFFFF
 
-    if (i >= k) {
-        return;
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
+    int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
+    int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
+
+    for (int64_t i = tid; i < nelements; i += stride) {
+        dst[i] = scale * x[i] + bias;
     }
-
-    dst[i] = scale * x[i] + bias;
 }
 
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
+    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
 }
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
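Two overflow points motivate this last change. First, the old const int k capped scale at 2^31 - 1 (about 2.1e9) elements before signed overflow; passing nelements as int64_t removes that ceiling. Second, assuming CUDA_SCALE_BLOCK_SIZE is 256 (the define lives in scale.cuh and is not part of this hunk, so treat the value as an assumption), the x-grid clamp at 0x7FFFFFFF only engages once nelements exceeds (2^31 - 1) * 256, roughly 5.5e11, at which point the grid-stride loop in scale_f32 covers the remainder.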