diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b493d80a06..55ebe76f89 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -8285,40 +8285,6 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2 + nl - 1)/nl : x; - - y += NK; - - // load matrices from threadgroup memory and conduct outer products - threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); - threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { - simdgroup_barrier(mem_flags::mem_none); - - FOR_UNROLL (short i = 0; i < 4; i++) { - simdgroup_load(ma[i], lsma + 64*i, 8, 0, false); - } - - simdgroup_barrier(mem_flags::mem_none); - - FOR_UNROLL (short i = 0; i < 2; i++) { - simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); - } - - simdgroup_barrier(mem_flags::mem_none); - - FOR_UNROLL (short i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); - } - - lsma += 8*64; - lsmb += 4*64; - } #else // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { @@ -8378,6 +8344,7 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); } +#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -8386,6 +8353,34 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); +#ifndef GGML_METAL_HAS_TENSOR + // load matrices from threadgroup memory and conduct outer products + threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); + + FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); + } + + simdgroup_barrier(mem_flags::mem_none); + + FOR_UNROLL (short i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + + lsma += 8*64; + lsmb += 4*64; + } +#else auto sA = tA.slice(0, 0); auto sB = tB.slice(0, 0);