make im2col_3d faster

This commit is contained in:
leejet
2025-08-13 01:09:25 +08:00
parent ae47caca70
commit dd745ba31f

View File

@@ -119,35 +119,36 @@ static __global__ void im2col_3d_kernel(
const float * src, T * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t N_OD_IC, int64_t OD_IC,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= OW_KD_KH_KW) {
if (i >= IC_KD_KH_KW) {
return;
}
const int64_t iow = i / KD_KH_KW;
const int64_t ikd = (i - iow * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iow * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t iic = i / KD_KH_KW;
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t ioh = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_IC; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_IC;
const int64_t iod = (iz - in*OD_IC) / IC;
const int64_t iic = iz % IC;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = (in*IC + iic)*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
dst[offset_dst] = src[offset_src];
}
}
@@ -166,13 +167,19 @@ static void im2col_3d_cuda(const float * src, T* dst,
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
const int64_t N_OD_IC = N*OD*IC;
const int64_t OD_IC = OD*IC;
const int64_t num_blocks = (OW_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, MIN(N_OD_IC, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(OW_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_KD_KH_KW,
OW_KD_KH_KW, N_OD_IC, OD_IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
const int64_t N_OD_OH = N*OD*OH;
const int64_t OD_OH = OD*OH;
const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
static void im2col_3d_cuda_f16(const float * src, half * dst,