mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	cuda : use CUDA memory pool with async memory allocation/deallocation when available (#3903)
* Using cuda memory pools for async alloc/dealloc. * If cuda device doesnt support memory pool than use old implementation. * Removed redundant cublasSetStream --------- Co-authored-by: Oleksii Maryshchenko <omaryshchenko@dtis.com>
This commit is contained in:
		 Oleksii Maryshchenko
					Oleksii Maryshchenko
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							4ff1046d75
						
					
				
				
					commit
					d6069051de
				
			
							
								
								
									
										130
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										130
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -181,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); | |||||||
|     do {                                                                                \ |     do {                                                                                \ | ||||||
|         cudaError_t err_ = (err);                                                       \ |         cudaError_t err_ = (err);                                                       \ | ||||||
|         if (err_ != cudaSuccess) {                                                      \ |         if (err_ != cudaSuccess) {                                                      \ | ||||||
|             int id;                                                                     \ |             int dev_id;                                                                     \ | ||||||
|             cudaGetDevice(&id);                                                         \ |             cudaGetDevice(&dev_id);                                                         \ | ||||||
|             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ |             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ | ||||||
|                 cudaGetErrorString(err_));                                              \ |                 cudaGetErrorString(err_));                                              \ | ||||||
|             fprintf(stderr, "current device: %d\n", id);                                \ |             fprintf(stderr, "current device: %d\n", dev_id);                                \ | ||||||
|             exit(1);                                                                    \ |             exit(1);                                                                    \ | ||||||
|         }                                                                               \ |         }                                                                               \ | ||||||
|     } while (0) |     } while (0) | ||||||
| @@ -195,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); | |||||||
|     do {                                                                                \ |     do {                                                                                \ | ||||||
|         cublasStatus_t err_ = (err);                                                    \ |         cublasStatus_t err_ = (err);                                                    \ | ||||||
|         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \ |         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \ | ||||||
|             int id;                                                                     \ |             int dev_id;                                                                     \ | ||||||
|             cudaGetDevice(&id);                                                         \ |             cudaGetDevice(&dev_id);                                                         \ | ||||||
|             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \ |             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \ | ||||||
|                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \ |                     err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \ | ||||||
|             fprintf(stderr, "current device: %d\n", id);                                \ |             fprintf(stderr, "current device: %d\n", dev_id);                                \ | ||||||
|             exit(1);                                                                    \ |             exit(1);                                                                    \ | ||||||
|         }                                                                               \ |         }                                                                               \ | ||||||
|     } while (0) |     } while (0) | ||||||
| @@ -465,6 +465,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA | |||||||
|  |  | ||||||
| #define MAX_STREAMS 8 | #define MAX_STREAMS 8 | ||||||
| static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; | static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; | ||||||
|  | static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr }; | ||||||
|  |  | ||||||
| struct ggml_tensor_extra_gpu { | struct ggml_tensor_extra_gpu { | ||||||
|     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors |     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors | ||||||
| @@ -5772,6 +5773,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { | |||||||
|     return ptr; |     return ptr; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) { | ||||||
|  |     if (g_cudaMemPools[id] == nullptr) { | ||||||
|  |         return ggml_cuda_pool_malloc(size, actual_size); | ||||||
|  |     } | ||||||
|  |     void *ptr; | ||||||
|  |     CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream)); | ||||||
|  |     *actual_size = size; | ||||||
|  |     return ptr; | ||||||
|  | } | ||||||
|  |  | ||||||
| static void ggml_cuda_pool_free(void * ptr, size_t size) { | static void ggml_cuda_pool_free(void * ptr, size_t size) { | ||||||
|     scoped_spin_lock lock(g_cuda_pool_lock); |     scoped_spin_lock lock(g_cuda_pool_lock); | ||||||
|     int id; |     int id; | ||||||
| @@ -5790,6 +5801,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { | |||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) { | ||||||
|  |     if (g_cudaMemPools[id] == nullptr) { | ||||||
|  |         return ggml_cuda_pool_free(ptr, actual_size); | ||||||
|  |     } | ||||||
|  |     CUDA_CHECK(cudaFreeAsync(ptr, stream)); | ||||||
|  | } | ||||||
|  |  | ||||||
| void ggml_init_cublas() { | void ggml_init_cublas() { | ||||||
|     static bool initialized = false; |     static bool initialized = false; | ||||||
|  |  | ||||||
| @@ -5844,6 +5862,13 @@ void ggml_init_cublas() { | |||||||
|             // create cublas handle |             // create cublas handle | ||||||
|             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); |             CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); | ||||||
|             CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); |             CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); | ||||||
|  |  | ||||||
|  |             // configure memory pool | ||||||
|  |             cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id); | ||||||
|  |             if (err == cudaSuccess) { | ||||||
|  |                 size_t treshold = UINT64_MAX; | ||||||
|  |                 CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold)); | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // configure logging to stdout |         // configure logging to stdout | ||||||
| @@ -6437,7 +6462,7 @@ inline void ggml_cuda_op_mul_mat_cublas( | |||||||
|             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); |             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); | ||||||
|             GGML_ASSERT(to_fp16_cuda != nullptr); |             GGML_ASSERT(to_fp16_cuda != nullptr); | ||||||
|             size_t ne = row_diff*ne00; |             size_t ne = row_diff*ne00; | ||||||
|             src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); |             src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream); | ||||||
|             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); |             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); | ||||||
|         } |         } | ||||||
|         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; |         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; | ||||||
| @@ -6448,13 +6473,12 @@ inline void ggml_cuda_op_mul_mat_cublas( | |||||||
|             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); |             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); | ||||||
|             GGML_ASSERT(to_fp16_cuda != nullptr); |             GGML_ASSERT(to_fp16_cuda != nullptr); | ||||||
|             size_t ne = src1_ncols*ne10; |             size_t ne = src1_ncols*ne10; | ||||||
|             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); |             src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream); | ||||||
|             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); |             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); | ||||||
|         } |         } | ||||||
|         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; |         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; | ||||||
|  |         size_t dst_f16_as = 0; | ||||||
|         size_t dst_as = 0; |         half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream); | ||||||
|         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); |  | ||||||
|  |  | ||||||
|         const half alpha_f16 = 1.0f; |         const half alpha_f16 = 1.0f; | ||||||
|         const half beta_f16 = 0.0f; |         const half beta_f16 = 0.0f; | ||||||
| @@ -6472,14 +6496,15 @@ inline void ggml_cuda_op_mul_mat_cublas( | |||||||
|         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); |         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); | ||||||
|         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); |         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); | ||||||
|  |  | ||||||
|         ggml_cuda_pool_free(dst_f16, dst_as); |         if (dst_f16_as != 0) { | ||||||
|  |             ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream); | ||||||
|         if (src0_as != 0) { |  | ||||||
|             ggml_cuda_pool_free(src0_as_f16, src0_as); |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         if (src0_as != 0) { | ||||||
|  |             ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream); | ||||||
|  |         } | ||||||
|         if (src1_as != 0) { |         if (src1_as != 0) { | ||||||
|             ggml_cuda_pool_free(src1_as_f16, src1_as); |             ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     else { |     else { | ||||||
| @@ -6489,7 +6514,7 @@ inline void ggml_cuda_op_mul_mat_cublas( | |||||||
|         if (src0->type != GGML_TYPE_F32) { |         if (src0->type != GGML_TYPE_F32) { | ||||||
|             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); |             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); | ||||||
|             GGML_ASSERT(to_fp32_cuda != nullptr); |             GGML_ASSERT(to_fp32_cuda != nullptr); | ||||||
|             src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT |             src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT | ||||||
|             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); |             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); | ||||||
|         } |         } | ||||||
|         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; |         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; | ||||||
| @@ -6506,7 +6531,7 @@ inline void ggml_cuda_op_mul_mat_cublas( | |||||||
|                     &beta,  dst_dd_i,   ldc)); |                     &beta,  dst_dd_i,   ldc)); | ||||||
|  |  | ||||||
|         if (src0_as != 0) { |         if (src0_as != 0) { | ||||||
|             ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); |             ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -6929,21 +6954,22 @@ static void ggml_cuda_op_mul_mat( | |||||||
|             src0_dd[id] = (char *) src0_extra->data_device[id]; |             src0_dd[id] = (char *) src0_extra->data_device[id]; | ||||||
|         } else { |         } else { | ||||||
|             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); |             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); | ||||||
|             src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); |             src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (src1_on_device && src1_is_contiguous) { |         if (src1_on_device && src1_is_contiguous) { | ||||||
|             src1_ddf[id] = (float *) src1_extra->data_device[id]; |             src1_ddf[id] = (float *) src1_extra->data_device[id]; | ||||||
|         } else { |         } else { | ||||||
|             src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); |             src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if (convert_src1_to_q8_1) { |         if (convert_src1_to_q8_1) { | ||||||
|             src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); |             const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; | ||||||
|  |             src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream); | ||||||
|  |  | ||||||
|             if (src1_on_device && src1_is_contiguous) { |             if (src1_on_device && src1_is_contiguous) { | ||||||
|                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); |                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); | ||||||
|                 CUDA_CHECK(cudaGetLastError()); |                 // CUDA_CHECK(cudaGetLastError()); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -6951,7 +6977,7 @@ static void ggml_cuda_op_mul_mat( | |||||||
|             dst_dd[id] = (float *) dst_extra->data_device[id]; |             dst_dd[id] = (float *) dst_extra->data_device[id]; | ||||||
|         } else { |         } else { | ||||||
|             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); |             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); | ||||||
|             dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); |             dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id,  stream); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -7077,24 +7103,6 @@ static void ggml_cuda_op_mul_mat( | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     for (int64_t id = 0; id < g_device_count; ++id) { |  | ||||||
|         CUDA_CHECK(ggml_cuda_set_device(id)); |  | ||||||
|  |  | ||||||
|         // free buffers again when done |  | ||||||
|         if (src0_as[id] > 0) { |  | ||||||
|             ggml_cuda_pool_free(src0_dd[id], src0_as[id]); |  | ||||||
|         } |  | ||||||
|         if (src1_asf[id] > 0) { |  | ||||||
|             ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); |  | ||||||
|         } |  | ||||||
|         if (src1_asq[id] > 0) { |  | ||||||
|             ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); |  | ||||||
|         } |  | ||||||
|         if (dst_as[id] > 0) { |  | ||||||
|             ggml_cuda_pool_free(dst_dd[id], dst_as[id]); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // main device waits for all other devices to be finished |     // main device waits for all other devices to be finished | ||||||
|     if (split && g_device_count > 1) { |     if (split && g_device_count > 1) { | ||||||
|         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; |         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; | ||||||
| @@ -7112,6 +7120,21 @@ static void ggml_cuda_op_mul_mat( | |||||||
|         CUDA_CHECK(ggml_cuda_set_device(g_main_device)); |         CUDA_CHECK(ggml_cuda_set_device(g_main_device)); | ||||||
|         CUDA_CHECK(cudaDeviceSynchronize()); |         CUDA_CHECK(cudaDeviceSynchronize()); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     for (int64_t id = 0; id < g_device_count; ++id) { | ||||||
|  |         if (src0_as[id] > 0) { | ||||||
|  |             ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]); | ||||||
|  |         } | ||||||
|  |         if (src1_asf[id] > 0) { | ||||||
|  |             ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]); | ||||||
|  |         } | ||||||
|  |         if (src1_asq[id] > 0) { | ||||||
|  |             ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]); | ||||||
|  |         } | ||||||
|  |         if (dst_as[id] > 0) { | ||||||
|  |             ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]); | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
| @@ -7298,11 +7321,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const | |||||||
|     GGML_ASSERT(to_fp16_cuda != nullptr); |     GGML_ASSERT(to_fp16_cuda != nullptr); | ||||||
|  |  | ||||||
|     size_t src1_as = 0; |     size_t src1_as = 0; | ||||||
|     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); |     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream); | ||||||
|     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); |     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); | ||||||
|  |  | ||||||
|     size_t dst_as = 0; |     size_t dst_as = 0; | ||||||
|     half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); |     half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream); | ||||||
|  |  | ||||||
|     GGML_ASSERT(ne12 % ne02 == 0); |     GGML_ASSERT(ne12 % ne02 == 0); | ||||||
|     GGML_ASSERT(ne13 % ne03 == 0); |     GGML_ASSERT(ne13 % ne03 == 0); | ||||||
| @@ -7349,10 +7372,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const | |||||||
|     } else { |     } else { | ||||||
|         // use cublasGemmBatchedEx |         // use cublasGemmBatchedEx | ||||||
|         const int ne23 = ne12*ne13; |         const int ne23 = ne12*ne13; | ||||||
|  |         // allocate device memory for pointers | ||||||
|         void ** ptrs_as = nullptr; |  | ||||||
|         size_t ptrs_s = 0; |         size_t ptrs_s = 0; | ||||||
|         ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s); |         void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream); | ||||||
|  |  | ||||||
|         dim3 block_dims(ne13, ne12); |         dim3 block_dims(ne13, ne12); | ||||||
|         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( |         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( | ||||||
| @@ -7365,7 +7387,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const | |||||||
|                 dst->nb[2], dst->nb[3], |                 dst->nb[2], dst->nb[3], | ||||||
|                 r2, r3); |                 r2, r3); | ||||||
|         CUDA_CHECK(cudaGetLastError()); |         CUDA_CHECK(cudaGetLastError()); | ||||||
|  |  | ||||||
|         CUBLAS_CHECK( |         CUBLAS_CHECK( | ||||||
|         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, |         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|                 ne01, ne11, ne10, |                 ne01, ne11, ne10, | ||||||
| @@ -7375,16 +7396,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const | |||||||
|                 ne23, |                 ne23, | ||||||
|                 CUBLAS_COMPUTE_16F, |                 CUBLAS_COMPUTE_16F, | ||||||
|                 CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |                 CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | ||||||
|  |         // free device memory for pointers | ||||||
|         ggml_cuda_pool_free(ptrs_as, ptrs_s); |         if (ptrs_s != 0) { | ||||||
|  |             ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream); | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); |     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); | ||||||
|     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); |     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); | ||||||
|  |     if (src1_as != 0) { | ||||||
|     ggml_cuda_pool_free(src1_as_f16, src1_as); |         ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream); | ||||||
|     ggml_cuda_pool_free(dst_f16, dst_as); |     } | ||||||
|  |     if (dst_as != 0) { | ||||||
|  |         ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream); | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user