mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-17 11:37:10 +00:00
[SYCL] refactor soft_max, add soft_max_back (#16472)
* refactor to support soft_max_ext * fix error and support soft_max_back * rm unused functions * fix format issue --------- Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
This commit is contained in:
@@ -197,6 +197,7 @@ struct sycl_device_info {
|
||||
int cc; // compute capability
|
||||
// int nsm; // number of streaming multiprocessors
|
||||
// size_t smpb; // max. shared memory per block
|
||||
size_t smpbo; // max. shared memory per block (with opt-in)
|
||||
bool vmm; // virtual memory support
|
||||
size_t total_vram;
|
||||
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
|
||||
@@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
|
||||
const sycl::nd_item<3>& item_ct1) {
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
/*
|
||||
DPCT1096:98: The right-most dimension of the work-group used in the SYCL
|
||||
kernel that calls this function may be less than "32". The function
|
||||
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
|
||||
CPU device. Modify the size of the work-group to ensure that the value
|
||||
of the right-most dimension is a multiple of "32".
|
||||
*/
|
||||
x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
|
||||
}
|
||||
return x;
|
||||
@@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ int warp_reduce_sum(int x) {
|
||||
return sycl::reduce_over_group(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x += dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
a.x() += dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
|
||||
width);
|
||||
a.y() += dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
|
||||
width);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
a = a + dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
|
||||
width);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
static constexpr int ggml_sycl_get_physical_warp_size() {
|
||||
// todo: for old iGPU + dGPU case, need to be changed.
|
||||
return WARP_SIZE;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
|
||||
offset, width));
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float warp_reduce_max(float x,
|
||||
const sycl::nd_item<3>& item_ct1) {
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
/*
|
||||
DPCT1096:97: The right-most dimension of the work-group used in the SYCL
|
||||
kernel that calls this function may be less than "32". The function
|
||||
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
|
||||
CPU device. Modify the size of the work-group to ensure that the value
|
||||
of the right-most dimension is a multiple of "32".
|
||||
*/
|
||||
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
|
||||
item_ct1.get_sub_group(), x, mask));
|
||||
}
|
||||
@@ -558,4 +602,18 @@ struct scope_op_debug_print {
|
||||
std::string_view func_suffix;
|
||||
};
|
||||
|
||||
static __dpct_inline__ float get_alibi_slope(const float max_bias,
|
||||
const uint32_t h,
|
||||
const uint32_t n_head_log2,
|
||||
const float m0,
|
||||
const float m1) {
|
||||
if (max_bias <= 0.0f) {
|
||||
return 1.0f;
|
||||
}
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
return dpct::pow(base, exph);
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
||||
Reference in New Issue
Block a user