add ggml_pad_ext for cpu & cuda backend

This commit is contained in:
leejet
2025-07-26 00:28:52 +08:00
parent c92f9b4a68
commit 93c7e775b8
5 changed files with 113 additions and 34 deletions

View File

@@ -2083,6 +2083,19 @@ extern "C" {
int p2,
int p3);
GGML_API struct ggml_tensor * ggml_pad_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
);
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx,

View File

@@ -587,9 +587,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// the position of elements in the array means which dirction to padding,
// each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
// dim2.front, dim2.behind, dim3.front, dim3.behind]
int64_t paddings[] = {
0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3};
aclnn_pad(ctx, acl_src, acl_dst, paddings);
ggml_cann_release_resources(ctx, acl_src, acl_dst);
}

View File

@@ -8014,6 +8014,15 @@ static void ggml_compute_forward_pad_f32(
GGML_TENSOR_UNARY_OP_LOCALS
float * dst_ptr = (float *) dst->data;
const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
// TODO: optimize
@@ -8022,10 +8031,12 @@ static void ggml_compute_forward_pad_f32(
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
if ((i0 >= lp0 && i0 < ne0 - rp0) \
&& (i1 >= lp1 && i1 < ne1 - rp1) \
&& (i2 >= lp2 && i2 < ne2 - rp2) \
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;

View File

@@ -1,36 +1,50 @@
#include "pad.cuh"
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
static __global__ void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
// blockIdx.z: i3*ne2+i2
// blockIdx.y: i1
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
// gridDim.y: ne1
int i0 = threadIdx.x + blockIdx.x * blockDim.x;
int i1 = blockIdx.y;
int i2 = blockIdx.z % ne2;
int i3 = blockIdx.z / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) \
&& (i1 >= lp1 && i1 < ne1 - rp1) \
&& (i2 >= lp2 && i2 < ne2 - rp2) \
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
int i00 = i0 - lp0;
int i01 = i1 - lp1;
int i02 = i2 - lp2;
int i03 = i3 - lp3;
int ne02 = ne2 - lp2 - rp2;
int ne01 = ne1 - lp1 - rp1;
int ne00 = ne0 - lp0 - rp0;
int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[offset_dst] = 0.0f;
dst[dst_idx] = 0.0f;
}
}
static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
static void pad_f32_cuda(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}

View File

@@ -4792,11 +4792,36 @@ struct ggml_tensor * ggml_pad(
int p1,
int p2,
int p3) {
return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
}
struct ggml_tensor * ggml_pad_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
) {
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
a->ne[0] + p0,
a->ne[1] + p1,
a->ne[2] + p2,
a->ne[3] + p3);
a->ne[0] + lp0 + rp0,
a->ne[1] + lp1 + rp1,
a->ne[2] + lp2 + rp2,
a->ne[3] + lp3 + rp3);
ggml_set_op_params_i32(result, 0, lp0);
ggml_set_op_params_i32(result, 1, rp0);
ggml_set_op_params_i32(result, 2, lp1);
ggml_set_op_params_i32(result, 3, rp1);
ggml_set_op_params_i32(result, 4, lp2);
ggml_set_op_params_i32(result, 5, rp2);
ggml_set_op_params_i32(result, 6, lp3);
ggml_set_op_params_i32(result, 7, rp3);
result->op = GGML_OP_PAD;
result->src[0] = a;