mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	ggml-cuda : perform cublas fp16 matrix multiplication as fp16 (#3370)
* ggml-cuda : perform cublas fp16 matrix multiplication as fp16 * try to fix rocm build * restrict fp16 mat mul to volta and up
This commit is contained in:
		
							
								
								
									
										104
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										104
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -14,9 +14,11 @@ | ||||
| // for rocblas_initialize() | ||||
| #include "rocblas/rocblas.h" | ||||
| #endif // __HIP_PLATFORM_AMD__ | ||||
| #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F | ||||
| #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F | ||||
| #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F | ||||
| #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT | ||||
| #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT | ||||
| #define CUBLAS_OP_N HIPBLAS_OP_N | ||||
| #define CUBLAS_OP_T HIPBLAS_OP_T | ||||
| #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS | ||||
| @@ -235,8 +237,12 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * | ||||
|     return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment | ||||
| } | ||||
|  | ||||
| template<typename T> | ||||
| using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); | ||||
| typedef to_t_cuda_t<float> to_fp32_cuda_t; | ||||
| typedef to_t_cuda_t<half> to_fp16_cuda_t; | ||||
|  | ||||
| typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); | ||||
| typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream); | ||||
| typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); | ||||
| typedef void (*cpy_kernel_t)(const char * cx, char * cdst); | ||||
| typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); | ||||
| @@ -1515,6 +1521,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, | ||||
|     v.y = x[ib + iqs + 1]; | ||||
| } | ||||
|  | ||||
| static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ | ||||
|     const float * x = (const float *) vx; | ||||
|  | ||||
|     // automatic half -> float type cast if dfloat == float | ||||
|     v.x = x[ib + iqs + 0]; | ||||
|     v.y = x[ib + iqs + 1]; | ||||
| } | ||||
|  | ||||
| static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { | ||||
|     const int ix = blockDim.x*blockIdx.x + threadIdx.x; | ||||
|  | ||||
| @@ -1554,8 +1568,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest | ||||
|     reinterpret_cast<half&>(y[ib].ds.y) = sum; | ||||
| } | ||||
|  | ||||
| template <int qk, int qr, dequantize_kernel_t dequantize_kernel> | ||||
| static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) { | ||||
| template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> | ||||
| static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { | ||||
|     const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; | ||||
|  | ||||
|     if (i >= k) { | ||||
| @@ -4826,6 +4840,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c | ||||
|     dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); | ||||
| } | ||||
|  | ||||
| static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { | ||||
|     const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; | ||||
|     dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k); | ||||
| } | ||||
|  | ||||
| static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
| @@ -4835,6 +4854,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); | ||||
| } | ||||
|  | ||||
| static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { | ||||
|     switch (type) { | ||||
|         case GGML_TYPE_F32: | ||||
|             return convert_fp32_to_fp16_cuda; | ||||
|         default: | ||||
|             return nullptr; | ||||
|     } | ||||
| } | ||||
|  | ||||
| static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { | ||||
|     switch (type) { | ||||
|         case GGML_TYPE_Q4_0: | ||||
| @@ -6016,8 +6044,6 @@ inline void ggml_cuda_op_mul_mat_cublas( | ||||
|     GGML_ASSERT(src1_ddf_i != nullptr); | ||||
|     GGML_ASSERT(dst_dd_i != nullptr); | ||||
|  | ||||
|     const float alpha = 1.0f; | ||||
|     const float beta = 0.0f; | ||||
|  | ||||
|     const int64_t ne00 = src0->ne[0]; | ||||
|  | ||||
| @@ -6026,16 +6052,6 @@ inline void ggml_cuda_op_mul_mat_cublas( | ||||
|     const int64_t ne0 = dst->ne[0]; | ||||
|     const int64_t row_diff = row_high - row_low; | ||||
|  | ||||
|     float * src0_ddq_as_f32; | ||||
|     size_t src0_as = 0; | ||||
|  | ||||
|     if (src0->type != GGML_TYPE_F32) { | ||||
|         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); | ||||
|         src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT | ||||
|         to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); | ||||
|     } | ||||
|     const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; | ||||
|  | ||||
|     int id; | ||||
|     CUDA_CHECK(cudaGetDevice(&id)); | ||||
|  | ||||
| @@ -6043,6 +6059,61 @@ inline void ggml_cuda_op_mul_mat_cublas( | ||||
|     // ldc == nrows of the matrix that cuBLAS writes into | ||||
|     int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; | ||||
|  | ||||
|     const int compute_capability = g_compute_capabilities[id]; | ||||
|  | ||||
|     if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) { | ||||
|         // convert src1 to fp16, multiply as fp16, convert dst to fp32 | ||||
|         half * src1_as_f16 = nullptr; | ||||
|         size_t src1_as = 0; | ||||
|         if (src1->type != GGML_TYPE_F16) { | ||||
|             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); | ||||
|             GGML_ASSERT(to_fp16_cuda != nullptr); | ||||
|             size_t ne = src1_ncols*ne10; | ||||
|             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); | ||||
|             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); | ||||
|         } | ||||
|         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; | ||||
|  | ||||
|         size_t dst_as = 0; | ||||
|         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); | ||||
|  | ||||
|         const half alpha_f16 = 1.0f; | ||||
|         const half beta_f16 = 0.0f; | ||||
|  | ||||
|         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); | ||||
|         CUBLAS_CHECK( | ||||
|             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||
|                     row_diff, src1_ncols, ne10, | ||||
|                     &alpha_f16, src0_dd_i, CUDA_R_16F, ne00, | ||||
|                                 src1_ptr,  CUDA_R_16F, ne10, | ||||
|                     &beta_f16,   dst_f16,  CUDA_R_16F, ldc, | ||||
|                     CUBLAS_COMPUTE_16F, | ||||
|                     CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | ||||
|  | ||||
|         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); | ||||
|         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); | ||||
|  | ||||
|         ggml_cuda_pool_free(dst_f16, dst_as); | ||||
|  | ||||
|         if (src1_as != 0) { | ||||
|             ggml_cuda_pool_free(src1_as_f16, src1_as); | ||||
|         } | ||||
|     } | ||||
|     else { | ||||
|         float * src0_ddq_as_f32 = nullptr; | ||||
|         size_t src0_as = 0; | ||||
|  | ||||
|         if (src0->type != GGML_TYPE_F32) { | ||||
|             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); | ||||
|             GGML_ASSERT(to_fp32_cuda != nullptr); | ||||
|             src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT | ||||
|             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); | ||||
|         } | ||||
|         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; | ||||
|  | ||||
|         const float alpha = 1.0f; | ||||
|         const float beta = 0.0f; | ||||
|  | ||||
|         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); | ||||
|         CUBLAS_CHECK( | ||||
|             cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||
| @@ -6051,9 +6122,10 @@ inline void ggml_cuda_op_mul_mat_cublas( | ||||
|                             src1_ddf_i,  ne10, | ||||
|                     &beta,  dst_dd_i,   ldc)); | ||||
|  | ||||
|     if (src0_as > 0) { | ||||
|         if (src0_as != 0) { | ||||
|             ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     (void) dst; | ||||
|     (void) src1_ddq_i; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren