From 9d7c518d642db8657e2a53a9793c5f039ed8ea5a Mon Sep 17 00:00:00 2001 From: YehuditE Date: Thu, 6 Nov 2025 12:02:33 +0200 Subject: [PATCH] sycl: add CONCAT operator support (#16047) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sycl: add CONCAT operator support * cleanup: remove stray lines added by mistake * fix: code format issues in concat.cpp and tests/test-backend-ops.cpp * chore: fix editorconfig violations * cleanup: drop unnecessary i16 type support * docs: update sycl-csv and regenerate ops.md * update docs/ops.md * fix: adapt to upstream master changes after rebase * fix: remove empty files * fix: drop whitespace --------- Co-authored-by: Sigbjørn Skjæret --- docs/ops.md | 2 +- docs/ops/SYCL.csv | 32 +++++------ ggml/src/ggml-sycl/concat.cpp | 99 ++++++++++++++++++-------------- ggml/src/ggml-sycl/ggml-sycl.cpp | 6 +- 4 files changed, 73 insertions(+), 66 deletions(-) diff --git a/docs/ops.md b/docs/ops.md index b8fcf04635..775b938bd1 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -24,7 +24,7 @@ Legend: | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | -| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | +| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | | CONV_2D | ❌ | ❌ | ✅ | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index 101e80f64c..d7e71990a8 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -9307,37 +9307,37 @@ "SYCL0","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","SYCL" "SYCL0","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","yes","SYCL" "SYCL0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","SYCL" -"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","SYCL" +"SYCL0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","yes","SYCL" "SYCL0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","SYCL" "SYCL0","ARGSORT","type=f32,ne=[16,10,10,10],order=0","support","1","yes","SYCL" "SYCL0","ARGSORT","type=f32,ne=[60,10,10,10],order=0","support","1","yes","SYCL" diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index c768365048..d16215bc91 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -11,9 +11,13 @@ // #include "concat.hpp" -#include "common.hpp" -static void concat_f32_dim0(const float *x, const float *y, float *dst, +static inline size_t elem_size(ggml_type t) { + return ggml_type_size(t) / ggml_blck_size(t); +} + +template +static void concat_T_dim0(const T *x, const T *y, T *dst, const int ne0, const int ne00, const sycl::nd_item<3> &item_ct1) { int nidx = item_ct1.get_local_id(2) + @@ -36,7 +40,8 @@ static void concat_f32_dim0(const float *x, const float *y, float *dst, } } -static void concat_f32_dim1(const float *x, const float *y, float *dst, +template +static void concat_T_dim1(const T *x, const T *y, T *dst, const int ne0, const int ne01, const sycl::nd_item<3> &item_ct1) { int nidx = item_ct1.get_local_id(2) + @@ -59,7 +64,8 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst, } } -static void concat_f32_dim2(const float *x, const float *y, float *dst, +template +static void concat_T_dim2(const T *x, const T *y, T *dst, const int ne0, const int ne02, const sycl::nd_item<3> &item_ct1) { int nidx = item_ct1.get_local_id(2) + @@ -82,45 +88,35 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst, } } -static void concat_f32_sycl(const float *x, const float *y, float *dst, +template +static void concat_T_sycl(const T *x, const T *y, T *dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, queue_ptr stream) { int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; sycl::range<3> gridDim(ne2, ne1, num_blocks); switch (dim) { case 0: - stream->parallel_for( - sycl::nd_range<3>(gridDim * - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); - }); - break; + stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { concat_T_dim0(x, y, dst, ne0, ne00, item_ct1); }); + break; case 1: - stream->parallel_for( - sycl::nd_range<3>(gridDim * - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); - }); - break; + stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { concat_T_dim1(x, y, dst, ne0, ne01, item_ct1); }); + break; // dim >=2 will be dispatched to the default path default: - stream->parallel_for( - sycl::nd_range<3>(gridDim * - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); - }); - break; + stream->parallel_for(sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { concat_T_dim2(x, y, dst, ne0, ne02, item_ct1); }); + break; } } // non-contiguous kernel (slow) -static void concat_f32_sycl_non_cont( +template +static void concat_T_sycl_non_cont( queue_ptr stream, const char *src0, const char *src1, char *dst, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00, uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/, @@ -137,24 +133,25 @@ static void concat_f32_sycl_non_cont( int64_t o[4] = { 0, 0, 0, 0 }; o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - const float * x; + const T * x; for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00); + x = (const T *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00); } else { - x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 + + x = (const T *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10); } - float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + T *y = (T *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); *y = *x; } }); } -void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +template +void concat_impl_sycl(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -163,15 +160,14 @@ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { const int32_t dim = ((int32_t *) dst->op_params)[0]; if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float * src0_d = (const float *) src0->data; - const float * src1_d = (const float *) src1->data; - - float * dst_d = (float *) dst->data; - + const T * src0_d = (const T *) src0->data; + const T * src1_d = (const T *) src1->data; + T * dst_d = (T *) dst->data; + size_t type_size = elem_size(dst->type); if (dim != 3) { for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_sycl(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], + concat_T_sycl(src0_d + i3 * (src0->nb[3] / type_size), src1_d + i3 * (src1->nb[3] / type_size), + dst_d + i3 * (dst->nb[3] / type_size), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } } else { @@ -179,13 +175,28 @@ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { const size_t size1 = ggml_nbytes(src1); SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); - SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / type_size, src1_d, size1).wait())); } } else { - concat_f32_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data, + concat_T_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } } + +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + switch (dst->type) { + case GGML_TYPE_F32: + concat_impl_sycl(ctx, dst); + break; + case GGML_TYPE_I32: + concat_impl_sycl(ctx, dst); + break; + default: + GGML_ASSERT(false && "ggml_sycl_op_concat: unsupported type"); + break; + } +} diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c97c589943..f3b3e36574 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4534,16 +4534,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g } return false; } - case GGML_OP_CONCAT: - { - ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; - } case GGML_OP_REPEAT_BACK: { ggml_type src0_type = op->src[0]->type; return src0_type == GGML_TYPE_F32; } + case GGML_OP_CONCAT: case GGML_OP_DUP: case GGML_OP_ARGMAX: case GGML_OP_NONE: