From ae47caca70592ce0bfce4a0112ea0efe946d568b Mon Sep 17 00:00:00 2001 From: leejet Date: Tue, 12 Aug 2025 23:44:46 +0800 Subject: [PATCH] fix cuda pad/scale/im2col3d --- ggml/src/ggml-cuda/im2col.cu | 4 ++-- ggml/src/ggml-cuda/pad.cu | 14 +++++++------- ggml/src/ggml-cuda/scale.cu | 20 +++++++++++--------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 39168c63fc..c0a3912982 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -142,9 +142,9 @@ static __global__ void im2col_3d_kernel( const int64_t iih = ioh * s1 + ikh * d1 - p1; const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { dst[offset_dst] = 0.0f; } else { const int64_t offset_src = (in*IC + iic)*ID_IH_IW + iid*IH_IW + iih*IW + iiw; diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 6824bf066c..0bb98f0ba5 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -22,13 +22,13 @@ static __global__ void pad_f32(const float * src, float * dst, && (i1 >= lp1 && i1 < ne1 - rp1) \ && (i2 >= lp2 && i2 < ne2 - rp2) \ && (i3 >= lp3 && i3 < ne3 - rp3)) { - int i00 = i0 - lp0; - int i01 = i1 - lp1; - int i02 = i2 - lp2; - int i03 = i3 - lp3; - int ne02 = ne2 - lp2 - rp2; - int ne01 = ne1 - lp1 - rp1; - int ne00 = ne0 - lp0 - rp0; + int64_t i00 = i0 - lp0; + int64_t i01 = i1 - lp1; + int64_t i02 = i2 - lp2; + int64_t i03 = i3 - lp3; + int64_t ne02 = ne2 - lp2 - rp2; + int64_t ne01 = ne1 - lp1 - rp1; + int64_t ne00 = ne0 - lp0 - rp0; int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00; diff --git 
a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu
index 2ee9e58899..bfc03d218d 100644
--- a/ggml/src/ggml-cuda/scale.cu
+++ b/ggml/src/ggml-cuda/scale.cu
@@ -1,18 +1,20 @@
 #include "scale.cuh"
 
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+#define MIN(a, b) ((a) < (b) ? (a) : (b)) // fully parenthesized so the expansion is safe inside larger expressions
+#define MAX_GRIDDIM_X 0x7FFFFFFF // upper bound applied to the grid's x dimension at launch (2^31 - 1)
 
-    if (i >= k) {
-        return;
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
+    int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
+    int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
+
+    for (int64_t i = tid; i < nelements; i += stride) { // each thread strides by the total thread count, so a capped grid still covers every element
+        dst[i] = scale * x[i] + bias;
     }
-
-    dst[i] = scale * x[i] + bias;
 }
 
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
+    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
 }
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {