mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	CUDA: don't convert BF16 weights to FP32 (ggml/1174)
* add bf16 support * use convert_from_bf16_cuda instead of convert_unary_cuda for f32 * revert 7ec5085 * move functionality into convert_unary with constexpr
This commit is contained in:
		
				
					committed by
					
						
						Georgi Gerganov
					
				
			
			
				
	
			
			
			
						parent
						
							53cb49e337
						
					
				
				
					commit
					4683cb402a
				
			@@ -579,7 +579,13 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
 | 
			
		||||
 | 
			
		||||
    const src_t * x = (const src_t *) vx;
 | 
			
		||||
 | 
			
		||||
    y[i] = x[i];
 | 
			
		||||
    if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
 | 
			
		||||
        y[i] = __bfloat162float(x[i]);
 | 
			
		||||
    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16> && std::is_same_v<src_t, half>) {
 | 
			
		||||
        y[i] = (float)x[i];
 | 
			
		||||
    } else {
 | 
			
		||||
        y[i] = x[i];
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename src_t, typename dst_t>
 | 
			
		||||
@@ -588,6 +594,17 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
 | 
			
		||||
    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
 | 
			
		||||
    switch (type) {
 | 
			
		||||
        case GGML_TYPE_F32:
 | 
			
		||||
            return convert_unary_cuda<float>;
 | 
			
		||||
        case GGML_TYPE_F16:
 | 
			
		||||
            return convert_unary_cuda<half>;
 | 
			
		||||
        default:
 | 
			
		||||
            return nullptr;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 | 
			
		||||
    switch (type) {
 | 
			
		||||
        case GGML_TYPE_Q4_0:
 | 
			
		||||
@@ -633,6 +650,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 | 
			
		||||
            return dequantize_row_iq3_s_cuda;
 | 
			
		||||
        case GGML_TYPE_F32:
 | 
			
		||||
            return convert_unary_cuda<float>;
 | 
			
		||||
        case GGML_TYPE_BF16:
 | 
			
		||||
            return convert_unary_cuda<nv_bfloat16>;
 | 
			
		||||
        default:
 | 
			
		||||
            return nullptr;
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,10 @@ using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, in
 | 
			
		||||
 | 
			
		||||
typedef to_t_cuda_t<float> to_fp32_cuda_t;
 | 
			
		||||
typedef to_t_cuda_t<half> to_fp16_cuda_t;
 | 
			
		||||
typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
 | 
			
		||||
 | 
			
		||||
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
 | 
			
		||||
 | 
			
		||||
to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
 | 
			
		||||
 | 
			
		||||
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
 | 
			
		||||
 
 | 
			
		||||
@@ -1194,7 +1194,35 @@ static void ggml_cuda_op_mul_mat_cublas(
 | 
			
		||||
 | 
			
		||||
    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 | 
			
		||||
 | 
			
		||||
    if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
 | 
			
		||||
    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
 | 
			
		||||
        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
 | 
			
		||||
        if (src1->type != GGML_TYPE_BF16) {
 | 
			
		||||
            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
 | 
			
		||||
            GGML_ASSERT(to_bf16_cuda != nullptr);
 | 
			
		||||
            size_t ne = src1_ncols*ne10;
 | 
			
		||||
            src1_as_bf16.alloc(ne);
 | 
			
		||||
            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
 | 
			
		||||
        }
 | 
			
		||||
        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
 | 
			
		||||
        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
 | 
			
		||||
        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
 | 
			
		||||
 | 
			
		||||
        const float alpha_f32 = 1.0f;
 | 
			
		||||
        const float beta_f32  = 0.0f;
 | 
			
		||||
 | 
			
		||||
        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 | 
			
		||||
        CUBLAS_CHECK(
 | 
			
		||||
            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
 | 
			
		||||
                    row_diff, src1_ncols, ne10,
 | 
			
		||||
                    &alpha_f32,  src0_ptr,       CUDA_R_16BF, ne00,
 | 
			
		||||
                                 src1_ptr,       CUDA_R_16BF, ne10,
 | 
			
		||||
                    &beta_f32,   dst_bf16.get(), CUDA_R_16BF, ldc,
 | 
			
		||||
                    CUBLAS_COMPUTE_32F,
 | 
			
		||||
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 | 
			
		||||
 | 
			
		||||
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
 | 
			
		||||
        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 | 
			
		||||
    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
 | 
			
		||||
        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
 | 
			
		||||
        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
 | 
			
		||||
        if (src0->type != GGML_TYPE_F16) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user