	sycl: use oneDNN for matrices multiplication (#12972)
Author: Łukasz Ślusarczyk
Committed by: GitHub
Parent: 6c8b91500e
Commit: 9c404ed54c
@@ -731,6 +731,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | GGML_SYCL_DEVICE_ARCH | Optional (except for AMD)             | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
 | GGML_SYCL_F16      | OFF *(default)* \|ON *(optional)*     | Enable FP16 build with SYCL code path.      |
 | GGML_SYCL_GRAPH    | ON *(default)* \|OFF *(Optional)*     | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
+| GGML_SYCL_DNN      | ON *(default)* \|OFF *(Optional)*     | Enable build with oneDNN.                   |
 | CMAKE_C_COMPILER   | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path.      |
 | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)*   | Set `icpx/icx` compiler for SYCL code path. |

@@ -741,6 +742,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | GGML_SYCL_DEBUG   | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
 | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
+| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
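The two new switches operate at different levels: `GGML_SYCL_DNN` decides at build time whether the oneDNN path is compiled in at all, while `GGML_SYCL_DISABLE_DNN` lets a user fall back to oneMKL at run time without rebuilding. A minimal sketch of that two-level gating (illustrative only; `use_onednn_gemm` is a placeholder name, mirroring the `#if GGML_SYCL_DNNL` / `g_ggml_sycl_disable_dnn` pattern used in the backend hunks further down):

```cpp
#include <cstdlib>

// GGML_SYCL_DNNL is the compile-time macro that GGML_SYCL_DNN=ON/OFF controls
// (see the CMake changes below); the env var is only consulted when oneDNN is built in.
static bool use_onednn_gemm() {
#if GGML_SYCL_DNNL
    const char * v = std::getenv("GGML_SYCL_DISABLE_DNN");
    return v == nullptr || std::atoi(v) == 0;   // default: oneDNN enabled
#else
    return false;                               // oneDNN not compiled in: always oneMKL
#endif
}
```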
@@ -193,6 +193,7 @@ option(GGML_RPC                             "ggml: use RPC"
 option(GGML_SYCL                            "ggml: use SYCL"                                  OFF)
 option(GGML_SYCL_F16                        "ggml: use 16 bit floats for sycl calculations"   OFF)
 option(GGML_SYCL_GRAPH                      "ggml: enable graphs in the SYCL backend"         ON)
+option(GGML_SYCL_DNN                        "ggml: enable oneDNN in the SYCL backend"         ON)
 set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
                                             "ggml: sycl target device")
 set   (GGML_SYCL_DEVICE_ARCH "" CACHE STRING

@@ -49,9 +49,10 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
 
 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(DNNL_FOUND)
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
         if (NOT DEFINED DNNL_GPU_VENDOR)
             # default to intel target
             set(DNNL_GPU_VENDOR "INTEL")
@@ -75,8 +76,11 @@ if(DNNL_FOUND)
                  llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
                  Disabling oneDNN support.")
         endif()
-else()
+    else()
         message(STATUS "oneDNN not found, disabling oneDNN support")
+    endif()
+else()
+    message(STATUS "oneDNN support disabled by the user")
 endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})

@@ -32,16 +32,36 @@ public:
         else static_assert(0);
     }
 
-    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches_a - number of A matrices
+    // batches_b - number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+        const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+        const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+        void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+        // { # strides, # rows, # columns }
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+
+        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::abc);
+
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
 
         matmul_prim.execute(stream, matmul_args);
     }
+
+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+        const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
 };
 
 #endif
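The stride parameters above are easy to misread, so here is a small self-contained CPU sketch (not part of llama.cpp; `gemm_ref` is a hypothetical reference) that applies the same row-skip / column-skip / batch-stride interpretation and the same argument pattern that `row_gemm` forwards to `gemm`:

```cpp
// Plain-C++ reference for the stride convention documented above (illustrative only).
// Element (row r, column c) of batch s in X is read from X[s*stride_x + r*nrx + c*ncx];
// C is written densely as row-major m x n matrices, one per batch, matching tag::abc.
#include <cstdint>
#include <cstdio>
#include <vector>

static void gemm_ref(int m, int n, int k,
                     const float * a, int64_t nra, int64_t nca, int64_t stride_a,
                     const float * b, int64_t nrb, int64_t ncb, int64_t stride_b,
                     float * c, int64_t batches) {
    for (int64_t s = 0; s < batches; ++s) {
        for (int r = 0; r < m; ++r) {
            for (int col = 0; col < n; ++col) {
                float acc = 0.0f;
                for (int i = 0; i < k; ++i) {
                    acc += a[s * stride_a + r * nra + i * nca] *
                           b[s * stride_b + i * nrb + col * ncb];
                }
                c[(s * m + r) * n + col] = acc;
            }
        }
    }
}

int main() {
    // Same argument pattern as row_gemm above: A is column-major k x m (nra = k, nca = 1
    // views it as A transposed), B is column-major k x n (nrb = 1, ncb = k), one batch each.
    const int m = 2, n = 3, k = 4;
    std::vector<float> A(k * m), B(k * n, 1.0f), C(m * n);
    for (size_t i = 0; i < A.size(); ++i) A[i] = (float) i;
    gemm_ref(m, n, k,
             A.data(), k, 1, (int64_t) k * m,
             B.data(), 1, k, (int64_t) n * k,
             C.data(), 1);
    for (int r = 0; r < m; ++r) {          // expected: 6 6 6 / 22 22 22
        for (int col = 0; col < n; ++col) {
            std::printf("%6.1f ", C[r * n + col]);
        }
        std::printf("\n");
    }
    return 0;
}
```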
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
         g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
         g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
         GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
         GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
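`get_sycl_env` is an existing helper in the SYCL backend whose implementation is not part of this diff; a minimal sketch of such a "read an integer environment variable with a default" helper (an assumption, the real implementation may differ):

```cpp
#include <cstdlib>
#include <string>

// Hypothetical stand-in for get_sycl_env(): returns default_val when the variable is
// unset, otherwise parses it as an integer (non-numeric values fall back to 0 here).
static int get_sycl_env_sketch(const char * env_name, int default_val) {
    const char * str_value = std::getenv(env_name);
    if (str_value == nullptr) {
        return default_val;
    }
    try {
        return std::stoi(str_value);
    } catch (...) {
        return 0;
    }
}

// Usage mirroring the diff: GGML_SYCL_DISABLE_DNN defaults to 0 (oneDNN enabled).
// int g_ggml_sycl_disable_dnn = get_sycl_env_sketch("GGML_SYCL_DISABLE_DNN", 0);
```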
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true;  // TODO(Yu) SYCL capability check
@@ -2033,7 +2043,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
                                          : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-#if !GGML_SYCL_DNNL
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
+#endif
+        {
             const sycl::half alpha_f16 = 1.0f;
             const sycl::half beta_f16  = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
@@ -2045,13 +2065,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
                 dpct::library_data_t::real_half)));
             const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
             to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
-#endif
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if !GGML_SYCL_DNNL
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
+#endif
+        {
             const float alpha = 1.0f;
             const float beta  = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
                 get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
                 src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
                 dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
-#endif
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
                                    const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
                                    size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
                                    int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
 
     const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
     const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
-    uint8_t *       dst_bytes  = reinterpret_cast<uint8_t *>(dst);
+    uint8_t *       dst_bytes  = static_cast<uint8_t *>(dst);
 
     ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
     ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
-    char *                           dst_t = reinterpret_cast<char *>(dst_ddf);
 
     dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t mkl_data_type    = dpct::library_data_t::real_float;
@@ -2783,17 +2801,57 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
+
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;
 
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+                            src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                            src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                            dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
         if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
             // there is no broadcast and src0, src1 are contiguous across dims 2, 3
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                                                         oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
                                                         src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
                                                         mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
         } else {
             const int ne23 = ne12 * ne13;
@@ -2809,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                 size_t        nb12_scaled  = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
                 size_t        nb13_scaled  = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
                 cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                            nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
                 });
             });
@@ -2820,6 +2878,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                 (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
                 (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
         }
+    }
 } catch (const sycl::exception & exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
     std::exit(1);
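The broadcast handling above relies on the usual ggml convention that a single src0 matrix (i02, i03) is reused for every src1 matrix (i12, i13) with i02 = i12 / r2 and i03 = i13 / r3, which is also what `k_compute_batched_ptrs` relies on. A tiny standalone sketch of that pairing (illustrative only; the shapes are made-up example values):

```cpp
// Prints which src0 matrix each src1 matrix is multiplied with under the ggml
// broadcast convention. The oneDNN path above exploits this by passing r2*r3
// src1 batches per single src0 matrix to dnn_gemm when r2 or r3 is greater than 1.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne02 = 2, ne03 = 1;   // src0 batch dims (example values)
    const int64_t ne12 = 4, ne13 = 2;   // src1 batch dims; multiples of ne02 / ne03
    const int64_t r2 = ne12 / ne02;     // broadcast factor along dim 2
    const int64_t r3 = ne13 / ne03;     // broadcast factor along dim 3

    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            const int64_t i02 = i12 / r2;
            const int64_t i03 = i13 / r3;
            std::printf("src1[%lld,%lld] pairs with src0[%lld,%lld]\n",
                        (long long) i13, (long long) i12,
                        (long long) i03, (long long) i02);
        }
    }
    return 0;
}
```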
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
             return GGML_STATUS_SUCCESS;
         }
 
-        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+        sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
 
         model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
         ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
         model_sycl_graph.end_recording();