mirror of https://github.com/ggml-org/llama.cpp.git
	Fix CUDA softmax by subtracting max value before exp (#2665)
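For context, softmax is invariant under shifting all inputs by the same constant: exp(x_i - m) / sum_j exp(x_j - m) equals exp(x_i) / sum_j exp(x_j) for any m, and picking m = max(x) keeps every exponent at or below zero, so expf can no longer overflow to +inf when the logits are large. Below is a minimal CPU sketch of that two-pass scheme (not part of the commit; the function name soft_max_ref and the sample logits are illustrative only):

// CPU reference sketch of numerically stable softmax: subtract the row
// maximum before exponentiating, then normalize by the sum.
#include <cmath>
#include <cstdio>
#include <vector>

static void soft_max_ref(const float * x, float * dst, int ncols) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        max_val = fmaxf(max_val, x[i]);   // pass 1: find the maximum
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = expf(x[i] - max_val);    // exponent <= 0, cannot overflow
        sum += dst[i];
    }
    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < ncols; ++i) {
        dst[i] *= inv_sum;                // normalize
    }
}

int main() {
    // expf(102.0f) alone would overflow a float (max ~3.4e38);
    // the shifted version stays finite and sums to 1.
    std::vector<float> logits = {100.0f, 101.0f, 102.0f};
    std::vector<float> probs(logits.size());
    soft_max_ref(logits.data(), probs.data(), (int) logits.size());
    printf("%f %f %f\n", probs[0], probs[1], probs[2]);
    return 0;
}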
ggml-cuda.cu | 37
@@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-// values are also not normalized to the maximum value by subtracting it in the exponential function
-// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int block_size = blockDim.y;
     const int tid = threadIdx.y;
 
-    float tmp = 0.0;
+    float max_val = -INFINITY;
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
-
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        const float val = expf(x[i]);
+        max_val = max(max_val, x[i]);
+    }
+
+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
+
+    float tmp = 0.f;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        const float val = expf(x[i] - max_val);
         tmp += val;
         dst[i] = val;
     }
@@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
+    const float inv_tmp = 1.f / tmp;
 
-        if (col >= ncols) {
-            break;
-        }
-
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        dst[i] /= tmp;
+        dst[i] *= inv_tmp;
     }
 }
 
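Two implementation details in the new kernel are worth spelling out. First, the `// find the max value in the block` loop is a warp-level butterfly reduction: with XOR lane masks 16, 8, 4, 2, 1, every lane repeatedly exchanges its running maximum with the lane whose index differs in one bit, so after five steps all 32 lanes hold the row maximum; the existing `tmp += __shfl_xor_sync(...)` loop does the same thing with addition to sum the exponentials. Second, the last hunk hoists the normalization out of the loop by computing `1.f / tmp` once and multiplying, instead of dividing every element. The sketch below isolates the max reduction as a standalone device helper; the name `warp_reduce_max` is illustrative, and it assumes the kernel is launched with 32 threads along y (one warp cooperating per row), which is what makes the full 0xffffffff mask and the 16..1 schedule valid.

// Hedged sketch (not part of the commit): the butterfly max-reduction used
// by the kernel, written as a standalone helper. Assumes all 32 lanes of
// the warp are active.
__device__ float warp_reduce_max(float v) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // exchange with the lane whose index differs in exactly one bit and
        // keep the larger value; after 5 steps every lane has the warp max
        v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, mask, 32));
    }
    return v;
}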
Jiahao Li