mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	metal : use less stack memory in FA kernel (#14088)
* metal : use less stack memory in FA kernel ggml-ci * cont : fix BF16 variant
This commit is contained in:
		| @@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext( | ||||
|  | ||||
|     threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data | ||||
|     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t | ||||
|     threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +                0*DK); // reuse query data for accumulation | ||||
|     threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +                0*DK); // same as above but in o4_t | ||||
|     threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix | ||||
|  | ||||
|     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory | ||||
| @@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext( | ||||
|  | ||||
|             // O = diag(ms)*O | ||||
|             { | ||||
|                 s8x8_t mm; | ||||
|                 simdgroup_load(mm, ss + 2*C, TS, 0, false); | ||||
|                 s8x8_t ms; | ||||
|                 simdgroup_load(ms, ss + 2*C, TS, 0, false); | ||||
|  | ||||
|                 #pragma unroll(DV8) | ||||
|                 for (short i = 0; i < DV8; ++i) { | ||||
|                     simdgroup_multiply(lo[i], mm, lo[i]); | ||||
|                     simdgroup_multiply(lo[i], ms, lo[i]); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // O = O + (Q*K^T)*V | ||||
|             { | ||||
|                 for (short cc = 0; cc < C/8; ++cc) { | ||||
|                     s8x8_t ms; | ||||
|                     simdgroup_load(ms, ss + 8*cc, TS, 0, false); | ||||
|                     s8x8_t vs; | ||||
|                     simdgroup_load(vs, ss + 8*cc, TS, 0, false); | ||||
|  | ||||
|                     if (is_same<vd4x4_t, v4x4_t>::value) { | ||||
|                         // we can read directly from global memory | ||||
| @@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext( | ||||
|                             v8x8_t mv; | ||||
|                             simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 | ||||
|  | ||||
|                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); | ||||
|                             simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); | ||||
|                         } | ||||
|                     } else { | ||||
|                         for (short ii = 0; ii < DV16; ii += 4) { | ||||
| @@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext( | ||||
|                                     v8x8_t mv; | ||||
|  | ||||
|                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); | ||||
|  | ||||
|                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); | ||||
|                                 } | ||||
|                             } else { | ||||
|                                 if (ii + tx < DV16) { | ||||
| @@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext( | ||||
|                                     v8x8_t mv; | ||||
|  | ||||
|                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); | ||||
|  | ||||
|                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); | ||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); | ||||
|                                 } | ||||
|                             } | ||||
|                         } | ||||
| @@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext( | ||||
|         } | ||||
|  | ||||
|         // these are needed for reducing the results from the simdgroups (reuse the ss buffer) | ||||
|         for (short j = 0; j < Q; ++j) { | ||||
|             if (tiisg == 0) { | ||||
|                 ss[j*TS + 0] = S[j]; | ||||
|                 ss[j*TS + 1] = M[j]; | ||||
|             } | ||||
|         for (short j = tiisg; j < Q; j += NW) { | ||||
|             ss[j*TS + 0] = S[j]; | ||||
|             ss[j*TS + 1] = M[j]; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     threadgroup float  * so  = (threadgroup float  *) (shmem_f16 + 0*DK); // reuse query data for accumulation | ||||
|     threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); | ||||
|  | ||||
|     // store result to shared memory in F32 | ||||
|     if (sgitg == 0) { | ||||
|         for (short i = 0; i < DV8; ++i) { | ||||
|             //simdgroup_store(lo[i], so + i*8, DV, 0, false); | ||||
|             simdgroup_float8x8 t(1.0f); | ||||
|             simdgroup_multiply(t, lo[i], t); | ||||
|             simdgroup_store(t, so + i*8, DV, 0, false); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     // reduce the warps sequentially | ||||
|     for (ushort sg = 1; sg < nsg; ++sg) { | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         // each simdgroup stores its output to shared memory, reusing sq | ||||
|         if (sgitg == sg) { | ||||
|             for (short i = 0; i < DV8; ++i) { | ||||
|                 simdgroup_store(lo[i], so + i*8, DV, 0, false); | ||||
|             } | ||||
|         } | ||||
|             for (short j = tiisg; j < Q; j += NW) { | ||||
|                 const float S0 = ss[j*TS - 1*SH + 0]; | ||||
|                 const float S1 = ss[j*TS        + 0]; | ||||
|  | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         // the first simdgroup accumulates the results from the other simdgroups | ||||
|         if (sgitg == 0) { | ||||
|             for (short j = 0; j < Q; ++j) { | ||||
|                 const float S0 = ss[j*TS +         0]; | ||||
|                 const float S1 = ss[j*TS + sg*SH + 0]; | ||||
|  | ||||
|                 const float M0 = ss[j*TS +         1]; | ||||
|                 const float M1 = ss[j*TS + sg*SH + 1]; | ||||
|                 const float M0 = ss[j*TS - 1*SH + 1]; | ||||
|                 const float M1 = ss[j*TS        + 1]; | ||||
|  | ||||
|                 const float M = max(M0, M1); | ||||
|  | ||||
|                 const float ms0 = exp(M0 - M); | ||||
|                 const float ms1 = exp(M1 - M); | ||||
|                 float ms0 = exp(M0 - M); | ||||
|                 float ms1 = exp(M1 - M); | ||||
|  | ||||
|                 const float S = S0*ms0 + S1*ms1; | ||||
|  | ||||
|                 if (tiisg == 0) { | ||||
|                     ss[j*TS + 0] = S; | ||||
|                     ss[j*TS + 1] = M; | ||||
|                 ss[j*TS + 0] = S; | ||||
|                 ss[j*TS + 1] = M; | ||||
|  | ||||
|                     ss[j*TS + 2*C + j        ] = ms0; | ||||
|                     ss[j*TS + 2*C + j + sg*SH] = ms1; | ||||
|                 } | ||||
|                 ss[j*TS + 2*C + j - 1*SH] = ms0; | ||||
|                 ss[j*TS + 2*C + j       ] = ms1; | ||||
|             } | ||||
|  | ||||
|             //simdgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 | ||||
|             { | ||||
|                 s8x8_t ms0; | ||||
|                 s8x8_t ms1; | ||||
|  | ||||
|                 simdgroup_load(ms0, ss + 2*C,         TS, 0, false); | ||||
|                 simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); | ||||
|                 simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); | ||||
|                 simdgroup_load(ms1, ss + 2*C,        TS, 0, false); | ||||
|  | ||||
|                 #pragma unroll(DV8) | ||||
|                 for (short i = 0; i < DV8; ++i) { | ||||
|                     o8x8_t t; | ||||
|                     simdgroup_float8x8 t; | ||||
|  | ||||
|                     simdgroup_load    (t, so + i*8, DV, 0, false); | ||||
|                     simdgroup_multiply(t, ms1, t); | ||||
|                     simdgroup_multiply(t, ms0, t); | ||||
|  | ||||
|                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); | ||||
|                     simdgroup_multiply_accumulate(t, ms1, lo[i], t); | ||||
|                     simdgroup_store(t, so + i*8, DV, 0, false); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     } | ||||
|  | ||||
|     // store result to shared memory (reuse sq) | ||||
|     if (sgitg == 0) { | ||||
|         for (short i = 0; i < DV8; ++i) { | ||||
|             simdgroup_store(lo[i], so + i*8, DV, 0, false); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|     threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK); | ||||
|     threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); | ||||
|  | ||||
|     // final rescale with 1/S and store to global memory | ||||
|     for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { | ||||
| @@ -3723,8 +3718,8 @@ kernel void kernel_flash_attn_ext( | ||||
|     half,   half4x4,   simdgroup_half8x8,  \ | ||||
|     float,             simdgroup_float8x8, \ | ||||
|     float,             simdgroup_float8x8, \ | ||||
|     float,  float4,    simdgroup_float8x8 | ||||
|     //half,   half4,     simdgroup_half8x8 | ||||
|     half,   half4,     simdgroup_half8x8 | ||||
|     //float,  float4,    simdgroup_float8x8 | ||||
|  | ||||
| #define FA_TYPES_BF \ | ||||
|     bfloat, bfloat4,   simdgroup_bfloat8x8, \ | ||||
| @@ -3732,8 +3727,8 @@ kernel void kernel_flash_attn_ext( | ||||
|     bfloat, bfloat4x4, simdgroup_bfloat8x8, \ | ||||
|     float,             simdgroup_float8x8,  \ | ||||
|     float,             simdgroup_float8x8,  \ | ||||
|     float,  float4,    simdgroup_float8x8 | ||||
|     //half,   half4,     simdgroup_half8x8 | ||||
|     half,   half4,     simdgroup_half8x8 | ||||
|     //float,  float4,    simdgroup_float8x8 | ||||
|  | ||||
| typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov