[SYCL] fix UT failure cases: count-equal, argsort, pad OPs (#16521)

* fix/refactor the argsort and pad OPs

* fix the count-equal OP

* update the SYCL OP list

* fix formatting issues

---------

Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
Author: Neo Zhang Jianyu
Date: 2025-10-12 21:53:35 +08:00
Committed by: GitHub
Parent: 8415f61e23
Commit: c7be9febcb
13 changed files with 12381 additions and 4393 deletions

View File

@@ -18,6 +18,7 @@
#include "concat.hpp"
#include "conv.hpp"
#include "convert.hpp"
#include "count-equal.hpp"
#include "cpy.hpp"
#include "dequantize.hpp"
#include "dmmv.hpp"
@@ -28,6 +29,7 @@
#include "mmvq.hpp"
#include "norm.hpp"
#include "outprod.hpp"
#include "pad.hpp"
#include "quantize.hpp"
#include "quants.hpp"
#include "rope.hpp"

View File

@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);

View File

@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}
static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}

View File

@@ -195,7 +195,8 @@ struct optimize_feature {
struct sycl_device_info {
int cc; // compute capability
// int nsm; // number of streaming multiprocessors
int nsm; // number of streaming multiprocessors (CUDA); maps to the maximum
// number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool vmm; // virtual memory support

View File

@@ -0,0 +1,79 @@
#include "count-equal.hpp"
#include <cstdint>
template <typename T>
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
int64_t *__restrict__ dst, const int64_t dk,
const int64_t k) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
const int64_t i1 = sycl::min(i0 + dk, k);
int nequal = 0;
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
const T xi = x[i];
const T yi = y[i];
nequal += xi == yi;
}
nequal = warp_reduce_sum(nequal);
if (item_ct1.get_local_id(2) != 0) {
return;
}
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
(int *)dst, nequal);
}
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT( dst->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
int64_t * dst_d = (int64_t *) dst->data;
dpct::queue_ptr stream = ctx.stream();
const int id = get_current_device_id();
const int nsm = ggml_sycl_info().devices[id].nsm;
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
const int64_t dne =
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
const dpct::dim3 block_nums(
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
SYCL_COUNT_EQUAL_CHUNK_SIZE),
1, 1);
switch (src0->type) {
case GGML_TYPE_I32: {
const int *src0_d = (const int *)src0->data;
const int *src1_d = (const int *)src1->data;
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
count_equal(src0_d, src1_d, dst_d, dne, ne);
GGML_UNUSED(item_ct1);
});
} break;
default:
GGML_ASSERT(false);
break;
}
}
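
For reference, COUNT_EQUAL reduces both inputs to a single int64 count of matching positions: the code above zeroes the destination, lets each work-group count its chunk, warp-reduces the partial sums, and accumulates them into the destination with a 32-bit atomic add (hence the ne < 2^30 assert). A minimal host-side reference of the expected value, written purely for illustration (count_equal_ref is a hypothetical helper, not part of the commit):

// Host-side reference for COUNT_EQUAL on I32 data: count the positions where
// the two contiguous inputs are equal and return the total as int64_t.
#include <cstdint>

static int64_t count_equal_ref(const int32_t * x, const int32_t * y, int64_t ne) {
    int64_t nequal = 0;
    for (int64_t i = 0; i < ne; ++i) {
        nequal += (x[i] == y[i]);
    }
    return nequal;
}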

View File

@@ -0,0 +1,9 @@
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
#define GGML_SYCL_COUNT_EQUAL_HPP
#include "common.hpp"
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif //GGML_SYCL_COUNT_EQUAL_HPP
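The chunking in count-equal.cpp splits the ne elements across at most 4 * nsm work-groups, each covering dne consecutive elements rounded up to SYCL_COUNT_EQUAL_CHUNK_SIZE. A small sketch of that launch arithmetic under the constants above (plan_count_equal is a hypothetical helper for illustration):

// Mirrors the dne / block_nums computation in ggml_sycl_count_equal:
//   dne    = GGML_PAD(ceil(ne / (4 * nsm)), SYCL_COUNT_EQUAL_CHUNK_SIZE)
//   blocks = min(4 * nsm, ceil(ne / SYCL_COUNT_EQUAL_CHUNK_SIZE))
#include <algorithm>
#include <cstdint>

static void plan_count_equal(int64_t ne, int nsm, int64_t & dne, int64_t & nblocks) {
    const int64_t chunk = 128;                       // SYCL_COUNT_EQUAL_CHUNK_SIZE
    dne     = (ne + 4 * nsm - 1) / (4 * nsm);        // elements per work-group
    dne     = ((dne + chunk - 1) / chunk) * chunk;   // GGML_PAD to the chunk size
    nblocks = std::min<int64_t>(4 * (int64_t) nsm, (ne + chunk - 1) / chunk);
}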

View File

@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}
template <typename T>
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
if (nidx >= ne0) {
return;
}
// operation
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
item_ct1.get_group(0) * ne00 * ne01;
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = static_cast<T>(0.0f);
}
}
template<typename T>
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
const sycl::nd_item<1> &item_ct1) {
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
});
}
template<typename T>
static void pad_sycl(const T *x, T *dst, const int ne00,
const int ne01, const int ne02, const int ne0,
const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
}
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
}
}
template<typename KernelInvoker, typename... Args>
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
#if defined (GGML_SYCL_F16)
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
GGML_ASSERT(dst->src[0]->type == dst->type);
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
dpct::queue_ptr main_stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
switch (dst->type) {
#if defined (GGML_SYCL_F16)
case GGML_TYPE_F16:
{
auto data_pts = cast_data<sycl::half>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
break;
}
#endif
case GGML_TYPE_F32:
{
auto data_pts = cast_data<float>(dst);
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
break;
}
default:
GGML_ABORT("GGML tensor type not supported!\n");
}
}
} // namespace ggml_sycl_detail
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
});
}
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
queue_ptr stream) {
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
});
}
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
float min_val;
float max_val;
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_upscale(ctx, dst);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);

View File

@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

View File

@@ -85,9 +85,11 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].nsm = prop.get_max_compute_units();
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
info.devices[i].smpbo = prop.get_local_mem_size();
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
}
for (int id = 0; id < info.device_count; ++id) {
@@ -1512,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
template <ggml_sort_order order>
__dpct_inline__ static void
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
uint8_t *dpct_local) {
// bitonic sort
int col = item_ct1.get_local_id(2);
int col_index = item_ct1.get_local_id(2);
int row = item_ct1.get_group(1);
if (col >= ncols_pad) {
return;
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col >= ncols_pad) {
return;
}
}
const float * x_row = x + row * ncols;
auto dst_row = (int *)dpct_local;
// initialize indices
dst_row[col] = col;
for (int i=0;i<tasks_per_thread;i++){
int col = col_index*tasks_per_thread+i;
dst_row[col] = col;
}
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (dst_row[col] >= ncols ||
(dst_row[ixj] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
: x_row[dst_row[col]] <
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (dst_row[ixj] >= ncols ||
(dst_row[col] < ncols &&
(order == GGML_SORT_ORDER_ASC
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
: x_row[dst_row[col]] >
x_row[dst_row[ixj]]))) {
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
}
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
/*
DPCT1118:1: SYCL group functions and algorithms must be encountered
in converged control flow. You may need to adjust the code.
*/
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
// copy the result to dst without the padding
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
for (int i = 0; i < tasks_per_thread; i++) {
int col = col_index * tasks_per_thread + i;
if (col < ncols) {
dst[row * ncols + col] = dst_row[col];
}
}
}
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
const sycl::nd_item<3> &item_ct1) {
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
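
The kernel keeps the same padded bitonic network; only the work distribution changes, with each work-item now handling tasks_per_thread columns per stage instead of exactly one. For clarity, a sequential host-side serialization of the comparison rules the kernel applies (argsort_ref is a hypothetical helper; ascending order shown, with padding indices >= ncols sorting to the end):

// Sequential reference of the padded bitonic argsort over one row.
#include <utility>
#include <vector>

static std::vector<int> argsort_ref(const float * x, int ncols, int ncols_pad /* power of two */) {
    std::vector<int> idx(ncols_pad);
    for (int i = 0; i < ncols_pad; ++i) {
        idx[i] = i;
    }
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int col = 0; col < ncols_pad; ++col) {
                const int ixj = col ^ j;                // compare-exchange partner
                if (ixj <= col) {
                    continue;                           // the smaller index owns the pair
                }
                const bool asc   = (col & k) == 0;      // direction of this bitonic block
                const bool a_pad = idx[col] >= ncols;   // padding entries sort last
                const bool b_pad = idx[ixj] >= ncols;
                const bool do_swap = asc
                    ? (a_pad || (!b_pad && x[idx[col]] > x[idx[ixj]]))
                    : (b_pad || (!a_pad && x[idx[col]] < x[idx[ixj]]));
                if (do_swap) {
                    std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }
    idx.resize(ncols);                                  // drop the padding
    return idx;
}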
@@ -1738,11 +1750,20 @@ static int next_power_of_2(int x) {
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const int nrows, ggml_sort_order order,
queue_ptr stream) {
queue_ptr stream, int device) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
const sycl::range<3> block_dims(1, 1, ncols_pad);
int nth = 1;
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
while (nth < ncols_pad && nth < max_block_size)
nth *= 2;
if (nth > max_block_size)
nth = max_block_size;
const int tasks_per_thread = ncols_pad / nth;
const sycl::range<3> block_dims(1, 1, nth);
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
@@ -1755,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
@@ -1769,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
dpct_local_acc_ct1
.get_multi_ptr<sycl::access::decorated::no>()
.get());
});
});
@@ -2142,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
main_stream, ctx.device);
}
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
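
The launch-size change is the substance of the argsort fix: the previous code requested a work-group of ncols_pad work-items, which can exceed the device's work-group limit for wide rows (the likely source of the argsort UT failures), while the new code caps the block size at max_work_group_sizes[device] (recorded during device init, see the @@ -85 hunk above) and lets each work-item cover the remainder. A sketch of that sizing logic, written for illustration (plan_argsort is a hypothetical helper):

// Mirrors the nth / tasks_per_thread computation in argsort_f32_i32_sycl.
// The division is exact whenever the device limit is a power of two, as GPU
// work-group size limits typically are.
#include <algorithm>

struct argsort_launch {
    int nth;               // work-items per work-group
    int tasks_per_thread;  // columns handled by each work-item
};

static argsort_launch plan_argsort(int ncols_pad /* power of two */, int max_block_size) {
    int nth = 1;
    while (nth < ncols_pad && nth < max_block_size) {
        nth *= 2;
    }
    nth = std::min(nth, max_block_size);
    return { nth, ncols_pad / nth };
}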
@@ -4413,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ACC:
return true;
case GGML_OP_PAD:
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
return ggml_is_contiguous(op->src[0]);
case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6:
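
The relaxed gate pairs with the general pad kernel added in the next file: GGML_OP_PAD carries left and right pad sizes for all four dimensions in its op_params (lp0, rp0, ..., lp3, rp3, as read in ggml_sycl_op_pad below), and the backend now only requires a contiguous source instead of rejecting any non-zero left padding. A worked shape example under that layout (padded_shape is a hypothetical helper for illustration):

// A [7, 5, 3, 1] F32 tensor padded with lp = {1, 0, 2, 0} and rp = {2, 3, 0, 0}
// yields a [10, 8, 5, 1] destination; the old gate rejected this case because
// lp0 and lp2 are non-zero, the new gate accepts it.
#include <cstdint>

static void padded_shape(const int64_t ne_src[4], const int32_t lp[4],
                         const int32_t rp[4], int64_t ne_dst[4]) {
    for (int d = 0; d < 4; ++d) {
        ne_dst[d] = ne_src[d] + lp[d] + rp[d];
    }
}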

View File

@@ -0,0 +1,97 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//#include "common.hpp"
#include "pad.hpp"
static void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
int i0 = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
int i1 = item_ct1.get_group(1);
int i2 = item_ct1.get_group(0) % ne2;
int i3 = item_ct1.get_group(0) / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
i02 * (ne00 * ne01) + i01 * ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] = 0.0f;
}
}
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3,
const int rp3, const int ne0, const int ne1,
const int ne2, const int ne3,
dpct::queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
ne2, ne3);
});
}
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_sycl(src0_d, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_pad(ctx, dst);
}
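
As a cross-check for the new kernel, here is a host-side reference of the same dst-to-src index mapping, written for illustration only (pad_f32_ref is a hypothetical helper): every destination element either copies from the shifted source region or is zero-filled, with both tensors assumed contiguous as asserted above.

// Host reference mirroring pad_f32's indexing. ne[] is the destination shape
// (pads included); the source extent in each dimension is ne[d] - lp[d] - rp[d].
#include <cstdint>

static void pad_f32_ref(const float * src, float * dst,
                        const int lp[4], const int rp[4], const int64_t ne[4]) {
    int64_t ne_src[4];
    for (int d = 0; d < 4; ++d) {
        ne_src[d] = ne[d] - lp[d] - rp[d];
    }
    for (int64_t i3 = 0; i3 < ne[3]; ++i3)
    for (int64_t i2 = 0; i2 < ne[2]; ++i2)
    for (int64_t i1 = 0; i1 < ne[1]; ++i1)
    for (int64_t i0 = 0; i0 < ne[0]; ++i0) {
        const int64_t dst_idx = ((i3 * ne[2] + i2) * ne[1] + i1) * ne[0] + i0;
        const bool inside =
            i0 >= lp[0] && i0 < ne[0] - rp[0] &&
            i1 >= lp[1] && i1 < ne[1] - rp[1] &&
            i2 >= lp[2] && i2 < ne[2] - rp[2] &&
            i3 >= lp[3] && i3 < ne[3] - rp[3];
        if (inside) {
            const int64_t src_idx =
                (((i3 - lp[3]) * ne_src[2] + (i2 - lp[2])) * ne_src[1]
                 + (i1 - lp[1])) * ne_src[0] + (i0 - lp[0]);
            dst[dst_idx] = src[src_idx];
        } else {
            dst[dst_idx] = 0.0f;
        }
    }
}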

View File

@@ -0,0 +1,24 @@
//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_PAD_HPP
#define GGML_SYCL_PAD_HPP
#include "common.hpp"
#define SYCL_PAD_BLOCK_SIZE 256
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_PAD_HPP