metal : restore im2col perf (#16219)

This commit is contained in:
Georgi Gerganov
2025-09-25 11:29:08 +03:00
committed by GitHub
parent c498fc82fe
commit 02a6a82ae7
3 changed files with 95 additions and 83 deletions

View File

@@ -1237,7 +1237,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
char base[256];
char name[256];
snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);

View File

@@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
ggml_metal_kargs_im2col args = {
/*.ofs0 =*/ ofs0,
/*.ofs1 =*/ ofs1,
@@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
return 1;
}

View File

@@ -3987,60 +3987,7 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
// TODO: obolete -- remove
//typedef void (im2col_t)(
// constant ggml_metal_kargs_im2col & args,
// device const float * x,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// uint3 tgpg[[threadgroups_per_grid]],
// uint3 tpitg[[thread_position_in_threadgroup]],
// uint3 ntg[[threads_per_threadgroup]]);
//
//template <typename T>
//kernel void kernel_im2col(
// constant ggml_metal_kargs_im2col & args,
// device const float * x,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// uint3 tgpg[[threadgroups_per_grid]],
// uint3 tpitg[[thread_position_in_threadgroup]],
// uint3 ntg[[threads_per_threadgroup]]) {
//// const int64_t IC = tgpg[0];
// const int64_t OH = tgpg[1];
// const int64_t OW = tgpg[2];
//
//// const int64_t N = ntg[0];
// const int64_t KH = ntg[1];
// const int64_t KW = ntg[2];
//
// const int64_t in = tpitg[0];
// const int64_t ikh = tpitg[1];
// const int64_t ikw = tpitg[2];
//
// const int64_t iic = tgpig[0];
// const int64_t ioh = tgpig[1];
// const int64_t iow = tgpig[2];
//
// const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
// const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
//
// const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
//
// device T * pdst = (device T *) (dst);
//
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
// pdst[offset_dst] = 0.0f;
// } else {
// const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
// pdst[offset_dst] = x[offset_src];
// }
//}
//
//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
typedef void (im2col_ext_t)(
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
@@ -4050,48 +3997,113 @@ typedef void (im2col_ext_t)(
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
kernel void kernel_im2col(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int64_t KHW = (int64_t)args.KHW;
uint3 ntg[[threads_per_threadgroup]]) {
// const int64_t IC = tgpg[0];
const int64_t OH = tgpg[1];
const int64_t OW = tgpg[2];
const int64_t d = tgpig[0] / args.CHW;
const int64_t chw = tgpig[0] % args.CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int64_t HW = tgpig[0] % KHW;
const int64_t KH = ntg[1];
const int64_t KW = ntg[2];
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= args.N) {
return;
}
int64_t in = tpitg[0];
const int64_t ikh = tpitg[1];
const int64_t ikw = tpitg[2];
const int64_t tpitg_1 = HW / args.KW;
const int64_t tpitg_2 = HW % args.KW;
const int64_t iic = tgpig[0];
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
const int64_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
while (in < args.N) {
pdst[offset_dst] = 0.0f;
offset_dst += ntg[0]*args.CHW*OH*OW;
in += ntg[0];
}
} else {
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
while (in < args.N) {
pdst[offset_dst] = x[offset_src];
offset_dst += ntg[0]*args.CHW*OH*OW;
offset_src += ntg[0]*args.ofs0;
in += ntg[0];
}
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// TODO: obolete -- remove
//typedef void (im2col_ext_t)(
// constant ggml_metal_kargs_im2col & args,
// device const float * x,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// uint3 tgpg[[threadgroups_per_grid]],
// uint3 tpitg[[thread_position_in_threadgroup]],
// uint3 ntg[[threads_per_threadgroup]]);
//
//template <typename T>
//kernel void kernel_im2col_ext(
// constant ggml_metal_kargs_im2col & args,
// device const float * x,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
// uint3 tpitg[[thread_position_in_threadgroup]],
// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
// const int64_t KHW = (int64_t)args.KHW;
//
// const int64_t d = tgpig[0] / args.CHW;
// const int64_t chw = tgpig[0] % args.CHW;
// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
// const int64_t HW = tgpig[0] % KHW;
//
// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
// if (tpitg_0 >= args.N) {
// return;
// }
//
// const int64_t tpitg_1 = HW / args.KW;
// const int64_t tpitg_2 = HW % args.KW;
//
// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
//
// const int64_t offset_dst =
// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
//
// device T * pdst = (device T *) (dst);
//
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
// pdst[offset_dst] = 0.0f;
// } else {
// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
// }
//}
//
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,