mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	| @@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { | |||||||
|             return dequantize_row_iq3_s_cuda; |             return dequantize_row_iq3_s_cuda; | ||||||
|         case GGML_TYPE_F16: |         case GGML_TYPE_F16: | ||||||
|             return convert_unary_cuda<half>; |             return convert_unary_cuda<half>; | ||||||
|  |         case GGML_TYPE_BF16: | ||||||
|  |             return convert_unary_cuda<nv_bfloat16>; | ||||||
|         default: |         default: | ||||||
|             return nullptr; |             return nullptr; | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co | |||||||
| static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); |     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); | ||||||
|  |  | ||||||
|     bool use_mul_mat_vec   = src0->type == GGML_TYPE_F16 |     bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) | ||||||
|         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 |         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 | ||||||
|         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; |         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; | ||||||
|     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) |     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) | ||||||
| @@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
|                     case GGML_TYPE_IQ3_XXS: |                     case GGML_TYPE_IQ3_XXS: | ||||||
|                     case GGML_TYPE_IQ4_NL: |                     case GGML_TYPE_IQ4_NL: | ||||||
|                     case GGML_TYPE_IQ4_XS: |                     case GGML_TYPE_IQ4_XS: | ||||||
|  |                     case GGML_TYPE_BF16: | ||||||
| #ifdef GGML_USE_MUSA | #ifdef GGML_USE_MUSA | ||||||
|                         if (a->type == GGML_TYPE_Q3_K) { |                         if (a->type == GGML_TYPE_Q3_K) { | ||||||
|                             return false; |                             return false; | ||||||
|   | |||||||
| @@ -1,9 +1,9 @@ | |||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
| #include "mmv.cuh" | #include "mmv.cuh" | ||||||
|  |  | ||||||
| template <typename type_acc, int block_size> | template <typename T, typename type_acc, int block_size> | ||||||
| static __global__ void mul_mat_vec( | static __global__ void mul_mat_vec( | ||||||
|         const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, |         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, | ||||||
|         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { |         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { | ||||||
|     const int64_t row     = blockIdx.x; |     const int64_t row     = blockIdx.x; | ||||||
|     const int64_t channel = blockIdx.z; |     const int64_t channel = blockIdx.z; | ||||||
| @@ -13,7 +13,6 @@ static __global__ void mul_mat_vec( | |||||||
|     y   +=  channel               *stride_channel_y; |     y   +=  channel               *stride_channel_y; | ||||||
|     dst +=  channel               *stride_channel_dst; |     dst +=  channel               *stride_channel_dst; | ||||||
|  |  | ||||||
|     const half2  * x2 = (const half2  *) x; |  | ||||||
|     const float2 * y2 = (const float2 *) y; |     const float2 * y2 = (const float2 *) y; | ||||||
|  |  | ||||||
|     extern __shared__ char data_mmv[]; |     extern __shared__ char data_mmv[]; | ||||||
| @@ -28,28 +27,44 @@ static __global__ void mul_mat_vec( | |||||||
|  |  | ||||||
|     float sumf; |     float sumf; | ||||||
|  |  | ||||||
|     if (std::is_same<type_acc, float>::value) { |     if constexpr (std::is_same<T, half>::value) { | ||||||
|  |         const half2 * x2 = (const half2 *) x; | ||||||
|  |  | ||||||
|  |         if (std::is_same<type_acc, float>::value) { | ||||||
|  |             sumf = 0.0f; | ||||||
|  |  | ||||||
|  |             for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { | ||||||
|  |                 const float2 tmpx = __half22float2(x2[col2]); | ||||||
|  |                 const float2 tmpy = y2[col2]; | ||||||
|  |                 sumf += tmpx.x * tmpy.x; | ||||||
|  |                 sumf += tmpx.y * tmpy.y; | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  | #ifdef FP16_AVAILABLE | ||||||
|  |             half2 sumh2 = make_half2(0.0f, 0.0f); | ||||||
|  |  | ||||||
|  |             for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { | ||||||
|  |                 const float2 tmp = y2[col2]; | ||||||
|  |                 sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             sumf = __low2float(sumh2) + __high2float(sumh2); | ||||||
|  | #else | ||||||
|  |             NO_DEVICE_CODE; | ||||||
|  | #endif // FP16_AVAILABLE | ||||||
|  |         } | ||||||
|  |     } else if constexpr (std::is_same<T, nv_bfloat16>::value) { | ||||||
|  |         const int * x2 = (const int *) x; | ||||||
|         sumf = 0.0f; |         sumf = 0.0f; | ||||||
|  |  | ||||||
|         for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { |         for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { | ||||||
|             const float2 tmpx = __half22float2(x2[col2]); |             const int    tmpx = x2[col2]; | ||||||
|             const float2 tmpy = y2[col2]; |             const float2 tmpy = y2[col2]; | ||||||
|             sumf += tmpx.x * tmpy.x; |             sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x; | ||||||
|             sumf += tmpx.y * tmpy.y; |             sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y; | ||||||
|         } |         } | ||||||
|     } else { |     } else { | ||||||
| #ifdef FP16_AVAILABLE |         static_assert(std::is_same<T, void>::value, "unsupported type"); | ||||||
|         half2 sumh2 = make_half2(0.0f, 0.0f); |  | ||||||
|  |  | ||||||
|         for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { |  | ||||||
|             const float2 tmp = y2[col2]; |  | ||||||
|             sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         sumf = __low2float(sumh2) + __high2float(sumh2); |  | ||||||
| #else |  | ||||||
|         NO_DEVICE_CODE; |  | ||||||
| #endif // FP16_AVAILABLE |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     sumf = warp_reduce_sum(sumf); |     sumf = warp_reduce_sum(sumf); | ||||||
| @@ -71,9 +86,9 @@ static __global__ void mul_mat_vec( | |||||||
|     dst[row] = sumf; |     dst[row] = sumf; | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename type_acc> | template <typename T, typename type_acc> | ||||||
| static void launch_mul_mat_vec_cuda( | static void launch_mul_mat_vec_cuda( | ||||||
|         const half * x, const float * y, float * dst, |         const T * x, const float * y, float * dst, | ||||||
|         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, |         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, | ||||||
|         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, |         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, | ||||||
|         cudaStream_t stream) { |         cudaStream_t stream) { | ||||||
| @@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda( | |||||||
|     const dim3 block_dims(block_size_best, 1, 1); |     const dim3 block_dims(block_size_best, 1, 1); | ||||||
|     switch (block_size_best) { |     switch (block_size_best) { | ||||||
|         case   32: { |         case   32: { | ||||||
|             mul_mat_vec<type_acc,  32><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case   64: { |         case   64: { | ||||||
|             mul_mat_vec<type_acc,  64><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case   96: { |         case   96: { | ||||||
|             mul_mat_vec<type_acc,  96><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case  128: { |         case  128: { | ||||||
|             mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case  160: { |         case  160: { | ||||||
|             mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case  192: { |         case  192: { | ||||||
|             mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case  224: { |         case  224: { | ||||||
|             mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         case  256: { |         case  256: { | ||||||
|             mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>> |             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>> | ||||||
|                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); |                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); | ||||||
|         } break; |         } break; | ||||||
|         default: { |         default: { | ||||||
| @@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template<typename T> | ||||||
| static void mul_mat_vec_cuda( | static void mul_mat_vec_cuda( | ||||||
|         const half * x, const float * y, float * dst, |         const T * x, const float * y, float * dst, | ||||||
|         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, |         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, | ||||||
|         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, |         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, | ||||||
|         enum ggml_prec prec, cudaStream_t stream) { |         enum ggml_prec prec, cudaStream_t stream) { | ||||||
|     switch (prec) { |     switch (prec) { | ||||||
|         case GGML_PREC_DEFAULT: { |         case GGML_PREC_DEFAULT: { | ||||||
|             launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, |             launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, | ||||||
|                 stride_channel_x, stride_channel_y, stride_channel_dst, stream); |                 stride_channel_x, stride_channel_y, stride_channel_dst, stream); | ||||||
|         } break; |         } break; | ||||||
|         case GGML_PREC_F32: { |         case GGML_PREC_F32: { | ||||||
|             launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, |             launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, | ||||||
|                 stride_channel_x, stride_channel_y, stride_channel_dst, stream); |                 stride_channel_x, stride_channel_y, stride_channel_dst, stream); | ||||||
|         } break; |         } break; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|     GGML_ASSERT(src0->type == GGML_TYPE_F16); |  | ||||||
|     GGML_ASSERT(src1->type == GGML_TYPE_F32); |     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||||||
|     GGML_ASSERT(dst->type  == GGML_TYPE_F32); |     GGML_ASSERT(dst->type  == GGML_TYPE_F32); | ||||||
|  |  | ||||||
| @@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * | |||||||
|     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; |     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; | ||||||
|     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; |     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; | ||||||
|  |  | ||||||
|     const half  * src0_d = (const half  *) src0->data; |  | ||||||
|     const float * src1_d = (const float *) src1->data; |     const float * src1_d = (const float *) src1->data; | ||||||
|     float       *  dst_d = (float       *)  dst->data; |     float       *  dst_d = (float       *)  dst->data; | ||||||
|  |  | ||||||
| @@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * | |||||||
|     const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type); |     const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type); | ||||||
|     const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type); |     const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type); | ||||||
|  |  | ||||||
|     mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); |     switch (src0->type) { | ||||||
|  |         case GGML_TYPE_F16: { | ||||||
|  |             const half * src0_d = (const half *) src0->data; | ||||||
|  |             mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, | ||||||
|  |                 channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); | ||||||
|  |         } break; | ||||||
|  |         case GGML_TYPE_BF16: { | ||||||
|  |             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; | ||||||
|  |             mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, | ||||||
|  |                 channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); | ||||||
|  |         } break; | ||||||
|  |         default: | ||||||
|  |             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_cuda_op_mul_mat_vec( | void ggml_cuda_op_mul_mat_vec( | ||||||
| @@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec( | |||||||
|     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, |     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, | ||||||
|     const int64_t src1_padded_row_size, cudaStream_t stream) { |     const int64_t src1_padded_row_size, cudaStream_t stream) { | ||||||
|  |  | ||||||
|     GGML_ASSERT(src0->type == GGML_TYPE_F16); |  | ||||||
|     GGML_ASSERT(src1->type == GGML_TYPE_F32); |     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||||||
|     GGML_ASSERT(dst->type  == GGML_TYPE_F32); |     GGML_ASSERT(dst->type  == GGML_TYPE_F32); | ||||||
|  |  | ||||||
| @@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec( | |||||||
|     const int64_t channel_stride_y   = 0; |     const int64_t channel_stride_y   = 0; | ||||||
|     const int64_t channel_stride_dst = 0; |     const int64_t channel_stride_dst = 0; | ||||||
|  |  | ||||||
|     mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, |     switch (src0->type) { | ||||||
|         nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); |         case GGML_TYPE_F16: { | ||||||
|  |             const half * src0_d = (const half *) src0_dd_i; | ||||||
|  |             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, | ||||||
|  |                 nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); | ||||||
|  |         } break; | ||||||
|  |         case GGML_TYPE_BF16: { | ||||||
|  |             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; | ||||||
|  |             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, | ||||||
|  |                 nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); | ||||||
|  |         } break; | ||||||
|  |         default: | ||||||
|  |             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     GGML_UNUSED(ctx); |     GGML_UNUSED(ctx); | ||||||
|     GGML_UNUSED(src1); |     GGML_UNUSED(src1); | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								ggml/src/ggml-cuda/vendors/cuda.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								ggml/src/ggml-cuda/vendors/cuda.h
									
									
									
									
										vendored
									
									
								
							| @@ -3,6 +3,7 @@ | |||||||
| #include <cuda_runtime.h> | #include <cuda_runtime.h> | ||||||
| #include <cuda.h> | #include <cuda.h> | ||||||
| #include <cublas_v2.h> | #include <cublas_v2.h> | ||||||
|  | #include <cuda_bf16.h> | ||||||
| #include <cuda_fp16.h> | #include <cuda_fp16.h> | ||||||
|  |  | ||||||
| #if CUDART_VERSION < 11020 | #if CUDART_VERSION < 11020 | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								ggml/src/ggml-cuda/vendors/hip.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								ggml/src/ggml-cuda/vendors/hip.h
									
									
									
									
										vendored
									
									
								
							| @@ -3,6 +3,7 @@ | |||||||
| #include <hip/hip_runtime.h> | #include <hip/hip_runtime.h> | ||||||
| #include <hipblas/hipblas.h> | #include <hipblas/hipblas.h> | ||||||
| #include <hip/hip_fp16.h> | #include <hip/hip_fp16.h> | ||||||
|  | #include <hip/hip_bfloat16.h> | ||||||
| #ifdef __HIP_PLATFORM_AMD__ | #ifdef __HIP_PLATFORM_AMD__ | ||||||
| // for rocblas_initialize() | // for rocblas_initialize() | ||||||
| #include "rocblas/rocblas.h" | #include "rocblas/rocblas.h" | ||||||
| @@ -121,6 +122,8 @@ | |||||||
|     #define __has_builtin(x) 0 |     #define __has_builtin(x) 0 | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | typedef hip_bfloat16 nv_bfloat16; | ||||||
|  |  | ||||||
| typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); | typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); | ||||||
| typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); | typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); | ||||||
| static __device__ __forceinline__ int __vsubss4(const int a, const int b) { | static __device__ __forceinline__ int __vsubss4(const int a, const int b) { | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								ggml/src/ggml-cuda/vendors/musa.h
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								ggml/src/ggml-cuda/vendors/musa.h
									
									
									
									
										vendored
									
									
								
							| @@ -3,6 +3,7 @@ | |||||||
| #include <musa_runtime.h> | #include <musa_runtime.h> | ||||||
| #include <musa.h> | #include <musa.h> | ||||||
| #include <mublas.h> | #include <mublas.h> | ||||||
|  | #include <musa_bf16.h> | ||||||
| #include <musa_fp16.h> | #include <musa_fp16.h> | ||||||
| #define CUBLAS_COMPUTE_16F CUDA_R_16F | #define CUBLAS_COMPUTE_16F CUDA_R_16F | ||||||
| #define CUBLAS_COMPUTE_32F CUDA_R_32F | #define CUBLAS_COMPUTE_32F CUDA_R_32F | ||||||
| @@ -132,3 +133,5 @@ | |||||||
| #define cudaKernelNodeParams musaKernelNodeParams | #define cudaKernelNodeParams musaKernelNodeParams | ||||||
| #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed | #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed | ||||||
| #define cudaStreamEndCapture musaStreamEndCapture | #define cudaStreamEndCapture musaStreamEndCapture | ||||||
|  |  | ||||||
|  | typedef mt_bfloat16 nv_bfloat16; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler