mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
cont : better ifdefs
This commit is contained in:
@@ -8285,40 +8285,6 @@ kernel void kernel_mul_mm(
|
|||||||
|
|
||||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||||
}
|
}
|
||||||
|
|
||||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
|
||||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
|
||||||
|
|
||||||
y += NK;
|
|
||||||
|
|
||||||
// load matrices from threadgroup memory and conduct outer products
|
|
||||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
|
||||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
|
||||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
|
||||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
|
||||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
lsma += 8*64;
|
|
||||||
lsmb += 4*64;
|
|
||||||
}
|
|
||||||
#else
|
#else
|
||||||
// load data and store to threadgroup memory
|
// load data and store to threadgroup memory
|
||||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||||
@@ -8378,6 +8344,7 @@ kernel void kernel_mul_mm(
|
|||||||
|
|
||||||
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||||
@@ -8386,6 +8353,34 @@ kernel void kernel_mul_mm(
|
|||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
#ifndef GGML_METAL_HAS_TENSOR
|
||||||
|
// load matrices from threadgroup memory and conduct outer products
|
||||||
|
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||||
|
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||||
|
|
||||||
|
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||||
|
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||||
|
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||||
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
lsma += 8*64;
|
||||||
|
lsmb += 4*64;
|
||||||
|
}
|
||||||
|
#else
|
||||||
auto sA = tA.slice(0, 0);
|
auto sA = tA.slice(0, 0);
|
||||||
auto sB = tB.slice(0, 0);
|
auto sB = tB.slice(0, 0);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user