[SYCL] refactor soft_max, add soft_max_back (#16472)

* refactor to support soft_max_ext

* fix error and support soft_max_back

* rm unused functions

* fix format issue

---------

Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
This commit is contained in:
Neo Zhang Jianyu
2025-10-09 15:25:11 +08:00
committed by GitHub
parent e08db42595
commit b260213755
5 changed files with 437 additions and 189 deletions

View File

@@ -197,6 +197,7 @@ struct sycl_device_info {
int cc; // compute capability
// int nsm; // number of streaming multiprocessors
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool vmm; // virtual memory support
size_t total_vram;
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
@@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
const sycl::nd_item<3>& item_ct1) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
/*
DPCT1096:98: The right-most dimension of the work-group used in the SYCL
kernel that calls this function may be less than "32". The function
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
CPU device. Modify the size of the work-group to ensure that the value
of the right-most dimension is a multiple of "32".
*/
x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
}
return x;
@@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
return a;
}
template <int width = WARP_SIZE>
static __dpct_inline__ int warp_reduce_sum(int x) {
return sycl::reduce_over_group(
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
}
template <int width = WARP_SIZE>
static __dpct_inline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
x += dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
}
return x;
}
template <int width = WARP_SIZE>
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
a.x() += dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
width);
a.y() += dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
width);
}
return a;
}
template <int width = WARP_SIZE>
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
a = a + dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
width);
}
return a;
}
static constexpr int ggml_sycl_get_physical_warp_size() {
// todo: for old iGPU + dGPU case, need to be changed.
return WARP_SIZE;
}
template <int width = WARP_SIZE>
static __dpct_inline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
offset, width));
}
return x;
}
static __dpct_inline__ float warp_reduce_max(float x,
const sycl::nd_item<3>& item_ct1) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
/*
DPCT1096:97: The right-most dimension of the work-group used in the SYCL
kernel that calls this function may be less than "32". The function
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
CPU device. Modify the size of the work-group to ensure that the value
of the right-most dimension is a multiple of "32".
*/
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
item_ct1.get_sub_group(), x, mask));
}
@@ -558,4 +602,18 @@ struct scope_op_debug_print {
std::string_view func_suffix;
};
static __dpct_inline__ float get_alibi_slope(const float max_bias,
const uint32_t h,
const uint32_t n_head_log2,
const float m0,
const float m1) {
if (max_bias <= 0.0f) {
return 1.0f;
}
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
return dpct::pow(base, exph);
}
#endif // GGML_SYCL_COMMON_HPP