From 59fc1ec8e83b14354c1a3a8acf8c5c2cbf9af42f Mon Sep 17 00:00:00 2001
From: shani-f
Date: Mon, 27 Oct 2025 03:19:50 +0200
Subject: [PATCH] sycl: add REPEAT_BACK operation support (#16734)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* SYCL repeat_back v1 — add core op + switch case

* Implement repeat_back SYCL operation and minor fixes

* Update ggml/src/ggml-sycl/repeat_back.cpp

Co-authored-by: Sigbjørn Skjæret

* Update ggml/src/ggml-sycl/repeat_back.hpp

Co-authored-by: Sigbjørn Skjæret

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret

---------

Co-authored-by: Sigbjørn Skjæret
---
 ggml/src/ggml-sycl/ggml-sycl.cpp   | 13 +++++++
 ggml/src/ggml-sycl/repeat_back.cpp | 78 ++++++++++++++++++++++++++++++++++++++++++
 ggml/src/ggml-sycl/repeat_back.hpp |  8 +++++
 3 files changed, 99 insertions(+)
 create mode 100644 ggml/src/ggml-sycl/repeat_back.cpp
 create mode 100644 ggml/src/ggml-sycl/repeat_back.hpp

diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index b695ba051b..e6bcc596a4 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -48,6 +48,7 @@
 #include "ggml-sycl/set.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/repeat_back.hpp"
 #include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_repeat_back(ctx, dst);
+}
 
 static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_sycl_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_sycl_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_sycl_get_rows(ctx, dst);
             break;
@@ -4516,6 +4524,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             }
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type == GGML_TYPE_F32;
+            }
         case GGML_OP_DUP:
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
diff --git a/ggml/src/ggml-sycl/repeat_back.cpp b/ggml/src/ggml-sycl/repeat_back.cpp
new file mode 100644
index 0000000000..abcd4cee72
--- /dev/null
+++ b/ggml/src/ggml-sycl/repeat_back.cpp
@@ -0,0 +1,78 @@
+#include "repeat_back.hpp"
+
+#include "common.hpp"
+
+// Backward pass of GGML_OP_REPEAT for F32 tensors.
+//
+// src0 (dst->src[0]) has extents ne0X = neX * nrX along every dimension X. Each element of dst
+// receives the sum of the nr0 * nr1 * nr2 * nr3 elements of src0 located at
+// (i0 + j0 * ne0, i1 + j1 * ne1, i2 + j2 * ne2, i3 + j3 * ne3).
+//
+// Both tensors are addressed through their byte strides (nb), so views with a non-contiguous
+// layout are read and written at the right locations.
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const char * src0_dd = (const char *) src0->data;
+    char *       dst_dd  = (char *) dst->data;
+
+    const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
+    const size_t  nb0 = dst->nb[0], nb1 = dst->nb[1], nb2 = dst->nb[2], nb3 = dst->nb[3];
+    const size_t  nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
+
+    // Nothing to launch for an empty destination; this also keeps the divisions below well defined.
+    const size_t total = (size_t) (ne0 * ne1 * ne2 * ne3);
+    if (total == 0) {
+        return;
+    }
+
+    // Number of repetitions of dst inside src0 along each dimension.
+    const int64_t nr0 = src0->ne[0] / ne0;
+    const int64_t nr1 = src0->ne[1] / ne1;
+    const int64_t nr2 = src0->ne[2] / ne2;
+    const int64_t nr3 = src0->ne[3] / ne3;
+
+    // Keep the launch geometry in size_t: an int block count overflows for very large tensors.
+    constexpr size_t BLOCK_SIZE = 256;
+    const size_t     num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    queue_ptr stream = ctx.stream();
+
+    stream->parallel_for(
+        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            const size_t i = item_ct1.get_global_linear_id();
+            if (i >= total) {
+                return;  // padding work-item of the last work-group
+            }
+
+            // Decompose the flat index into dst coordinates (dimension 0 varies fastest).
+            const int64_t idx = (int64_t) i;
+            const int64_t i0  = idx % ne0;
+            const int64_t i1  = (idx / ne0) % ne1;
+            const int64_t i2  = (idx / (ne0 * ne1)) % ne2;
+            const int64_t i3  = idx / (ne0 * ne1 * ne2);
+
+            float acc = 0.0f;
+
+            for (int64_t j3 = 0; j3 < nr3; ++j3) {
+                const size_t off3 = (size_t) (i3 + j3 * ne3) * nb03;
+                for (int64_t j2 = 0; j2 < nr2; ++j2) {
+                    const size_t off2 = off3 + (size_t) (i2 + j2 * ne2) * nb02;
+                    for (int64_t j1 = 0; j1 < nr1; ++j1) {
+                        const size_t off1 = off2 + (size_t) (i1 + j1 * ne1) * nb01;
+                        for (int64_t j0 = 0; j0 < nr0; ++j0) {
+                            const size_t off0 = off1 + (size_t) (i0 + j0 * ne0) * nb00;
+                            acc += *(const float *) (src0_dd + off0);
+                        }
+                    }
+                }
+            }
+
+            const size_t dst_off = (size_t) i0 * nb0 + (size_t) i1 * nb1 + (size_t) i2 * nb2 + (size_t) i3 * nb3;
+            *(float *) (dst_dd + dst_off) = acc;
+        });
+}
diff --git a/ggml/src/ggml-sycl/repeat_back.hpp b/ggml/src/ggml-sycl/repeat_back.hpp
new file mode 100644
index 0000000000..17a87f3e15
--- /dev/null
+++ b/ggml/src/ggml-sycl/repeat_back.hpp
@@ -0,0 +1,8 @@
+#ifndef GGML_SYCL_REPEAT_BACK_HPP
+#define GGML_SYCL_REPEAT_BACK_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_REPEAT_BACK_HPP