diff --git a/docs/ops.md b/docs/ops.md index dfd1cfab6a..3738a48072 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -79,7 +79,7 @@ Legend: | REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | | REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | -| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | +| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index fe6876357f..101e80f64c 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -5637,25 +5637,25 @@ "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL" "SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 328d1a71b7..c97c589943 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -42,6 +42,7 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" +#include "ggml-sycl/norm.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/set_rows.hpp" @@ -2637,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds ggml_sycl_op_rms_norm(ctx, dst); } +static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_rms_norm_back(ctx, dst); +} + static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_l2_norm(ctx, dst); @@ -3827,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_LEAKY_RELU: ggml_sycl_leaky_relu(ctx, dst); break; + case GGML_OP_RMS_NORM_BACK: + ggml_sycl_rms_norm_back(ctx, dst); + break; case GGML_OP_RMS_NORM: ggml_sycl_rms_norm(ctx, dst); break; @@ -4571,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + case GGML_OP_RMS_NORM_BACK: + return ((op->src[0]->ne[0] % WARP_SIZE) == 0); case GGML_OP_SCALE: return true; case GGML_OP_CONT: diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 4ec1416849..823d3a4828 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); } +void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float eps = 1e-5f; + std::memcpy(&eps, dst->op_params, sizeof(float)); + if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f; + + const float * g_base = static_cast(dst->src[0]->data); // dz + const float * x_base = static_cast(dst->src[1]->data); // x + float * dx_base = static_cast< float *>(dst->data); + + const int64_t D = dst->ne[0]; + const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3; + const int64_t N = ggml_nrows(dst); + if (D == 0 || N == 0) return; + + const ggml_tensor *G = dst->src[0]; + const ggml_tensor *X = dst->src[1]; + const int ts = (int) ggml_type_size(X->type); + GGML_ASSERT((size_t) X->nb[0] == (size_t) ts); + GGML_ASSERT((size_t) G->nb[0] == (size_t) ts); + GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts); + + const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts; + const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts; + const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts; + + dpct::queue_ptr q = ctx.stream(); + + // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D + const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device]; + auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; }; + int wg_cap = 256; + if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg); + int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min(D, wg_cap), WARP_SIZE), wg_cap)); + + // FP32 path: per-thread compensated accumulation + hierarchical reduction + q->submit([&](sycl::handler &cgh) { + const int nwarps_loc = std::max(1, WG / WARP_SIZE); + // store one partial value per warp (xx and xg) for cross-warp reduction + auto l_xx = sycl::local_accessor(sycl::range<1>(nwarps_loc), cgh); + auto l_xg = sycl::local_accessor(sycl::range<1>(nwarps_loc), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG), + sycl::range<3>(1, 1, WG)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const int row = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t i1 = row % n1; + const int64_t i2 = (row / n1) % n2; + const int64_t i3 = row / (n1 * n2); + + const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1; + const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1; + float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1; + + // per-thread accumulation (compensated by default) + float sum_xx = 0.f, sum_xg = 0.f; +#ifndef GGML_SYCL_RMS_BACK_FAST + float c_xx = 0.f, c_xg = 0.f; +#endif + for (int64_t col = tid; col < D; col += WG) { + const float xv = x_row[col]; + const float gv = g_row[col]; +#ifdef GGML_SYCL_RMS_BACK_FAST + sum_xx += xv * xv; + sum_xg += xv * gv; +#else + float y1 = xv * xv - c_xx; + float t1 = sum_xx + y1; + c_xx = (t1 - sum_xx) - y1; + sum_xx = t1; + + float y2 = xv * gv - c_xg; + float t2 = sum_xg + y2; + c_xg = (t2 - sum_xg) - y2; + sum_xg = t2; +#endif + } + + // warp-level reduction + sycl::float2 xx = sycl::float2(sum_xx, +#ifndef GGML_SYCL_RMS_BACK_FAST + c_xx +#else + 0.f +#endif + ); + sycl::float2 xg = sycl::float2(sum_xg, +#ifndef GGML_SYCL_RMS_BACK_FAST + c_xg +#else + 0.f +#endif + ); + xx = warp_reduce_sum(xx, item_ct1); + xg = warp_reduce_sum(xg, item_ct1); + + // cross-warp reduction using local memory (single barrier) + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + + sycl::float2 xx_total = xx; + sycl::float2 xg_total = xg; + if (nwarps > 1) { + if (wi_in_sg == 0) { + l_xx[sg_id] = xx; + l_xg[sg_id] = xg; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (sg_id == 0) { + const unsigned wi_u = wi_in_sg; + sycl::float2 xx_first = (wi_u < static_cast(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f); + sycl::float2 xg_first = (wi_u < static_cast(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f); + xx_total = warp_reduce_sum(xx_first, item_ct1); + xg_total = warp_reduce_sum(xg_first, item_ct1); + } else { + // other subgroups keep their local totals; they'll be ignored + xx_total = xx; + xg_total = xg; + } + // ensure all threads see the first-subgroup result via broadcast below + } + + // compute inv_r and coeff once per row and broadcast to the whole work-group + float inv_r = 0.f; + float coeff = 0.f; + if (tid == 0) { + const float sum_xx_f = xx_total.x() + xx_total.y(); + const float sum_xdz_f = xg_total.x() + xg_total.y(); + const float mean_eps = sum_xx_f / (float) D + eps; + const float sum_eps = sum_xx_f + eps * (float) D; + inv_r = sycl::rsqrt(mean_eps); + coeff = -sum_xdz_f / sum_eps; + } + inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r); + coeff = sycl::group_broadcast(item_ct1.get_group(), coeff); + + for (int64_t col = tid; col < D; col += WG) { + d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r; + } + }); + }); + +} + void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp index 612cd67cf9..8cb885eb2e 100644 --- a/ggml/src/ggml-sycl/norm.hpp +++ b/ggml/src/ggml-sycl/norm.hpp @@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); +void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);