mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	cuda : fix const ptrs warning causing ROCm build issues (#3913)
This commit is contained in:
		
							
								
								
									
										37
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -7248,7 +7248,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor | |||||||
|  |  | ||||||
| __global__ void k_compute_batched_ptrs( | __global__ void k_compute_batched_ptrs( | ||||||
|         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, |         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, | ||||||
|         void ** ptrs, |         const void ** ptrs_src, void ** ptrs_dst, | ||||||
|         int ne12, int ne13, |         int ne12, int ne13, | ||||||
|         int ne23, |         int ne23, | ||||||
|         int nb02, int nb03, |         int nb02, int nb03, | ||||||
| @@ -7265,9 +7265,9 @@ __global__ void k_compute_batched_ptrs( | |||||||
|     int i03 = i13 / r3; |     int i03 = i13 / r3; | ||||||
|     int i02 = i12 / r2; |     int i02 = i12 / r2; | ||||||
|  |  | ||||||
|     ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03; |     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03; | ||||||
|     ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2; |     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2; | ||||||
|     ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* nb2/2 + i13* nb3/2; |     ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2; | ||||||
| } | } | ||||||
|  |  | ||||||
| static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
| @@ -7372,14 +7372,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const | |||||||
|     } else { |     } else { | ||||||
|         // use cublasGemmBatchedEx |         // use cublasGemmBatchedEx | ||||||
|         const int ne23 = ne12*ne13; |         const int ne23 = ne12*ne13; | ||||||
|         // allocate device memory for pointers |  | ||||||
|         size_t ptrs_s = 0; |         const void ** ptrs_src = nullptr; | ||||||
|         void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream); |               void ** ptrs_dst = nullptr; | ||||||
|  |  | ||||||
|  |         size_t ptrs_src_s = 0; | ||||||
|  |         size_t ptrs_dst_s = 0; | ||||||
|  |  | ||||||
|  |         ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream); | ||||||
|  |         ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream); | ||||||
|  |  | ||||||
|         dim3 block_dims(ne13, ne12); |         dim3 block_dims(ne13, ne12); | ||||||
|         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( |         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( | ||||||
|                 src0_as_f16, src1_as_f16, dst_f16, |                 src0_as_f16, src1_as_f16, dst_f16, | ||||||
|                 ptrs_as, |                 ptrs_src, ptrs_dst, | ||||||
|                 ne12, ne13, |                 ne12, ne13, | ||||||
|                 ne23, |                 ne23, | ||||||
|                 nb02, nb03, |                 nb02, nb03, | ||||||
| @@ -7390,15 +7396,18 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const | |||||||
|         CUBLAS_CHECK( |         CUBLAS_CHECK( | ||||||
|         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, |         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|                 ne01, ne11, ne10, |                 ne01, ne11, ne10, | ||||||
|                 &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half), |                 &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half), | ||||||
|                             (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float), |                             (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float), | ||||||
|                 &beta_f16,  (      void **       ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01, |                 &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, | ||||||
|                 ne23, |                 ne23, | ||||||
|                 CUBLAS_COMPUTE_16F, |                 CUBLAS_COMPUTE_16F, | ||||||
|                 CUBLAS_GEMM_DEFAULT_TENSOR_OP)); |                 CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | ||||||
|         // free device memory for pointers |  | ||||||
|         if (ptrs_s != 0) { |         if (ptrs_src_s != 0) { | ||||||
|             ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream); |             ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream); | ||||||
|  |         } | ||||||
|  |         if (ptrs_dst_s != 0) { | ||||||
|  |             ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| #endif | #endif | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov