opencl: support pad_ext (#15888)

This commit is contained in:
lhez
2025-09-30 10:45:45 -07:00
committed by GitHub
parent 16b0ca0d2e
commit 7c156df414
2 changed files with 80 additions and 36 deletions

View File

@@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_OP_REPEAT:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
(ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
@@ -5881,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
GGML_ASSERT(dst->extra);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -5899,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
const int s_ne0 = src0->ne[0];
const int s_ne1 = src0->ne[1];
const int s_ne2 = src0->ne[2];
const int s_ne3 = src0->ne[3];
const int s_nb0 = src0->nb[0];
const int s_nb1 = src0->nb[1];
const int s_nb2 = src0->nb[2];
const int s_nb3 = src0->nb[3];
const int d_ne0 = dst->ne[0];
const int d_ne1 = dst->ne[1];
const int d_ne2 = dst->ne[2];
const int d_ne3 = dst->ne[3];
const int d_nb0 = dst->nb[0];
const int d_nb1 = dst->nb[1];
const int d_nb2 = dst->nb[2];
const int d_nb3 = dst->nb[3];
const int lp0 = ((const int*)(dst->op_params))[0];
const int rp0 = ((const int*)(dst->op_params))[1];
const int lp1 = ((const int*)(dst->op_params))[2];
const int rp1 = ((const int*)(dst->op_params))[3];
const int lp2 = ((const int*)(dst->op_params))[4];
const int rp2 = ((const int*)(dst->op_params))[5];
const int lp3 = ((const int*)(dst->op_params))[6];
const int rp3 = ((const int*)(dst->op_params))[7];
cl_kernel kernel = backend_ctx->kernel_pad;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2));
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3));
size_t lws0 = 64;
size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
size_t local_work_size[] = { lws0, 1, 1 };
size_t * local_work_size_ptr = local_work_size;

View File

@@ -1,30 +1,39 @@
kernel void kernel_pad(
global const void * src0_ptr,
ulong src0_offset,
global void * dst_ptr,
ulong dst_offset,
int s_ne0, int s_ne1, int s_ne2,
int d_ne0, int d_ne1, int d_ne2
global void * src0,
ulong offset0,
global void * dst,
ulong offsetd,
int ne00, int ne01, int ne02, int ne03,
ulong nb00, ulong nb01, ulong nb02, ulong nb03,
int ne0, int ne1, int ne2, int ne3,
ulong nb0, ulong nb1, ulong nb2, ulong nb3,
int lp0, int rp0,
int lp1, int rp1,
int lp2, int rp2,
int lp3, int rp3
) {
global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int nidx = get_global_id(0);
int idx_d1 = get_group_id(1);
int idx_d2 = get_group_id(2);
int i0 = get_global_id(0);
int i1 = get_group_id(1);
int i2 = get_group_id(2) % ne2;
int i3 = get_group_id(2) / ne2;
if (nidx >= d_ne0) {
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);
global float * dst_ptr = (global float *)((global char *)dst + dst_idx);
if (in_src_bounds) {
int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
dst[dst_el_offset] = src0[src_el_offset];
} else {
dst[dst_el_offset] = 0.0f;
}
bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3);
*dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;
}