mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
cont : better ifdefs
This commit is contained in:
@@ -8285,40 +8285,6 @@ kernel void kernel_mul_mm(
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
||||
y += NK;
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
#else
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
@@ -8378,6 +8344,7 @@ kernel void kernel_mul_mm(
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#endif
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
@@ -8386,6 +8353,34 @@ kernel void kernel_mul_mm(
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
|
||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
#else
|
||||
auto sA = tA.slice(0, 0);
|
||||
auto sB = tB.slice(0, 0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user