From 7e994168b1ccc12337ba8de939c4fd466107c1fb Mon Sep 17 00:00:00 2001 From: shani-f Date: Mon, 3 Nov 2025 03:35:33 +0200 Subject: [PATCH] =?UTF-8?q?SYCL:=20optimized=20repeat=5Fback=20kernel=20(3?= =?UTF-8?q?=C3=97=20fewer=20asm=20instructions,=202=C3=97=20faster)Feature?= =?UTF-8?q?/sycl=20repeat=20back=20opt=20(#16869)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL repeat_back v1 — add core op + switch case * Implement repeat_back SYCL operation and minor fixes * SYCL: optimize repeat_back kernel * Remove Hebrew comment from repeat_back.cpp * Remove comments for code clarity Removed comments to clean up the code. * Fix formatting in ggml-sycl.cpp * Formatted lambda according to legacy style. No logic changes * Remove blank line in repeat_back.cpp Remove unnecessary blank line before assigning acc to dst_dd. --- ggml/src/ggml-sycl/repeat_back.cpp | 70 +++++++++++++++++++----------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-sycl/repeat_back.cpp b/ggml/src/ggml-sycl/repeat_back.cpp index abcd4cee72..845b48468c 100644 --- a/ggml/src/ggml-sycl/repeat_back.cpp +++ b/ggml/src/ggml-sycl/repeat_back.cpp @@ -2,26 +2,43 @@ #include "common.hpp" -void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#define GGML_ASSERT_TENSOR_FITS_INT(t) \ + GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX) +void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); const float * src0_dd = (const float *) dst->src[0]->data; float * dst_dd = (float *) dst->data; - const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3]; - const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2], - ne03 = dst->src[0]->ne[3]; + GGML_ASSERT_TENSOR_FITS_INT(dst); + GGML_ASSERT_TENSOR_FITS_INT(dst->src[0]); - const int nr0 = (int) (ne00 / ne0); - const int nr1 = (int) (ne01 / ne1); - const int nr2 = (int) (ne02 / ne2); - const int nr3 = (int) (ne03 / ne3); + const int ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3]; + const int ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2], + ne03 = dst->src[0]->ne[3]; - const size_t total = ne0 * ne1 * ne2 * ne3; - const int BLOCK_SIZE = 256; - const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int nr0 = ne00 / ne0; + const int nr1 = ne01 / ne1; + const int nr2 = ne02 / ne2; + const int nr3 = ne03 / ne3; + + const int nb0 = dst->src[0]->nb[0]; + const int nb1 = dst->src[0]->nb[1]; + const int nb2 = dst->src[0]->nb[2]; + const int nb3 = dst->src[0]->nb[3]; + + const char * base = (const char *) src0_dd; + + const size_t total = (size_t) ne0 * ne1 * ne2 * ne3; + constexpr int BLOCK_SIZE = 256; + const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE; + + const float inv_ne0 = 1.0f / ne0; + const float inv_ne_01 = 1.0f / (ne0 * ne1); + const float inv_ne_012 = 1.0f / (ne0 * ne1 * ne2); + const int repeat_count = nr0 * nr1 * nr2 * nr3; queue_ptr stream = ctx.stream(); @@ -33,24 +50,27 @@ void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst return; } - const int i0 = i % ne0; - const int i1 = (i / ne0) % ne1; - const int i2 = (i / (ne0 * ne1)) % ne2; - const int i3 = i / (ne0 * ne1 * ne2); + const int i3 = (int) (i * inv_ne_012); + const int i2 = (int) (i * inv_ne_01) - i3 * ne2; + const int i1 = (int) (i * inv_ne0) - (int) (i * inv_ne_01) * ne1; + const int i0 = i - (int) (i * inv_ne0) * ne0; + int j0 = 0, j1 = 0, j2 = 0, j3 = 0; float acc = 0.0f; - for (int j3 = 0; j3 < nr3; ++j3) { - for (int j2 = 0; j2 < nr2; ++j2) { - for (int j1 = 0; j1 < nr1; ++j1) { - for (int j0 = 0; j0 < nr0; ++j0) { - acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 + - (i3 + j3 * ne3) * ne00 * ne01 * ne02]; - } - } - } - } + for (int j = 0; j < repeat_count; ++j) { + const float * ptr = (const float *) (base + (i0 + j0 * ne0) * nb0 + (i1 + j1 * ne1) * nb1 + + (i2 + j2 * ne2) * nb2 + (i3 + j3 * ne3) * nb3); + acc += *ptr; + int carry = (++j0 >= nr0); + j0 -= carry * nr0; + carry = (carry && (++j1 >= nr1)); + j1 -= carry * nr1; + carry = (carry && (++j2 >= nr2)); + j2 -= carry * nr2; + j3 += carry; + } dst_dd[i] = acc; }); }