Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	cuda : use 512 threads for soft_max instead of 32
1 changed file: ggml-cuda.cu (51 lines changed)
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 512
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
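Worth noting in passing (my arithmetic, not stated in the commit): this constant also fixes the size of the static __shared__ buffer declared in the kernel below, so the per-block shared-memory cost is known at compile time:

// shared-memory footprint implied by the new constant:
//   CUDA_SOFT_MAX_BLOCK_SIZE * sizeof(float) = 512 * 4 bytes = 2 KiB per block
// comfortably below the 48 KiB static shared-memory limit, so occupancy is not
// constrained by this buffer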
@@ -4717,26 +4718,32 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
+// TODO: maybe can be improved with some warp-based primitives
 static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
 
-    float max_val = -INFINITY;
+    const int block_size = blockDim.x;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
+
+    buf[tid] = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
+        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
+    __syncthreads();
+
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] = max(buf[tid], buf[tid + i]);
+        }
+        __syncthreads();
     }
 
     float tmp = 0.f;
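For readers less familiar with the two patterns: the removed __shfl_xor_sync loop can only combine values held by the 32 lanes of one warp, which is why the old kernel was pinned to 32 threads per row. The replacement stages one partial maximum per thread in shared memory and folds them in log2(block_size) halving steps. A minimal standalone sketch of that tree reduction, with a hypothetical kernel name (block_max_f32) and the power-of-two block size that the launch code further down guarantees:

#include <cfloat>

// hypothetical kernel: each block reduces one row of x to its maximum
__global__ void block_max_f32(const float * x, float * out, const int ncols) {
    const int tid = threadIdx.x;
    const int row = blockIdx.x;

    __shared__ float buf[512]; // same capacity as CUDA_SOFT_MAX_BLOCK_SIZE in the diff
    buf[tid] = -FLT_MAX;

    // strided pass: each thread folds every blockDim.x-th column into its slot
    for (int col = tid; col < ncols; col += blockDim.x) {
        buf[tid] = fmaxf(buf[tid], x[row*ncols + col]);
    }
    __syncthreads();

    // tree reduction: halve the active threads until buf[0] holds the row maximum
    for (int i = blockDim.x/2; i > 0; i >>= 1) {
        if (tid < i) {
            buf[tid] = fmaxf(buf[tid], buf[tid + i]);
        }
        __syncthreads();
    }

    if (tid == 0) {
        out[row] = buf[0];
    }
}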
@@ -4744,18 +4751,26 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     for (int col = tid; col < ncols; col += block_size) {
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
-        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
         tmp += val;
         dst[ix] = val;
     }
 
+    __syncthreads();
+
+    buf[tid] = tmp;
+
+    __syncthreads();
+
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] += buf[tid + i];
+        }
+        __syncthreads();
     }
 
-    const float inv_tmp = 1.f / tmp;
+    const float inv_tmp = 1.f / buf[0];
 
     for (int col = tid; col < ncols; col += block_size) {
         const int i = rowx*ncols + col;
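The math is unchanged by this commit: scale the logits, add the mask if present, subtract the row maximum inside expf so the exponentials cannot overflow, then normalize by the sum. As a plain host-side reference for what each row ends up as (a hypothetical helper, not part of the diff):

#include <cmath>

// hypothetical single-row reference:
// dst[i] = exp(x[i]*scale + y[i] - max_val) / sum_j exp(x[j]*scale + y[j] - max_val)
static void soft_max_f32_ref(const float * x, const float * y, float * dst,
                             const int ncols, const float scale) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        max_val = fmaxf(max_val, x[i]*scale + (y ? y[i] : 0.0f));
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = expf((x[i]*scale + (y ? y[i] : 0.0f)) - max_val); // max-subtraction keeps expf in range
        sum += dst[i];
    }
    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < ncols; ++i) {
        dst[i] *= inv_sum; // mirrors the kernel's final inv_tmp multiply
    }
}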
@@ -5796,7 +5811,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
 }
 
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
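The launch now picks the smallest power of two that covers the row, clamped to [WARP_SIZE, CUDA_SOFT_MAX_BLOCK_SIZE]; keeping it a power of two is what lets the halving loops in the kernel visit every shared-memory slot. A few worked values, assuming WARP_SIZE == 32:

// ncols_x =   20 -> nth =  32  (never below one warp)
// ncols_x =  100 -> nth = 128  (smallest power of two >= 100)
// ncols_x = 4096 -> nth = 512  (capped at CUDA_SOFT_MAX_BLOCK_SIZE)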
@@ -6853,7 +6870,7 @@ inline void ggml_cuda_op_soft_max(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 0;
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
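My reading of this last hunk, not stated in the commit: the fallback value feeds the kernel's mask indexing, where 0 would be hazardous:

//   const int rowy = rowx % nrows_y;  // nrows_y == 0 would be a modulo by zero
// with nrows_y == 1 every row maps to rowy == 0, and since y is null in that case
// the (y ? y[iy] : 0.0f) expression selects 0.0f anyway: rowy is unused but well-defined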
Author: Georgi Gerganov