cuda/cpu: add im2col_3d support

This commit is contained in:
leejet
2025-08-02 00:38:51 +08:00
parent 93c7e775b8
commit f7a12f9e69
8 changed files with 353 additions and 12 deletions

View File

@@ -511,6 +511,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL, GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_BACK,
GGML_OP_IM2COL_3D,
GGML_OP_CONV_2D, GGML_OP_CONV_2D,
GGML_OP_CONV_3D, GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW, GGML_OP_CONV_2D_DW,

View File

@@ -1876,6 +1876,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_im2col_back_f32(params, tensor); ggml_compute_forward_im2col_back_f32(params, tensor);
} break; } break;
case GGML_OP_IM2COL_3D:
{
ggml_compute_forward_im2col_3d(params, tensor);
} break;
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
{ {
ggml_compute_forward_conv_2d(params, tensor); ggml_compute_forward_conv_2d(params, tensor);
@@ -2255,6 +2259,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break; } break;
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK: case GGML_OP_IM2COL_BACK:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D: case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_2D_DW:

View File

@@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32(
} }
} }
// ggml_compute_forward_im2col_3d_f16
// src0: kernel [OC*IC, KD, KH, KW]
// src1: image [N*IC, ID, IH, IW]
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
static void ggml_compute_forward_im2col_3d_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
GGML_ASSERT(nb10 == sizeof(float));
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t iod = 0; iod < OD; iod++) {
for (int64_t ioh = 0; ioh < OH; ioh++) {
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
ggml_fp16_t * dst_data = wdata + (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) {
for (int64_t ikh = 0; ikh < KH; ikh++) {
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iid*IH_IW + iih*IW + iiw]);
}
}
}
}
}
}
}
}
}
}
}
// ggml_compute_forward_im2col_3d_f32
// src0: kernel [OC*IC, KD, KH, KW]
// src1: image [N*IC, ID, IH, IW]
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
static void ggml_compute_forward_im2col_3d_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
GGML_ASSERT(nb10 == sizeof(float));
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
{
float * const wdata = (float *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t iod = 0; iod < OD; iod++) {
for (int64_t ioh = 0; ioh < OH; ioh++) {
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
float * dst_data = wdata + (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) {
for (int64_t ikh = 0; ikh < KH; ikh++) {
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = src_data[iid*IH_IW + iih*IW + iiw];
}
}
}
}
}
}
}
}
}
}
}
void ggml_compute_forward_im2col_3d(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_3d_f16(params, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_im2col_3d_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
void * a, void * b, float * c) { void * a, void * b, float * c) {
const ggml_type_traits * traits = ggml_get_type_traits(type); const ggml_type_traits * traits = ggml_get_type_traits(type);

View File

@@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@@ -2452,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst); ggml_cuda_op_im2col(ctx, dst);
break; break;
case GGML_OP_IM2COL_3D:
ggml_cuda_op_im2col_3d(ctx, dst);
break;
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst); ggml_cuda_op_conv2d(ctx, dst);
break; break;
@@ -3559,6 +3562,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
} }
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_CONV_TRANSPOSE_2D:

View File

@@ -112,3 +112,125 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
} }
} }
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static __global__ void im2col_3d_kernel(
const float * src, T * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t N_OD_IC, int64_t OD_IC,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= OW_KD_KH_KW) {
return;
}
const int64_t iow = i / KD_KH_KW;
const int64_t ikd = (i - iow * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iow * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t ioh = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_IC; iz+=MAX_GRID_DIM_Z) {
const int64_t in = iz / OD_IC;
const int64_t iod = (iz - in*OD_IC) / IC;
const int64_t iic = iz % IC;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = (in*IC + iic)*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
dst[offset_dst] = src[offset_src];
}
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static void im2col_3d_cuda(const float * src, T* dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
const int64_t N_OD_IC = N*OD*IC;
const int64_t OD_IC = OD*IC;
const int64_t num_blocks = (OW_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, MIN(N_OD_IC, MAX_GRID_DIM_Z));
im2col_3d_kernel<<<block_nums, MIN(OW_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_KD_KH_KW,
OW_KD_KH_KW, N_OD_IC, OD_IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
static void im2col_3d_cuda_f16(const float * src, half * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
static void im2col_3d_cuda_f32(const float * src, float * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
if(dst->type == GGML_TYPE_F16) {
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
} else {
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
}

View File

@@ -3,3 +3,4 @@
#define CUDA_IM2COL_BLOCK_SIZE 256 #define CUDA_IM2COL_BLOCK_SIZE 256
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -974,6 +974,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CONV_TRANSPOSE_1D", "CONV_TRANSPOSE_1D",
"IM2COL", "IM2COL",
"IM2COL_BACK", "IM2COL_BACK",
"IM2COL_3D",
"CONV_2D", "CONV_2D",
"CONV_3D", "CONV_3D",
"CONV_2D_DW", "CONV_2D_DW",
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"conv_transpose_1d(x)", "conv_transpose_1d(x)",
"im2col(x)", "im2col(x)",
"im2col_back(x)", "im2col_back(x)",
"im2col_3d(x)"
"conv_2d(x)", "conv_2d(x)",
"conv_3d(x)", "conv_3d(x)",
"conv_2d_dw(x)", "conv_2d_dw(x)",
@@ -4379,7 +4381,6 @@ struct ggml_tensor * ggml_im2col_3d(
int d1, // dilation height int d1, // dilation height
int d2, // dilation depth int d2, // dilation depth
enum ggml_type dst_type) { enum ggml_type dst_type) {
const int64_t N = b->ne[3] / IC; const int64_t N = b->ne[3] / IC;
const int64_t ID = b->ne[2]; const int64_t ID = b->ne[2];
const int64_t IH = b->ne[1]; const int64_t IH = b->ne[1];
@@ -4390,22 +4391,25 @@ struct ggml_tensor * ggml_im2col_3d(
const int64_t KH = a->ne[1]; const int64_t KH = a->ne[1];
const int64_t KW = a->ne[0]; const int64_t KW = a->ne[0];
const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2); const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
struct ggml_tensor* img = ggml_reshape_4d(ctx, b, IW*IH, ID, IC, N); // [N, IC, ID, IH * IW] GGML_ASSERT((OD > 0) && "b too small compared to a");
img = ggml_cont(ctx, ggml_permute(ctx, img, 2, 0, 1, 3)); // [N, IH*IW, IC, ID] GGML_ASSERT((OH > 0) && "b too small compared to a");
img = ggml_reshape_3d(ctx, b, ID, IC, IW*IH*N); // [N*IH*IW, IC, ID] GGML_ASSERT((OW > 0) && "b too small compared to a");
a = ggml_reshape_3d(ctx, a, KD, IC, OC*KW*KH); // [OC*KW*KH, IC, KD]
img = ggml_im2col(ctx, a, img, s2, 1, p2, 0, d2, 1, false, GGML_TYPE_F32); // [N*IH*IW, OD, IC*KD]
img = ggml_reshape_4d(ctx, img, IC*KD, OD, IW*IH, N); // [N, IH*IW, OD, IC*KD] const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
img = ggml_cont(ctx, ggml_permute(ctx, img, 1, 2, 0, 3)); // [N, OD, IC*KD, IH*IW]
img = ggml_reshape_4d(ctx, img, IW, IH, IC*KD, OD*N); // [N*OD, IC*KD, IH, IW]
a = ggml_reshape_4d(ctx, a, KW, KH, IC*KD, OC); // [OC, KD*IC, KH, KW] struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
ggml_set_op_params(result, params, sizeof(params));
img = ggml_im2col(ctx, a, img, s0, s1, p0, p1, d0, d1, true, dst_type); // [N * OD, OH, OW, IC * KD * KH * KW] result->op = GGML_OP_IM2COL_3D;
return img; result->src[0] = a;
result->src[1] = b;
return result;
} }
// a: [OC*IC, KD, KH, KW] // a: [OC*IC, KD, KH, KW]