mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Improve cuBLAS performance by using a memory pool (#1094)
* Improve cuBLAS performance by using a memory pool * Move cuda specific definitions to ggml-cuda.h/cu * Add CXX flags to nvcc * Change memory pool synchronization mechanism to a spin lock General code cleanup
This commit is contained in:
		
							
								
								
									
										4
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								Makefile
									
									
									
									
									
								
							| @@ -104,8 +104,10 @@ ifdef LLAMA_CUBLAS | |||||||
| 	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include | 	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include | ||||||
| 	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 | 	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 | ||||||
| 	OBJS      += ggml-cuda.o | 	OBJS      += ggml-cuda.o | ||||||
|  | 	NVCC      = nvcc | ||||||
|  | 	NVCCFLAGS = --forward-unknown-to-host-linker -arch=native | ||||||
| ggml-cuda.o: ggml-cuda.cu ggml-cuda.h | ggml-cuda.o: ggml-cuda.cu ggml-cuda.h | ||||||
| 	nvcc -arch=native -c -o $@ $< | 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@ | ||||||
| endif | endif | ||||||
| ifdef LLAMA_GPROF | ifdef LLAMA_GPROF | ||||||
| 	CFLAGS   += -pg | 	CFLAGS   += -pg | ||||||
|   | |||||||
							
								
								
									
										94
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										94
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -1,5 +1,7 @@ | |||||||
| #include <stdint.h> | #include <stdint.h> | ||||||
|  | #include <stdio.h> | ||||||
| #include <cuda_fp16.h> | #include <cuda_fp16.h> | ||||||
|  | #include <atomic> | ||||||
| #include "ggml-cuda.h" | #include "ggml-cuda.h" | ||||||
|  |  | ||||||
| typedef uint16_t ggml_fp16_t; | typedef uint16_t ggml_fp16_t; | ||||||
| @@ -35,8 +37,6 @@ typedef struct { | |||||||
| } block_q4_3; | } block_q4_3; | ||||||
| static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); | static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| static __global__ void dequantize_block_q4_0(const void * vx, float * y) { | static __global__ void dequantize_block_q4_0(const void * vx, float * y) { | ||||||
|     const block_q4_0 * x = (const block_q4_0 *) vx; |     const block_q4_0 * x = (const block_q4_0 *) vx; | ||||||
|  |  | ||||||
| @@ -131,24 +131,98 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| extern "C" { | void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | ||||||
|     __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { |  | ||||||
|     const int nb = k / QK4_0; |     const int nb = k / QK4_0; | ||||||
|     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y); |     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y); | ||||||
|     } | } | ||||||
|  |  | ||||||
|     __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | ||||||
|     const int nb = k / QK4_1; |     const int nb = k / QK4_1; | ||||||
|     dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y); |     dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y); | ||||||
|     } | } | ||||||
|  |  | ||||||
|     __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | ||||||
|     const int nb = k / QK4_2; |     const int nb = k / QK4_2; | ||||||
|     dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y); |     dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y); | ||||||
|     } | } | ||||||
|  |  | ||||||
|     __host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) { | ||||||
|     const int nb = k / QK4_3; |     const int nb = k / QK4_3; | ||||||
|     dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y); |     dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // buffer pool for cuda | ||||||
|  | #define MAX_CUDA_BUFFERS 16 | ||||||
|  |  | ||||||
|  | struct scoped_spin_lock { | ||||||
|  |     std::atomic_flag& lock; | ||||||
|  |     scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { | ||||||
|  |         while (lock.test_and_set(std::memory_order_acquire)) { | ||||||
|  |             ; // spin | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     ~scoped_spin_lock() { | ||||||
|  |         lock.clear(std::memory_order_release); | ||||||
|  |     } | ||||||
|  |     scoped_spin_lock(const scoped_spin_lock&) = delete; | ||||||
|  |     scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | struct cuda_buffer { | ||||||
|  |     void * ptr = nullptr; | ||||||
|  |     size_t size = 0; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; | ||||||
|  | static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; | ||||||
|  |  | ||||||
|  | void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { | ||||||
|  |     scoped_spin_lock lock(g_cuda_pool_lock); | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { | ||||||
|  |         cuda_buffer& b = g_cuda_buffer_pool[i]; | ||||||
|  |         if (b.size >= size && b.ptr != nullptr) { | ||||||
|  |             void * ptr = b.ptr; | ||||||
|  |             *actual_size = b.size; | ||||||
|  |             b.ptr = nullptr; | ||||||
|  |             b.size = 0; | ||||||
|  |             return ptr; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     void * ptr; | ||||||
|  |     CUDA_CHECK(cudaMalloc((void **) &ptr, size)); | ||||||
|  |     *actual_size = size; | ||||||
|  |     return ptr; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void ggml_cuda_pool_free(void * ptr, size_t size) { | ||||||
|  |     scoped_spin_lock lock(g_cuda_pool_lock); | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { | ||||||
|  |         cuda_buffer& b = g_cuda_buffer_pool[i]; | ||||||
|  |         if (b.ptr == nullptr) { | ||||||
|  |             b.ptr = ptr; | ||||||
|  |             b.size = size; | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); | ||||||
|  |     CUDA_CHECK(cudaFree(ptr)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | cublasHandle_t g_cublasH = NULL; | ||||||
|  | cudaStream_t g_cudaStream = NULL; | ||||||
|  |  | ||||||
|  | void ggml_init_cublas(void) { | ||||||
|  |     if (g_cublasH == NULL) { | ||||||
|  |         // create cublas handle, bind a stream | ||||||
|  |         CUBLAS_CHECK(cublasCreate(&g_cublasH)); | ||||||
|  |  | ||||||
|  |         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking)); | ||||||
|  |  | ||||||
|  |         CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream)); | ||||||
|  |  | ||||||
|  |         // configure logging to stdout | ||||||
|  |         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								ggml-cuda.h
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								ggml-cuda.h
									
									
									
									
									
								
							| @@ -1,7 +1,36 @@ | |||||||
|  | #include <cublas_v2.h> | ||||||
|  | #include <cuda_runtime.h> | ||||||
|  |  | ||||||
| #ifdef  __cplusplus | #ifdef  __cplusplus | ||||||
| extern "C" { | extern "C" { | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | #define CUDA_CHECK(err)                                                                 \ | ||||||
|  |     do {                                                                                \ | ||||||
|  |         cudaError_t err_ = (err);                                                       \ | ||||||
|  |         if (err_ != cudaSuccess) {                                                      \ | ||||||
|  |             fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \ | ||||||
|  |                 cudaGetErrorString(err_));                                              \ | ||||||
|  |             exit(1);                                                                    \ | ||||||
|  |         }                                                                               \ | ||||||
|  |     } while (0) | ||||||
|  |  | ||||||
|  | #define CUBLAS_CHECK(err)                                                               \ | ||||||
|  |     do {                                                                                \ | ||||||
|  |         cublasStatus_t err_ = (err);                                                    \ | ||||||
|  |         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \ | ||||||
|  |             fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \ | ||||||
|  |             exit(1);                                                                    \ | ||||||
|  |         }                                                                               \ | ||||||
|  |     } while (0) | ||||||
|  |  | ||||||
|  | extern cublasHandle_t g_cublasH; | ||||||
|  | extern cudaStream_t   g_cudaStream; | ||||||
|  |  | ||||||
|  | void   ggml_init_cublas(void); | ||||||
|  | void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); | ||||||
|  | void   ggml_cuda_pool_free(void * ptr, size_t size); | ||||||
|  |  | ||||||
| void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); | void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); | ||||||
| void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); | void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); | ||||||
| void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); | void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); | ||||||
|   | |||||||
							
								
								
									
										124
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										124
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) { | |||||||
| #elif defined(GGML_USE_OPENBLAS) | #elif defined(GGML_USE_OPENBLAS) | ||||||
| #include <cblas.h> | #include <cblas.h> | ||||||
| #elif defined(GGML_USE_CUBLAS) | #elif defined(GGML_USE_CUBLAS) | ||||||
| #include <cublas_v2.h> |  | ||||||
| #include <cuda_runtime.h> |  | ||||||
| #include "ggml-cuda.h" | #include "ggml-cuda.h" | ||||||
|  |  | ||||||
| #define CUDA_CHECK(err)                                                        \ |  | ||||||
|     do {                                                                       \ |  | ||||||
|         cudaError_t err_ = (err);                                              \ |  | ||||||
|         if (err_ != cudaSuccess) {                                             \ |  | ||||||
|             printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \ |  | ||||||
|                 cudaGetErrorString(err_));                                     \ |  | ||||||
|             exit(1);                                                           \ |  | ||||||
|         }                                                                      \ |  | ||||||
|     } while (0) |  | ||||||
|  |  | ||||||
| #define CUBLAS_CHECK(err)                                                      \ |  | ||||||
|     do {                                                                       \ |  | ||||||
|         cublasStatus_t err_ = (err);                                           \ |  | ||||||
|         if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \ |  | ||||||
|             printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \ |  | ||||||
|             exit(1);                                                           \ |  | ||||||
|         }                                                                      \ |  | ||||||
|     } while (0) |  | ||||||
|  |  | ||||||
| static cublasHandle_t cublasH = NULL; |  | ||||||
| static cudaStream_t cudaStream = NULL; |  | ||||||
| static void init_cublas(void) { |  | ||||||
|     if (cublasH == NULL) { |  | ||||||
|         // create cublas handle, bind a stream |  | ||||||
|         CUBLAS_CHECK(cublasCreate(&cublasH)); |  | ||||||
|  |  | ||||||
|         CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking)); |  | ||||||
|  |  | ||||||
|         CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream)); |  | ||||||
|  |  | ||||||
|         // configure logging to stdout |  | ||||||
|         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #undef MIN | #undef MIN | ||||||
| @@ -3748,7 +3711,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { | |||||||
|  |  | ||||||
|         // initialize cuBLAS |         // initialize cuBLAS | ||||||
|         #if defined(GGML_USE_CUBLAS) |         #if defined(GGML_USE_CUBLAS) | ||||||
|         init_cublas(); |         ggml_init_cublas(); | ||||||
|         #endif |         #endif | ||||||
|  |  | ||||||
|         is_first_call = false; |         is_first_call = false; | ||||||
| @@ -7594,18 +7557,16 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|         } |         } | ||||||
|  |  | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|         float *d_X = NULL; |  | ||||||
|         float *d_Y = NULL; |  | ||||||
|         float *d_D = NULL; |  | ||||||
|         const float alpha = 1.0f; |         const float alpha = 1.0f; | ||||||
|         const float beta = 0.0f; |         const float beta = 0.0f; | ||||||
|         const int x_ne = ne01 * ne10; |         const int x_ne = ne01 * ne10; | ||||||
|         const int y_ne = ne11 * ne10; |         const int y_ne = ne11 * ne10; | ||||||
|         const int d_ne = ne11 * ne01; |         const int d_ne = ne11 * ne01; | ||||||
|  |  | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); |         size_t x_size, y_size, d_size; | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); |         float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); |         float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); | ||||||
|  |         float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|         for (int64_t i03 = 0; i03 < ne03; i03++) { |         for (int64_t i03 = 0; i03 < ne03; i03++) { | ||||||
| @@ -7617,19 +7578,19 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|  |  | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|                 // copy data to device |                 // copy data to device | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream)); | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); | ||||||
|  |  | ||||||
|                 // compute |                 // compute | ||||||
|                 CUBLAS_CHECK( |                 CUBLAS_CHECK( | ||||||
|                     cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, |                     cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|                             ne01, ne11, ne10, |                             ne01, ne11, ne10, | ||||||
|                             &alpha, d_X, ne00, |                             &alpha, d_X, ne00, | ||||||
|                                     d_Y, ne10, |                                     d_Y, ne10, | ||||||
|                             &beta,  d_D, ne01)); |                             &beta,  d_D, ne01)); | ||||||
|  |  | ||||||
|                 // copy data to host |                 // copy data to host | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); | ||||||
| #else | #else | ||||||
|                 // zT = y * xT |                 // zT = y * xT | ||||||
|                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, |                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, | ||||||
| @@ -7641,10 +7602,10 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|         CUDA_CHECK(cudaStreamSynchronize(cudaStream)); |         CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); | ||||||
|         CUDA_CHECK(cudaFree(d_X)); |         ggml_cuda_pool_free(d_X, x_size); | ||||||
|         CUDA_CHECK(cudaFree(d_Y)); |         ggml_cuda_pool_free(d_Y, y_size); | ||||||
|         CUDA_CHECK(cudaFree(d_D)); |         ggml_cuda_pool_free(d_D, d_size); | ||||||
| #endif | #endif | ||||||
|         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); |         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); | ||||||
|  |  | ||||||
| @@ -7794,18 +7755,16 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|         ggml_fp16_t * const wdata = params->wdata; |         ggml_fp16_t * const wdata = params->wdata; | ||||||
|  |  | ||||||
|         float *d_X = NULL; |  | ||||||
|         float *d_Y = NULL; |  | ||||||
|         float *d_D = NULL; |  | ||||||
|         const float alpha = 1.0f; |         const float alpha = 1.0f; | ||||||
|         const float beta = 0.0f; |         const float beta = 0.0f; | ||||||
|         const int x_ne = ne01 * ne10; |         const int x_ne = ne01 * ne10; | ||||||
|         const int y_ne = ne11 * ne10; |         const int y_ne = ne11 * ne10; | ||||||
|         const int d_ne = ne11 * ne01; |         const int d_ne = ne11 * ne01; | ||||||
|  |  | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne)); |         size_t x_size, y_size, d_size; | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); |         float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); |         float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); | ||||||
|  |         float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); | ||||||
| #else | #else | ||||||
|         float * const wdata = params->wdata; |         float * const wdata = params->wdata; | ||||||
| #endif | #endif | ||||||
| @@ -7839,12 +7798,12 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); |                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); | ||||||
|  |  | ||||||
|                 // copy data to device |                 // copy data to device | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream)); | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); | ||||||
|  |  | ||||||
|                 // compute |                 // compute | ||||||
|                 CUBLAS_CHECK( |                 CUBLAS_CHECK( | ||||||
|                     cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, |                     cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|                             ne01, ne11, ne10, |                             ne01, ne11, ne10, | ||||||
|                             &alpha, d_X, CUDA_R_16F, ne00, |                             &alpha, d_X, CUDA_R_16F, ne00, | ||||||
|                                     d_Y, CUDA_R_16F, ne10, |                                     d_Y, CUDA_R_16F, ne10, | ||||||
| @@ -7853,7 +7812,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|                             CUBLAS_GEMM_DEFAULT)); |                             CUBLAS_GEMM_DEFAULT)); | ||||||
|  |  | ||||||
|                 // copy data to host |                 // copy data to host | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); | ||||||
| #else | #else | ||||||
|                 const float * x = wdata; |                 const float * x = wdata; | ||||||
|                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); |                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); | ||||||
| @@ -7871,10 +7830,10 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|         } |         } | ||||||
|  |  | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|         CUDA_CHECK(cudaStreamSynchronize(cudaStream)); |         CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); | ||||||
|         CUDA_CHECK(cudaFree(d_X)); |         ggml_cuda_pool_free(d_X, x_size); | ||||||
|         CUDA_CHECK(cudaFree(d_Y)); |         ggml_cuda_pool_free(d_Y, y_size); | ||||||
|         CUDA_CHECK(cudaFree(d_D)); |         ggml_cuda_pool_free(d_D, d_size); | ||||||
| #endif | #endif | ||||||
|         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ |         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ | ||||||
|  |  | ||||||
| @@ -8042,20 +8001,17 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|         } |         } | ||||||
|  |  | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|         float *d_X = NULL; |  | ||||||
|         float *d_Y = NULL; |  | ||||||
|         float *d_D = NULL; |  | ||||||
|         float *d_Q = NULL; |  | ||||||
|         const float alpha = 1.0f; |         const float alpha = 1.0f; | ||||||
|         const float beta = 0.0f; |         const float beta = 0.0f; | ||||||
|         const int x_ne = ne01 * ne10; |         const int x_ne = ne01 * ne10; | ||||||
|         const int y_ne = ne11 * ne10; |         const int y_ne = ne11 * ne10; | ||||||
|         const int d_ne = ne11 * ne01; |         const int d_ne = ne11 * ne01; | ||||||
|  |  | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); |         size_t x_size, y_size, d_size, q_size; | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); |         float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); |         float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); | ||||||
|         CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type])); |         float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); | ||||||
|  |         float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size); | ||||||
|  |  | ||||||
|         void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL; |         void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL; | ||||||
|         if (type == GGML_TYPE_Q4_0) { |         if (type == GGML_TYPE_Q4_0) { | ||||||
| @@ -8085,9 +8041,9 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|                 // copy and dequantize on device |                 // copy and dequantize on device | ||||||
|                 CUDA_CHECK( |                 CUDA_CHECK( | ||||||
|                     cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02, |                     cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02, | ||||||
|                         GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream)); |                         GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream)); | ||||||
|  |  | ||||||
|                 dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream); |                 dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); | ||||||
|                 CUDA_CHECK(cudaGetLastError()); |                 CUDA_CHECK(cudaGetLastError()); | ||||||
| #else | #else | ||||||
|                 { |                 { | ||||||
| @@ -8103,18 +8059,18 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|  |  | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|                 // copy data to device |                 // copy data to device | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); | ||||||
|  |  | ||||||
|                 // compute |                 // compute | ||||||
|                 CUBLAS_CHECK( |                 CUBLAS_CHECK( | ||||||
|                     cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, |                     cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|                             ne01, ne11, ne10, |                             ne01, ne11, ne10, | ||||||
|                             &alpha, d_X, ne00, |                             &alpha, d_X, ne00, | ||||||
|                                     d_Y, ne10, |                                     d_Y, ne10, | ||||||
|                             &beta,  d_D, ne01)); |                             &beta,  d_D, ne01)); | ||||||
|  |  | ||||||
|                 // copy data to host |                 // copy data to host | ||||||
|                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); |                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); | ||||||
| #else | #else | ||||||
|                 // zT = y * xT |                 // zT = y * xT | ||||||
|                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, |                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, | ||||||
| @@ -8127,11 +8083,11 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|         } |         } | ||||||
|  |  | ||||||
| #if defined(GGML_USE_CUBLAS) | #if defined(GGML_USE_CUBLAS) | ||||||
|         CUDA_CHECK(cudaStreamSynchronize(cudaStream)); |         CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); | ||||||
|         CUDA_CHECK(cudaFree(d_X)); |         ggml_cuda_pool_free(d_X, x_size); | ||||||
|         CUDA_CHECK(cudaFree(d_Y)); |         ggml_cuda_pool_free(d_Y, y_size); | ||||||
|         CUDA_CHECK(cudaFree(d_D)); |         ggml_cuda_pool_free(d_D, d_size); | ||||||
|         CUDA_CHECK(cudaFree(d_Q)); |         ggml_cuda_pool_free(d_Q, q_size); | ||||||
| #endif | #endif | ||||||
|         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); |         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren