diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 866cd2da58..7581163422 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 28ae2e1765..4d58297481 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c3c83abe4e..360fbe19f0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SCALE: case GGML_OP_CONV_TRANSPOSE_1D: return true; + case GGML_OP_CONV_TRANSPOSE_2D: + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32; case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index fa2d82cefb..96f43d260a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -514,6 +514,19 @@ typedef struct { uint64_t nb1; } ggml_metal_kargs_conv_transpose_1d; +typedef struct { + int32_t IC; + int32_t IH; + int32_t IW; + int32_t KH; + int32_t KW; + int32_t OC; + int32_t s0; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_conv_transpose_2d; + typedef struct { uint64_t ofs0; uint64_t ofs1; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 4f9f6bda00..7a85edbdcd 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + + const int32_t IC = op->src[1]->ne[2]; + const int32_t IH = op->src[1]->ne[1]; + const int32_t IW = op->src[1]->ne[0]; + + const int32_t KH = op->src[0]->ne[1]; + const int32_t KW = op->src[0]->ne[0]; + + const int32_t OW = op->ne[0]; + const int32_t OH = op->ne[1]; + const int32_t OC = op->ne[2]; + + ggml_metal_kargs_conv_transpose_2d args = { + /*.IC =*/ IC, + /*.IH =*/ IH, + /*.IW =*/ IW, + /*.KH =*/ KH, + /*.KW =*/ KW, + /*.OC =*/ OC, + /*.s0 =*/ s0, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + // Metal requires buffer size to be multiple of 16 bytes + const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1); + + return 1; +} + int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f352738698..0d9cb8af7c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 496610b154..2c2f014151 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]); + +typedef void (conv_transpose_2d_t)( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const float * src0, + device const float * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const T * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t out_x = tgpig[0]; + const int64_t out_y = tgpig[1]; + const int64_t out_c = tgpig[2]; + + const int64_t kw = tpitg[0]; + const int64_t kh = tpitg[1]; + + float v = 0.0f; + + for (int64_t in_c = 0; in_c < args.IC; in_c++) { + int64_t in_y = out_y - kh; + + if (in_y < 0 || in_y % args.s0) continue; + + in_y /= args.s0; + + if (in_y >= args.IH) continue; + + int64_t in_x = out_x - kw; + + if (in_x < 0 || in_x % args.s0) continue; + + in_x /= args.s0; + + if (in_x >= args.IW) continue; + + const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x; + const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw; + + v += (float)src0[kernel_idx] * src1[input_idx]; + } + + const uint tid = tpitg.y * ntg.x + tpitg.x; + shared_sum[tid] = v; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tid == 0) { + float total = 0.0f; + const uint num_threads = ntg.x * ntg.y; + for (uint i = 0; i < num_threads; i++) { + total += shared_sum[i]; + } + + device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2); + dst_ptr[0] = total; + } +} + +template [[host_name("kernel_conv_transpose_2d_f32_f32")]] +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const float * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template [[host_name("kernel_conv_transpose_2d_f16_f32")]] +kernel void kernel_conv_transpose_2d( + constant ggml_metal_kargs_conv_transpose_2d & args, + device const half * src0, + device const float * src1, + device char * dst, + threadgroup float * shared_sum [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + kernel void kernel_upscale_f32( constant ggml_metal_kargs_upscale & args, device const char * src0, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d5c5a2a665..82bb55ea0e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6989,6 +6989,8 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true)); test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1)); + test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1)); + test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2)); test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));