mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	fix cuda pad/scale/im2col3d
This commit is contained in:
		@@ -142,9 +142,9 @@ static  __global__ void im2col_3d_kernel(
 | 
			
		||||
        const int64_t iih = ioh * s1 + ikh * d1 - p1;
 | 
			
		||||
        const int64_t iid = iod * s2 + ikd * d2 - p2;
 | 
			
		||||
 | 
			
		||||
        const int64_t offset_dst = (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
 | 
			
		||||
        const int64_t offset_dst = (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
 | 
			
		||||
 | 
			
		||||
        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
 | 
			
		||||
        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
 | 
			
		||||
            dst[offset_dst] = 0.0f;
 | 
			
		||||
        } else {
 | 
			
		||||
            const int64_t offset_src = (in*IC + iic)*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
 | 
			
		||||
 
 | 
			
		||||
@@ -22,13 +22,13 @@ static __global__ void pad_f32(const float * src, float * dst,
 | 
			
		||||
         && (i1 >= lp1 && i1 < ne1 - rp1) \
 | 
			
		||||
         && (i2 >= lp2 && i2 < ne2 - rp2) \
 | 
			
		||||
         && (i3 >= lp3 && i3 < ne3 - rp3)) {
 | 
			
		||||
        int i00 = i0 - lp0;
 | 
			
		||||
        int i01 = i1 - lp1;
 | 
			
		||||
        int i02 = i2 - lp2;
 | 
			
		||||
        int i03 = i3 - lp3;
 | 
			
		||||
        int ne02 = ne2 - lp2 - rp2;
 | 
			
		||||
        int ne01 = ne1 - lp1 - rp1;
 | 
			
		||||
        int ne00 = ne0 - lp0 - rp0;
 | 
			
		||||
        int64_t i00 = i0 - lp0;
 | 
			
		||||
        int64_t i01 = i1 - lp1;
 | 
			
		||||
        int64_t i02 = i2 - lp2;
 | 
			
		||||
        int64_t i03 = i3 - lp3;
 | 
			
		||||
        int64_t ne02 = ne2 - lp2 - rp2;
 | 
			
		||||
        int64_t ne01 = ne1 - lp1 - rp1;
 | 
			
		||||
        int64_t ne00 = ne0 - lp0 - rp0;
 | 
			
		||||
 | 
			
		||||
        int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,18 +1,20 @@
 | 
			
		||||
#include "scale.cuh"
 | 
			
		||||
 | 
			
		||||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
 | 
			
		||||
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
 | 
			
		||||
#define MIN(a, b) (a) < (b) ? (a) : (b)
 | 
			
		||||
#define MAX_GRIDDIM_X 0x7FFFFFFF
 | 
			
		||||
 | 
			
		||||
    if (i >= k) {
 | 
			
		||||
        return;
 | 
			
		||||
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
 | 
			
		||||
    int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
 | 
			
		||||
    int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
 | 
			
		||||
 | 
			
		||||
    for (int64_t i = tid; i < nelements; i += stride) {
 | 
			
		||||
        dst[i] = scale * x[i] + bias;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    dst[i] = scale * x[i] + bias;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
 | 
			
		||||
    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
 | 
			
		||||
    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
 | 
			
		||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
 | 
			
		||||
    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
 | 
			
		||||
    scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user