mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Add NVIDIA cuBLAS support (#1044)
This commit is contained in:
		| @@ -66,6 +66,7 @@ endif() | |||||||
| # 3rd party libs | # 3rd party libs | ||||||
| option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"                    ON) | option(LLAMA_ACCELERATE             "llama: enable Accelerate framework"                    ON) | ||||||
| option(LLAMA_OPENBLAS               "llama: use OpenBLAS"                                   OFF) | option(LLAMA_OPENBLAS               "llama: use OpenBLAS"                                   OFF) | ||||||
|  | option(LLAMA_CUBLAS                 "llama: use cuBLAS"                                     OFF) | ||||||
|  |  | ||||||
| option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE}) | option(LLAMA_BUILD_TESTS            "llama: build tests"    ${LLAMA_STANDALONE}) | ||||||
| option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE}) | option(LLAMA_BUILD_EXAMPLES         "llama: build examples" ${LLAMA_STANDALONE}) | ||||||
| @@ -142,6 +143,26 @@ if (LLAMA_OPENBLAS) | |||||||
|     endif() |     endif() | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
|  | if (LLAMA_CUBLAS) | ||||||
|  |     cmake_minimum_required(VERSION 3.17) | ||||||
|  |  | ||||||
|  |     find_package(CUDAToolkit) | ||||||
|  |     if (CUDAToolkit_FOUND) | ||||||
|  |         message(STATUS "cuBLAS found") | ||||||
|  |  | ||||||
|  |         add_compile_definitions(GGML_USE_CUBLAS) | ||||||
|  |  | ||||||
|  |         if (LLAMA_STATIC) | ||||||
|  |             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) | ||||||
|  |         else() | ||||||
|  |             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) | ||||||
|  |         endif() | ||||||
|  |  | ||||||
|  |     else() | ||||||
|  |         message(WARNING "cuBLAS not found") | ||||||
|  |     endif() | ||||||
|  | endif() | ||||||
|  |  | ||||||
| if (LLAMA_ALL_WARNINGS) | if (LLAMA_ALL_WARNINGS) | ||||||
|     if (NOT MSVC) |     if (NOT MSVC) | ||||||
|         set(c_flags |         set(c_flags | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								Makefile
									
									
									
									
									
								
							| @@ -97,6 +97,10 @@ ifdef LLAMA_OPENBLAS | |||||||
| 	CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas | 	CFLAGS  += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas | ||||||
| 	LDFLAGS += -lopenblas | 	LDFLAGS += -lopenblas | ||||||
| endif | endif | ||||||
|  | ifdef LLAMA_CUBLAS | ||||||
|  | 	CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include | ||||||
|  | 	LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64 | ||||||
|  | endif | ||||||
| ifdef LLAMA_GPROF | ifdef LLAMA_GPROF | ||||||
| 	CFLAGS   += -pg | 	CFLAGS   += -pg | ||||||
| 	CXXFLAGS += -pg | 	CXXFLAGS += -pg | ||||||
|   | |||||||
							
								
								
									
										206
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										206
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -142,10 +142,46 @@ inline static void* ggml_aligned_malloc(size_t size) { | |||||||
|         } \ |         } \ | ||||||
|     } while (0) |     } while (0) | ||||||
|  |  | ||||||
| #ifdef GGML_USE_ACCELERATE | #if defined(GGML_USE_ACCELERATE) | ||||||
| #include <Accelerate/Accelerate.h> | #include <Accelerate/Accelerate.h> | ||||||
| #elif GGML_USE_OPENBLAS | #elif defined(GGML_USE_OPENBLAS) | ||||||
| #include <cblas.h> | #include <cblas.h> | ||||||
|  | #elif defined(GGML_USE_CUBLAS) | ||||||
|  | #include <cublas_v2.h> | ||||||
|  | #include <cuda_runtime.h> | ||||||
|  | #define CUDA_CHECK(err)                                                                            \ | ||||||
|  |     do {                                                                                           \ | ||||||
|  |         cudaError_t err_ = (err);                                                                  \ | ||||||
|  |         if (err_ != cudaSuccess) {                                                                 \ | ||||||
|  |             printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,                       \ | ||||||
|  |                 cudaGetErrorString(err_));                                                         \ | ||||||
|  |             exit(1);                                                                               \ | ||||||
|  |         }                                                                                          \ | ||||||
|  |     } while (0) | ||||||
|  |  | ||||||
|  | #define CUBLAS_CHECK(err)                                                                          \ | ||||||
|  |     do {                                                                                           \ | ||||||
|  |         cublasStatus_t err_ = (err);                                                               \ | ||||||
|  |         if (err_ != CUBLAS_STATUS_SUCCESS) {                                                       \ | ||||||
|  |             printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);                        \ | ||||||
|  |             exit(1);                                                                               \ | ||||||
|  |         }                                                                                          \ | ||||||
|  |     } while (0) | ||||||
|  |  | ||||||
|  | static cublasHandle_t cublasH = NULL; | ||||||
|  | static cudaStream_t cudaStream = NULL; | ||||||
|  | static void init_cublas(void) { | ||||||
|  |     if (cublasH == NULL) { | ||||||
|  |         // create cublas handle, bind a stream | ||||||
|  |         CUBLAS_CHECK(cublasCreate(&cublasH)); | ||||||
|  |  | ||||||
|  |         CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking)); | ||||||
|  |         CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream)); | ||||||
|  |  | ||||||
|  |         // configure logging to stdout | ||||||
|  |         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); | ||||||
|  |     } | ||||||
|  | } | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #undef MIN | #undef MIN | ||||||
| @@ -3836,6 +3872,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { | |||||||
|             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); |             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         // initialize cuBLAS | ||||||
|  |         #if defined(GGML_USE_CUBLAS) | ||||||
|  |         init_cublas(); | ||||||
|  |         #endif | ||||||
|  |  | ||||||
|         is_first_call = false; |         is_first_call = false; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -7567,7 +7608,7 @@ static void ggml_compute_forward_rms_norm( | |||||||
|  |  | ||||||
| // ggml_compute_forward_mul_mat | // ggml_compute_forward_mul_mat | ||||||
|  |  | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
| // helper function to determine if it is better to use BLAS or not | // helper function to determine if it is better to use BLAS or not | ||||||
| // for large matrices, BLAS is faster | // for large matrices, BLAS is faster | ||||||
| static bool ggml_compute_forward_mul_mat_use_blas( | static bool ggml_compute_forward_mul_mat_use_blas( | ||||||
| @@ -7607,7 +7648,7 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|     const int64_t ne02 = src0->ne[2]; |     const int64_t ne02 = src0->ne[2]; | ||||||
|     const int64_t ne03 = src0->ne[3]; |     const int64_t ne03 = src0->ne[3]; | ||||||
|  |  | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|     const int64_t ne10 = src1->ne[0]; |     const int64_t ne10 = src1->ne[0]; | ||||||
| #endif | #endif | ||||||
|     const int64_t ne11 = src1->ne[1]; |     const int64_t ne11 = src1->ne[1]; | ||||||
| @@ -7664,7 +7705,7 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|     // nb01 >= nb00 - src0 is not transposed |     // nb01 >= nb00 - src0 is not transposed | ||||||
|     //   compute by src0 rows |     //   compute by src0 rows | ||||||
|  |  | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { |     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { | ||||||
|         if (params->ith != 0) { |         if (params->ith != 0) { | ||||||
|             return; |             return; | ||||||
| @@ -7678,6 +7719,21 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |         float *d_X = NULL; | ||||||
|  |         float *d_Y = NULL; | ||||||
|  |         float *d_D = NULL; | ||||||
|  |         const float alpha = 1.0f; | ||||||
|  |         const float beta = 0.0f; | ||||||
|  |         const int x_ne = ne01 * ne10; | ||||||
|  |         const int y_ne = ne11 * ne10; | ||||||
|  |         const int d_ne = ne11 * ne01; | ||||||
|  |  | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|         for (int64_t i03 = 0; i03 < ne03; i03++) { |         for (int64_t i03 = 0; i03 < ne03; i03++) { | ||||||
|             for (int64_t i02 = 0; i02 < ne02; i02++) { |             for (int64_t i02 = 0; i02 < ne02; i02++) { | ||||||
|                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); |                 const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); | ||||||
| @@ -7685,15 +7741,37 @@ static void ggml_compute_forward_mul_mat_f32( | |||||||
|  |  | ||||||
|                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); |                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |                 // copy data to device | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream)); | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); | ||||||
|  |  | ||||||
|  |                 // compute | ||||||
|  |                 CUBLAS_CHECK( | ||||||
|  |                     cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|  |                             ne01, ne11, ne10, | ||||||
|  |                             &alpha, d_X, ne00, | ||||||
|  |                                     d_Y, ne10, | ||||||
|  |                             &beta,  d_D, ne01)); | ||||||
|  |  | ||||||
|  |                 // copy data to host | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); | ||||||
|  |                 CUDA_CHECK(cudaStreamSynchronize(cudaStream)); | ||||||
|  | #else | ||||||
|                 // zT = y * xT |                 // zT = y * xT | ||||||
|                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, |                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, | ||||||
|                         ne11, ne01, ne10, |                         ne11, ne01, ne10, | ||||||
|                         1.0f,    y, ne10, |                         1.0f,    y, ne10, | ||||||
|                                  x, ne00, |                                  x, ne00, | ||||||
|                         0.0f,    d, ne01); |                         0.0f,    d, ne01); | ||||||
|  | #endif | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |         CUDA_CHECK(cudaFree(d_X)); | ||||||
|  |         CUDA_CHECK(cudaFree(d_Y)); | ||||||
|  |         CUDA_CHECK(cudaFree(d_D)); | ||||||
|  | #endif | ||||||
|         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); |         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); | ||||||
|  |  | ||||||
|         return; |         return; | ||||||
| @@ -7823,7 +7901,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|     // nb01 >= nb00 - src0 is not transposed |     // nb01 >= nb00 - src0 is not transposed | ||||||
|     //   compute by src0 rows |     //   compute by src0 rows | ||||||
|  |  | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { |     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { | ||||||
|         GGML_ASSERT(nb10 == sizeof(float)); |         GGML_ASSERT(nb10 == sizeof(float)); | ||||||
|  |  | ||||||
| @@ -7839,10 +7917,37 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         float * const wdata = params->wdata; | #if defined(GGML_USE_CUBLAS) | ||||||
|  |         ggml_fp16_t * const wdata = params->wdata; | ||||||
|  |  | ||||||
|  |         float *d_X = NULL; | ||||||
|  |         float *d_Y = NULL; | ||||||
|  |         float *d_D = NULL; | ||||||
|  |         const float alpha = 1.0f; | ||||||
|  |         const float beta = 0.0f; | ||||||
|  |         const int x_ne = ne01 * ne10; | ||||||
|  |         const int y_ne = ne11 * ne10; | ||||||
|  |         const int d_ne = ne11 * ne01; | ||||||
|  |  | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne)); | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); | ||||||
|  | #else | ||||||
|  |         float * const wdata = params->wdata; | ||||||
|  | #endif | ||||||
|         for (int64_t i03 = 0; i03 < ne03; i03++) { |         for (int64_t i03 = 0; i03 < ne03; i03++) { | ||||||
|             for (int64_t i02 = 0; i02 < ne02; i02++) { |             for (int64_t i02 = 0; i02 < ne02; i02++) { | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |                 // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16 | ||||||
|  |                 { | ||||||
|  |                     size_t id = 0; | ||||||
|  |                     for (int64_t i01 = 0; i01 < ne11; ++i01) { | ||||||
|  |                         for (int64_t i00 = 0; i00 < ne10; ++i00) { | ||||||
|  |                             wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  | #else | ||||||
|                 { |                 { | ||||||
|                     size_t id = 0; |                     size_t id = 0; | ||||||
|                     for (int64_t i01 = 0; i01 < ne01; ++i01) { |                     for (int64_t i01 = 0; i01 < ne01; ++i01) { | ||||||
| @@ -7851,7 +7956,32 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |                 const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03); | ||||||
|  |                 const ggml_fp16_t * y = (ggml_fp16_t *) wdata; | ||||||
|  |  | ||||||
|  |                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); | ||||||
|  |  | ||||||
|  |                 // copy data to device | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream)); | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream)); | ||||||
|  |  | ||||||
|  |                 // compute | ||||||
|  |                 CUBLAS_CHECK( | ||||||
|  |                     cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|  |                             ne01, ne11, ne10, | ||||||
|  |                             &alpha, d_X, CUDA_R_16F, ne00, | ||||||
|  |                                     d_Y, CUDA_R_16F, ne10, | ||||||
|  |                             &beta,  d_D, CUDA_R_32F, ne01, | ||||||
|  |                             CUBLAS_COMPUTE_32F, | ||||||
|  |                             CUBLAS_GEMM_DEFAULT)); | ||||||
|  |  | ||||||
|  |                 // copy data to host | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); | ||||||
|  |                 CUDA_CHECK(cudaStreamSynchronize(cudaStream)); | ||||||
|  | #else | ||||||
|                 const float * x = wdata; |                 const float * x = wdata; | ||||||
|                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); |                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); | ||||||
|  |  | ||||||
| @@ -7863,9 +7993,15 @@ static void ggml_compute_forward_mul_mat_f16_f32( | |||||||
|                         1.0f,    y, ne10, |                         1.0f,    y, ne10, | ||||||
|                                  x, ne00, |                                  x, ne00, | ||||||
|                         0.0f,    d, ne01); |                         0.0f,    d, ne01); | ||||||
|  | #endif | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |         CUDA_CHECK(cudaFree(d_X)); | ||||||
|  |         CUDA_CHECK(cudaFree(d_Y)); | ||||||
|  |         CUDA_CHECK(cudaFree(d_D)); | ||||||
|  | #endif | ||||||
|         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ |         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ | ||||||
|  |  | ||||||
|         return; |         return; | ||||||
| @@ -8017,7 +8153,7 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|     // nb01 >= nb00 - src0 is not transposed |     // nb01 >= nb00 - src0 is not transposed | ||||||
|     //   compute by src0 rows |     //   compute by src0 rows | ||||||
|  |  | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { |     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { | ||||||
|         if (params->ith != 0) { |         if (params->ith != 0) { | ||||||
|             return; |             return; | ||||||
| @@ -8034,6 +8170,21 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|         float * const wdata = params->wdata; |         float * const wdata = params->wdata; | ||||||
|         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; |         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |         float *d_X = NULL; | ||||||
|  |         float *d_Y = NULL; | ||||||
|  |         float *d_D = NULL; | ||||||
|  |         const float alpha = 1.0f; | ||||||
|  |         const float beta = 0.0f; | ||||||
|  |         const int x_ne = ne01 * ne10; | ||||||
|  |         const int y_ne = ne11 * ne10; | ||||||
|  |         const int d_ne = ne11 * ne01; | ||||||
|  |  | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); | ||||||
|  |         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|         for (int64_t i03 = 0; i03 < ne03; i03++) { |         for (int64_t i03 = 0; i03 < ne03; i03++) { | ||||||
|             for (int64_t i02 = 0; i02 < ne02; i02++) { |             for (int64_t i02 = 0; i02 < ne02; i02++) { | ||||||
|                 { |                 { | ||||||
| @@ -8049,15 +8200,38 @@ static void ggml_compute_forward_mul_mat_q_f32( | |||||||
|  |  | ||||||
|                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); |                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |                 // copy data to device | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream)); | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); | ||||||
|  |  | ||||||
|  |                 // compute | ||||||
|  |                 CUBLAS_CHECK( | ||||||
|  |                     cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|  |                             ne01, ne11, ne10, | ||||||
|  |                             &alpha, d_X, ne00, | ||||||
|  |                                     d_Y, ne10, | ||||||
|  |                             &beta,  d_D, ne01)); | ||||||
|  |  | ||||||
|  |                 // copy data to host | ||||||
|  |                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); | ||||||
|  |                 CUDA_CHECK(cudaStreamSynchronize(cudaStream)); | ||||||
|  | #else | ||||||
|                 // zT = y * xT |                 // zT = y * xT | ||||||
|                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, |                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, | ||||||
|                         ne11, ne01, ne10, |                         ne11, ne01, ne10, | ||||||
|                         1.0f,    y, ne10, |                         1.0f,    y, ne10, | ||||||
|                                  x, ne00, |                                  x, ne00, | ||||||
|                         0.0f,    d, ne01); |                         0.0f,    d, ne01); | ||||||
|  | #endif | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|  |         CUDA_CHECK(cudaFree(d_X)); | ||||||
|  |         CUDA_CHECK(cudaFree(d_Y)); | ||||||
|  |         CUDA_CHECK(cudaFree(d_D)); | ||||||
|  | #endif | ||||||
|         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); |         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); | ||||||
|  |  | ||||||
|         return; |         return; | ||||||
| @@ -10874,7 +11048,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) | |||||||
|                         size_t cur = 0; |                         size_t cur = 0; | ||||||
|  |  | ||||||
|                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { |                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { |                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { | ||||||
|                                 node->n_tasks = 1; // TODO: this actually is doing nothing |                                 node->n_tasks = 1; // TODO: this actually is doing nothing | ||||||
|                                                    //       the threads are still spinning |                                                    //       the threads are still spinning | ||||||
| @@ -10891,7 +11065,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) | |||||||
|                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { |                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { | ||||||
|                             cur = 0; |                             cur = 0; | ||||||
|                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { |                         } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) { | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { |                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { | ||||||
|                                 node->n_tasks = 1; |                                 node->n_tasks = 1; | ||||||
|                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); |                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); | ||||||
| @@ -12231,7 +12405,15 @@ int ggml_cpu_has_wasm_simd(void) { | |||||||
| } | } | ||||||
|  |  | ||||||
| int ggml_cpu_has_blas(void) { | int ggml_cpu_has_blas(void) { | ||||||
| #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) | #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) | ||||||
|  |     return 1; | ||||||
|  | #else | ||||||
|  |     return 0; | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int ggml_cpu_has_cublas(void) { | ||||||
|  | #if defined(GGML_USE_CUBLAS) | ||||||
|     return 1; |     return 1; | ||||||
| #else | #else | ||||||
|     return 0; |     return 0; | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								ggml.h
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								ggml.h
									
									
									
									
									
								
							| @@ -825,6 +825,7 @@ int ggml_cpu_has_f16c(void); | |||||||
| int ggml_cpu_has_fp16_va(void); | int ggml_cpu_has_fp16_va(void); | ||||||
| int ggml_cpu_has_wasm_simd(void); | int ggml_cpu_has_wasm_simd(void); | ||||||
| int ggml_cpu_has_blas(void); | int ggml_cpu_has_blas(void); | ||||||
|  | int ggml_cpu_has_cublas(void); | ||||||
| int ggml_cpu_has_sse3(void); | int ggml_cpu_has_sse3(void); | ||||||
| int ggml_cpu_has_vsx(void); | int ggml_cpu_has_vsx(void); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1069,7 +1069,7 @@ static bool llama_eval_internal( | |||||||
|     // for big prompts, if BLAS is enabled, it is better to use only one thread |     // for big prompts, if BLAS is enabled, it is better to use only one thread | ||||||
|     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance |     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance | ||||||
|     ggml_cgraph gf = {}; |     ggml_cgraph gf = {}; | ||||||
|     gf.n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads; |     gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads; | ||||||
|  |  | ||||||
|     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); |     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); | ||||||
|     memcpy(embd->data, tokens, N*ggml_element_size(embd)); |     memcpy(embd->data, tokens, N*ggml_element_size(embd)); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren