mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	cuda : performance optimizations (#1530)
* xor hack * block y dim * loop unrolling * Fixed cmake LLAMA_CUDA_BY option * Removed hipblas compatibility code * Define GGML_CUDA_DMMV_BLOCK_Y if not defined * Fewer iters, more ops per iter * Renamed DMMV X/Y compilation options
This commit is contained in:
		| @@ -37,42 +37,44 @@ endif() | ||||
| # | ||||
|  | ||||
| # general | ||||
| option(LLAMA_STATIC                 "llama: static link libraries"                          OFF) | ||||
| option(LLAMA_NATIVE                 "llama: enable -march=native flag"                      OFF) | ||||
| option(LLAMA_LTO                    "llama: enable link time optimization"                  OFF) | ||||
| option(LLAMA_STATIC                     "llama: static link libraries"                          OFF) | ||||
| option(LLAMA_NATIVE                     "llama: enable -march=native flag"                      OFF) | ||||
| option(LLAMA_LTO                        "llama: enable link time optimization"                  OFF) | ||||
|  | ||||
| # debug | ||||
| option(LLAMA_ALL_WARNINGS           "llama: enable all compiler warnings"                   ON) | ||||
| option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) | ||||
| option(LLAMA_GPROF                  "llama: enable gprof"                                   OFF) | ||||
| option(LLAMA_ALL_WARNINGS               "llama: enable all compiler warnings"                   ON) | ||||
| option(LLAMA_ALL_WARNINGS_3RD_PARTY     "llama: enable all compiler warnings in 3rd party libs" OFF) | ||||
| option(LLAMA_GPROF                      "llama: enable gprof"                                   OFF) | ||||
|  | ||||
| # sanitizers | ||||
| option(LLAMA_SANITIZE_THREAD        "llama: enable thread sanitizer"                        OFF) | ||||
| option(LLAMA_SANITIZE_ADDRESS       "llama: enable address sanitizer"                       OFF) | ||||
| option(LLAMA_SANITIZE_UNDEFINED     "llama: enable undefined sanitizer"                     OFF) | ||||
| option(LLAMA_SANITIZE_THREAD            "llama: enable thread sanitizer"                        OFF) | ||||
| option(LLAMA_SANITIZE_ADDRESS           "llama: enable address sanitizer"                       OFF) | ||||
| option(LLAMA_SANITIZE_UNDEFINED         "llama: enable undefined sanitizer"                     OFF) | ||||
|  | ||||
| # instruction set specific | ||||
| option(LLAMA_AVX                    "llama: enable AVX"                                     ON) | ||||
| option(LLAMA_AVX2                   "llama: enable AVX2"                                    ON) | ||||
| option(LLAMA_AVX512                 "llama: enable AVX512"                                  OFF) | ||||
| option(LLAMA_AVX512_VBMI            "llama: enable AVX512-VBMI"                             OFF) | ||||
| option(LLAMA_AVX512_VNNI            "llama: enable AVX512-VNNI"                             OFF) | ||||
| option(LLAMA_FMA                    "llama: enable FMA"                                     ON) | ||||
| option(LLAMA_AVX                        "llama: enable AVX"                                     ON) | ||||
| option(LLAMA_AVX2                       "llama: enable AVX2"                                    ON) | ||||
| option(LLAMA_AVX512                     "llama: enable AVX512"                                  OFF) | ||||
| option(LLAMA_AVX512_VBMI                "llama: enable AVX512-VBMI"                             OFF) | ||||
| option(LLAMA_AVX512_VNNI                "llama: enable AVX512-VNNI"                             OFF) | ||||
| option(LLAMA_FMA                        "llama: enable FMA"                                     ON) | ||||
| # in MSVC F16C is implied with AVX2/AVX512 | ||||
| if (NOT MSVC) | ||||
|     option(LLAMA_F16C               "llama: enable F16C"                                    ON) | ||||
|     option(LLAMA_F16C                   "llama: enable F16C"                                    ON) | ||||
| endif() | ||||
|  | ||||
| # 3rd party libs | ||||
| option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"                    ON) | ||||
| option(LLAMA_BLAS                   "llama: use BLAS"                                       OFF) | ||||
| option(LLAMA_BLAS_VENDOR            "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) | ||||
| option(LLAMA_CUBLAS                 "llama: use cuBLAS"                                     OFF) | ||||
| option(LLAMA_CLBLAST                "llama: use CLBlast"                                    OFF) | ||||
| option(LLAMA_ACCELERATE                 "llama: enable Accelerate framework"                    ON) | ||||
| option(LLAMA_BLAS                       "llama: use BLAS"                                       OFF) | ||||
| option(LLAMA_BLAS_VENDOR                "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) | ||||
| option(LLAMA_CUBLAS                     "llama: use cuBLAS"                                     OFF) | ||||
| set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") | ||||
| set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING  "llama: y block size for dmmv CUDA kernels") | ||||
| option(LLAMA_CLBLAST                    "llama: use CLBlast"                                    OFF) | ||||
|  | ||||
| option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE}) | ||||
| option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE}) | ||||
| option(LLAMA_BUILD_SERVER           "llama: build server example"                           OFF) | ||||
| option(LLAMA_BUILD_TESTS                "llama: build tests"    ${LLAMA_STANDALONE}) | ||||
| option(LLAMA_BUILD_EXAMPLES             "llama: build examples" ${LLAMA_STANDALONE}) | ||||
| option(LLAMA_BUILD_SERVER               "llama: build server example"                           OFF) | ||||
|  | ||||
| # | ||||
| # Build info header | ||||
| @@ -184,6 +186,8 @@ if (LLAMA_CUBLAS) | ||||
|         set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) | ||||
|  | ||||
|         add_compile_definitions(GGML_USE_CUBLAS) | ||||
|         add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) | ||||
|         add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) | ||||
|  | ||||
|         if (LLAMA_STATIC) | ||||
|             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) | ||||
|   | ||||
							
								
								
									
										12
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								Makefile
									
									
									
									
									
								
							| @@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS | ||||
| 	OBJS      += ggml-cuda.o | ||||
| 	NVCC      = nvcc | ||||
| 	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native | ||||
| ifdef LLAMA_CUDA_DMMV_X | ||||
| 	NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) | ||||
| else | ||||
| 	NVCCFLAGS += -DGGML_CUDA_DMMV_X=32 | ||||
| endif # LLAMA_CUDA_DMMV_X | ||||
| ifdef LLAMA_CUDA_DMMV_Y | ||||
| 	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y) | ||||
| else | ||||
| 	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 | ||||
| endif # LLAMA_CUDA_DMMV_Y | ||||
| ggml-cuda.o: ggml-cuda.cu ggml-cuda.h | ||||
| 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ | ||||
| endif | ||||
| endif # LLAMA_CUBLAS | ||||
| ifdef LLAMA_CLBLAST | ||||
| 	CFLAGS  += -DGGML_USE_CLBLAST | ||||
| 	CXXFLAGS  += -DGGML_USE_CLBLAST | ||||
|   | ||||
							
								
								
									
										110
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										110
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -83,9 +83,19 @@ typedef struct { | ||||
| } block_q8_0; | ||||
| static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); | ||||
|  | ||||
| #define WARP_SIZE 32 | ||||
|  | ||||
| #define CUDA_MUL_BLOCK_SIZE 256 | ||||
|  | ||||
| #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 | ||||
| #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec | ||||
|  | ||||
| // dmmv = dequantize_mul_mat_vec | ||||
| #ifndef GGML_CUDA_DMMV_X | ||||
| #define GGML_CUDA_DMMV_X 32 | ||||
| #endif | ||||
| #ifndef GGML_CUDA_DMMV_Y | ||||
| #define GGML_CUDA_DMMV_Y 1 | ||||
| #endif | ||||
|  | ||||
| static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { | ||||
|     const int i = blockDim.x*blockIdx.x + threadIdx.x; | ||||
| @@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) | ||||
|     dequantize_kernel(vx, ib, iqs, v0, v1); | ||||
| } | ||||
|  | ||||
| template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel> | ||||
| template <int qk, int qr, dequantize_kernel_t dequantize_kernel> | ||||
| static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { | ||||
|     const int row = blockIdx.x; | ||||
|     // qk = quantized weights per x block | ||||
|     // qr = number of quantized weights per data value in x block | ||||
|     const int row = blockIdx.x*blockDim.y + threadIdx.y; | ||||
|     const int tid = threadIdx.x; | ||||
|  | ||||
|     const int iter_stride = 2*GGML_CUDA_DMMV_X; | ||||
|     const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter | ||||
|     const int y_offset = qr == 1 ? 1 : qk/2; | ||||
|  | ||||
|     __shared__ float tmp[block_size]; // separate sum for each thread | ||||
|     tmp[tid] = 0; | ||||
|     float tmp = 0; // partial sum for thread in warp | ||||
|  | ||||
|     for (int i = 0; i < ncols/block_size; i += 2) { | ||||
|         const int col = i*block_size + 2*tid; | ||||
|         const int ib = (row*ncols + col)/qk; // block index | ||||
|         const int iqs = (col%qk)/qr; // quant index | ||||
|     for (int i = 0; i < ncols; i += iter_stride) { | ||||
|         const int col = i + vals_per_iter*tid; | ||||
|         const int ib = (row*ncols + col)/qk; // x block index | ||||
|         const int iqs = (col%qk)/qr; // x quant index | ||||
|         const int iybs = col - col%qk; // y block start index | ||||
|  | ||||
|         // dequantize | ||||
|         float v0, v1; | ||||
|         dequantize_kernel(vx, ib, iqs, v0, v1); | ||||
| // processing >2 values per i iter is faster for fast GPUs | ||||
| #pragma unroll | ||||
|         for (int j = 0; j < vals_per_iter; j += 2) { | ||||
|             // process 2 vals per j iter | ||||
|  | ||||
|         // matrix multiplication | ||||
|         tmp[tid] += v0 * y[iybs + iqs + 0]; | ||||
|         tmp[tid] += v1 * y[iybs + iqs + y_offset]; | ||||
|             // dequantize | ||||
|             float v0, v1; | ||||
|             dequantize_kernel(vx, ib, iqs + j/qr, v0, v1); | ||||
|             // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val | ||||
|  | ||||
|             // matrix multiplication | ||||
|             tmp += v0 * y[iybs + iqs + j/qr + 0]; | ||||
|             tmp += v1 * y[iybs + iqs + j/qr + y_offset]; | ||||
|             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // sum up partial sums and write back result | ||||
|     __syncthreads(); | ||||
|     for (int s=block_size/2; s>0; s>>=1) { | ||||
|         if (tid < s) { | ||||
|             tmp[tid] += tmp[tid + s]; | ||||
|         } | ||||
|         __syncthreads(); | ||||
| #pragma unroll | ||||
|     for (int mask = 16; mask > 0; mask >>= 1) { | ||||
|         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); | ||||
|     } | ||||
|  | ||||
|     if (tid == 0) { | ||||
|         dst[row] = tmp[0]; | ||||
|         dst[row] = tmp; | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu | ||||
| } | ||||
|  | ||||
| static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); | ||||
|     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0> | ||||
|         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); | ||||
|     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0> | ||||
|         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); | ||||
| } | ||||
|  | ||||
| static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); | ||||
|     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1> | ||||
|         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); | ||||
|     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1> | ||||
|         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); | ||||
| } | ||||
|  | ||||
| static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); | ||||
|     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0> | ||||
|         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); | ||||
|     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0> | ||||
|         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); | ||||
| } | ||||
|  | ||||
| static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); | ||||
|     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1> | ||||
|         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); | ||||
|     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1> | ||||
|         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); | ||||
| } | ||||
|  | ||||
| static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); | ||||
|     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0> | ||||
|         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); | ||||
|     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0> | ||||
|         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); | ||||
| } | ||||
|  | ||||
| static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { | ||||
| @@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c | ||||
| } | ||||
|  | ||||
| static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); | ||||
|     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16> | ||||
|         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); | ||||
|     dequantize_mul_mat_vec<1, 1, convert_f16> | ||||
|         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols); | ||||
| } | ||||
|  | ||||
| static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler