mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	move mask to shared mem
This commit is contained in:
		@@ -3277,7 +3277,7 @@ static void ggml_metal_encode_node(
 | 
			
		||||
                    // the shared memory needed for the simdgroups to load the KV cache
 | 
			
		||||
                    // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
 | 
			
		||||
                    //
 | 
			
		||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 | 
			
		||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 | 
			
		||||
 | 
			
		||||
                    int64_t nsgmax = 2;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2836,8 +2836,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
        uint3  tgpig[[threadgroup_position_in_grid]],
 | 
			
		||||
        uint3  tpitg[[thread_position_in_threadgroup]],
 | 
			
		||||
        uint3    ntg[[threads_per_threadgroup]],
 | 
			
		||||
        ushort tiisg[[thread_index_in_simdgroup]],
 | 
			
		||||
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 | 
			
		||||
        uint   tiisg[[thread_index_in_simdgroup]],
 | 
			
		||||
        uint   sgitg[[simdgroup_index_in_threadgroup]]) {
 | 
			
		||||
    const short nsg = ntg.y; // number of simdgroups
 | 
			
		||||
 | 
			
		||||
    const int iq3 = tgpig[2];
 | 
			
		||||
@@ -2848,7 +2848,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
    const short D8  = D/8;
 | 
			
		||||
    const short D16 = D/16;
 | 
			
		||||
    const short NW  = N_SIMDWIDTH;
 | 
			
		||||
    const short SH  = (C + Q); // shared memory per simdgroup in (half)
 | 
			
		||||
    const short SH  = (2*C + Q); // shared memory per simdgroup in (half)
 | 
			
		||||
 | 
			
		||||
    const short SF = sizeof(s_t)/sizeof(half);
 | 
			
		||||
 | 
			
		||||
@@ -2933,9 +2933,6 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
            simdgroup_load(mq[i], sq + i*8, T);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // pointer to the mask
 | 
			
		||||
        device const half * mp = (device const half *) (mask + iq1*nb31);
 | 
			
		||||
 | 
			
		||||
        const bool has_mask = mask != q;
 | 
			
		||||
 | 
			
		||||
        float slope = 1.0f;
 | 
			
		||||
@@ -2958,6 +2955,26 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (has_mask) {
 | 
			
		||||
                // used to detect blocks full of -INF
 | 
			
		||||
                half smax = -INFINITY;
 | 
			
		||||
 | 
			
		||||
                for (short j = 0; j < Q; ++j) {
 | 
			
		||||
                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
 | 
			
		||||
 | 
			
		||||
                    const half m = pm[ic + tiisg];
 | 
			
		||||
 | 
			
		||||
                    ss[j*TS + C + tiisg] = m;
 | 
			
		||||
                    smax = max(smax, m);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                smax = simd_max(smax);
 | 
			
		||||
 | 
			
		||||
                if (smax == -INFINITY) {
 | 
			
		||||
                    continue;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Q*K^T
 | 
			
		||||
            {
 | 
			
		||||
                for (short cc = 0; cc < C/8; ++cc) {
 | 
			
		||||
@@ -3033,9 +3050,6 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // used to detect blocks full of -INF
 | 
			
		||||
            float smax = -INFINITY;
 | 
			
		||||
 | 
			
		||||
            // online softmax
 | 
			
		||||
            {
 | 
			
		||||
                float ms[Q];
 | 
			
		||||
@@ -3052,10 +3066,10 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
                    if (has_mask) {
 | 
			
		||||
                        // mqk = mqk + mask*slope
 | 
			
		||||
                        s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
 | 
			
		||||
                        //s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
 | 
			
		||||
                        s += slope*ss[j*TS + C + tiisg];
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    smax = simd_max(max(smax, s));
 | 
			
		||||
                    M[j] = simd_max(max(M[j], s));
 | 
			
		||||
 | 
			
		||||
                                ms[j] = exp(m - M[j]);
 | 
			
		||||
@@ -3069,19 +3083,14 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
                // create a QxQ diagonal matrix for rescaling the output
 | 
			
		||||
                if (tiisg < Q) {
 | 
			
		||||
                    ss[tiisg*TS + C + tiisg] = ms[tiisg];
 | 
			
		||||
                    ss[tiisg*TS + 2*C + tiisg] = ms[tiisg];
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // skip -INF blocks
 | 
			
		||||
            if (smax == -INFINITY) {
 | 
			
		||||
                continue;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // O = diag(ms)*O
 | 
			
		||||
            {
 | 
			
		||||
                s8x8_t mm;
 | 
			
		||||
                simdgroup_load(mm, ss + C, TS, 0, false);
 | 
			
		||||
                simdgroup_load(mm, ss + 2*C, TS, 0, false);
 | 
			
		||||
 | 
			
		||||
#pragma unroll
 | 
			
		||||
                for (short i = 0; i < D8; ++i) {
 | 
			
		||||
@@ -3199,8 +3208,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                    ss[j*TS + 0] = S;
 | 
			
		||||
                    ss[j*TS + 1] = M;
 | 
			
		||||
 | 
			
		||||
                    ss[j*TS + C + j        ] = ms0;
 | 
			
		||||
                    ss[j*TS + C + j + sg*SH] = ms1;
 | 
			
		||||
                    ss[j*TS + 2*C + j        ] = ms0;
 | 
			
		||||
                    ss[j*TS + 2*C + j + sg*SH] = ms1;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -3209,8 +3218,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                s8x8_t ms0;
 | 
			
		||||
                s8x8_t ms1;
 | 
			
		||||
 | 
			
		||||
                simdgroup_load(ms0, ss + C,         TS, 0, false);
 | 
			
		||||
                simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
 | 
			
		||||
                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
 | 
			
		||||
                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
 | 
			
		||||
 | 
			
		||||
                for (short i = 0; i < D8; ++i) {
 | 
			
		||||
                    o8x8_t t;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user