	Improve cuBLAS performance by dequantizing on the GPU (#1065)
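Summary of the change: until now the cuBLAS path dequantized q4_0/q4_1/q4_2 weight blocks on the CPU and uploaded the resulting floats; with this commit the raw quantized blocks are copied to the GPU and expanded there by new CUDA kernels (ggml-cuda.cu), and the per-matrix cudaStreamSynchronize calls are replaced by a single synchronization after all work has been queued. The following is a simplified, self-contained sketch of that pattern for q4_0, not code from the commit: the toy matrix sizes, fill values, and names such as h_Q, d_Q, and sketch.cu are invented, and error checking is omitted for brevity.

// Sketch (illustrative only): dequantize q4_0 blocks on the GPU, then run the
// cuBLAS GEMM on the same stream. Build with something like:
//   nvcc sketch.cu -lcublas -o sketch
#include <cstdio>
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

#define QK4_0 32
typedef struct {
    float   d;              // per-block scale (delta)
    uint8_t qs[QK4_0 / 2];  // 32 4-bit quants packed two per byte
} block_q4_0;

// One single-thread block per quantization block, as in ggml-cuda.cu.
static __global__ void dequantize_block_q4_0(const block_q4_0 * x, float * y) {
    const int i = blockIdx.x;
    const float d = x[i].d;
    for (int l = 0; l < QK4_0; l += 2) {
        const uint8_t vi = x[i].qs[l/2];
        y[i*QK4_0 + l + 0] = ((vi & 0xf) - 8)*d;
        y[i*QK4_0 + l + 1] = ((vi >>  4) - 8)*d;
    }
}

int main() {
    // Toy sizes: X is m x k (quantized), Y is k x n (float), D = X*Y is m x n.
    const int m = 64, k = 128, n = 32;
    const int nb = m*k / QK4_0;  // number of q4_0 blocks in X

    std::vector<block_q4_0> h_Q(nb);
    std::vector<float> h_Y(k*n, 1.0f), h_D(m*n);
    for (block_q4_0 & b : h_Q) {
        b.d = 0.1f;
        for (int j = 0; j < QK4_0/2; ++j) b.qs[j] = 0x99;  // every quant decodes to 0.1
    }

    block_q4_0 * d_Q; float * d_X; float * d_Y; float * d_D;
    cudaMalloc((void **)(&d_Q), nb*sizeof(block_q4_0));
    cudaMalloc((void **)(&d_X), m*k*sizeof(float));
    cudaMalloc((void **)(&d_Y), k*n*sizeof(float));
    cudaMalloc((void **)(&d_D), m*n*sizeof(float));

    cudaStream_t stream;
    cublasHandle_t handle;
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
    cublasCreate(&handle);
    cublasSetStream(handle, stream);

    // Everything is queued on one stream; the host blocks only once at the end.
    cudaMemcpyAsync(d_Q, h_Q.data(), nb*sizeof(block_q4_0), cudaMemcpyHostToDevice, stream);
    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(d_Q, d_X);
    cudaMemcpyAsync(d_Y, h_Y.data(), k*n*sizeof(float), cudaMemcpyHostToDevice, stream);

    // cuBLAS is column-major: computing D^T = Y^T * X^T yields row-major D = X*Y.
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                n, m, k,
                &alpha, d_Y, n,
                        d_X, k,
                &beta,  d_D, n);

    cudaMemcpyAsync(h_D.data(), d_D, m*n*sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);

    printf("D[0][0] = %.2f (expected %.2f)\n", h_D[0], 0.1f*k);

    cudaFree(d_Q); cudaFree(d_X); cudaFree(d_Y); cudaFree(d_D);
    cublasDestroy(handle);
    cudaStreamDestroy(stream);
    return 0;
}

In the commit itself the same sequence runs once per i02/i03 slice inside ggml_compute_forward_mul_mat_q_f32, with a single cudaStreamSynchronize after the loop.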
CMakeLists.txt

@@ -110,6 +110,7 @@ if (APPLE AND LLAMA_ACCELERATE)
         message(WARNING "Accelerate framework not found")
     endif()
 endif()
+
 if (LLAMA_OPENBLAS)
     if (LLAMA_STATIC)
         set(BLA_STATIC ON)
@@ -150,6 +151,10 @@ if (LLAMA_CUBLAS)
     if (CUDAToolkit_FOUND)
         message(STATUS "cuBLAS found")
 
+        enable_language(CUDA)
+
+        set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+
         add_compile_definitions(GGML_USE_CUBLAS)
 
         if (LLAMA_STATIC)
@@ -241,21 +246,26 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
     message(STATUS "x86 detected")
     if (MSVC)
         if (LLAMA_AVX512)
-            add_compile_options(/arch:AVX512)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
             # MSVC has no compile-time flags enabling specific
             # AVX512 extensions, neither it defines the
             # macros corresponding to the extensions.
             # Do it manually.
             if (LLAMA_AVX512_VBMI)
-                add_compile_definitions(__AVX512VBMI__)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
             endif()
             if (LLAMA_AVX512_VNNI)
-                add_compile_definitions(__AVX512VNNI__)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
             endif()
         elseif (LLAMA_AVX2)
-            add_compile_options(/arch:AVX2)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
         elseif (LLAMA_AVX)
-            add_compile_options(/arch:AVX)
+            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
+            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
         endif()
     else()
         if (LLAMA_F16C)
@@ -292,7 +302,8 @@ endif()
 
 add_library(ggml OBJECT
             ggml.c
-            ggml.h)
+            ggml.h
+            ${GGML_CUDA_SOURCES})
 
 target_include_directories(ggml PUBLIC .)
 target_compile_features(ggml PUBLIC c_std_11) # don't bump
@@ -314,6 +325,14 @@ if (BUILD_SHARED_LIBS)
     target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
 endif()
 
+if (GGML_CUDA_SOURCES)
+    message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
+    set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
+    set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
+    set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
+endif()
+
+
 #
 # programs, examples and tests
 #
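With these CMakeLists.txt changes, configuring with the existing LLAMA_CUBLAS option (for example cmake .. -DLLAMA_CUBLAS=ON from a build directory, assuming the CUDA toolkit is installed) now also enables the CUDA language, compiles ggml-cuda.cu into the ggml object library through GGML_CUDA_SOURCES, and sets the CUDA architecture properties on the ggml and llama targets.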
							
								
								
									
Makefile: 24 changed lines
@@ -1,3 +1,6 @@
+# Define the default target now so that it is always the first target
+default: main quantize quantize-stats perplexity embedding vdot
+
 ifndef UNAME_S
 UNAME_S := $(shell uname -s)
 endif
@@ -100,6 +103,9 @@ endif
 ifdef LLAMA_CUBLAS
 	CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
 	LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
+	OBJS	+= ggml-cuda.o
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+	nvcc -arch=native -c -o $@ $<
 endif
 ifdef LLAMA_GPROF
 	CFLAGS   += -pg
@@ -137,8 +143,6 @@ $(info I CC:       $(CCV))
 $(info I CXX:      $(CXXV))
 $(info )
 
-default: main quantize quantize-stats perplexity embedding vdot
-
 #
 # Build library
 #
@@ -155,35 +159,35 @@ common.o: examples/common.cpp examples/common.h
 clean:
 	rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
 
-main: examples/main/main.cpp ggml.o llama.o common.o
+main: examples/main/main.cpp ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 	@echo
 	@echo '====  Run ./main -h for help.  ===='
 	@echo
 
-quantize: examples/quantize/quantize.cpp ggml.o llama.o
+quantize: examples/quantize/quantize.cpp ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
-quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o
+quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
-perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
+perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
-embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
+embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
-vdot: pocs/vdot/vdot.cpp ggml.o
+vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
-libllama.so: llama.o ggml.o
+libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
 #
 # Tests
 #
 
-benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o
+benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS)
 	./benchmark-q4_0-matmult
 
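The equivalent Makefile path is enabled with make LLAMA_CUBLAS=1: the new ggml-cuda.o rule compiles the kernels with nvcc -arch=native (so it targets whatever GPU is present on the build machine and needs a CUDA toolkit recent enough to support -arch=native), and $(OBJS) links the object into every binary. Moving the default target to the top of the Makefile keeps it the first, and therefore default, goal now that the LLAMA_CUBLAS block defines a rule of its own.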
							
								
								
									
ggml-cuda.cu: 116 lines (new file)
@@ -0,0 +1,116 @@
+#include <stdint.h>
+#include <cuda_fp16.h>
+#include "ggml-cuda.h"
+
+typedef uint16_t ggml_fp16_t;
+static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define QK4_0 32
+typedef struct {
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    __half d;               // delta
+    uint8_t qs[QK4_2 / 2];  // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+
+static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_0; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_0 + l + 0] = v0;
+        y[i*QK4_0 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_1; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK4_1 + l + 0] = v0;
+        y[i*QK4_1 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
+    const block_q4_2 * x = (const block_q4_2 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    for (int l = 0; l < QK4_2; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vi0 = vi & 0xf;
+        const int8_t vi1 = vi >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        y[i*QK4_2 + l + 0] = v0;
+        y[i*QK4_2 + l + 1] = v1;
+    }
+}
+
+extern "C" {
+    __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+        const int nb = k / QK4_0;
+        dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+    }
+
+    __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+        const int nb = k / QK4_1;
+        dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+    }
+
+    __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+        const int nb = k / QK4_2;
+        dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+    }
+}
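Each launcher starts one single-thread block per quantization block (k / QK blocks for k quantized values), and block i expands the packed nibbles of x[i].qs into QK consecutive floats of y. For q4_0 and q4_2 a nibble q decodes as (q - 8)*d, so, for example, a stored nibble of 12 with d = 0.5 becomes 2.0; q4_1 decodes as q*d + m. The block layouts mirror the quantization structs in ggml.c, which the static_asserts guard against drift.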
							
								
								
									
ggml-cuda.h: 11 lines (new file)
@@ -0,0 +1,11 @@
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+
+#ifdef  __cplusplus
+}
+#endif
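The declarations are wrapped in extern "C" so that ggml.c, which is compiled as C, can call the launchers defined with C linkage in ggml-cuda.cu; the cudaStream_t type comes from cuda_runtime.h, which ggml.c already includes before this header whenever GGML_USE_CUBLAS is defined.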
							
								
								
									
ggml.c: 80 changed lines
@@ -150,23 +150,25 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_CUBLAS)
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
-#define CUDA_CHECK(err)                                                                            \
-    do {                                                                                           \
-        cudaError_t err_ = (err);                                                                  \
-        if (err_ != cudaSuccess) {                                                                 \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,                       \
-                cudaGetErrorString(err_));                                                         \
-            exit(1);                                                                               \
-        }                                                                                          \
+#include "ggml-cuda.h"
+
+#define CUDA_CHECK(err)                                                        \
+    do {                                                                       \
+        cudaError_t err_ = (err);                                              \
+        if (err_ != cudaSuccess) {                                             \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                     \
+            exit(1);                                                           \
+        }                                                                      \
     } while (0)
 
-#define CUBLAS_CHECK(err)                                                                          \
-    do {                                                                                           \
-        cublasStatus_t err_ = (err);                                                               \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                                       \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);                        \
-            exit(1);                                                                               \
-        }                                                                                          \
+#define CUBLAS_CHECK(err)                                                      \
+    do {                                                                       \
+        cublasStatus_t err_ = (err);                                           \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                           \
+        }                                                                      \
     } while (0)
 
 static cublasHandle_t cublasH = NULL;
@@ -177,6 +179,7 @@ static void init_cublas(void) {
         CUBLAS_CHECK(cublasCreate(&cublasH));
 
         CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
         CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
 
         // configure logging to stdout
@@ -7311,7 +7314,6 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7323,6 +7325,7 @@ static void ggml_compute_forward_mul_mat_f32(
             }
         }
 #if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
@@ -7535,7 +7538,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7553,6 +7555,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
@@ -7722,13 +7725,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
-        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
 #if defined(GGML_USE_CUBLAS)
         float *d_X = NULL;
         float *d_Y = NULL;
         float *d_D = NULL;
+        float *d_Q = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
@@ -7738,10 +7739,41 @@ static void ggml_compute_forward_mul_mat_q_f32(
         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+
+        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
+        if (type == GGML_TYPE_Q4_0) {
+            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_1) {
+            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_2) {
+            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+#else
+        float * const wdata = params->wdata;
+        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+                // copy and dequantize on device
+                CUDA_CHECK(
+                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7749,15 +7781,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
                         id += ne00;
                     }
                 }
 
                 const float * x = wdata;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+#endif
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
 
                 // compute
@@ -7770,7 +7799,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7783,9 +7811,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
+        CUDA_CHECK(cudaFree(d_Q));
 #endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
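Net effect in ggml.c: in ggml_compute_forward_mul_mat_q_f32 the quantized src0 slice is now copied into the new d_Q device buffer and dequantized into d_X by the type-specific CUDA launcher, replacing the CPU loop that previously filled wdata (that loop is kept behind #else for non-cuBLAS builds). In all three mul_mat paths the per-iteration cudaStreamSynchronize after the device-to-host copy is removed and a single synchronize is issued just before the buffers are freed, so the uploads, kernels, GEMMs, and downloads for all i02/i03 slices are queued back-to-back on the non-blocking stream.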
Author: slaren