Revert "sycl: add usage of enqueue_functions extension (#14244)" (#15910)

* Revert "sycl: add usage of enqueue_functions extension (#14244)"

This reverts commit 8308f98c7f.

* fix missed revert code, format the code
This commit is contained in:
Neo Zhang Jianyu
2025-09-12 09:15:12 +08:00
committed by GitHub
parent 360d6533db
commit 704d90c987
20 changed files with 845 additions and 674 deletions

View File

@@ -254,13 +254,14 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -271,15 +272,16 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
sycl::range<1>(work_group_size / WARP_SIZE), cgh);
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
@@ -288,14 +290,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
const int ne_elements, queue_ptr stream, int device) {
if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler& cgh) {
const float eps_ct4 = eps;
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
WARP_SIZE);
});
});
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -307,18 +313,22 @@ static void group_norm_f32_sycl(const float* x, float* dst,
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
const float eps_ct4 = eps;
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
@@ -330,13 +340,14 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -347,15 +358,16 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
@@ -366,12 +378,16 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
});
});
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -382,15 +398,18 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
work_group_size);
});
});
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}