ggml: add ops for WAN video model (cuda && cpu) (#15669)

* add conv3d support

* add ggml_pad_ext for cpu & cuda backend

* cuda/cpu: add im2col_3d support

* cuda: make im2col a little faster

* fix cuda pad/scale/im2col3d

* make im2col_3d faster

* gguf: support loading tensors which n_dims > GGML_MAX_DIMS

* fix cuda get_rows

* avoid ggml_conv_3d conflict

* correct GGML_OP_COUNT assertion

* avoid build failure

* avoid build failure on MacOS

* cuda: remove unnecessary MIN define

* fix cpu im2col_3d

* adjust the code style

* cuda: use simpler loop in get_rows

* add test_im2col_3d to test-backend-ops

* test-backend-ops.cpp: remove trailing whitespace

* cpu: im2col_3d support non continuous src

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>

* fix test_im2col_3d

* remove unused variables

* cuda: get_rows: dfloat2 -> float2

* add test_pad_ext to test-backend-ops.cpp

* add gguf_init_from_file_ext impl

* Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS"

This reverts commit d8377a0a37.

* Revert "add gguf_init_from_file_ext impl"

This reverts commit d9f1d13208.

* update ggml_backend_vk_device_supports_op

* fix ggml_backend_vk_device_supports_op

* update other backend supports op for ggml_pad_ext

* metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op

---------

Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
This commit is contained in:
leejet
2025-09-04 16:38:49 +08:00
committed by GitHub
parent 5421f63ab0
commit 0a1b3982cd
17 changed files with 759 additions and 90 deletions

View File

@@ -2,6 +2,8 @@
#include "dequantize.cuh"
#include "convert.cuh"
#define MAX_GRIDDIM_Y 65535
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
@@ -11,32 +13,29 @@ static __global__ void k_get_rows(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;
if (i00 >= ne00) {
return;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
float2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
float2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
template<typename src0_t, typename dst_t>
@@ -48,22 +47,23 @@ static __global__ void k_get_rows_float(
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;
if (i00 >= ne00) {
return;
if (i00 >= ne00) {
return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
template<typename grad_t, typename dst_t>
@@ -98,7 +98,7 @@ static void get_rows_cuda_q(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -131,7 +131,7 @@ static void get_rows_cuda_float(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);

View File

@@ -2452,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_IM2COL_3D:
ggml_cuda_op_im2col_3d(ctx, dst);
break;
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
@@ -3559,6 +3562,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:

View File

@@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static __global__ void im2col_3d_kernel(
const float * src, T * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= IC_KD_KH_KW) {
return;
}
const int64_t iic = i / KD_KH_KW;
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
dst[offset_dst] = src[offset_src];
}
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static void im2col_3d_cuda(const float * src, T* dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
const int64_t N_OD_OH = N*OD*OH;
const int64_t OD_OH = OD*OH;
const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
static void im2col_3d_cuda_f16(const float * src, half * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
static void im2col_3d_cuda_f32(const float * src, float * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
if(dst->type == GGML_TYPE_F16) {
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
} else {
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
}

View File

@@ -3,3 +3,4 @@
#define CUDA_IM2COL_BLOCK_SIZE 256
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -1,36 +1,50 @@
#include "pad.cuh"
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
static __global__ void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
// blockIdx.z: i3*ne2+i2
// blockIdx.y: i1
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
// gridDim.y: ne1
int i0 = threadIdx.x + blockIdx.x * blockDim.x;
int i1 = blockIdx.y;
int i2 = blockIdx.z % ne2;
int i3 = blockIdx.z / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[offset_dst] = 0.0f;
dst[dst_idx] = 0.0f;
}
}
static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
static void pad_f32_cuda(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}

View File

@@ -1,18 +1,19 @@
#include "scale.cuh"
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
#define MAX_GRIDDIM_X 0x7FFFFFFF
if (i >= k) {
return;
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
for (int64_t i = tid; i < nelements; i += stride) {
dst[i] = scale * x[i] + bias;
}
dst[i] = scale * x[i] + bias;
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
}
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {