mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-09 10:17:06 +00:00
sycl: add usage of enqueue_functions extension (#14244)
* Add header and namespace to use enqueue_functions extension * Convert submit and parallel_for to use new extension in convert.cpp * Convert submit and parallel_for to use extension in ggml-sycl.cpp * Convert submit and parallel_for to use extension in gla.cpp * Convert submit and parallel_for in mmq.cpp * Convert submit and parallel_for in mmvq.cpp * Convert submit and parallel_for in remaining files * Convert all simple parallel_for to nd_launch from enqueue_functions extension * Wrapping extension in general function Create a general function that enable the enqueue_functions extension if it is enable in the compiler, otherwise call the general SYCL function to launch kernels. --------- Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
@@ -118,12 +118,10 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
|
||||
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_get_rows<qk, qr, dq>(
|
||||
src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
|
||||
item_ct1);
|
||||
});
|
||||
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(ctx);
|
||||
@@ -156,9 +154,8 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user