From b2602137557b2b28a39e03612717d85ead9a6f5a Mon Sep 17 00:00:00 2001
From: Neo Zhang Jianyu
Date: Thu, 9 Oct 2025 15:25:11 +0800
Subject: [PATCH] [SYCL] refactor soft_max, add soft_max_back (#16472)

* refactor to support soft_max_ext

* fix error and support soft_max_back

* rm unused functions

* fix format issue

---------

Co-authored-by: Zhang Jianyu
---
 ggml/src/ggml-sycl/common.hpp      |  86 ++++-
 ggml/src/ggml-sycl/dpct/helper.hpp |  20 ++
 ggml/src/ggml-sycl/ggml-sycl.cpp   |  25 +-
 ggml/src/ggml-sycl/softmax.cpp     | 491 +++++++++++++++++++----------
 ggml/src/ggml-sycl/softmax.hpp     |   4 +
 5 files changed, 437 insertions(+), 189 deletions(-)
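For reference, the backward op added by this patch computes the softmax Jacobian-vector product: given a row y = softmax(scale * x) from the forward pass and an upstream gradient g, it produces dL/dx_i = scale * (g_i - sum_j g_j * y_j) * y_i. A scalar sketch of the same math (illustration only, not part of the patch; the function name is hypothetical):

    // one row: y = softmax(scale * x) from the forward pass, g = upstream grad
    static void soft_max_back_row_ref(const float * g, const float * y,
                                      float * dx, int n, float scale) {
        float dot = 0.0f;                         // dot(g, y), one reduction per row
        for (int i = 0; i < n; ++i) {
            dot += g[i] * y[i];
        }
        for (int i = 0; i < n; ++i) {
            dx[i] = scale * (g[i] - dot) * y[i];  // same formula as soft_max_back_f32 below
        }
    }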
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 4e7449d06e..d66d7ade90 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -197,6 +197,7 @@ struct sycl_device_info {
     int     cc;              // compute capability
     // int     nsm;          // number of streaming multiprocessors
     // size_t  smpb;         // max. shared memory per block
+    size_t  smpbo;           // max. shared memory per block (with opt-in)
     bool    vmm;             // virtual memory support
     size_t  total_vram;
     //sycl_hw_info hw_info; // device id and aarch, currently not used
@@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
                                              const sycl::nd_item<3>& item_ct1) {
 #pragma unroll
     for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
-        /*
-        DPCT1096:98: The right-most dimension of the work-group used in the SYCL
-        kernel that calls this function may be less than "32". The function
-        "dpct::permute_sub_group_by_xor" may return an unexpected result on the
-        CPU device. Modify the size of the work-group to ensure that the value
-        of the right-most dimension is a multiple of "32".
-        */
         x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
     }
     return x;
@@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
     return a;
 }
 
+template <int width = WARP_SIZE>
+static __dpct_inline__ int warp_reduce_sum(int x) {
+    return sycl::reduce_over_group(
+        sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
+}
+
+template <int width = WARP_SIZE>
+static __dpct_inline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int offset = width / 2; offset > 0; offset >>= 1) {
+        x += dpct::permute_sub_group_by_xor(
+            sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
+    }
+    return x;
+}
+
+template <int width = WARP_SIZE>
+static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
+#pragma unroll
+    for (int offset = width / 2; offset > 0; offset >>= 1) {
+        a.x() += dpct::permute_sub_group_by_xor(
+            sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
+            width);
+        a.y() += dpct::permute_sub_group_by_xor(
+            sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
+            width);
+    }
+    return a;
+}
+
+template <int width = WARP_SIZE>
+static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
+#pragma unroll
+    for (int offset = width / 2; offset > 0; offset >>= 1) {
+        a = a + dpct::permute_sub_group_by_xor(
+                    sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
+                    width);
+    }
+    return a;
+}
+
+static constexpr int ggml_sycl_get_physical_warp_size() {
+    // TODO: needs to be changed for the old iGPU + dGPU case.
+    return WARP_SIZE;
+}
+
+template <int width = WARP_SIZE>
+static __dpct_inline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int offset = width / 2; offset > 0; offset >>= 1) {
+        x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
+                              sycl::ext::oneapi::this_work_item::get_sub_group(),
+                              x, offset, width));
+    }
+    return x;
+}
+
 static __dpct_inline__ float warp_reduce_max(float x,
                                              const sycl::nd_item<3>& item_ct1) {
 #pragma unroll
     for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
-        /*
-        DPCT1096:97: The right-most dimension of the work-group used in the SYCL
-        kernel that calls this function may be less than "32". The function
-        "dpct::permute_sub_group_by_xor" may return an unexpected result on the
-        CPU device. Modify the size of the work-group to ensure that the value
-        of the right-most dimension is a multiple of "32".
-        */
         x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
                               item_ct1.get_sub_group(), x, mask));
     }
@@ -558,4 +602,18 @@ struct scope_op_debug_print {
     std::string_view func_suffix;
 };
 
+static __dpct_inline__ float get_alibi_slope(const float max_bias,
+                                             const uint32_t h,
+                                             const uint32_t n_head_log2,
+                                             const float m0,
+                                             const float m1) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return dpct::pow(base, exph);
+}
+
 #endif // GGML_SYCL_COMMON_HPP
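The get_alibi_slope() helper added above follows ggml's ALiBi slope schedule: the first n_head_log2 heads use base m0 with exponents 1..n_head_log2, the remaining heads use base m1 with odd exponents. A host-side sketch of the same computation, including the m0/m1 setup that ggml_sycl_op_soft_max performs later in this patch (illustration only, hypothetical function name):

    #include <cmath>
    #include <cstdint>

    static float alibi_slope_ref(float max_bias, uint32_t h, uint32_t n_head) {
        if (max_bias <= 0.0f) {
            return 1.0f;  // ALiBi disabled
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        return powf(base, exph);
    }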
diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp
index d538965b09..f93cfa701f 100644
--- a/ggml/src/ggml-sycl/dpct/helper.hpp
+++ b/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -277,6 +277,26 @@ namespace dpct
 
     } // namespace detail
 
+    // COPY from DPCT header files
+    /// dim3 is used to store 3 component dimensions.
+    class dim3 {
+      public:
+        unsigned x, y, z;
+
+        constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
+            : x(x), y(y), z(z) {}
+
+        dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
+
+        operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
+    }; // class dim3
+
+    inline dim3 operator*(const dim3 &a, const dim3 &b) {
+        return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
+    }
+    // COPY from DPCT header files
+
     /// Pitched 2D/3D memory data.
     class pitched_data {
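Note the component order: dpct::dim3 mirrors CUDA's dim3, so x is the fastest-varying component, while sycl::range<3> orders dimensions slowest-first; the conversion operator therefore reverses to (z, y, x). A small illustration (assumes only the class added above):

    dpct::dim3 grid(64, 4, 2);               // x = 64, y = 4, z = 2
    sycl::range<3> r = grid;                 // r == sycl::range<3>(2, 4, 64)
    dpct::dim3 g2 = grid * dpct::dim3(2);    // elementwise: x = 128, y = 4, z = 2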
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 4ac919ea2d..e4cc3c8ed8 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -87,6 +87,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
         info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+        info.devices[i].smpbo = prop.get_local_mem_size();
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -3741,6 +3742,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_SOFT_MAX:
             ggml_sycl_op_soft_max(ctx, dst);
             break;
+        case GGML_OP_SOFT_MAX_BACK:
+            ggml_sycl_op_soft_max_back(ctx, dst);
+            break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
             break;
@@ -3778,6 +3782,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
         return true;
     } catch (sycl::exception & e) {
         std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+        std::cerr << "Error OP " << ggml_op_name(dst->op) << std::endl;
         std::exit(1);
     }
 
@@ -4386,19 +4391,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
-        case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support attention sinks [TAG_ATTN_SINKS]
-            if (op->src[2]) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
         case GGML_OP_DIAG_MASK_INF:
+            return true;
+        case GGML_OP_SOFT_MAX:
+            return true;
+        case GGML_OP_SOFT_MAX_BACK: {
+            float max_bias = 0.0f;
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+            return max_bias == 0.0f;
+        }
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
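The supports_op change advertises SOFT_MAX_BACK only when max_bias == 0.0f, i.e. when ALiBi is disabled, since the backward kernel added below does not apply ALiBi slopes. It relies on ggml's soft_max parameter layout, where op_params packs two floats:

    float scale    = 1.0f;   // op_params[0]
    float max_bias = 0.0f;   // op_params[1]
    memcpy(&scale,    (const float *) op->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));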
diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp
index 52fcf4b3db..83b7c71b66 100644
--- a/ggml/src/ggml-sycl/softmax.cpp
+++ b/ggml/src/ggml-sycl/softmax.cpp
@@ -1,37 +1,94 @@
 #include "softmax.hpp"
+#include <cstddef>
+#include <cstdint>
+#include <type_traits>
 
-template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
-                         const int nrows_y, const float scale, const float max_bias, const float m0,
-                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-    const int tid = item_ct1.get_local_id(2);
-    const int rowx = item_ct1.get_group(2);
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+template <typename T> static __dpct_inline__ float t2f32(T val) {
+    return (float) val;
+}
 
-    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+template <> float __dpct_inline__ t2f32<sycl::half>(sycl::half val) {
+    return sycl::vec<sycl::half, 1>(val)
+        .convert<float, sycl::rounding_mode::automatic>()[0];
+}
 
-    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
+// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
+// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <bool use_shared, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x,
+                         const T * mask,
+                         const float * sinks,
+                         float * dst,
+                         const soft_max_params p,
+                         uint8_t * dpct_local) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
+    const int block_size = block_size_template == 0
+                               ? item_ct1.get_local_range(2)
+                               : block_size_template;
 
     const int nthreads = block_size;
     const int nwarps = nthreads / WARP_SIZE;
     size_t nreduce = nwarps / WARP_SIZE;
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = rowx/nrows_y; // head index
+    const int tid = item_ct1.get_local_id(2);
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+    const int64_t i03 = item_ct1.get_group(0);
+    const int64_t i02 = item_ct1.get_group(1);
+    const int64_t i01 = item_ct1.get_group(2);
 
-        slope = sycl::pow(base, float(exp));
-    }
+    //TODO: noncontiguous inputs/outputs
+    const int rowx = item_ct1.get_group(2) +
+                     item_ct1.get_group(1) * item_ct1.get_group_range(2) +
+                     item_ct1.get_group(0) * item_ct1.get_group_range(2) *
+                         item_ct1.get_group_range(1);
 
-    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
-    float max_val = -INFINITY;
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
+
+    x    += int64_t(rowx)*ncols;
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
+    dst  += int64_t(rowx)*ncols;
+
+    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
+
+    float * buf_iw = (float *) dpct_local;
+
+    // shared memory buffer to cache values between iterations:
+    float * vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst;
+
+    float max_val = sinks ? sinks[i02] : -INFINITY;
 
 #pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
         const int col = col0 + tid;
@@ -39,42 +96,35 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
             break;
         }
 
-        const int ix = rowx*ncols + col;
-        const int iy = rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
-        max_val = sycl::max(max_val, val);
+        max_val   = sycl::max(max_val, val);
     }
 
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val, item_ct1);
+    max_val = warp_reduce_max(max_val);
+
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
-            buf[lane_id] = -INFINITY;
-            for (size_t i = 1; i < nreduce; i += 1) {
-                buf[lane_id + i * WARP_SIZE] = -INFINITY;
-            }
+            buf_iw[lane_id] = -INFINITY;
         }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
+        item_ct1.barrier();
 
         if (lane_id == 0) {
-            buf[warp_id] = max_val;
+            buf_iw[warp_id] = max_val;
        }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        max_val = buf[lane_id];
-        for (size_t i = 1; i < nreduce; i += 1) {
-            max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
-        }
-        max_val = warp_reduce_max(max_val, item_ct1);
-    }
+        item_ct1.barrier();
+
+        max_val = buf_iw[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
 
+    float tmp = 0.0f; // partial sum
-    float tmp = 0.f;
 #pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
         const int col = col0 + tid;
-        if (ncols_template == 0 && col >= ncols) {
+
+        if (ncols_template == 0 && col >= ncols) {
             break;
         }
@@ -82,32 +132,33 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
         const float val = sycl::native::exp(vals[col] - max_val);
         tmp += val;
         vals[col] = val;
     }
 
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp, item_ct1);
+    tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
-        item_ct1.barrier(sycl::access::fence_space::local_space);
+        item_ct1.barrier();
         if (warp_id == 0) {
-            buf[lane_id] = 0.f;
+            buf_iw[lane_id] = 0.0f;
             for (size_t i = 1; i < nreduce; i += 1) {
-                buf[lane_id + i * WARP_SIZE] = 0.f;
+                buf_iw[lane_id + i * WARP_SIZE] = 0.f;
             }
         }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
+        item_ct1.barrier();
 
         if (lane_id == 0) {
-            buf[warp_id] = tmp;
+            buf_iw[warp_id] = tmp;
         }
-        item_ct1.barrier(sycl::access::fence_space::local_space);
+        item_ct1.barrier();
 
-        tmp = buf[lane_id];
+        tmp = buf_iw[lane_id];
         for (size_t i = 1; i < nreduce; i += 1) {
-            tmp += buf[lane_id + i * WARP_SIZE];
+            tmp += buf_iw[lane_id + i * WARP_SIZE];
         }
-        tmp = warp_reduce_sum(tmp, item_ct1);
+        tmp = warp_reduce_sum(tmp);
     }
-
-    const float inv_sum = 1.f / tmp;
+    if (sinks) {
+        tmp += sycl::native::exp(sinks[i02] - max_val);
+    }
 
+    const float inv_sum = 1.0f / tmp;
 
 #pragma unroll
@@ -117,145 +168,259 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
         const int col = col0 + tid;
 
         if (ncols_template == 0 && col >= ncols) {
             return;
         }
 
-        const int idst = rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
+        dst[col] = vals[col] * inv_sum;
     }
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
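With an attention sink s = sinks[i02], the kernel above seeds the running maximum with s and later adds exp(s - M) to the denominator, so per row it effectively computes

    M   = max(s, max_j(scale * x_j + slope * mask_j))
    Z   = exp(s - M) + sum_j exp(v_j - M)
    y_j = exp(v_j - M) / Z

i.e. the sink behaves as one extra logit that absorbs probability mass but is never written to dst.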
 
+static void soft_max_back_f32(const float * grad, const float * dstf, float * dst,
+                              const int ncols, const float scale) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int tid  = item_ct1.get_local_id(2);
+    const int rowx = item_ct1.get_group(2);
+
+    grad += int64_t(rowx)*ncols;
+    dstf += int64_t(rowx)*ncols;
+    dst  += int64_t(rowx)*ncols;
+
+    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dgf_dot += dstf[col]*grad[col];
+    }
+
+    dgf_dot = warp_reduce_sum(dgf_dot);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
+    }
+}
 
-template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
-                                   const int nrows_y, const float scale, const float max_bias, const float m0,
-                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
-                                   const size_t n_local_scratch, queue_ptr stream) {
+template <int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x,
+                                    const T * mask,
+                                    const float * sinks,
+                                    float * dst,
+                                    const soft_max_params & p,
+                                    dpct::queue_ptr stream,
+                                    dpct::dim3 block_dims,
+                                    dpct::dim3 block_nums,
+                                    size_t nbytes_shared)
+{
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+        if (p.ncols == ncols) {
+            stream->submit([&](sycl::handler & cgh) {
+                sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                    sycl::range<1>(nbytes_shared), cgh);
+
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        soft_max_f32<true, ncols, block>(
+                            x, mask, sinks, dst, p,
+                            dpct_local_acc_ct1
+                                .get_multi_ptr<sycl::access::decorated::no>()
+                                .get());
+                        GGML_UNUSED(item_ct1);
+                    });
+            });
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
     stream->submit([&](sycl::handler & cgh) {
-        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
+        sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+            sycl::range<1>(nbytes_shared), cgh);
 
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(
-                    x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0,
-                    m1, n_head_log2, item_ct1,
-                    get_pointer(local_buf_acc));
-            });
+            [=](sycl::nd_item<3> item_ct1)
+                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                    soft_max_f32<true, 0, 0>(
+                        x, mask, sinks, dst, p,
+                        dpct_local_acc_ct1
+                            .get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                    GGML_UNUSED(item_ct1);
+                });
     });
 }
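launch_soft_max_kernels dispatches at compile time: the unary || fold instantiates launch_kernel once per value in Ns and short-circuits on the first ncols match, falling through to the generic (ncols_template == 0) submit when nothing matches. The same pattern in isolation (self-contained illustration, hypothetical names):

    #include <cstdio>
    #include <type_traits>

    template <int... Ns>
    static bool dispatch(int ncols) {
        auto try_one = [&](auto I) -> bool {
            constexpr int N = decltype(I)::value;  // compile-time column count
            if (ncols == N) {
                std::printf("specialized kernel for %d cols\n", N);
                return true;                       // stops the fold here
            }
            return false;
        };
        // unary right fold over ||: evaluates left to right, short-circuits
        return (try_one(std::integral_constant<int, Ns>{}) || ...);
    }

    // usage: fall back to a generic path when no Ns matches
    // if (!dispatch<32, 64, 128, 256>(ncols)) { /* generic kernel */ }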
 
-template <typename T>
-static void soft_max_f32_sycl(const float * x, const T * mask,
-                              float * dst, const int ncols_x, const int nrows_x,
-                              const int nrows_y, const float scale, const float max_bias,
-                              queue_ptr stream, int device) {
+template <typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
+                              const float * sinks, float * dst,
+                              const soft_max_params & params,
+                              dpct::queue_ptr stream, int device) {
     int nth = WARP_SIZE;
     int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < max_block_size) nth *= 2;
     if (nth>max_block_size) nth = max_block_size;
-    const sycl::range<3> block_dims(1, 1, nth);
-    const sycl::range<3> block_nums(1, 1, nrows_x);
-    const size_t n_val_tmp = nth / WARP_SIZE;
-    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
+    const dpct::dim3 block_dims(nth, 1, 1);
+    const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03);
+    const size_t nbytes_shared =
+        (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float);
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const int id = get_current_device_id();
+    const size_t smpbo = ggml_sycl_info().devices[id].smpbo;
 
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
-    if (n_local_scratch*sizeof(float) < local_mem_size) {
-        if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                               max_bias, m0, m1, n_head_log2, block_nums,
-                                               block_dims, n_local_scratch, stream);
-            return;
-        }
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                     max_bias, m0, m1, n_head_log2, block_nums,
-                                                     block_dims, n_local_scratch, stream);
-                break;
-            case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                     max_bias, m0, m1, n_head_log2, block_nums,
-                                                     block_dims, n_local_scratch, stream);
-                break;
-            case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                       max_bias, m0, m1, n_head_log2, block_nums,
-                                                       block_dims, n_local_scratch, stream);
-                break;
-            case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                         max_bias, m0, m1, n_head_log2, block_nums,
-                                                         block_dims, n_local_scratch, stream);
-                break;
-            default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                                   max_bias, m0, m1, n_head_log2, block_nums,
-                                                   block_dims, n_local_scratch, stream);
-                break;
-        }
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
+            x, mask, sinks, dst, params, stream, block_dims, block_nums,
+            nbytes_shared);
     } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
-                                            max_bias, m0, m1, n_head_log2, block_nums,
-                                            block_dims, WARP_SIZE, stream);
+        const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+
+        stream->submit([&](sycl::handler & cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(nbytes_shared_low), cgh);
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    soft_max_f32<false, 0, 0>(
+                        x, mask, sinks, dst, params,
+                        dpct_local_acc_ct1
+                            .get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                    GGML_UNUSED(item_ct1);
+                });
+        });
     }
 }
 
+static void soft_max_back_f32_sycl(const float * grad,
+                                   const float * dstf,
+                                   float * dst,
+                                   const int ncols,
+                                   const int nrows,
+                                   const float scale,
+                                   dpct::queue_ptr stream) {
+    const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
+    const dpct::dim3 block_nums(nrows, 1, 1);
+
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             soft_max_back_f32(grad, dstf, dst, ncols, scale);
+                             GGML_UNUSED(item_ct1);
+                         });
+}
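The local-memory budget decides between the cached and uncached paths: nbytes_shared = (GGML_PAD(ncols, WARP_SIZE) + WARP_SIZE) * sizeof(float) reserves one cached float per padded column plus one warp-reduction slot per lane. For ncols = 4096 and WARP_SIZE = 32 that is (4096 + 32) * 4 = 16,512 bytes; only when this exceeds the device's smpbo does the code take the use_shared == false fallback, which keeps just the WARP_SIZE reduction slots and stages intermediate values through dst instead of local memory.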
 
 void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    const float * src0_d = (const float *) src0->data;
+    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
+    float * dst_d = (float *) dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
 
-    const int64_t ne00    = dst->src[0]->ne[0];
-    const int64_t nrows_x = ggml_nrows(dst->src[0]);
-    const int64_t nrows_y = dst->src[0]->ne[1];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
 
-    float scale = 1.0f;
+    const int64_t ne00 = src0->ne[0];
+
+    float scale    = 1.0f;
     float max_bias = 0.0f;
 
-    memcpy(&scale, dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
 
-    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
-    float       * dst_dd  = static_cast<float *>(dst->data);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
-    ggml_sycl_set_device(ctx.device);
-    dpct::queue_ptr main_stream = ctx.stream();
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
 
-    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
-        const sycl::half * src1_dd = static_cast<const sycl::half *>(dst->src[1]->data);
-        soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
-                          main_stream, ctx.device);
-    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
-        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
-        soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
+    if (use_f16) {
+        soft_max_f32_sycl(src0_d, (const sycl::half *) src1_d,
+                          (const float *) src2_d, dst_d, params, stream,
+                          ctx.device);
     } else {
-        /* mask unavailable */
-        soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+        soft_max_f32_sycl(src0_d, (const float *) src1_d, (const float *) src2_d,
+                          dst_d, params, stream, ctx.device);
     }
 }
+
+void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    const ggml_tensor * src0 = dst->src[0]; // grad
+    const ggml_tensor * src1 = dst->src[1]; // forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float *) dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
+    soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
+}
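soft_max_back_f32_sycl launches one WARP_SIZE-wide work-group per row (block_nums = (nrows, 1, 1), block_dims = (WARP_SIZE, 1, 1)), so the dot-product reduction stays within a single sub-group and needs no local memory. For example, a src0 with ncols = 512 and nrows = 64 runs 64 work-groups of 32 work-items, each work-item striding over columns 32 apart.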
diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp
index 2cf8582ec9..23f1e5a9d6 100644
--- a/ggml/src/ggml-sycl/softmax.hpp
+++ b/ggml/src/ggml-sycl/softmax.hpp
@@ -15,6 +15,10 @@
 
 #include "common.hpp"
 
+#define SYCL_SOFT_MAX_BLOCK_SIZE 1024
+
 void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
 
+void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
 #endif // GGML_SYCL_SOFTMAX_HPP