	CUDA: add __restrict__ to mul mat vec kernels (#2140)
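Every pointer parameter of the dequantize, vec-dot, and mul-mat-vec kernels (and the matching function-pointer typedefs) gains the __restrict__ qualifier. The qualifier is a promise to the compiler that the annotated pointers never alias each other, which frees nvcc to keep loaded values in registers and to batch or reorder global memory accesses instead of conservatively re-reading through a pointer after every store. A minimal sketch of the idea follows; the kernel name and the fixed 256-thread launch are illustrative assumptions, not code from this commit:

#include <cuda_runtime.h>

// Illustrative only (not from this commit): one block computes the dot
// product of one matrix row with the input vector. __restrict__ asserts
// that x, y and dst never alias, so the compiler does not have to assume
// that the store to dst[blockIdx.x] could modify the data behind x or y,
// and may keep partial sums in registers and reorder the loads in the loop.
__global__ void mat_vec_row(const float * __restrict__ x,   // nrows*ncols matrix
                            const float * __restrict__ y,   // ncols vector
                            float       * __restrict__ dst, // nrows outputs
                            const int ncols) {
    float sum = 0.0f;
    for (int i = threadIdx.x; i < ncols; i += blockDim.x) {
        sum += x[blockIdx.x*ncols + i] * y[i];
    }

    // Block-wide reduction; assumes a power-of-two launch of 256 threads,
    // e.g. mat_vec_row<<<nrows, 256>>>(x, y, dst, ncols);
    __shared__ float tmp[256];
    tmp[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x/2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            tmp[threadIdx.x] += tmp[threadIdx.x + s];
        }
        __syncthreads();
    }
    if (threadIdx.x == 0) {
        dst[blockIdx.x] = tmp[0];
    }
}

Without the qualifiers, the compiler has to assume dst might overlap x or y and schedule the loop's loads conservatively; with them, the loop body can be pipelined more aggressively, which is the performance motivation behind this commit.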
Author: Johannes Gäßler

ggml-cuda.cu | 53 lines changed
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -59,8 +59,8 @@ typedef float2 dfloat2;
 #endif //GGML_CUDA_DMMV_F16
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
@@ -131,7 +131,7 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
 
-typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
 
 //================================= k-quants
 
@@ -407,7 +407,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 
 //================================== k-quants
 
-static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i   = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
@@ -440,7 +440,7 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
 }
 
-static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
 
     const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
@@ -504,7 +504,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
 }
 #endif
 
-static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
@@ -544,7 +544,7 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
     const int i = blockIdx.x;
@@ -590,7 +590,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
    const int i = blockIdx.x;
@@ -634,7 +634,7 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
 #endif
 }
 
-static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -742,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -846,7 +846,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -949,7 +949,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
 
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
@@ -1053,7 +1053,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     }
 }
 
-static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
 
     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
 
@@ -1171,7 +1171,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -1207,7 +1207,7 @@ static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
     if (i >= k) {
@@ -1227,7 +1227,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
@@ -1252,7 +1252,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
@@ -1277,7 +1277,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
@@ -1312,7 +1312,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
@@ -1346,7 +1346,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, cons
 #endif // __CUDA_ARCH__ >= 600
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
@@ -1366,7 +1366,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, cons
 }
 
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
@@ -1404,7 +1404,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -1471,7 +1471,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
     }
 }
 
-static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
@@ -1518,7 +1518,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 }
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
     const int row_stride_x, const int channel_stride_x) {
 
     const half * x = (const half *) vx;
@@ -2355,10 +2355,7 @@ inline void ggml_cuda_op_mul_mat_vec(
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0;
 
-    // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
-    // However, they have bad performance with Pascal cards.
-    // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
-    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+    const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
 #endif
 
     if (use_mul_mat_vec_q) {
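The final hunk is the behavioral change the qualifiers enable: the quantized mul_mat_vec_q path is now taken from compute capability 6.0 (Pascal) upward instead of 7.0, and the old comment about poor Pascal performance of the integer intrinsics is dropped along with the higher threshold. The check compares against per-device values cached in g_compute_capabilities; a small sketch of how such a value can be derived at runtime follows (the helper name is illustrative; the CUDA runtime calls are standard):

#include <cuda_runtime.h>

// Compute capability of device `id`, encoded as 100*major + 10*minor
// (e.g. 610 for a Pascal GTX 1080, 700 for Volta). Thresholds such as
// ">= 600" in the hunk above compare against a value in this encoding.
static int compute_capability_of(int id) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, id);
    return 100*prop.major + 10*prop.minor;
}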