mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Enabled more data types for oneMKL gemm_batch (#8236)
This commit is contained in:
		 Ouadie EL FAROUKI
					Ouadie EL FAROUKI
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							148ec970b6
						
					
				
				
					commit
					1f3e1b66e2
				
			| @@ -3493,10 +3493,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | |||||||
|     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); | ||||||
|     queue_ptr main_stream = ctx.stream();; |     queue_ptr main_stream = ctx.stream();; | ||||||
|  |  | ||||||
|     bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda || |  | ||||||
|                            main_stream->get_backend() == sycl::backend::ext_oneapi_hip; |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     void * src0_ddq = src0->data; |     void * src0_ddq = src0->data; | ||||||
|     sycl::half *src0_as_f16 = (sycl::half *)src0_ddq; |     sycl::half *src0_as_f16 = (sycl::half *)src0_ddq; | ||||||
|     float * src1_ddf = (float *) src1->data; |     float * src1_ddf = (float *) src1->data; | ||||||
| @@ -3514,15 +3510,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | |||||||
|     sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf |     sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf | ||||||
|                                                        : src1_f16_alloc.get(); |                                                        : src1_f16_alloc.get(); | ||||||
|  |  | ||||||
|     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool()); |  | ||||||
|     char * dst_t; |     char * dst_t; | ||||||
|  |  | ||||||
|     dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; |     dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; | ||||||
|     dpct::library_data_t cu_data_type = dpct::library_data_t::real_float; |     dpct::library_data_t cu_data_type = dpct::library_data_t::real_float; | ||||||
|     if (no_mixed_dtypes) { |  | ||||||
|         cu_compute_type = dpct::library_data_t::real_half; |  | ||||||
|         cu_data_type = dpct::library_data_t::real_half; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // dst strides |     // dst strides | ||||||
|     size_t nbd2 = dst->nb[2]; |     size_t nbd2 = dst->nb[2]; | ||||||
| @@ -3531,26 +3522,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | |||||||
|     const float alpha_f32 = 1.0f; |     const float alpha_f32 = 1.0f; | ||||||
|     const float beta_f32 = 0.0f; |     const float beta_f32 = 0.0f; | ||||||
|  |  | ||||||
|     const sycl::half alpha_f16 = 1.0f; |  | ||||||
|     const sycl::half beta_f16 = 0.0f; |  | ||||||
|  |  | ||||||
|     const void * alpha = &alpha_f32; |     const void * alpha = &alpha_f32; | ||||||
|     const void * beta  = &beta_f32; |     const void * beta  = &beta_f32; | ||||||
|     if (no_mixed_dtypes) { |  | ||||||
|         alpha = &alpha_f16; |  | ||||||
|         beta  = &beta_f16; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway |  | ||||||
|     // when oneMKL open source supports half, half, float, float: datatypes |  | ||||||
|  |  | ||||||
|     dst_t = (char *) dst_ddf; |     dst_t = (char *) dst_ddf; | ||||||
|     if (no_mixed_dtypes) { |  | ||||||
|         dst_t = (char *) dst_f16.alloc(ne_dst); |  | ||||||
|  |  | ||||||
|         nbd2 /= sizeof(float) / sizeof(sycl::half); |  | ||||||
|         nbd3 /= sizeof(float) / sizeof(sycl::half); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     GGML_ASSERT(ne12 % ne02 == 0); |     GGML_ASSERT(ne12 % ne02 == 0); | ||||||
|     GGML_ASSERT(ne13 % ne03 == 0); |     GGML_ASSERT(ne13 % ne03 == 0); | ||||||
| @@ -3612,11 +3587,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | |||||||
|             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, |             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, | ||||||
|             cu_compute_type))); |             cu_compute_type))); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (no_mixed_dtypes) { |  | ||||||
|         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); |  | ||||||
|         to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream); |  | ||||||
|     } |  | ||||||
| } | } | ||||||
| catch (sycl::exception const &exc) { | catch (sycl::exception const &exc) { | ||||||
|   std::cerr << exc.what() << "Exception caught at file:" << __FILE__ |   std::cerr << exc.what() << "Exception caught at file:" << __FILE__ | ||||||
|   | |||||||
| @@ -2426,6 +2426,7 @@ namespace dpct | |||||||
|                                            b, ldb, beta, c, ldc, batch_size); |                                            b, ldb, beta, c, ldc, batch_size); | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|  | #endif | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_int8, library_data_t::real_int8, |             library_data_t::real_int8, library_data_t::real_int8, | ||||||
|             library_data_t::real_int32, library_data_t::real_int32): |             library_data_t::real_int32, library_data_t::real_int32): | ||||||
| @@ -2458,7 +2459,6 @@ namespace dpct | |||||||
|                 batch_size); |                 batch_size); | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
| #endif |  | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_half, library_data_t::real_half, |             library_data_t::real_half, library_data_t::real_half, | ||||||
|             library_data_t::real_half, library_data_t::real_float): |             library_data_t::real_half, library_data_t::real_float): | ||||||
| @@ -2595,6 +2595,7 @@ namespace dpct | |||||||
|                                            stride_c, batch_size); |                                            stride_c, batch_size); | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|  | #endif | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_int8, library_data_t::real_int8, |             library_data_t::real_int8, library_data_t::real_int8, | ||||||
|             library_data_t::real_int32, library_data_t::real_int32): |             library_data_t::real_int32, library_data_t::real_int32): | ||||||
| @@ -2623,7 +2624,6 @@ namespace dpct | |||||||
|                 beta, c, ldc, stride_c, batch_size); |                 beta, c, ldc, stride_c, batch_size); | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
| #endif |  | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_half, library_data_t::real_half, |             library_data_t::real_half, library_data_t::real_half, | ||||||
|             library_data_t::real_half, library_data_t::real_float): |             library_data_t::real_half, library_data_t::real_float): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user