mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	[SYCL] Add oneDNN primitive support (#9091)
* add onednn * add sycl_f16 * add dnnl stream * add engine map * use dnnl for intel only * use fp16fp16fp16 * update doc
This commit is contained in:
		| @@ -28,6 +28,7 @@ | |||||||
|     { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, |     { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, | ||||||
|     { "name": "reldbg",  "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, |     { "name": "reldbg",  "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, | ||||||
|     { "name": "static",  "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, |     { "name": "static",  "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, | ||||||
|  |     { "name": "sycl_f16",  "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, | ||||||
|  |  | ||||||
|     { |     { | ||||||
|         "name": "arm64-windows-msvc", "hidden": true, |         "name": "arm64-windows-msvc", "hidden": true, | ||||||
| @@ -60,6 +61,8 @@ | |||||||
|     { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, |     { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, | ||||||
|  |  | ||||||
|     { "name": "x64-windows-sycl-debug"  , "inherits": [ "sycl-base", "debug"   ] }, |     { "name": "x64-windows-sycl-debug"  , "inherits": [ "sycl-base", "debug"   ] }, | ||||||
|     { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] } |     { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, | ||||||
|  |     { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, | ||||||
|  |     { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] } | ||||||
|   ] |   ] | ||||||
| } | } | ||||||
|   | |||||||
| @@ -20,7 +20,7 @@ | |||||||
| **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: | **oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: | ||||||
|  |  | ||||||
| - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. | - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. | ||||||
| - **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL - Math Kernel Library)*. | - **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*. | ||||||
| - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs. | - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs. | ||||||
| - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. | - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. | ||||||
|  |  | ||||||
| @@ -28,10 +28,6 @@ | |||||||
|  |  | ||||||
| The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*). | The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*). | ||||||
|  |  | ||||||
| When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneMKL](README.md#intel-onemkl) backend. |  | ||||||
|  |  | ||||||
| It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose. |  | ||||||
|  |  | ||||||
| ## Recommended Release | ## Recommended Release | ||||||
|  |  | ||||||
| The SYCL backend would be broken by some PRs due to no online CI. | The SYCL backend would be broken by some PRs due to no online CI. | ||||||
| @@ -45,6 +41,10 @@ The following release is verified with good quality: | |||||||
|  |  | ||||||
| ## News | ## News | ||||||
|  |  | ||||||
|  |  | ||||||
|  | - 2024.8 | ||||||
|  |   - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. | ||||||
|  |  | ||||||
| - 2024.5 | - 2024.5 | ||||||
|   - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. |   - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. | ||||||
|   - Arch Linux is verified successfully. |   - Arch Linux is verified successfully. | ||||||
| @@ -196,7 +196,7 @@ Please follow the instructions for downloading and installing the Toolkit for Li | |||||||
|  |  | ||||||
| Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable. | Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable. | ||||||
|  |  | ||||||
| Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI MKL for intel GPUs. | Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs. | ||||||
|  |  | ||||||
| - **Adding support to Nvidia GPUs** | - **Adding support to Nvidia GPUs** | ||||||
|  |  | ||||||
| @@ -255,8 +255,6 @@ or | |||||||
| # Export relevant ENV variables | # Export relevant ENV variables | ||||||
| source /opt/intel/oneapi/setvars.sh | source /opt/intel/oneapi/setvars.sh | ||||||
|  |  | ||||||
| # Build LLAMA with MKL BLAS acceleration for intel GPU |  | ||||||
|  |  | ||||||
| # Option 1: Use FP32 (recommended for better performance in most cases) | # Option 1: Use FP32 (recommended for better performance in most cases) | ||||||
| cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx | cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx | ||||||
|  |  | ||||||
|   | |||||||
| @@ -549,6 +549,13 @@ if (GGML_SYCL) | |||||||
|     file(GLOB   GGML_SOURCES_SYCL "ggml-sycl/*.cpp") |     file(GLOB   GGML_SOURCES_SYCL "ggml-sycl/*.cpp") | ||||||
|     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") |     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") | ||||||
|  |  | ||||||
|  |     find_package(DNNL) | ||||||
|  |     message("-- DNNL found:"${DNNL_FOUND}) | ||||||
|  |     if (GGML_SYCL_TARGET STREQUAL "INTEL") | ||||||
|  |         add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND}) | ||||||
|  |     else() | ||||||
|  |         add_compile_definitions(GGML_SYCL_DNNL=0) | ||||||
|  |     endif() | ||||||
|     if (WIN32) |     if (WIN32) | ||||||
|         find_package(IntelSYCL REQUIRED) |         find_package(IntelSYCL REQUIRED) | ||||||
|         find_package(MKL REQUIRED) |         find_package(MKL REQUIRED) | ||||||
| @@ -561,6 +568,9 @@ if (GGML_SYCL) | |||||||
|             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl) |             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl) | ||||||
|         endif() |         endif() | ||||||
|     endif() |     endif() | ||||||
|  |     if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") | ||||||
|  |         list(APPEND GGML_EXTRA_LIBS DNNL::dnnl) | ||||||
|  |     endif() | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
| if (GGML_RPC) | if (GGML_RPC) | ||||||
|   | |||||||
| @@ -38,6 +38,7 @@ | |||||||
|  |  | ||||||
| #include "ggml-sycl/backend.hpp" | #include "ggml-sycl/backend.hpp" | ||||||
| #include "ggml-sycl/presets.hpp" | #include "ggml-sycl/presets.hpp" | ||||||
|  | #include "ggml-sycl/gemm.hpp" | ||||||
|  |  | ||||||
| bool   ggml_sycl_loaded(void); | bool   ggml_sycl_loaded(void); | ||||||
| void   ggml_sycl_free_data(struct ggml_tensor * tensor); | void   ggml_sycl_free_data(struct ggml_tensor * tensor); | ||||||
| @@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl( | |||||||
|  |  | ||||||
|         const sycl::half alpha_f16 = 1.0f; |         const sycl::half alpha_f16 = 1.0f; | ||||||
|         const sycl::half beta_f16 = 0.0f; |         const sycl::half beta_f16 = 0.0f; | ||||||
|  | #if !GGML_SYCL_DNNL | ||||||
|         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( |         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( | ||||||
|             *stream, oneapi::mkl::transpose::trans, |             *stream, oneapi::mkl::transpose::trans, | ||||||
|             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, |             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, | ||||||
| @@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl( | |||||||
|             dpct::library_data_t::real_half))); |             dpct::library_data_t::real_half))); | ||||||
|         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); |         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); | ||||||
|         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); |         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); | ||||||
|  | #else | ||||||
|  |         auto dnnl_stream = ctx.stream_dnnl(stream); | ||||||
|  |         DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), | ||||||
|  |             src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>()); | ||||||
|  |         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); | ||||||
|  |         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); | ||||||
|  | #endif | ||||||
|     } |     } | ||||||
|     else { |     else { | ||||||
|         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); |         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); | ||||||
| @@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl( | |||||||
|  |  | ||||||
|         const float alpha = 1.0f; |         const float alpha = 1.0f; | ||||||
|         const float beta = 0.0f; |         const float beta = 0.0f; | ||||||
|  | #if !GGML_SYCL_DNNL | ||||||
|         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( |         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( | ||||||
|             *stream, oneapi::mkl::transpose::trans, |             *stream, oneapi::mkl::transpose::trans, | ||||||
|             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, |             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, | ||||||
|             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, |             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, | ||||||
|             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), |             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), | ||||||
|             dst_dd_i, ldc))); |             dst_dd_i, ldc))); | ||||||
|  | #else | ||||||
|  |         auto dnnl_stream = ctx.stream_dnnl(stream); | ||||||
|  |          DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(), | ||||||
|  |             src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>()); | ||||||
|  | #endif | ||||||
|     } |     } | ||||||
|     (void) dst; |     (void) dst; | ||||||
|     (void) src1_ddq_i; |     (void) src1_ddq_i; | ||||||
|   | |||||||
| @@ -19,6 +19,10 @@ | |||||||
| #include "dpct/helper.hpp" | #include "dpct/helper.hpp" | ||||||
| #include "ggml-sycl.h" | #include "ggml-sycl.h" | ||||||
| #include "presets.hpp" | #include "presets.hpp" | ||||||
|  | #if GGML_SYCL_DNNL | ||||||
|  | #include "dnnl.hpp" | ||||||
|  | #include "dnnl_sycl.hpp" | ||||||
|  | #endif | ||||||
|  |  | ||||||
| #define GGML_COMMON_DECL_SYCL | #define GGML_COMMON_DECL_SYCL | ||||||
| #define GGML_COMMON_IMPL_SYCL | #define GGML_COMMON_IMPL_SYCL | ||||||
| @@ -277,6 +281,52 @@ struct ggml_backend_sycl_context { | |||||||
|         return stream(device, 0); |         return stream(device, 0); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | #if GGML_SYCL_DNNL | ||||||
|  |     dnnl::engine make_engine(sycl::queue* q) { | ||||||
|  |         // Get the device associated with the queue | ||||||
|  |         sycl::device dev = q->get_device(); | ||||||
|  |         // Get the context associated with the queue | ||||||
|  |         sycl::context ctx = q->get_context(); | ||||||
|  |         const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); | ||||||
|  |         return eng; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::unordered_map<sycl::queue*, dnnl::stream> stream_map; | ||||||
|  |     std::unordered_map<sycl::queue*, dnnl::engine> engine_map; | ||||||
|  |     dnnl::stream stream_dnnl(int device, int _stream) { | ||||||
|  |         auto q = stream(device, _stream); | ||||||
|  |         return stream_dnnl(q); | ||||||
|  |     } | ||||||
|  |     dnnl::engine engine_dnnl(sycl::queue* qptr) { | ||||||
|  |         auto it = engine_map.find(qptr); | ||||||
|  |         if (it == engine_map.end()) { | ||||||
|  |             auto eng = make_engine(qptr); | ||||||
|  |             engine_map[qptr] = eng; | ||||||
|  |             return eng; | ||||||
|  |         } | ||||||
|  |         else | ||||||
|  |         { | ||||||
|  |             return it->second; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     dnnl::stream stream_dnnl(sycl::queue* qptr) { | ||||||
|  |         auto it = stream_map.find(qptr); | ||||||
|  |         if (it == stream_map.end()) { | ||||||
|  |             auto eng = engine_dnnl(qptr); | ||||||
|  |             auto stream = dnnl::sycl_interop::make_stream(eng, *qptr); | ||||||
|  |             stream_map[qptr] = stream; | ||||||
|  |             return stream; | ||||||
|  |         } | ||||||
|  |         else | ||||||
|  |         { | ||||||
|  |             return it->second; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     dnnl::stream stream_dnnl() { | ||||||
|  |         return stream_dnnl(device, 0); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     // pool |     // pool | ||||||
|     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; |     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										101
									
								
								ggml/src/ggml-sycl/gemm.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								ggml/src/ggml-sycl/gemm.hpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | |||||||
|  | // | ||||||
|  | // MIT license | ||||||
|  | // Copyright (C) 2024 Intel Corporation | ||||||
|  | // SPDX-License-Identifier: MIT | ||||||
|  | // | ||||||
|  |  | ||||||
|  | // | ||||||
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||||
|  | // See https://llvm.org/LICENSE.txt for license information. | ||||||
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||||
|  | // | ||||||
|  |  | ||||||
|  | #ifndef GGML_SYCL_GEMM_HPP | ||||||
|  | #define GGML_SYCL_GEMM_HPP | ||||||
|  |  | ||||||
|  | #include <fstream> | ||||||
|  | #include <iostream> | ||||||
|  |  | ||||||
|  | #include "ggml-sycl.h" | ||||||
|  |  | ||||||
|  | #if GGML_SYCL_DNNL | ||||||
|  |  | ||||||
|  | #include "dnnl.hpp" | ||||||
|  | #include "dnnl_sycl.hpp" | ||||||
|  |  | ||||||
|  | class DnnlGemmWrapper { | ||||||
|  | public: | ||||||
|  |     using dt = dnnl::memory::data_type; | ||||||
|  |     using tag = dnnl::memory::format_tag; | ||||||
|  |  | ||||||
|  |     template<typename T> | ||||||
|  |     static constexpr dt to_dt() { | ||||||
|  |         if constexpr (std::is_same_v<T, float>) return dt::f32; | ||||||
|  |         else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16; | ||||||
|  |         else static_assert(0); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     static inline void row_gemm(sycl::queue& q, bool a_trans, | ||||||
|  |         bool b_trans, int m, int n, int k, | ||||||
|  |         const void* a, dt at, const void* b, dt bt, void* c, dt ct) | ||||||
|  |     { | ||||||
|  |         // Get the device associated with the queue | ||||||
|  |         sycl::device dev = q.get_device(); | ||||||
|  |         // Get the context associated with the queue | ||||||
|  |         sycl::context ctx = q.get_context(); | ||||||
|  |         const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx); | ||||||
|  |         const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q); | ||||||
|  |         dnnl::memory::dims a_dims = { m, k }; | ||||||
|  |         dnnl::memory::dims b_dims = { k, n }; | ||||||
|  |         dnnl::memory::dims c_dims = { m, n }; | ||||||
|  |         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); | ||||||
|  |         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); | ||||||
|  |         const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); | ||||||
|  |         auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); | ||||||
|  |         auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); | ||||||
|  |         auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); | ||||||
|  |         auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); | ||||||
|  |  | ||||||
|  |         // Create the primitive. | ||||||
|  |         auto matmul_prim = dnnl::matmul(matmul_pd); | ||||||
|  |         // Primitive arguments. | ||||||
|  |         std::unordered_map<int, dnnl::memory> matmul_args; | ||||||
|  |         matmul_args.insert({ DNNL_ARG_SRC, a_mem }); | ||||||
|  |         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); | ||||||
|  |         matmul_args.insert({ DNNL_ARG_DST, c_mem }); | ||||||
|  |  | ||||||
|  |         matmul_prim.execute(stream, matmul_args); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     static inline void row_gemm(const dnnl::stream& stream, bool a_trans, | ||||||
|  |         bool b_trans, int m, int n, int k, | ||||||
|  |         const void* a, dt at, const void* b, dt bt, void* c, dt ct) | ||||||
|  |     { | ||||||
|  |         auto const eng = stream.get_engine(); | ||||||
|  |         dnnl::memory::dims a_dims = { m, k }; | ||||||
|  |         dnnl::memory::dims b_dims = { k, n }; | ||||||
|  |         dnnl::memory::dims c_dims = { m, n }; | ||||||
|  |         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); | ||||||
|  |         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); | ||||||
|  |         const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); | ||||||
|  |         auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); | ||||||
|  |         auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); | ||||||
|  |         auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); | ||||||
|  |         auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); | ||||||
|  |  | ||||||
|  |         // Create the primitive. | ||||||
|  |         auto matmul_prim = dnnl::matmul(matmul_pd); | ||||||
|  |         // Primitive arguments. | ||||||
|  |         std::unordered_map<int, dnnl::memory> matmul_args; | ||||||
|  |         matmul_args.insert({ DNNL_ARG_SRC, a_mem }); | ||||||
|  |         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); | ||||||
|  |         matmul_args.insert({ DNNL_ARG_DST, c_mem }); | ||||||
|  |  | ||||||
|  |         matmul_prim.execute(stream, matmul_args); | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #endif // GGML_SYCL_GEMM_HPP | ||||||
		Reference in New Issue
	
	Block a user
	 luoyu-intel
					luoyu-intel