From dd745ba31f21a9bcd810391b9020cb41e44fcb81 Mon Sep 17 00:00:00 2001 From: leejet Date: Wed, 13 Aug 2025 01:09:25 +0800 Subject: [PATCH] make im2col_3d faster --- ggml/src/ggml-cuda/im2col.cu | 47 +++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index c0a3912982..7737d6a5d5 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -119,35 +119,36 @@ static __global__ void im2col_3d_kernel( const float * src, T * dst, int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, - int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, - int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t N_OD_IC, int64_t OD_IC, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; - if (i >= OW_KD_KH_KW) { + if (i >= IC_KD_KH_KW) { return; } - const int64_t iow = i / KD_KH_KW; - const int64_t ikd = (i - iow * KD_KH_KW) / KH_KW; - const int64_t ikh = (i - iow * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t ioh = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_IC; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_IC; - const int64_t iod = (iz - in*OD_IC) / IC; - const int64_t iic = iz % IC; + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; const int64_t iiw = iow * s0 + ikw * d0 - p0; const int64_t iih = ioh * s1 + ikh * d1 - p1; const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { dst[offset_dst] = 0.0f; } else { - const int64_t offset_src = (in*IC + iic)*ID_IH_IW + iid*IH_IW + iih*IW + iiw; + const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw; dst[offset_dst] = src[offset_src]; } } @@ -166,13 +167,19 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t IH_IW = IH*IW; const int64_t IC_KD_KH_KW = IC*KD*KH*KW; const int64_t OW_KD_KH_KW = OW*KD*KH*KW; - const int64_t N_OD_IC = N*OD*IC; - const int64_t OD_IC = OD*IC; - const int64_t num_blocks = (OW_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OH, MIN(N_OD_IC, MAX_GRIDDIM_Z)); - im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, - OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_KD_KH_KW, - OW_KD_KH_KW, N_OD_IC, OD_IC, s0, s1, s2, p0, p1, p2, d0, d1, d2); + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, + IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, + OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + s0, s1, s2, p0, p1, p2, d0, d1, d2); } static void im2col_3d_cuda_f16(const float * src, half * dst,