mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	remove inner if mask
This commit is contained in:
		@@ -2834,7 +2834,6 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
        constant     float & logit_softcap,
 | 
			
		||||
        threadgroup   half * shared [[threadgroup(0)]],
 | 
			
		||||
        uint3  tgpig[[threadgroup_position_in_grid]],
 | 
			
		||||
        uint3  tpitg[[thread_position_in_threadgroup]],
 | 
			
		||||
        uint3    ntg[[threads_per_threadgroup]],
 | 
			
		||||
        uint   tiisg[[thread_index_in_simdgroup]],
 | 
			
		||||
        uint   sgitg[[simdgroup_index_in_threadgroup]]) {
 | 
			
		||||
@@ -2981,7 +2980,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 | 
			
		||||
 | 
			
		||||
                    // this is compile-time check, so it does not have runtime overhead
 | 
			
		||||
                    if constexpr (is_same<kd4x4_t, k4x4_t>::value) {
 | 
			
		||||
                    if (is_same<kd4x4_t, k4x4_t>::value) {
 | 
			
		||||
                        // we can read directly from global memory
 | 
			
		||||
                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 | 
			
		||||
 | 
			
		||||
@@ -2996,7 +2995,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                        for (short ii = 0; ii < D16; ii += 4) {
 | 
			
		||||
                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 | 
			
		||||
 | 
			
		||||
                            if constexpr (D16%4 == 0) {
 | 
			
		||||
                            if (D16%4 == 0) {
 | 
			
		||||
                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
 | 
			
		||||
                                {
 | 
			
		||||
                                    k4x4_t tmp;
 | 
			
		||||
@@ -3038,15 +3037,10 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    if constexpr (is_same<qk_t, s_t>::value) {
 | 
			
		||||
                        // same type - store directly
 | 
			
		||||
                        simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
 | 
			
		||||
                    } else {
 | 
			
		||||
                        // cast qk_t -> s_t
 | 
			
		||||
                        s8x8_t mqks(1.0f);
 | 
			
		||||
                        simdgroup_multiply(mqks, mqk, mqks);
 | 
			
		||||
                        simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
 | 
			
		||||
                    }
 | 
			
		||||
                    // cast qk_t -> s_t
 | 
			
		||||
                    s8x8_t mqks(1.0f);
 | 
			
		||||
                    simdgroup_multiply(mqks, mqk, mqks);
 | 
			
		||||
                    simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -3062,11 +3056,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                        s = logit_softcap*precise::tanh(s);
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    if (has_mask) {
 | 
			
		||||
                        // mqk = mqk + mask*slope
 | 
			
		||||
                        //s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
 | 
			
		||||
                        s += slope*ss[j*TS + C + tiisg];
 | 
			
		||||
                    }
 | 
			
		||||
                    // mqk = mqk + mask*slope
 | 
			
		||||
                    s += slope*ss[j*TS + C + tiisg];
 | 
			
		||||
 | 
			
		||||
                    M[j] = simd_max(max(M[j], s));
 | 
			
		||||
 | 
			
		||||
@@ -3078,6 +3069,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                    // the P matrix from the paper (Q rows, C columns)
 | 
			
		||||
                    ss[j*TS + tiisg] = vs;
 | 
			
		||||
 | 
			
		||||
                    // create a QxQ diagonal matrix for rescaling the output
 | 
			
		||||
                    if (tiisg == j) {
 | 
			
		||||
                        ss[j*TS + 2*C + j] = ms;
 | 
			
		||||
                    }
 | 
			
		||||
@@ -3101,7 +3093,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                    s8x8_t ms;
 | 
			
		||||
                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
 | 
			
		||||
 | 
			
		||||
                    if constexpr (is_same<vd4x4_t, v4x4_t>::value) {
 | 
			
		||||
                    if (is_same<vd4x4_t, v4x4_t>::value) {
 | 
			
		||||
                        // we can read directly from global memory
 | 
			
		||||
                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
 | 
			
		||||
#pragma unroll
 | 
			
		||||
@@ -3115,7 +3107,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                        for (short ii = 0; ii < D16; ii += 4) {
 | 
			
		||||
                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
 | 
			
		||||
 | 
			
		||||
                            if constexpr (D16%4 == 0) {
 | 
			
		||||
                            if (D16%4 == 0) {
 | 
			
		||||
                                // no need for bound checks
 | 
			
		||||
                                {
 | 
			
		||||
                                    v4x4_t tmp;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user