mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	SYCL: Add non-contiguous input support in RMS_NORM and NORM kernels (#13611)
* SYCL: Add non-contiguous input support to the norm kernel
* Refactor and add RMS_NORM non-contiguous input support (ggml-ci)
* Restore subgroup reduction for multi-subgroup thread blocks in norm kernels
* Swap grid dims of nsamples and nrows (ggml-ci)
* Revert "Swap grid dims of nsamples and nrows". This reverts commit 43be2d657fec7f7fba54e2cd154106bc0fc45adf.
* Restore not required changes (ggml-ci)
* Address review comments: make the code more SYCL-like
* Use a common function to calculate the offset
* Remove wrap-around logic for handling broadcasts
* Remove static from the calculate_offset function and use ceil_div
This commit is contained in:
		| @@ -13,6 +13,7 @@ | |||||||
| #ifndef GGML_SYCL_COMMON_HPP | #ifndef GGML_SYCL_COMMON_HPP | ||||||
| #define GGML_SYCL_COMMON_HPP | #define GGML_SYCL_COMMON_HPP | ||||||
|  |  | ||||||
|  | #include <cstddef> | ||||||
| #include <fstream> | #include <fstream> | ||||||
| #include <iostream> | #include <iostream> | ||||||
| #include <string> | #include <string> | ||||||
| @@ -481,6 +482,19 @@ static __dpct_inline__ float warp_reduce_max(float x, | |||||||
|     return x; |     return x; | ||||||
| } | } | ||||||
|  |  | ||||||
/**
 * Computes a linear offset from per-dimension strides and indices, i.e. the
 * sum over all dimensions of strides[i] * indices[i].
 *
 * @tparam N       Number of dimensions.
 * @param strides  Stride of each dimension; the result is expressed in the
 *                 same unit as these strides.
 * @param indices  Index along each dimension, in the same order as `strides`.
 * @return         The accumulated offset as a size_t.
 */
template<int N>
__dpct_inline__ size_t calculate_offset(const std::array<int, N> & strides, const std::array<int, N> & indices) {
    size_t offset = 0;
#pragma unroll
    for (int i = 0; i < N; i++) {
        // Widen both operands before multiplying: an int * int product is
        // evaluated in int and overflows (undefined behaviour) as soon as it
        // exceeds INT_MAX, which happens for tensors with more than 2^31 elements.
        offset += static_cast<size_t>(strides[i]) * static_cast<size_t>(indices[i]);
    }
    return offset;
}
|  |  | ||||||
| // Helper for vec loading aligned data | // Helper for vec loading aligned data | ||||||
| template <typename Tp, int n> | template <typename Tp, int n> | ||||||
| inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) { | inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) { | ||||||
|   | |||||||
| @@ -4241,6 +4241,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
| #endif | #endif | ||||||
|         case GGML_OP_NORM: |         case GGML_OP_NORM: | ||||||
|         case GGML_OP_RMS_NORM: |         case GGML_OP_RMS_NORM: | ||||||
|  |             return true; | ||||||
|         case GGML_OP_L2_NORM: |         case GGML_OP_L2_NORM: | ||||||
|         case GGML_OP_GROUP_NORM: |         case GGML_OP_GROUP_NORM: | ||||||
|             return ggml_is_contiguous(op->src[0]); |             return ggml_is_contiguous(op->src[0]); | ||||||
|   | |||||||
| @@ -1,17 +1,31 @@ | |||||||
| #include "norm.hpp" | #include "norm.hpp" | ||||||
|  | #include "ggml-sycl/common.hpp" | ||||||
|  | #include "ggml-sycl/presets.hpp" | ||||||
|  |  | ||||||
| static void norm_f32(const float* x, float* dst, const int ncols, const float eps, | static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, | ||||||
|     const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { |         const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { | ||||||
|     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + |  | ||||||
|         item_ct1.get_local_id(1); |     const int nrows = item_ct1.get_group_range(2); | ||||||
|     const int tid = item_ct1.get_local_id(2); |     const int nchannels = item_ct1.get_group_range(1); | ||||||
|  |  | ||||||
|     const int nthreads = item_ct1.get_local_range(2); |     const int nthreads = item_ct1.get_local_range(2); | ||||||
|  |     const int sample  = item_ct1.get_group(0); | ||||||
|  |     const int channel = item_ct1.get_group(1); | ||||||
|  |     const int row     = item_ct1.get_group(2); | ||||||
|  |  | ||||||
|  |     const int tid = item_ct1.get_local_id(2); | ||||||
|     const int nwarps = nthreads / WARP_SIZE; |     const int nwarps = nthreads / WARP_SIZE; | ||||||
|  |  | ||||||
|  |     const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); | ||||||
|  |     const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); | ||||||
|  |  | ||||||
|  |     x += strided_offset; | ||||||
|  |     dst += packed_offset; | ||||||
|  |  | ||||||
|     sycl::float2 mean_var = sycl::float2(0.f, 0.f); |     sycl::float2 mean_var = sycl::float2(0.f, 0.f); | ||||||
|  |  | ||||||
|     for (int col = tid; col < ncols; col += block_size) { |     for (int col = tid; col < ncols; col += block_size) { | ||||||
|         const float xi = x[row * ncols + col]; |         const float xi = x[col]; | ||||||
|         mean_var.x() += xi; |         mean_var.x() += xi; | ||||||
|         mean_var.y() += xi * xi; |         mean_var.y() += xi * xi; | ||||||
|     } |     } | ||||||
| @@ -19,22 +33,18 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep | |||||||
|     // sum up partial sums |     // sum up partial sums | ||||||
|     mean_var = warp_reduce_sum(mean_var, item_ct1); |     mean_var = warp_reduce_sum(mean_var, item_ct1); | ||||||
|     if  (block_size > WARP_SIZE) { |     if  (block_size > WARP_SIZE) { | ||||||
|  |         const auto sub_group = item_ct1.get_sub_group(); | ||||||
|         int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; |         const auto sg_id = sub_group.get_group_linear_id(); | ||||||
|         int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; |         const auto wi_in_sg = sub_group.get_local_linear_id(); | ||||||
|         if (lane_id == 0) { |         if (wi_in_sg == 0) { | ||||||
|             s_sum[warp_id] = mean_var; |             s_sum[sg_id] = mean_var; | ||||||
|         } |         } | ||||||
|         /* |  | ||||||
|         DPCT1118:0: SYCL group functions and algorithms must be encountered in |  | ||||||
|         converged control flow. You may need to adjust the code. |  | ||||||
|         */ |  | ||||||
|         item_ct1.barrier(sycl::access::fence_space::local_space); |         item_ct1.barrier(sycl::access::fence_space::local_space); | ||||||
|         mean_var = 0.f; |         mean_var = 0.f; | ||||||
|         size_t nreduce = nwarps / WARP_SIZE; |         const size_t nreduce = ceil_div(nwarps, WARP_SIZE); | ||||||
|         for (size_t i = 0; i < nreduce; i += 1) |         for (size_t i = 0; i < nreduce; i += 1) | ||||||
|         { |         { | ||||||
|             mean_var += s_sum[lane_id + i * WARP_SIZE]; |             mean_var += s_sum[wi_in_sg + i * WARP_SIZE]; | ||||||
|         } |         } | ||||||
|         mean_var = warp_reduce_sum(mean_var, item_ct1); |         mean_var = warp_reduce_sum(mean_var, item_ct1); | ||||||
|     } |     } | ||||||
| @@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep | |||||||
|     const float inv_std = sycl::rsqrt(var + eps); |     const float inv_std = sycl::rsqrt(var + eps); | ||||||
|  |  | ||||||
|     for (int col = tid; col < ncols; col += block_size) { |     for (int col = tid; col < ncols; col += block_size) { | ||||||
|         dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std; |         dst[col] = (x[col] - mean) * inv_std; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps, | static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, | ||||||
|     const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { |         const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { | ||||||
|     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + |  | ||||||
|         item_ct1.get_local_id(1); |     const int nrows = item_ct1.get_group_range(2); | ||||||
|     const int tid = item_ct1.get_local_id(2); |     const int nchannels = item_ct1.get_group_range(1); | ||||||
|  |  | ||||||
|  |     const int sample  = item_ct1.get_group(0); | ||||||
|  |     const int channel = item_ct1.get_group(1); | ||||||
|  |     const int row     = item_ct1.get_group(2); | ||||||
|  |  | ||||||
|     const int nthreads = item_ct1.get_local_range(2); |     const int nthreads = item_ct1.get_local_range(2); | ||||||
|  |  | ||||||
|  |     const int tid = item_ct1.get_local_id(2); | ||||||
|     const int nwarps = nthreads / WARP_SIZE; |     const int nwarps = nthreads / WARP_SIZE; | ||||||
|  |  | ||||||
|  |     const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); | ||||||
|  |     const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); | ||||||
|  |  | ||||||
|  |     x   += strided_offset; | ||||||
|  |     dst += packed_offset; | ||||||
|  |  | ||||||
|  |  | ||||||
|     float tmp = 0.0f; // partial sum for thread in warp |     float tmp = 0.0f; // partial sum for thread in warp | ||||||
|  |  | ||||||
|     for (int col = tid; col < ncols; col += block_size) { |     for (int col = tid; col < ncols; col += block_size) { | ||||||
|         const float xi = x[row * ncols + col]; |         const float xi = x[col]; | ||||||
|         tmp += xi * xi; |         tmp += xi * xi; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // sum up partial sums |     // sum up partial sums | ||||||
|     tmp = warp_reduce_sum(tmp, item_ct1); |     tmp = warp_reduce_sum(tmp, item_ct1); | ||||||
|     if (block_size > WARP_SIZE) { |     if (block_size > WARP_SIZE) { | ||||||
|  |         const auto sub_group = item_ct1.get_sub_group(); | ||||||
|         int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; |         const auto sg_id = sub_group.get_group_linear_id(); | ||||||
|         int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; |         const auto wi_in_sg = sub_group.get_local_linear_id(); | ||||||
|         if (lane_id == 0) { |         if (wi_in_sg == 0) { | ||||||
|             s_sum[warp_id] = tmp; |             s_sum[sg_id] = tmp; | ||||||
|         } |         } | ||||||
|         /* |  | ||||||
|         DPCT1118:3: SYCL group functions and algorithms must be encountered in |  | ||||||
|         converged control flow. You may need to adjust the code. |  | ||||||
|         */ |  | ||||||
|         item_ct1.barrier(sycl::access::fence_space::local_space); |         item_ct1.barrier(sycl::access::fence_space::local_space); | ||||||
|         size_t nreduce = nwarps / WARP_SIZE; |         const size_t nreduce = ceil_div(nwarps, WARP_SIZE); | ||||||
|         tmp = 0.f; |         tmp = 0.f; | ||||||
|         for (size_t i = 0; i < nreduce; i += 1) |         for (size_t i = 0; i < nreduce; i += 1) | ||||||
|         { |         { | ||||||
|             tmp += s_sum[lane_id + i * WARP_SIZE]; |             tmp += s_sum[wi_in_sg + i * WARP_SIZE]; | ||||||
|         } |         } | ||||||
|         tmp = warp_reduce_sum(tmp, item_ct1); |         tmp = warp_reduce_sum(tmp, item_ct1); | ||||||
|     } |     } | ||||||
| @@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa | |||||||
|     const float scale = sycl::rsqrt(mean + eps); |     const float scale = sycl::rsqrt(mean + eps); | ||||||
|  |  | ||||||
|     for (int col = tid; col < ncols; col += block_size) { |     for (int col = tid; col < ncols; col += block_size) { | ||||||
|         dst[row * ncols + col] = scale * x[row * ncols + col]; |         dst[col] = scale * x[col]; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void norm_f32_sycl(const float* x, float* dst, const int ncols, | static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, | ||||||
|     const int nrows, const float eps, |         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, | ||||||
|     queue_ptr stream, int device) { |         const float eps, queue_ptr stream, int device) { | ||||||
|  |  | ||||||
|  |     const sycl::range<3> global_dims(nsamples, nchannels, nrows); | ||||||
|     GGML_ASSERT(ncols % WARP_SIZE == 0); |     GGML_ASSERT(ncols % WARP_SIZE == 0); | ||||||
|     if (ncols < 1024) { |     if (ncols < 1024) { | ||||||
|         const sycl::range<3> block_dims(1, 1, WARP_SIZE); |         const sycl::range<3> block_dims(1, 1, WARP_SIZE); | ||||||
|         stream->submit([&](sycl::handler& cgh) { |         stream->submit([&](sycl::handler& cgh) { | ||||||
|             cgh.parallel_for( |             cgh.parallel_for( | ||||||
|                 sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, |                 sycl::nd_range<3>(global_dims * block_dims, block_dims), | ||||||
|                     block_dims), |  | ||||||
|                 [=](sycl::nd_item<3> item_ct1) |                 [=](sycl::nd_item<3> item_ct1) | ||||||
|                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||||
|                     norm_f32(x, dst, ncols, eps, item_ct1, |                     norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); | ||||||
|                         nullptr, WARP_SIZE); |  | ||||||
|                 }); |                 }); | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| @@ -253,14 +275,11 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, | |||||||
|         stream->submit([&](sycl::handler& cgh) { |         stream->submit([&](sycl::handler& cgh) { | ||||||
|             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1( |             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1( | ||||||
|                             sycl::range<1>(work_group_size / WARP_SIZE), cgh); |                             sycl::range<1>(work_group_size / WARP_SIZE), cgh); | ||||||
|  |  | ||||||
|             cgh.parallel_for( |             cgh.parallel_for( | ||||||
|                 sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, |                 sycl::nd_range<3>(global_dims * block_dims, block_dims), | ||||||
|                     block_dims), |  | ||||||
|                 [=](sycl::nd_item<3> item_ct1) |                 [=](sycl::nd_item<3> item_ct1) | ||||||
|                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||||
|                     norm_f32(x, dst, ncols, eps, item_ct1, |                     norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); | ||||||
|                         get_pointer(s_sum_acc_ct1), work_group_size); |  | ||||||
|                 }); |                 }); | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| @@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst, | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, | static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, | ||||||
|     const int nrows, const float eps, |         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { | ||||||
|     queue_ptr stream, int device) { |  | ||||||
|     GGML_ASSERT(ncols % WARP_SIZE == 0); |     GGML_ASSERT(ncols % WARP_SIZE == 0); | ||||||
|     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); |     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); | ||||||
|  |  | ||||||
|  |     const sycl::range<3> global_dims(nsamples, nchannels, nrows); | ||||||
|     if (ncols < 1024) { |     if (ncols < 1024) { | ||||||
|         const sycl::range<3> block_dims(1, 1, WARP_SIZE); |         const sycl::range<3> block_dims(1, 1, WARP_SIZE); | ||||||
|         stream->submit([&](sycl::handler& cgh) { |         stream->submit([&](sycl::handler& cgh) { | ||||||
|             cgh.parallel_for( |             cgh.parallel_for( | ||||||
|                 sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, |                 sycl::nd_range<3>(global_dims * block_dims, block_dims), | ||||||
|                     block_dims), |  | ||||||
|                 [=](sycl::nd_item<3> item_ct1) |                 [=](sycl::nd_item<3> item_ct1) | ||||||
|                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||||
|                     rms_norm_f32(x, dst, ncols, eps, item_ct1, |                     rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); | ||||||
|                         nullptr, WARP_SIZE); |  | ||||||
|                 }); |                 }); | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| @@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, | |||||||
|             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), |             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), | ||||||
|                 cgh); |                 cgh); | ||||||
|             cgh.parallel_for( |             cgh.parallel_for( | ||||||
|                 sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, |                 sycl::nd_range<3>(global_dims * block_dims, block_dims), | ||||||
|                     block_dims), |  | ||||||
|                 [=](sycl::nd_item<3> item_ct1) |                 [=](sycl::nd_item<3> item_ct1) | ||||||
|                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { |                 [[sycl::reqd_sub_group_size(WARP_SIZE)]] { | ||||||
|                     rms_norm_f32(x, dst, ncols, eps, item_ct1, |                     rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); | ||||||
|                         get_pointer(s_sum_acc_ct1), work_group_size); |  | ||||||
|                 }); |                 }); | ||||||
|             }); |             }); | ||||||
|     } |     } | ||||||
| @@ -398,12 +414,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, | |||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | ||||||
|  |     const ggml_tensor * src0 = dst->src[0]; | ||||||
|  |  | ||||||
|     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); | ||||||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); |     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||||||
|  |  | ||||||
|     const int64_t ne00 = dst->src[0]->ne[0]; |     GGML_TENSOR_UNARY_OP_LOCALS | ||||||
|     const int64_t nrows = ggml_nrows(dst->src[0]); |  | ||||||
|     dpct::queue_ptr main_stream = ctx.stream(); |     dpct::queue_ptr main_stream = ctx.stream(); | ||||||
|     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); | ||||||
|     const float * src0_dd = static_cast<const float *>(dst->src[0]->data); |     const float * src0_dd = static_cast<const float *>(dst->src[0]->data); | ||||||
| @@ -411,8 +427,14 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | |||||||
|  |  | ||||||
|     float eps; |     float eps; | ||||||
|     memcpy(&eps, dst->op_params, sizeof(float)); |     memcpy(&eps, dst->op_params, sizeof(float)); | ||||||
|  |     GGML_ASSERT(eps >= 0.0f); | ||||||
|  |     const size_t ts0 = ggml_type_size(src0->type); | ||||||
|  |     GGML_ASSERT(nb00 == ts0); | ||||||
|  |     const int64_t s01 = nb01 / ts0; | ||||||
|  |     const int64_t s02 = nb02 / ts0; | ||||||
|  |     const int64_t s03 = nb03 / ts0; | ||||||
|  |  | ||||||
|     norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); |     norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | ||||||
| @@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | |||||||
|  |  | ||||||
| void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | ||||||
|  |  | ||||||
|  |     const ggml_tensor * src0 = dst->src[0]; | ||||||
|     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); |     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); | ||||||
|     GGML_ASSERT(dst->type == GGML_TYPE_F32); |     GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||||||
|  |  | ||||||
|     const int64_t ne00 = dst->src[0]->ne[0]; |  | ||||||
|     const int64_t nrows = ggml_nrows(dst->src[0]); |  | ||||||
|     dpct::queue_ptr main_stream = ctx.stream(); |     dpct::queue_ptr main_stream = ctx.stream(); | ||||||
|     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); |     SYCL_CHECK(ggml_sycl_set_device(ctx.device)); | ||||||
|  |  | ||||||
| @@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { | |||||||
|     float eps; |     float eps; | ||||||
|     memcpy(&eps, dst->op_params, sizeof(float)); |     memcpy(&eps, dst->op_params, sizeof(float)); | ||||||
|  |  | ||||||
|     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); |     GGML_TENSOR_UNARY_OP_LOCALS | ||||||
|  |     const size_t ts0 = ggml_type_size(src0->type); | ||||||
|  |     GGML_ASSERT(nb00 == ts0); | ||||||
|  |     const int64_t s01 = nb01 / ts0; | ||||||
|  |     const int64_t s02 = nb02 / ts0; | ||||||
|  |     const int64_t s03 = nb03 / ts0; | ||||||
|  |     rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Akarshan Biswas
					Akarshan Biswas