mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03)

ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext
* cuda : implement soft_max_ext
* ggml : implement soft_max_ext (CPU)
* batched-bench : print threads

ggml-ci

* metal : simplify soft_max encoding

ggml-ci

* cuda : use 512 threads for soft_max instead of 32
* ggml : update soft max cpu
* cuda : do warp-based block reduce
* cuda : increase max block size to 1024
* cuda : fix warp reduction initialization of shared mem
* metal : warp-based reduction for soft max kernel
* metal : warp-based reduce for rms_norm
* metal : simplify soft max kernel

ggml-ci

* alloc : fix build with debug
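In terms of semantics, the fused op computes, per row, softmax(x*scale + mask), with the optional mask broadcast across rows when it has fewer rows than x. The following plain-C reference is a sketch written for this note (the helper name is hypothetical, not from the ggml sources); it mirrors what the CUDA kernel in the diff below computes:

// Hypothetical CPU reference for the fused op: dst = softmax(x*scale + mask).
// Written for illustration; not taken verbatim from the ggml sources.
#include <math.h>
#include <stdio.h>

static void soft_max_ext_ref(const float * x, const float * mask, float * dst,
                             int ncols, int nrows_x, int nrows_y, float scale) {
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y; // broadcast the mask across rows
        float max_val = -INFINITY;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale
                          + (mask ? mask[rowy*ncols + col] : 0.0f);
            max_val = fmaxf(max_val, v);
        }
        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale
                          + (mask ? mask[rowy*ncols + col] : 0.0f);
            const float e = expf(v - max_val); // subtract max for numerical stability
            dst[rowx*ncols + col] = e;
            sum += e;
        }
        for (int col = 0; col < ncols; ++col) {
            dst[rowx*ncols + col] /= sum;
        }
    }
}

int main(void) {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float dst[4];
    soft_max_ext_ref(x, NULL, dst, 4, 1, 1, 1.0f);
    for (int i = 0; i < 4; ++i) printf("%f\n", dst[i]); // the outputs sum to 1.0
    return 0;
}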
Changed file: ggml-cuda.cu | 130
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
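The three helpers above use a butterfly (XOR) shuffle: each step halves the number of distinct partial values, so after log2(32) = 5 steps every lane of the warp holds the full result, with no shared memory needed. A standalone demo of the same pattern (warp_sum_demo is a hypothetical name, not part of the commit):

// Standalone sketch: butterfly warp reduction as used by warp_reduce_sum above.
// After 5 XOR-shuffle steps, every lane in the warp holds the full sum.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_sum_demo(const float * x, float * out) {
    float v = x[threadIdx.x];               // one value per lane
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    if (threadIdx.x == 0) {
        *out = v;                           // lane 0 writes the warp-wide sum
    }
}

int main() {
    float h_x[32], *d_x, *d_out, h_out;
    for (int i = 0; i < 32; ++i) h_x[i] = 1.0f;   // expected sum: 32
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
    warp_sum_demo<<<1, 32>>>(d_x, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %f\n", h_out);             // prints 32.0
    cudaFree(d_x); cudaFree(d_out);
    return 0;
}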
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
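The rewritten kernel reduces in two levels: each warp first reduces its 32 lanes with shuffles, warp leaders then stash their partial results in shared memory, and a final warp-level pass over those slots yields the block-wide value. A standalone sketch of this block-level max reduction under the same WARP_SIZE = 32 assumption (block_max_demo is a hypothetical name, not part of the commit):

// Standalone sketch: block-wide max using the same warp-then-shared-memory
// pattern as soft_max_f32 above.
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE  32
#define BLOCK_SIZE 256

__device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

__global__ void block_max_demo(const float * x, float * out) {
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    __shared__ float buf[WARP_SIZE];            // 32 slots: enough for up to 1024 threads

    float v = warp_reduce_max(x[threadIdx.x]);  // level 1: reduce within each warp

    if (warp_id == 0) buf[lane_id] = -INFINITY; // warp 0 initializes all 32 slots
    __syncthreads();
    if (lane_id == 0) buf[warp_id] = v;         // one partial max per warp
    __syncthreads();

    v = warp_reduce_max(buf[lane_id]);          // level 2: reduce the partials
    if (threadIdx.x == 0) *out = v;
}

int main() {
    float h_x[BLOCK_SIZE], *d_x, *d_out, h_out;
    for (int i = 0; i < BLOCK_SIZE; ++i) h_x[i] = (float) i;
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
    block_max_demo<<<1, BLOCK_SIZE>>>(d_x, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block max = %f\n", h_out);          // prints 255.0
    cudaFree(d_x); cudaFree(d_out);
    return 0;
}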
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
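The launcher rounds the block size up to the smallest power of two that covers ncols_x, capped at CUDA_SOFT_MAX_BLOCK_SIZE (1024); past the cap, each thread simply loops over several columns. A host-side sketch of the same heuristic (soft_max_block_size is a hypothetical helper, not from the sources):

// Sketch of the launch heuristic above: smallest power of two >= ncols_x,
// never below one warp, capped at CUDA_SOFT_MAX_BLOCK_SIZE.
#include <stdio.h>

static int soft_max_block_size(int ncols_x) {
    int nth = 32;                                 // WARP_SIZE
    while (nth < ncols_x && nth < 1024) nth *= 2; // CUDA_SOFT_MAX_BLOCK_SIZE
    return nth;
}

int main(void) {
    printf("%d %d %d\n",
           soft_max_block_size(20),    // 32   (never below one warp)
           soft_max_block_size(100),   // 128
           soft_max_block_size(5000)); // 1024 (cap; threads then loop over columns)
    return 0;
}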
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
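The scale factor travels to the backend through the tensor's op_params blob; the memcpy above (rather than a pointer cast) keeps the read well-defined with respect to strict aliasing. A minimal standalone sketch of the same pack/unpack pattern (hypothetical, not from the ggml sources):

// Sketch: round-tripping a float through an int32-typed op_params array,
// mirroring the memcpy in ggml_cuda_op_soft_max above.
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    int32_t op_params[2] = {0};
    const float scale_in = 0.125f;
    memcpy(&op_params[0], &scale_in, sizeof(float)); // pack (as the op setter would)

    float scale = 1.0f;
    memcpy(&scale, &op_params[0], sizeof(float));    // unpack (as the CUDA op does)
    printf("scale = %f\n", scale);                   // prints 0.125
    return 0;
}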