From 81c72f1faf8a7b667ebfb8ca7ec638df3e3b9a73 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 17 Oct 2025 15:16:36 +0300 Subject: [PATCH] metal : initial Metal4 support --- ggml/src/ggml-metal/ggml-metal.metal | 120 ++++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a2c7254028..03cc2716a7 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9,6 +9,18 @@ __embed_ggml-common.h__ #include +#define GGML_METAL_USE_METAL4 + +#ifdef GGML_METAL_USE_METAL4 +#include +#include + +#include + +using namespace metal; +using namespace mpp::tensor_ops; +#endif + using namespace metal; #define MAX(x, y) ((x) > (y) ? (x) : (y)) @@ -8145,6 +8157,8 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); + threadgroup float * sc = (threadgroup float *)(shmem); + constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8164,15 +8178,6 @@ kernel void kernel_mul_mm( const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63 const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31 - S0_8x8 ma[4]; - S1_8x8 mb[2]; - - simdgroup_float8x8 mc[8]; - - for (short i = 0; i < 8; i++){ - mc[i] = make_filled_simdgroup_matrix(0.f); - } - const short il0 = (tiitg % NL0); short il = il0; @@ -8193,7 +8198,28 @@ kernel void kernel_mul_mm( + args.nb11*(r1 + lr1) + args.nb10*iy); +#ifndef GGML_METAL_USE_METAL4 + S0_8x8 ma[4]; + S1_8x8 mb[2]; + + simdgroup_float8x8 mc[8]; + + for (short i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); + } +#else + auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); + auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); + + constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate); + + matmul2d> mm; + + auto cT = mm.get_destination_cooperative_tensor(); +#endif + for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { +#ifndef GGML_METAL_USE_METAL4 // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -8297,26 +8323,100 @@ kernel void kernel_mul_mm( lsma += 8*64; lsmb += 4*64; } +#else + // load data and store to threadgroup memory + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // no need for dequantization + for (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + //const short lx = (tiitg/NL0)%8; + //const short ly = i%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + const short sx = 2*il0 + i/8; + const short sy = (tiitg/NL0)/8; + + const short lx = i%8; + const short ly = (tiitg/NL0)%8; + //const short lx = (tiitg/NL0)%8; + //const short ly = i%8; + + *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; + } + } + + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + } + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; + + y += NK; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mm.run(sB, sA, cT); +#endif } if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { // if no bounds checks on the output are needed, we can directly write to device memory +#ifdef GGML_METAL_USE_METAL4 + device float * C = (device float *) dst + + r0 + \ + r1 * args.ne0 + im*args.ne1*args.ne0; + + auto tC = tensor, tensor_inline>(C, dextents(args.ne0, NR1)); + cT.store(tC); +#else device float * C = (device float *) dst + (r0 + 32*(sgitg & 1)) + \ (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false); + simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); } +#endif } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; +#ifdef GGML_METAL_USE_METAL4 + auto tC = tensor, tensor_inline>(sc, dextents(NR0, NR1)); + cT.store(tC); +#else for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } +#endif threadgroup_barrier(mem_flags::mem_threadgroup);