CUDA: add bf16 and f32 support to cublas_mul_mat_batched (#14361)

* CUDA: add bf16 and f32 support to cublas_mul_mat_batched
* Review: add type traits and make function more generic
* Review: make check more explicit, add back comments, and fix formatting
* Review: fix formatting, remove useless type conversion, fix naming for bools
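The review rounds above converged on traits-based dispatch. Below is a minimal, self-contained sketch of that pattern, offered purely as an illustration: the names elem_type, mul_mat_traits and describe_batched_gemm are hypothetical stand-ins and are not part of ggml; only the CUDA/cuBLAS types and enums are the real ones that appear in the diff that follows.

// Minimal sketch (hypothetical names) of traits-based dispatch: map an
// element-type tag to its CUDA storage type and the cuBLAS data/compute enums
// at compile time, so one templated wrapper can serve F32, F16 and BF16.
#include <cstdio>

#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

enum class elem_type { f32, f16, bf16 };       // stand-in for ggml_type

template <elem_type T> struct mul_mat_traits;  // primary template, specialized per type

template <> struct mul_mat_traits<elem_type::f32> {
    using cuda_type = float;
    static constexpr cudaDataType_t      data_type    = CUDA_R_32F;
    static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
};

template <> struct mul_mat_traits<elem_type::f16> {
    using cuda_type = half;
    static constexpr cudaDataType_t      data_type    = CUDA_R_16F;
    static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
};

template <> struct mul_mat_traits<elem_type::bf16> {
    using cuda_type = nv_bfloat16;
    static constexpr cudaDataType_t      data_type    = CUDA_R_16BF;
    static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;  // bf16 accumulates in fp32
};

// A single templated entry point reads everything it needs from the traits.
template <elem_type T>
void describe_batched_gemm() {
    using traits = mul_mat_traits<T>;
    std::printf("element size = %zu bytes, cuBLAS data type = %d, compute type = %d\n",
                sizeof(typename traits::cuda_type),
                (int) traits::data_type, (int) traits::compute_type);
}

int main() {
    describe_batched_gemm<elem_type::f32>();
    describe_batched_gemm<elem_type::f16>();
    describe_batched_gemm<elem_type::bf16>();
    return 0;
}

Each specialization supplies only constants and type aliases, so a single templated implementation can cover F32, F16 and BF16 without duplicating the cuBLAS call sites, which is how the patch below restructures ggml_cuda_mul_mat_batched_cublas.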
@@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
            return nullptr;
    }
}

to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32:
            return convert_unary_cuda<float, nv_bfloat16>;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, nv_bfloat16>;
        default:
            return nullptr;
    }
}

to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F16:
            return convert_unary_cuda<half, float>;
        case GGML_TYPE_BF16:
            return convert_unary_cuda<nv_bfloat16, float>;
        default:
            return nullptr;
    }
}

@@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y,
    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
    int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);

typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;

to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);

@@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat(
}

static __global__ void k_compute_batched_ptrs(
        const half * src0_as_f16, const half * src1_as_f16, char * dst,
        const void * src0_as_f16, const void * src1_as_f16, char * dst,
        const void ** ptrs_src, void ** ptrs_dst,
        int64_t ne12, int64_t ne13,
        int64_t ne23,
@@ -1772,83 +1772,131 @@ static __global__ void k_compute_batched_ptrs(
    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
}

static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// Type traits for mapping ggml types to CUDA/cuBLAS types
template<ggml_type T>
struct batched_mul_mat_traits;

template<>
struct batched_mul_mat_traits<GGML_TYPE_F32> {
    using cuda_type = float;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static inline const cudaDataType_t data_type = CUDA_R_32F;
    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
    static inline const float alpha = 1.0f;
    static inline const float beta = 0.0f;
    static inline const void* get_alpha() { static const float val = alpha; return &val; }
    static inline const void* get_beta() { static const float val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
};

template<>
struct batched_mul_mat_traits<GGML_TYPE_BF16> {
    using cuda_type = nv_bfloat16;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static inline const cudaDataType_t data_type = CUDA_R_16BF;
    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
    static inline const float alpha = 1.0f;
    static inline const float beta = 0.0f;
    static inline const void* get_alpha() { static const float val = alpha; return &val; }
    static inline const void* get_beta() { static const float val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
};

template<>
struct batched_mul_mat_traits<GGML_TYPE_F16> {
    using cuda_type = half;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
    static inline const cudaDataType_t data_type = CUDA_R_16F;
    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
    static inline const half alpha = 1.0;
    static inline const half beta = 0.0;
    static inline const void* get_alpha() { static const half val = alpha; return &val; }
    static inline const void* get_beta() { static const half val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
};

template<ggml_type src0_type>
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    using traits = batched_mul_mat_traits<src0_type>;
    using cuda_t = typename traits::cuda_type;

    GGML_ASSERT(!ggml_is_transposed(src0));
    GGML_ASSERT(!ggml_is_transposed(src1));

    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == src0_type);
    GGML_ASSERT(ggml_is_contiguous(dst));

    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
    // As long as dst is contiguous this does not matter though.
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t ne_dst = ggml_nelements(dst);

    cudaStream_t main_stream = ctx.stream();

    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));

    const half * src0_f16 = (const half *) src0->data;
    float * dst_ddf = (float *) dst->data;

    const half * src1_f16 = (const half *) src1->data;
    const size_t ts_src1 = ggml_type_size(src1->type);
    GGML_ASSERT(nb10 == ts_src1);
    int64_t s11 = nb11 / ts_src1;
    int64_t s12 = nb12 / ts_src1;
    int64_t s13 = nb13 / ts_src1;
    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());

    // convert src1 to fp16
    if (src1->type != GGML_TYPE_F16) {
        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
    const cuda_t * src0_ptr = nullptr;
    const cuda_t * src1_ptr = nullptr;

    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());

    // Handle src0
    src0_ptr = (const cuda_t *) src0->data;

    // Handle src1 - convert if necessary
    if (src1->type == src0_type) {
        src1_ptr = (const cuda_t *) src1->data;
    } else {
        // Convert src1 to target type using traits conversion functions
        const int64_t ne_src1 = ggml_nelements(src1);
        src1_f16_alloc.alloc(ne_src1);
        GGML_ASSERT(to_fp16_cuda != nullptr);
        src1_alloc.alloc(ne_src1);

        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);

        src1_f16 = src1_f16_alloc.get();
        const auto convert_func = traits::get_nc_converter(src1->type);
        GGML_ASSERT(convert_func != nullptr);
        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
        src1_ptr = src1_alloc.get();
        s11 = ne10;
        s12 = ne11*s11;
        s13 = ne12*s12;
    }

    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
    // Setup destination buffer
    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
    char * dst_t;

    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
    cudaDataType_t      cu_data_type    = CUDA_R_16F;

    // dst strides
    size_t nbd2 = dst->nb[2];
    size_t nbd3 = dst->nb[3];

    const half  alpha_f16 = 1.0f;
    const half  beta_f16  = 0.0f;

    cublasComputeType_t cu_compute_type = traits::compute_type;
    cudaDataType_t cu_data_type = traits::data_type;
    cudaDataType_t cu_data_type_a = traits::data_type;
    cudaDataType_t cu_data_type_b = traits::data_type;
    const void * alpha = traits::get_alpha();
    const void * beta = traits::get_beta();
    const float alpha_f32 = 1.0f;
    const float beta_f32  = 0.0f;

    const void * alpha = &alpha_f16;
    const void * beta  = &beta_f16;
    const float beta_f32 = 0.0f;

    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
        dst_t = (char *) dst_f16.alloc(ne_dst);

        nbd2 /= sizeof(float) / sizeof(half);
        nbd3 /= sizeof(float) / sizeof(half);
        if constexpr (src0_type == GGML_TYPE_F32) {
            dst_t = (char *) dst_ddf;  // Direct F32 output
        } else {
            dst_t = (char *) dst_temp.alloc(ne_dst);
            nbd2 /= sizeof(float) / sizeof(cuda_t);
            nbd3 /= sizeof(float) / sizeof(cuda_t);
        }
    } else {
        dst_t = (char *) dst_ddf;

        cu_compute_type = CUBLAS_COMPUTE_32F;
        cu_data_type    = CUDA_R_32F;

        cu_data_type = CUDA_R_32F;
        alpha = &alpha_f32;
        beta  = &beta_f32;
        beta = &beta_f32;
    }

    int id = ggml_cuda_get_device();
@@ -1856,7 +1904,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
        cu_compute_type = CUBLAS_COMPUTE_32F;
        alpha = &alpha_f32;
        beta  = &beta_f32;
        beta = &beta_f32;
    }

    GGML_ASSERT(ne12 % ne02 == 0);
@@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

#if 0
    // use cublasGemmEx
    {
        for (int i13 = 0; i13 < ne13; ++i13) {
            for (int i12 = 0; i12 < ne12; ++i12) {
                int i03 = i13 / r3;
                int i02 = i12 / r2;

                CUBLAS_CHECK(
                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                    ne01, ne11, ne10,
                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F,   nb01/sizeof(half),
                                          src1_f16 + i13*s13  + i12*s12,  CUDA_R_16F,   s11,
                    beta,  (      char *)    dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
                    cu_compute_type,
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
            }
        }
    }
#else
    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
        // use cublasGemmStridedBatchedEx
        CUBLAS_CHECK(
        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
                       src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
                beta,     dst_t, cu_data_type, ne0,       ne1*ne0,   // strideC
                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
                ne12*ne13,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);

        size_t src1_stride_size = sizeof(cuda_t);

        dim3 block_dims(ne13, ne12);
        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                src0_f16, src1_f16, dst_t,
                src0_ptr, src1_ptr, dst_t,
                ptrs_src.get(), ptrs_dst.get(),
                ne12, ne13,
                ne23,
                nb02, nb03,
                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
                nbd2, nbd3,
                r2, r3);

        CUDA_CHECK(cudaGetLastError());

        CUBLAS_CHECK(
        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                ne23,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }
#endif

    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
    // Convert output back to F32 if needed
    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
    }
}

static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);

    switch (src0->type) {
        case GGML_TYPE_F32:
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
            break;
        case GGML_TYPE_BF16:
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
            break;
        case GGML_TYPE_F16:
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
            break;
        default:
            GGML_ABORT("Unsupported type");
    }
}

@@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

    //TODO update for generic tensor parallelism
    const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
    bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
    bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;

    if (!split && use_mul_mat_vec) {
        // the custom F16 vector kernel can be used over batched cuBLAS GEMM
        // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
@@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
        ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
    } else if (!split && use_mul_mat_q) {
        ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
            !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
    } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
        // general KQ + KQV multi-batch without FlashAttention
        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
    } else if (use_mul_mat_vec) {

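As a rough illustration of when the new routing is taken (this host-side snippet is not part of the commit; the shapes and overall setup are an assumption for the example): a batched mul_mat whose src0 is BF16 or F32 and whose src1 carries more than one batch should now reach ggml_cuda_mul_mat_batched_cublas on a CUDA build, subject to the bf16_mma_hardware_available check in the BF16 case.

// Hypothetical usage sketch (not from this commit): build a small batched
// BF16 x F32 matmul through the public ggml API. With a CUDA backend and
// bf16 MMA hardware, a shape like this should be routed to the batched
// cuBLAS path shown above. Tensor data is left uninitialized for brevity.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,  // tensor data is allocated in the backend buffer below
    };
    struct ggml_context * ctx = ggml_init(ip);

    // src0: 128 x 1056 BF16 "weights" with 4 batch slices; src1: 128 x 8 F32 activations
    struct ggml_tensor * src0 = ggml_new_tensor_3d(ctx, GGML_TYPE_BF16, 128, 1056, 4);
    struct ggml_tensor * src1 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  128,    8, 4);
    struct ggml_tensor * dst  = ggml_mul_mat(ctx, src0, src1);  // F32 result, ne = [1056, 8, 4]

    ggml_backend_t        backend = ggml_backend_cuda_init(0);
    ggml_backend_buffer_t buffer  = ggml_backend_alloc_ctx_tensors(ctx, backend);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);
    ggml_backend_graph_compute(backend, gf);  // src1->ne[2]*src1->ne[3] > 1 -> batched cuBLAS path

    ggml_backend_buffer_free(buffer);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}

The test-backend-ops additions below exercise the same three src0 types across both permutation patterns.
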
@@ -4425,8 +4425,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        for (auto nr : {1,4}) {
            for (uint32_t m = 0; m < 2; ++m) {
                for (uint32_t k = 0; k < 2; ++k) {
                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
                    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
                    for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
                    }
                }
            }
        }
Aman Gupta