mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	cont : better ifdefs
This commit is contained in:
		@@ -8285,40 +8285,6 @@ kernel void kernel_mul_mm(
 | 
			
		||||
 | 
			
		||||
            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        il = (il + 2 < nl) ? il + 2 : il % 2;
 | 
			
		||||
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
 | 
			
		||||
 | 
			
		||||
        y += NK;
 | 
			
		||||
 | 
			
		||||
        // load matrices from threadgroup memory and conduct outer products
 | 
			
		||||
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
 | 
			
		||||
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
 | 
			
		||||
 | 
			
		||||
        threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
 | 
			
		||||
            simdgroup_barrier(mem_flags::mem_none);
 | 
			
		||||
 | 
			
		||||
            FOR_UNROLL (short i = 0; i < 4; i++) {
 | 
			
		||||
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            simdgroup_barrier(mem_flags::mem_none);
 | 
			
		||||
 | 
			
		||||
            FOR_UNROLL (short i = 0; i < 2; i++) {
 | 
			
		||||
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            simdgroup_barrier(mem_flags::mem_none);
 | 
			
		||||
 | 
			
		||||
            FOR_UNROLL (short i = 0; i < 8; i++){
 | 
			
		||||
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            lsma += 8*64;
 | 
			
		||||
            lsmb += 4*64;
 | 
			
		||||
        }
 | 
			
		||||
#else
 | 
			
		||||
        // load data and store to threadgroup memory
 | 
			
		||||
        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
 | 
			
		||||
@@ -8378,6 +8344,7 @@ kernel void kernel_mul_mm(
 | 
			
		||||
 | 
			
		||||
            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
 | 
			
		||||
        }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        il = (il + 2 < nl) ? il + 2 : il % 2;
 | 
			
		||||
        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
 | 
			
		||||
@@ -8386,6 +8353,34 @@ kernel void kernel_mul_mm(
 | 
			
		||||
 | 
			
		||||
        threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
#ifndef GGML_METAL_HAS_TENSOR
 | 
			
		||||
        // load matrices from threadgroup memory and conduct outer products
 | 
			
		||||
        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
 | 
			
		||||
        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
 | 
			
		||||
 | 
			
		||||
        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
 | 
			
		||||
            simdgroup_barrier(mem_flags::mem_none);
 | 
			
		||||
 | 
			
		||||
            FOR_UNROLL (short i = 0; i < 4; i++) {
 | 
			
		||||
                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            simdgroup_barrier(mem_flags::mem_none);
 | 
			
		||||
 | 
			
		||||
            FOR_UNROLL (short i = 0; i < 2; i++) {
 | 
			
		||||
                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            simdgroup_barrier(mem_flags::mem_none);
 | 
			
		||||
 | 
			
		||||
            FOR_UNROLL (short i = 0; i < 8; i++){
 | 
			
		||||
                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            lsma += 8*64;
 | 
			
		||||
            lsmb += 4*64;
 | 
			
		||||
        }
 | 
			
		||||
#else
 | 
			
		||||
        auto sA = tA.slice(0, 0);
 | 
			
		||||
        auto sB = tB.slice(0, 0);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user