diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 9f91662cbd..8aeefd2c68 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1237,7 +1237,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_ char base[256]; char name[256]; - snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3b163d9a38..15cea72513 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4; const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4; - ggml_metal_kargs_im2col args = { /*.ofs0 =*/ ofs0, /*.ofs1 =*/ ofs1, @@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); - const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N); - const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2ba4cb50b9..339cbf91fb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3987,60 +3987,7 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; -// TODO: obolete -- remove -//typedef void (im2col_t)( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]); -// -//template -//kernel void kernel_im2col( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]) { -//// const int64_t IC = tgpg[0]; -// const int64_t OH = tgpg[1]; -// const int64_t OW = tgpg[2]; -// -//// const int64_t N = ntg[0]; -// const int64_t KH = ntg[1]; -// const int64_t KW = ntg[2]; -// -// const int64_t in = tpitg[0]; -// const int64_t ikh = tpitg[1]; -// const int64_t ikw = tpitg[2]; -// -// const int64_t iic = tgpig[0]; -// const int64_t ioh = tgpig[1]; -// const int64_t iow = tgpig[2]; -// -// const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; -// const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; -// -// const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); -// -// device T * pdst = (device T *) (dst); -// -// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { -// pdst[offset_dst] = 0.0f; -// } else { -// const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; -// pdst[offset_dst] = x[offset_src]; -// } -//} -// -//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; -//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; - -typedef void (im2col_ext_t)( +typedef void (im2col_t)( constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, @@ -4050,48 +3997,113 @@ typedef void (im2col_ext_t)( uint3 ntg[[threads_per_threadgroup]]); template -kernel void kernel_im2col_ext( +kernel void kernel_im2col( constant ggml_metal_kargs_im2col & args, device const float * x, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tgpg[[threadgroups_per_grid]], uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] - const int64_t KHW = (int64_t)args.KHW; + uint3 ntg[[threads_per_threadgroup]]) { +// const int64_t IC = tgpg[0]; + const int64_t OH = tgpg[1]; + const int64_t OW = tgpg[2]; - const int64_t d = tgpig[0] / args.CHW; - const int64_t chw = tgpig[0] % args.CHW; - const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) - const int64_t HW = tgpig[0] % KHW; + const int64_t KH = ntg[1]; + const int64_t KW = ntg[2]; - const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; - if (tpitg_0 >= args.N) { - return; - } + int64_t in = tpitg[0]; + const int64_t ikh = tpitg[1]; + const int64_t ikw = tpitg[2]; - const int64_t tpitg_1 = HW / args.KW; - const int64_t tpitg_2 = HW % args.KW; + const int64_t iic = tgpig[0]; + const int64_t ioh = tgpig[1]; + const int64_t iow = tgpig[2]; - const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; - const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; + const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; - const int64_t offset_dst = - (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + - (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); device T * pdst = (device T *) (dst); if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { - pdst[offset_dst] = 0.0f; + while (in < args.N) { + pdst[offset_dst] = 0.0f; + offset_dst += ntg[0]*args.CHW*OH*OW; + + in += ntg[0]; + } } else { - const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; - pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; + + while (in < args.N) { + pdst[offset_dst] = x[offset_src]; + + offset_dst += ntg[0]*args.CHW*OH*OW; + offset_src += ntg[0]*args.ofs0; + + in += ntg[0]; + } } } -template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; -template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +// TODO: obolete -- remove +//typedef void (im2col_ext_t)( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]); +// +//template +//kernel void kernel_im2col_ext( +// constant ggml_metal_kargs_im2col & args, +// device const float * x, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW +// uint3 tpitg[[thread_position_in_threadgroup]], +// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] +// const int64_t KHW = (int64_t)args.KHW; +// +// const int64_t d = tgpig[0] / args.CHW; +// const int64_t chw = tgpig[0] % args.CHW; +// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) +// const int64_t HW = tgpig[0] % KHW; +// +// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; +// if (tpitg_0 >= args.N) { +// return; +// } +// +// const int64_t tpitg_1 = HW / args.KW; +// const int64_t tpitg_2 = HW % args.KW; +// +// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; +// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; +// +// const int64_t offset_dst = +// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + +// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); +// +// device T * pdst = (device T *) (dst); +// +// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { +// pdst[offset_dst] = 0.0f; +// } else { +// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; +// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; +// } +//} +// +//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; typedef void (conv_transpose_1d_t)( constant ggml_metal_kargs_conv_transpose_1d & args,