	CUDA: don't convert BF16 weights to FP32 (ggml/1174)
* add bf16 support
* use convert_from_bf16_cuda instead of convert_unary_cuda for f32
* revert 7ec5085
* move functionality into convert_unary with constexpr
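The last bullet folds the BF16 handling into a single templated element-wise conversion kernel instead of keeping a dedicated BF16 kernel. A minimal sketch of that pattern follows; the kernel and launcher names are illustrative and simplified, not the exact code in ggml-cuda/convert.cu.

#include <cstdint>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// One element-wise conversion kernel; the source/destination type pair is
// resolved at compile time, so BF16 -> FP32 needs no separate kernel.
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * vx, dst_t * y, const int64_t k) {
    const int64_t i = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    const src_t * x = (const src_t *) vx;

    if constexpr (std::is_same_v<src_t, nv_bfloat16> && std::is_same_v<dst_t, float>) {
        y[i] = __bfloat162float(x[i]); // direct BF16 -> FP32, no FP16 round trip
    } else {
        y[i] = (dst_t) x[i];           // generic cast for the remaining type pairs
    }
}

// Hypothetical launcher: convert k BF16 values to FP32 on the given stream.
static void convert_bf16_to_f32_cuda(const void * x, float * y, const int64_t k, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (k + block_size - 1) / block_size;
    convert_unary<nv_bfloat16, float><<<num_blocks, block_size, 0, stream>>>(x, y, k);
}

The if constexpr branch is discarded at compile time for every other type pair, so the BF16 specialisation adds no runtime cost to the generic path.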
Sigbjørn Skjæret, committed by Georgi Gerganov
parent 995083e4ed
commit 36ca8b3628
@@ -1194,7 +1194,35 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
+    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_BF16) {
+            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_bf16.alloc(ne);
+            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
+        }
+        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
+        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
+        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+
+        const float alpha_f32 = 1.0f;
+        const float beta_f32  = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f32,  src0_ptr,       CUDA_R_16BF, ne00,
+                                 src1_ptr,       CUDA_R_16BF, ne10,
+                    &beta_f32,   dst_bf16.get(), CUDA_R_16BF, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
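The same recipe can be tried in isolation outside the ggml sources: convert the FP32 operand to BF16, call cublasGemmEx with BF16 operands and FP32 accumulation, then convert the BF16 result back to FP32. The standalone sketch below mirrors the hunk above; the matrix sizes, fill values and helper kernels are illustrative and are not part of the commit.

// Build (assumption): nvcc -std=c++17 bf16_gemm_sketch.cu -lcublas
#include <cassert>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cublas_v2.h>

__global__ void f32_to_bf16(const float * x, nv_bfloat16 * y, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) y[i] = __float2bfloat16(x[i]);
}

__global__ void bf16_to_f32(const nv_bfloat16 * x, float * y, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) y[i] = __bfloat162float(x[i]);
}

int main() {
    // C (m x n) = A^T (m x k) * B (k x n), matching the CUBLAS_OP_T/CUBLAS_OP_N call above.
    const int m = 4, n = 3, k = 8;
    std::vector<float> a_h(k*m, 0.5f); // column-major, lda = k
    std::vector<float> b_h(k*n, 2.0f); // column-major, ldb = k

    float *a_f32, *b_f32, *c_f32;
    nv_bfloat16 *a_bf16, *b_bf16, *c_bf16;
    cudaMalloc(&a_f32, k*m*sizeof(float));
    cudaMalloc(&b_f32, k*n*sizeof(float));
    cudaMalloc(&c_f32, m*n*sizeof(float));
    cudaMalloc(&a_bf16, k*m*sizeof(nv_bfloat16));
    cudaMalloc(&b_bf16, k*n*sizeof(nv_bfloat16));
    cudaMalloc(&c_bf16, m*n*sizeof(nv_bfloat16));
    cudaMemcpy(a_f32, a_h.data(), k*m*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(b_f32, b_h.data(), k*n*sizeof(float), cudaMemcpyHostToDevice);

    // In the commit only src1 needs converting (the BF16 weights stay as-is);
    // here both operands start as FP32, so convert both.
    f32_to_bf16<<<(k*m + 255)/256, 256>>>(a_f32, a_bf16, k*m);
    f32_to_bf16<<<(k*n + 255)/256, 256>>>(b_f32, b_bf16, k*n);

    cublasHandle_t handle;
    cublasStatus_t st = cublasCreate(&handle);
    assert(st == CUBLAS_STATUS_SUCCESS);

    // FP32 scalars because the compute type is CUBLAS_COMPUTE_32F.
    const float alpha = 1.0f, beta = 0.0f;
    st = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                      m, n, k,
                      &alpha, a_bf16, CUDA_R_16BF, k,
                              b_bf16, CUDA_R_16BF, k,
                      &beta,  c_bf16, CUDA_R_16BF, m,
                      CUBLAS_COMPUTE_32F,
                      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    assert(st == CUBLAS_STATUS_SUCCESS);
    (void) st;

    // Convert the BF16 result back to FP32, as the hunk does for dst.
    bf16_to_f32<<<(m*n + 255)/256, 256>>>(c_bf16, c_f32, m*n);

    std::vector<float> c_h(m*n);
    cudaMemcpy(c_h.data(), c_f32, m*n*sizeof(float), cudaMemcpyDeviceToHost);
    printf("c[0] = %.1f (expected %.1f)\n", c_h[0], 0.5f*2.0f*k);

    cublasDestroy(handle);
    cudaFree(a_f32); cudaFree(b_f32); cudaFree(c_f32);
    cudaFree(a_bf16); cudaFree(b_bf16); cudaFree(c_bf16);
    return 0;
}

With all-0.5 and all-2.0 inputs every output element should be exactly 8.0, since 0.5, 2.0 and 8.0 are all representable in BF16, which makes a quick correctness check easy.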