sycl: add usage of enqueue_functions extension (#14244)

* Add header and namespace to use enqueue_functions extension

* Convert submit and parallel_for to use new extension in convert.cpp

* Convert submit and parallel_for to use extension in ggml-sycl.cpp

* Convert submit and parallel_for to use extension in gla.cpp

* Convert submit and parallel_for in mmq.cpp

* Convert submit and parallel_for in mmvq.cpp

* Convert submit and parallel_for in remaining files

* Convert all simple parallel_for to nd_launch from enqueue_functions
extension

* Wrapping extension in general function

Create a general function that enable the enqueue_functions extension if
it is enable in the compiler, otherwise call the general SYCL function
to launch kernels.

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
Nicolò Scipione
2025-06-20 15:07:21 +02:00
committed by GitHub
parent 6369be0735
commit 8308f98c7f
19 changed files with 750 additions and 986 deletions

View File

@@ -33,14 +33,11 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(
sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
});
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
}
}
@@ -53,24 +50,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
}
#endif
@@ -85,24 +76,18 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
}
#endif
}
@@ -116,12 +101,9 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_0(vx, y, nb32, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
}
}
@@ -135,13 +117,12 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
int constexpr WARP_K = WARP_SIZE * QK4_0;
const int n_warp = (k + WARP_K - 1) / WARP_K;
GGML_ASSERT(k % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
});
}
template <typename dst_t>
@@ -153,12 +134,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_1(vx, y, nb32, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
}
}
@@ -171,14 +149,13 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
});
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
});
});
}
}
@@ -191,13 +168,13 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->submit([&](sycl::handler & cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
});
sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
});
});
}
@@ -210,24 +187,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
}
#endif
@@ -242,24 +213,18 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
}
#endif
@@ -271,9 +236,9 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
}
template <typename dst_t>
@@ -284,15 +249,10 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_s(
vx, y, item_ct1, iq1s_grid_gpu
);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
});
}
}
@@ -305,15 +265,10 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_m(
vx, y, item_ct1, iq1s_grid_gpu
);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
});
}
}
@@ -326,15 +281,12 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xxs(
vx, y, item_ct1, iq2xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
});
});
}
}
@@ -347,15 +299,12 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xs(
vx, y, item_ct1, iq2xs_grid,
ksigns_iq2xs, kmask_iq2xs);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
});
});
}
}
@@ -368,13 +317,10 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_s(vx, y, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
});
}
}
@@ -388,15 +334,12 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_xxs(
vx, y, item_ct1, iq3xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
});
});
}
}
@@ -409,14 +352,10 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_s(
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
});
}
}
@@ -432,14 +371,11 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_xs(vx, y, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
});
}
#endif
@@ -453,14 +389,11 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_nl(vx, y, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
});
}
}