Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	cuda : replace remaining shfl_xor with calls to warp_reduce functions (#5744)
Author: Engininja2

Changed file: ggml-cuda.cu (73 changed lines: 24 additions, 49 deletions)
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-//    }
-//    return a;
-//#else
-//    (void) a;
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//}
+#ifdef GGML_CUDA_F16
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    (void) a;
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-    }
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
 
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
 #ifdef GGML_CUDA_F16
@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
     const int idst = channel*nrows_dst + row_dst;
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
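For context, the warp_reduce_sum and warp_reduce_max calls that replace the unrolled loops follow the same butterfly pattern those loops used. A minimal sketch of the float variants, mirroring the removed code and the half2 overload added above (the exact definitions live near the top of ggml-cuda.cu):

static __device__ __forceinline__ float warp_reduce_sum(float x) {
    // pair each lane with the lane whose ID differs in one bit; after
    // log2(32) = 5 steps every lane of the warp holds the full sum
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

Because the XOR shuffle leaves the reduced value in every lane, the callers above can read tmp from lane 0 (threadIdx.x == 0 or tid == 0) and write it out without any further synchronization.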
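The quantize_q8_1 hunk is the one place that uses both helpers: the warp-wide max of |x| fixes the scale d = amax / 127, while the warp-wide sum is kept alongside the quantized values. A simplified sketch of just that arithmetic for one 32-value chunk, reusing the helpers sketched above (hypothetical function and parameter names; the real kernel also handles padding and packs d and the sum into the block header):

// One warp quantizes 32 floats to int8 with a shared per-chunk scale.
static __device__ void quantize_chunk_q8_1_like(const float * x, int8_t * q_out,
                                                float * d_out, float * s_out) {
    const float xi = x[threadIdx.x];

    float amax = fabsf(xi);        // per-lane |x|
    float sum  = xi;               // per-lane value

    amax = warp_reduce_max(amax);  // warp-wide max of |x|
    sum  = warp_reduce_sum(sum);   // warp-wide sum of x

    const float d = amax / 127;    // largest value maps to +/-127
    q_out[threadIdx.x] = amax == 0.0f ? 0 : roundf(xi / d);

    if (threadIdx.x == 0) {        // one lane records the chunk metadata
        *d_out = d;
        *s_out = sum;
    }
}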
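Finally, a self-contained sanity check (not part of llama.cpp) that the butterfly reduction really does leave the same result in every lane. It launches a single warp and compares the GPU results against a CPU reference; file name and build line are only an example:

// warp_reduce_demo.cu: build with  nvcc warp_reduce_demo.cu -o warp_reduce_demo
#include <cstdio>
#include <cmath>

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

__global__ void reduce_one_warp(const float * x, float * sum_out, float * max_out) {
    const float v = x[threadIdx.x];
    // every lane ends up with the same reduced values; each lane writes its copy
    sum_out[threadIdx.x] = warp_reduce_sum(v);
    max_out[threadIdx.x] = warp_reduce_max(v);
}

int main() {
    float h_x[32], h_sum[32], h_max[32];
    float ref_sum = 0.0f, ref_max = -INFINITY;
    for (int i = 0; i < 32; ++i) {
        h_x[i]   = 0.25f*i - 3.0f;          // arbitrary test data
        ref_sum += h_x[i];
        ref_max  = fmaxf(ref_max, h_x[i]);
    }

    float *d_x, *d_sum, *d_max;
    cudaMalloc(&d_x,   sizeof(h_x));
    cudaMalloc(&d_sum, sizeof(h_sum));
    cudaMalloc(&d_max, sizeof(h_max));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

    reduce_one_warp<<<1, 32>>>(d_x, d_sum, d_max);
    cudaMemcpy(h_sum, d_sum, sizeof(h_sum), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_max, d_max, sizeof(h_max), cudaMemcpyDeviceToHost);

    printf("sum: gpu %.3f  cpu %.3f\n", h_sum[0], ref_sum);
    printf("max: gpu %.3f  cpu %.3f\n", h_max[0], ref_max);

    cudaFree(d_x); cudaFree(d_sum); cudaFree(d_max);
    return 0;
}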