cuda: make im2col a little faster

This commit is contained in:
leejet
2025-08-02 01:06:22 +08:00
parent f7a12f9e69
commit 85c8e1e519

View File

@@ -133,7 +133,7 @@ static __global__ void im2col_3d_kernel(
const int64_t ikw = i % KW;
const int64_t ioh = blockIdx.y;
-for (int64_t iz = blockIdx.z; iz < N_OD_IC; iz+=MAX_GRID_DIM_Z) {
+for (int64_t iz = blockIdx.z; iz < N_OD_IC; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_IC;
const int64_t iod = (iz - in*OD_IC) / IC;
const int64_t iic = iz % IC;
@@ -169,7 +169,7 @@ static void im2col_3d_cuda(const float * src, T* dst,
const int64_t N_OD_IC = N*OD*IC;
const int64_t OD_IC = OD*IC;
const int64_t num_blocks = (OW_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-dim3 block_nums(num_blocks, OH, MIN(N_OD_IC, MAX_GRID_DIM_Z));
+dim3 block_nums(num_blocks, OH, MIN(N_OD_IC, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(OW_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_KD_KH_KW,
OW_KD_KH_KW, N_OD_IC, OD_IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);