mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
metal : initial Metal4 support
This commit is contained in:
@@ -9,6 +9,18 @@ __embed_ggml-common.h__
|
|||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#define GGML_METAL_USE_METAL4
|
||||||
|
|
||||||
|
#ifdef GGML_METAL_USE_METAL4
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <metal_tensor>
|
||||||
|
|
||||||
|
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
using namespace mpp::tensor_ops;
|
||||||
|
#endif
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
@@ -8145,6 +8157,8 @@ kernel void kernel_mul_mm(
|
|||||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||||
|
|
||||||
|
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||||
|
|
||||||
constexpr int NR0 = 64;
|
constexpr int NR0 = 64;
|
||||||
constexpr int NR1 = 32;
|
constexpr int NR1 = 32;
|
||||||
|
|
||||||
@@ -8164,15 +8178,6 @@ kernel void kernel_mul_mm(
|
|||||||
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
||||||
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
||||||
|
|
||||||
S0_8x8 ma[4];
|
|
||||||
S1_8x8 mb[2];
|
|
||||||
|
|
||||||
simdgroup_float8x8 mc[8];
|
|
||||||
|
|
||||||
for (short i = 0; i < 8; i++){
|
|
||||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
||||||
}
|
|
||||||
|
|
||||||
const short il0 = (tiitg % NL0);
|
const short il0 = (tiitg % NL0);
|
||||||
|
|
||||||
short il = il0;
|
short il = il0;
|
||||||
@@ -8193,7 +8198,28 @@ kernel void kernel_mul_mm(
|
|||||||
+ args.nb11*(r1 + lr1)
|
+ args.nb11*(r1 + lr1)
|
||||||
+ args.nb10*iy);
|
+ args.nb10*iy);
|
||||||
|
|
||||||
|
#ifndef GGML_METAL_USE_METAL4
|
||||||
|
S0_8x8 ma[4];
|
||||||
|
S1_8x8 mb[2];
|
||||||
|
|
||||||
|
simdgroup_float8x8 mc[8];
|
||||||
|
|
||||||
|
for (short i = 0; i < 8; i++){
|
||||||
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
||||||
|
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
||||||
|
|
||||||
|
constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
|
||||||
|
|
||||||
|
matmul2d<desc, execution_simdgroups<4>> mm;
|
||||||
|
|
||||||
|
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
||||||
|
#endif
|
||||||
|
|
||||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
||||||
|
#ifndef GGML_METAL_USE_METAL4
|
||||||
// load data and store to threadgroup memory
|
// load data and store to threadgroup memory
|
||||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
@@ -8297,26 +8323,100 @@ kernel void kernel_mul_mm(
|
|||||||
lsma += 8*64;
|
lsma += 8*64;
|
||||||
lsmb += 4*64;
|
lsmb += 4*64;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
// load data and store to threadgroup memory
|
||||||
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// no need for dequantization
|
||||||
|
for (short i = 0; i < 16; i++) {
|
||||||
|
const short sx = 2*il0 + i/8;
|
||||||
|
const short sy = (tiitg/NL0)/8;
|
||||||
|
|
||||||
|
const short lx = i%8;
|
||||||
|
const short ly = (tiitg/NL0)%8;
|
||||||
|
//const short lx = (tiitg/NL0)%8;
|
||||||
|
//const short ly = i%8;
|
||||||
|
|
||||||
|
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
S0_4x4 temp_a;
|
||||||
|
dequantize_func(x, il, temp_a);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||||
|
const short sx = 2*il0 + i/8;
|
||||||
|
const short sy = (tiitg/NL0)/8;
|
||||||
|
|
||||||
|
const short lx = i%8;
|
||||||
|
const short ly = (tiitg/NL0)%8;
|
||||||
|
//const short lx = (tiitg/NL0)%8;
|
||||||
|
//const short ly = i%8;
|
||||||
|
|
||||||
|
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (short i = 0; i < 8; ++i) {
|
||||||
|
const short sx = (tiitg%NL1);
|
||||||
|
const short sy = (tiitg/NL1)/8;
|
||||||
|
|
||||||
|
const short lx = i;
|
||||||
|
const short ly = (tiitg/NL1)%8;
|
||||||
|
//const short lx = (tiitg/NL1)%8;
|
||||||
|
//const short ly = i;
|
||||||
|
|
||||||
|
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||||
|
|
||||||
|
y += NK;
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
auto sA = tA.slice(0, 0);
|
||||||
|
auto sB = tB.slice(0, 0);
|
||||||
|
|
||||||
|
mm.run(sB, sA, cT);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
||||||
// if no bounds checks on the output are needed, we can directly write to device memory
|
// if no bounds checks on the output are needed, we can directly write to device memory
|
||||||
|
#ifdef GGML_METAL_USE_METAL4
|
||||||
|
device float * C = (device float *) dst +
|
||||||
|
r0 + \
|
||||||
|
r1 * args.ne0 + im*args.ne1*args.ne0;
|
||||||
|
|
||||||
|
auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
|
||||||
|
cT.store(tC);
|
||||||
|
#else
|
||||||
device float * C = (device float *) dst +
|
device float * C = (device float *) dst +
|
||||||
(r0 + 32*(sgitg & 1)) + \
|
(r0 + 32*(sgitg & 1)) + \
|
||||||
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||||
|
|
||||||
for (short i = 0; i < 8; i++) {
|
for (short i = 0; i < 8; i++) {
|
||||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0, 0, false);
|
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
||||||
|
|
||||||
|
#ifdef GGML_METAL_USE_METAL4
|
||||||
|
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
||||||
|
cT.store(tC);
|
||||||
|
#else
|
||||||
for (short i = 0; i < 8; i++) {
|
for (short i = 0; i < 8; i++) {
|
||||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user