mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	SYCL: Introducing memory host pool (#11251)
* Implement host pool for matrix_info Creating a new memory pool on the host to store memory location for matrix_info needed to launch gemm_batch from oneMKL/oneMath. Removing complex support in gemm_batch since it is not used in llama.cpp * Remove unnecessary headers and cast * Reorder member variable to avoid warning on initialization * Formatting * Remove unused variable * Address PR review feedback - remove warning --------- Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
		| @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { | ||||
|     // pool | ||||
|     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; | ||||
|  | ||||
|     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES]; | ||||
|  | ||||
|     static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device); | ||||
|  | ||||
|     static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device); | ||||
|  | ||||
|     ggml_sycl_pool & pool(int device) { | ||||
|         if (pools[device] == nullptr) { | ||||
|             pools[device] = new_pool_for_device(stream(device,0), device); | ||||
| @@ -345,6 +349,15 @@ struct ggml_backend_sycl_context { | ||||
|     ggml_sycl_pool & pool() { | ||||
|         return pool(device); | ||||
|     } | ||||
|  | ||||
|     ggml_sycl_pool & host_pool(int device) { | ||||
|         if (host_pools[device] == nullptr) { | ||||
|             host_pools[device] = new_pool_for_host(stream(device, 0), device); | ||||
|         } | ||||
|         return *host_pools[device]; | ||||
|     } | ||||
|  | ||||
|     ggml_sycl_pool & host_pool() { return host_pool(device); } | ||||
| }; | ||||
|  | ||||
| // common device functions | ||||
|   | ||||
| @@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { | ||||
|     return device_type.str(); | ||||
| } | ||||
|  | ||||
| template <typename Ts> struct matrix_info_t { | ||||
|     oneapi::mkl::transpose transpose_info[2]; | ||||
|     Ts                     value_info[2]; | ||||
|     std::int64_t           size_info[3]; | ||||
|     std::int64_t           ld_info[3]; | ||||
|     std::int64_t           groupsize_info; | ||||
| }; | ||||
|  | ||||
| namespace dpct | ||||
| { | ||||
|     typedef sycl::queue *queue_ptr; | ||||
| @@ -1727,26 +1735,13 @@ namespace dpct | ||||
|         }; | ||||
|  | ||||
|         template <class Ta, class Tb, class Tc, class Ts> | ||||
|         inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, | ||||
|                                     oneapi::mkl::transpose b_trans, int m, int n, int k, | ||||
|                                     const void *alpha, const void **a, int lda, | ||||
|                                     const void **b, int ldb, const void *beta, void **c, | ||||
|                                     int ldc, int batch_size) | ||||
|         { | ||||
|             struct matrix_info_t | ||||
|             { | ||||
|                 oneapi::mkl::transpose transpose_info[2]; | ||||
|                 Ts value_info[2]; | ||||
|                 std::int64_t size_info[3]; | ||||
|                 std::int64_t ld_info[3]; | ||||
|                 std::int64_t groupsize_info; | ||||
|             }; | ||||
|  | ||||
|         inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, | ||||
|                                     int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, | ||||
|                                     int ldb, const void * beta, void ** c, int ldc, int batch_size, | ||||
|                                     matrix_info_t<float> * matrix_info) { | ||||
|             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); | ||||
|             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q); | ||||
|  | ||||
|             matrix_info_t *matrix_info = | ||||
|                 (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); | ||||
|             matrix_info->transpose_info[0] = a_trans; | ||||
|             matrix_info->transpose_info[1] = b_trans; | ||||
|             matrix_info->value_info[0] = alpha_value; | ||||
| @@ -1763,23 +1758,18 @@ namespace dpct | ||||
|             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( | ||||
|                 oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info, | ||||
|                 matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, | ||||
|                 matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a), | ||||
|                 matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1, | ||||
|                 matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, | ||||
|                 &(matrix_info->groupsize_info)); | ||||
|                 matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info), | ||||
|                 reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), | ||||
|                 matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1), | ||||
|                 reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); | ||||
| #else | ||||
|             sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( | ||||
|                 q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, | ||||
|                 matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, | ||||
|                 matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info), | ||||
|                 reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), | ||||
|                 matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), | ||||
|                 matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); | ||||
|                 matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1), | ||||
|                 reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); | ||||
| #endif | ||||
|  | ||||
|             q.submit([&](sycl::handler &cgh) | ||||
|                      { | ||||
|     cgh.depends_on(e); | ||||
|     cgh.host_task([=] { std::free(matrix_info); }); }); | ||||
|         } | ||||
|  | ||||
|         template <class Ta, class Tb, class Tc, class Ts> | ||||
| @@ -2422,25 +2412,11 @@ namespace dpct | ||||
|     /// \param [in] ldc Leading dimension of C. | ||||
|     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. | ||||
|     /// \param [in] scaling_type Data type of the scaling factors. | ||||
|     inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, | ||||
|                            oneapi::mkl::transpose b_trans, int m, int n, int k, | ||||
|                            const void *alpha, const void *a[], | ||||
|                            library_data_t a_type, int lda, const void *b[], | ||||
|                            library_data_t b_type, int ldb, const void *beta, | ||||
|                            void *c[], library_data_t c_type, int ldc, | ||||
|                            int batch_size, library_data_t scaling_type) | ||||
|     { | ||||
|         if (scaling_type == library_data_t::real_float && | ||||
|             c_type == library_data_t::complex_float) | ||||
|         { | ||||
|             scaling_type = library_data_t::complex_float; | ||||
|         } | ||||
|         else if (scaling_type == library_data_t::real_double && | ||||
|                  c_type == library_data_t::complex_double) | ||||
|         { | ||||
|             scaling_type = library_data_t::complex_double; | ||||
|         } | ||||
|  | ||||
|     inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, | ||||
|                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, | ||||
|                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], | ||||
|                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, | ||||
|                            matrix_info_t<float> * matrix_info) { | ||||
|         std::uint64_t key = | ||||
|             detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); | ||||
|         switch (key) | ||||
| @@ -2449,48 +2425,24 @@ namespace dpct | ||||
|             library_data_t::real_float, library_data_t::real_float, | ||||
|             library_data_t::real_float, library_data_t::real_float): | ||||
|         { | ||||
|             detail::gemm_batch_impl<float, float, float, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|             detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, | ||||
|                                                                 beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
|             library_data_t::real_double, library_data_t::real_double, | ||||
|             library_data_t::real_double, library_data_t::real_double): | ||||
|         { | ||||
|             detail::gemm_batch_impl<double, double, double, double>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
|             library_data_t::complex_float, library_data_t::complex_float, | ||||
|             library_data_t::complex_float, library_data_t::complex_float): | ||||
|         { | ||||
|             detail::gemm_batch_impl<std::complex<float>, std::complex<float>, | ||||
|                                     std::complex<float>, std::complex<float>>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
|             library_data_t::complex_double, library_data_t::complex_double, | ||||
|             library_data_t::complex_double, library_data_t::complex_double): | ||||
|         { | ||||
|             detail::gemm_batch_impl<std::complex<double>, std::complex<double>, | ||||
|                                     std::complex<double>, std::complex<double>>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|             detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, | ||||
|                                                                     beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
|             library_data_t::real_half, library_data_t::real_half, | ||||
|             library_data_t::real_half, library_data_t::real_half): | ||||
|         { | ||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, | ||||
|                                     sycl::half>(q, a_trans, b_trans, m, n, k, alpha, | ||||
|                                                 a, lda, b, ldb, beta, c, ldc, | ||||
|                                                 batch_size); | ||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
| #ifdef __INTEL_MKL__ | ||||
| @@ -2498,19 +2450,16 @@ namespace dpct | ||||
|             library_data_t::real_bfloat16, library_data_t::real_bfloat16, | ||||
|             library_data_t::real_bfloat16, library_data_t::real_float): | ||||
|         { | ||||
|             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, | ||||
|                                     oneapi::mkl::bfloat16, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
|             library_data_t::real_bfloat16, library_data_t::real_bfloat16, | ||||
|             library_data_t::real_float, library_data_t::real_float): | ||||
|         { | ||||
|             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, | ||||
|                                     float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, | ||||
|                                            b, ldb, beta, c, ldc, batch_size); | ||||
|             detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
| #endif | ||||
| @@ -2522,10 +2471,9 @@ namespace dpct | ||||
|                 dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q); | ||||
|             float beta_float = | ||||
|                 dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q); | ||||
|             detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, | ||||
|                                     float>(q, a_trans, b_trans, m, n, k, &alpha_float, | ||||
|                                            a, lda, b, ldb, &beta_float, c, ldc, | ||||
|                                            batch_size); | ||||
|             detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, | ||||
|                 matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
| @@ -2533,8 +2481,7 @@ namespace dpct | ||||
|             library_data_t::real_float, library_data_t::real_float): | ||||
|         { | ||||
|             detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
| @@ -2542,8 +2489,7 @@ namespace dpct | ||||
|             library_data_t::real_float, library_data_t::real_float): | ||||
|         { | ||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, float, float>( | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, | ||||
|                 batch_size); | ||||
|                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         case detail::get_type_combination_id( | ||||
| @@ -2557,8 +2503,7 @@ namespace dpct | ||||
|             sycl::half alpha_half(alpha_value); | ||||
|             sycl::half beta_half(beta_value); | ||||
|             detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>( | ||||
|                 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, | ||||
|                 batch_size); | ||||
|                 q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); | ||||
|             break; | ||||
|         } | ||||
|         default: | ||||
|   | ||||
| @@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct ggml_sycl_pool_host : public ggml_sycl_pool { | ||||
|     queue_ptr qptr; | ||||
|     int       device; | ||||
|  | ||||
|     inline static int counter{ 0 }; | ||||
|  | ||||
|     struct ggml_sycl_buffer { | ||||
|         void * ptr  = nullptr; | ||||
|         size_t size = 0; | ||||
|     }; | ||||
|  | ||||
|     // Set arbitrarly to 64 | ||||
|     static constexpr int          MAX_POOL_SIZE{ 64 }; | ||||
|     std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE); | ||||
|     size_t                        pool_size   = 0; | ||||
|  | ||||
|     explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} | ||||
|  | ||||
|     ~ggml_sycl_pool_host() { | ||||
|         for (int i = 0; i < MAX_POOL_SIZE; ++i) { | ||||
|             ggml_sycl_buffer & b = buffer_pool[i]; | ||||
|             if (b.ptr != nullptr) { | ||||
|                 SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); | ||||
|                 b.ptr = nullptr; | ||||
|                 pool_size -= b.size; | ||||
|                 b.size = 0; | ||||
|             } | ||||
|         } | ||||
|         counter = 0; | ||||
|     } | ||||
|  | ||||
|     void * alloc(size_t size, size_t * actual_size) override { | ||||
|         if (counter == MAX_POOL_SIZE) { | ||||
|             ggml_sycl_buffer b               = buffer_pool[0]; | ||||
|             void *           ptr             = b.ptr; | ||||
|             *actual_size                     = b.size; | ||||
|             counter                          = 1; | ||||
|             return ptr; | ||||
|         } | ||||
|         ggml_sycl_buffer & b = buffer_pool[counter]; | ||||
|  | ||||
|         if (b.ptr == nullptr) { | ||||
|             void * ptr; | ||||
|  | ||||
|             SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); | ||||
|             if (!ptr) { | ||||
|                 GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); | ||||
|                 return nullptr; | ||||
|             } | ||||
|             pool_size += size; | ||||
|             *actual_size = size; | ||||
|             counter      = counter + 1; | ||||
|             return ptr; | ||||
|         } else { | ||||
|             ++counter; | ||||
|             b.size = size; | ||||
|             return b.ptr; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     void free(void * ptr, size_t size) override { | ||||
|         // if the pool is not completed add the pointer to it in place of the first nullptr found. | ||||
|         // Otherwise do nothing, pointers will be freed once the pool is deallocated. | ||||
|         for (int i = 0; i < MAX_POOL_SIZE; ++i) { | ||||
|             ggml_sycl_buffer & b = buffer_pool[i]; | ||||
|             if (b.ptr == nullptr) { | ||||
|                 b.ptr  = ptr; | ||||
|                 b.size = size; | ||||
|                 return; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { | ||||
|     // return pool for the host to speed up memory management | ||||
|     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device)); | ||||
| } | ||||
|  | ||||
| std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { | ||||
|     // TBD: NO VMM support | ||||
|     // if (ggml_sycl_info().devices[device].vmm) { | ||||
| @@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | ||||
|  | ||||
|         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); | ||||
|         ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23); | ||||
|         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1); | ||||
|  | ||||
|         sycl::range<3> block_dims(1, ne12, ne13); | ||||
|         /* | ||||
| @@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, | ||||
|             }); | ||||
|         } | ||||
|         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( | ||||
|             *main_stream, oneapi::mkl::transpose::trans, | ||||
|             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, | ||||
|             (const void **)(ptrs_src.get() + 0 * ne23), | ||||
|             dpct::library_data_t::real_half, nb01 / nb00, | ||||
|             (const void **)(ptrs_src.get() + 1 * ne23), | ||||
|             dpct::library_data_t::real_half, nb11 / nb10, beta, | ||||
|             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, | ||||
|             cu_compute_type))); | ||||
|             *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, | ||||
|             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, | ||||
|             (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, | ||||
|             (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); | ||||
|     } | ||||
| } | ||||
| catch (sycl::exception const &exc) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Nicolò Scipione
					Nicolò Scipione