mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	SYCL: Introducing memory host pool (#11251)
* Implement host pool for matrix_info Creating a new memory pool on the host to store memory location for matrix_info needed to launch gemm_batch from oneMKL/oneMath. Removing complex support in gemm_batch since it is not used in llama.cpp * Remove unnecessary headers and cast * Reorder member variable to avoid warning on initialization * Formatting * Remove unused variable * Address PR review feedback - remove warning --------- Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
		| @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { | |||||||
|     // pool |     // pool | ||||||
|     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; |     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; | ||||||
|  |  | ||||||
|  |     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES]; | ||||||
|  |  | ||||||
|     static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device); |     static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device); | ||||||
|  |  | ||||||
|  |     static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device); | ||||||
|  |  | ||||||
|     ggml_sycl_pool & pool(int device) { |     ggml_sycl_pool & pool(int device) { | ||||||
|         if (pools[device] == nullptr) { |         if (pools[device] == nullptr) { | ||||||
|             pools[device] = new_pool_for_device(stream(device,0), device); |             pools[device] = new_pool_for_device(stream(device,0), device); | ||||||
| @@ -345,6 +349,15 @@ struct ggml_backend_sycl_context { | |||||||
|     ggml_sycl_pool & pool() { |     ggml_sycl_pool & pool() { | ||||||
|         return pool(device); |         return pool(device); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     ggml_sycl_pool & host_pool(int device) { | ||||||
|  |         if (host_pools[device] == nullptr) { | ||||||
|  |             host_pools[device] = new_pool_for_host(stream(device, 0), device); | ||||||
|  |         } | ||||||
|  |         return *host_pools[device]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     ggml_sycl_pool & host_pool() { return host_pool(device); } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| // common device functions | // common device functions | ||||||
|   | |||||||
| @@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { | |||||||
|     return device_type.str(); |     return device_type.str(); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <typename Ts> struct matrix_info_t { | ||||||
|  |     oneapi::mkl::transpose transpose_info[2]; | ||||||
|  |     Ts                     value_info[2]; | ||||||
|  |     std::int64_t           size_info[3]; | ||||||
|  |     std::int64_t           ld_info[3]; | ||||||
|  |     std::int64_t           groupsize_info; | ||||||
|  | }; | ||||||
|  |  | ||||||
| namespace dpct | namespace dpct | ||||||
| { | { | ||||||
|     typedef sycl::queue *queue_ptr; |     typedef sycl::queue *queue_ptr; | ||||||
| @@ -1727,26 +1735,13 @@ namespace dpct | |||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         template <class Ta, class Tb, class Tc, class Ts> |         template <class Ta, class Tb, class Tc, class Ts> | ||||||
|         inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, |         inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, | ||||||
|                                     oneapi::mkl::transpose b_trans, int m, int n, int k, |                                     int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, | ||||||
|                                     const void *alpha, const void **a, int lda, |                                     int ldb, const void * beta, void ** c, int ldc, int batch_size, | ||||||
|                                     const void **b, int ldb, const void *beta, void **c, |                                     matrix_info_t<float> * matrix_info) { | ||||||
|                                     int ldc, int batch_size) |  | ||||||
|         { |  | ||||||
|             struct matrix_info_t |  | ||||||
|             { |  | ||||||
|                 oneapi::mkl::transpose transpose_info[2]; |  | ||||||
|                 Ts value_info[2]; |  | ||||||
|                 std::int64_t size_info[3]; |  | ||||||
|                 std::int64_t ld_info[3]; |  | ||||||
|                 std::int64_t groupsize_info; |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); |             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); | ||||||
|             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); |             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); | ||||||
|  |  | ||||||
|             matrix_info_t *matrix_info = |  | ||||||
|                 (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); |  | ||||||
|             matrix_info->transpose_info[0] = a_trans; |             matrix_info->transpose_info[0] = a_trans; | ||||||
|             matrix_info->transpose_info[1] = b_trans; |             matrix_info->transpose_info[1] = b_trans; | ||||||
|             matrix_info->value_info[0] = alpha_value; |             matrix_info->value_info[0] = alpha_value; | ||||||
| @@ -1763,23 +1758,18 @@ namespace dpct | |||||||
|             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( |             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( | ||||||
|                 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info, |                 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info, | ||||||
|                 matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, |                 matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, | ||||||
|                 matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a), |                 matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info), | ||||||
|                 matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1, |                 reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), | ||||||
|                 matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, |                 matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1), | ||||||
|                 &(matrix_info->groupsize_info)); |                 reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); | ||||||
| #else | #else | ||||||
|             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( |             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( | ||||||
|                 q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, |                 q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, | ||||||
|                 matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, |                 matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info), | ||||||
|                 reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), |                 reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), | ||||||
|                 matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), |                 matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1), | ||||||
|                 matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); |                 reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|             q.submit([&](sycl::handler &cgh) |  | ||||||
|                      { |  | ||||||
|     cgh.depends_on(e); |  | ||||||
|     cgh.host_task([=] { std::free(matrix_info); }); }); |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         template <class Ta, class Tb, class Tc, class Ts> |         template <class Ta, class Tb, class Tc, class Ts> | ||||||
| @@ -2422,25 +2412,11 @@ namespace dpct | |||||||
|     /// \param [in] ldc Leading dimension of C. |     /// \param [in] ldc Leading dimension of C. | ||||||
|     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. |     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. | ||||||
|     /// \param [in] scaling_type Data type of the scaling factors. |     /// \param [in] scaling_type Data type of the scaling factors. | ||||||
|     inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, |     inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, | ||||||
|                            oneapi::mkl::transpose b_trans, int m, int n, int k, |                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, | ||||||
|                            const void *alpha, const void *a[], |                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], | ||||||
|                            library_data_t a_type, int lda, const void *b[], |                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, | ||||||
|                            library_data_t b_type, int ldb, const void *beta, |                            matrix_info_t<float> * matrix_info) { | ||||||
|                            void *c[], library_data_t c_type, int ldc, |  | ||||||
|                            int batch_size, library_data_t scaling_type) |  | ||||||
|     { |  | ||||||
|         if (scaling_type == library_data_t::real_float && |  | ||||||
|             c_type == library_data_t::complex_float) |  | ||||||
|         { |  | ||||||
|             scaling_type = library_data_t::complex_float; |  | ||||||
|         } |  | ||||||
|         else if (scaling_type == library_data_t::real_double && |  | ||||||
|                  c_type == library_data_t::complex_double) |  | ||||||
|         { |  | ||||||
|             scaling_type = library_data_t::complex_double; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         std::uint64_t key = |         std::uint64_t key = | ||||||
|             detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); |             detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); | ||||||
|         switch (key) |         switch (key) | ||||||
| @@ -2449,48 +2425,24 @@ namespace dpct | |||||||
|             library_data_t::real_float, library_data_t::real_float, |             library_data_t::real_float, library_data_t::real_float, | ||||||
|             library_data_t::real_float, library_data_t::real_float): |             library_data_t::real_float, library_data_t::real_float): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<float, float, float, float>( |             detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |                                                                 beta, c, ldc, batch_size, matrix_info); | ||||||
|                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_double, library_data_t::real_double, |             library_data_t::real_double, library_data_t::real_double, | ||||||
|             library_data_t::real_double, library_data_t::real_double): |             library_data_t::real_double, library_data_t::real_double): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<double, double, double, double>( |             detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |                                                                     beta, c, ldc, batch_size, matrix_info); | ||||||
|                 batch_size); |  | ||||||
|             break; |  | ||||||
|         } |  | ||||||
|         case detail::get_type_combination_id( |  | ||||||
|             library_data_t::complex_float, library_data_t::complex_float, |  | ||||||
|             library_data_t::complex_float, library_data_t::complex_float): |  | ||||||
|         { |  | ||||||
|             detail::gemm_batch_impl<std::complex<float>, std::complex<float>, |  | ||||||
|                                     std::complex<float>, std::complex<float>>( |  | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |  | ||||||
|                 batch_size); |  | ||||||
|             break; |  | ||||||
|         } |  | ||||||
|         case detail::get_type_combination_id( |  | ||||||
|             library_data_t::complex_double, library_data_t::complex_double, |  | ||||||
|             library_data_t::complex_double, library_data_t::complex_double): |  | ||||||
|         { |  | ||||||
|             detail::gemm_batch_impl<std::complex<double>, std::complex<double>, |  | ||||||
|                                     std::complex<double>, std::complex<double>>( |  | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |  | ||||||
|                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_half, library_data_t::real_half, |             library_data_t::real_half, library_data_t::real_half, | ||||||
|             library_data_t::real_half, library_data_t::real_half): |             library_data_t::real_half, library_data_t::real_half): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, |             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( | ||||||
|                                     sycl::half>(q, a_trans, b_trans, m, n, k, alpha, |                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||||
|                                                 a, lda, b, ldb, beta, c, ldc, |  | ||||||
|                                                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
| #ifdef __INTEL_MKL__ | #ifdef __INTEL_MKL__ | ||||||
| @@ -2498,19 +2450,16 @@ namespace dpct | |||||||
|             library_data_t::real_bfloat16, library_data_t::real_bfloat16, |             library_data_t::real_bfloat16, library_data_t::real_bfloat16, | ||||||
|             library_data_t::real_bfloat16, library_data_t::real_float): |             library_data_t::real_bfloat16, library_data_t::real_float): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, |             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>( | ||||||
|                                     oneapi::mkl::bfloat16, float>( |                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |  | ||||||
|                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
|             library_data_t::real_bfloat16, library_data_t::real_bfloat16, |             library_data_t::real_bfloat16, library_data_t::real_bfloat16, | ||||||
|             library_data_t::real_float, library_data_t::real_float): |             library_data_t::real_float, library_data_t::real_float): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, |             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>( | ||||||
|                                     float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, |                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||||
|                                            b, ldb, beta, c, ldc, batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
| #endif | #endif | ||||||
| @@ -2522,10 +2471,9 @@ namespace dpct | |||||||
|                 dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q); |                 dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q); | ||||||
|             float beta_float = |             float beta_float = | ||||||
|                 dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q); |                 dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q); | ||||||
|             detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, |             detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>( | ||||||
|                                     float>(q, a_trans, b_trans, m, n, k, &alpha_float, |                 q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, | ||||||
|                                            a, lda, b, ldb, &beta_float, c, ldc, |                 matrix_info); | ||||||
|                                            batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
| @@ -2533,8 +2481,7 @@ namespace dpct | |||||||
|             library_data_t::real_float, library_data_t::real_float): |             library_data_t::real_float, library_data_t::real_float): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>( |             detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>( | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||||
|                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
| @@ -2542,8 +2489,7 @@ namespace dpct | |||||||
|             library_data_t::real_float, library_data_t::real_float): |             library_data_t::real_float, library_data_t::real_float): | ||||||
|         { |         { | ||||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, float, float>( |             detail::gemm_batch_impl<sycl::half, sycl::half, float, float>( | ||||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, |                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||||
|                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         case detail::get_type_combination_id( |         case detail::get_type_combination_id( | ||||||
| @@ -2557,8 +2503,7 @@ namespace dpct | |||||||
|             sycl::half alpha_half(alpha_value); |             sycl::half alpha_half(alpha_value); | ||||||
|             sycl::half beta_half(beta_value); |             sycl::half beta_half(beta_value); | ||||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( |             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( | ||||||
|                 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, |                 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); | ||||||
|                 batch_size); |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|         default: |         default: | ||||||
|   | |||||||
| @@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { | |||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | struct ggml_sycl_pool_host : public ggml_sycl_pool { | ||||||
|  |     queue_ptr qptr; | ||||||
|  |     int       device; | ||||||
|  |  | ||||||
|  |     inline static int counter{ 0 }; | ||||||
|  |  | ||||||
|  |     struct ggml_sycl_buffer { | ||||||
|  |         void * ptr  = nullptr; | ||||||
|  |         size_t size = 0; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     // Set arbitrarly to 64 | ||||||
|  |     static constexpr int          MAX_POOL_SIZE{ 64 }; | ||||||
|  |     std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE); | ||||||
|  |     size_t                        pool_size   = 0; | ||||||
|  |  | ||||||
|  |     explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} | ||||||
|  |  | ||||||
|  |     ~ggml_sycl_pool_host() { | ||||||
|  |         for (int i = 0; i < MAX_POOL_SIZE; ++i) { | ||||||
|  |             ggml_sycl_buffer & b = buffer_pool[i]; | ||||||
|  |             if (b.ptr != nullptr) { | ||||||
|  |                 SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); | ||||||
|  |                 b.ptr = nullptr; | ||||||
|  |                 pool_size -= b.size; | ||||||
|  |                 b.size = 0; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         counter = 0; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void * alloc(size_t size, size_t * actual_size) override { | ||||||
|  |         if (counter == MAX_POOL_SIZE) { | ||||||
|  |             ggml_sycl_buffer b               = buffer_pool[0]; | ||||||
|  |             void *           ptr             = b.ptr; | ||||||
|  |             *actual_size                     = b.size; | ||||||
|  |             counter                          = 1; | ||||||
|  |             return ptr; | ||||||
|  |         } | ||||||
|  |         ggml_sycl_buffer & b = buffer_pool[counter]; | ||||||
|  |  | ||||||
|  |         if (b.ptr == nullptr) { | ||||||
|  |             void * ptr; | ||||||
|  |  | ||||||
|  |             SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); | ||||||
|  |             if (!ptr) { | ||||||
|  |                 GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); | ||||||
|  |                 return nullptr; | ||||||
|  |             } | ||||||
|  |             pool_size += size; | ||||||
|  |             *actual_size = size; | ||||||
|  |             counter      = counter + 1; | ||||||
|  |             return ptr; | ||||||
|  |         } else { | ||||||
|  |             ++counter; | ||||||
|  |             b.size = size; | ||||||
|  |             return b.ptr; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     void free(void * ptr, size_t size) override { | ||||||
|  |         // if the pool is not completed add the pointer to it in place of the first nullptr found. | ||||||
|  |         // Otherwise do nothing, pointers will be freed once the pool is deallocated. | ||||||
|  |         for (int i = 0; i < MAX_POOL_SIZE; ++i) { | ||||||
|  |             ggml_sycl_buffer & b = buffer_pool[i]; | ||||||
|  |             if (b.ptr == nullptr) { | ||||||
|  |                 b.ptr  = ptr; | ||||||
|  |                 b.size = size; | ||||||
|  |                 return; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { | ||||||
|  |     // return pool for the host to speed up memory management | ||||||
|  |     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device)); | ||||||
|  | } | ||||||
|  |  | ||||||
| std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { | std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { | ||||||
|     // TBD: NO VMM support |     // TBD: NO VMM support | ||||||
|     // if (ggml_sycl_info().devices[device].vmm) { |     // if (ggml_sycl_info().devices[device].vmm) { | ||||||
| @@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | |||||||
|  |  | ||||||
|         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); |         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); | ||||||
|         ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23); |         ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23); | ||||||
|  |         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1); | ||||||
|  |  | ||||||
|         sycl::range<3> block_dims(1, ne12, ne13); |         sycl::range<3> block_dims(1, ne12, ne13); | ||||||
|         /* |         /* | ||||||
| @@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | |||||||
|             }); |             }); | ||||||
|         } |         } | ||||||
|         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( |         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( | ||||||
|             *main_stream, oneapi::mkl::transpose::trans, |             *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, | ||||||
|             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, |             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, | ||||||
|             (const void **)(ptrs_src.get() + 0 * ne23), |             (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, | ||||||
|             dpct::library_data_t::real_half, nb01 / nb00, |             (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); | ||||||
|             (const void **)(ptrs_src.get() + 1 * ne23), |  | ||||||
|             dpct::library_data_t::real_half, nb11 / nb10, beta, |  | ||||||
|             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, |  | ||||||
|             cu_compute_type))); |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
| catch (sycl::exception const &exc) { | catch (sycl::exception const &exc) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Nicolò Scipione
					Nicolò Scipione