Mirror of https://github.com/ggml-org/llama.cpp.git — synced 2025-10-31 08:51:55 +00:00.
			
		
		
		
Commit: metal : use less stack memory in FA kernel (#14088)

Commit message:
* metal : use less stack memory in FA kernel (ggml-ci)
* cont : fix BF16 variant

This commit is contained in the diff below:
		| @@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext( | |||||||
|  |  | ||||||
|     threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data |     threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data | ||||||
|     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t |     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t | ||||||
|     threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +                0*DK); // reuse query data for accumulation |  | ||||||
|     threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +                0*DK); // same as above but in o4_t |  | ||||||
|     threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix |     threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix | ||||||
|  |  | ||||||
|     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory |     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory | ||||||
| @@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext( | |||||||
|  |  | ||||||
|             // O = diag(ms)*O |             // O = diag(ms)*O | ||||||
|             { |             { | ||||||
|                 s8x8_t mm; |                 s8x8_t ms; | ||||||
|                 simdgroup_load(mm, ss + 2*C, TS, 0, false); |                 simdgroup_load(ms, ss + 2*C, TS, 0, false); | ||||||
|  |  | ||||||
|                 #pragma unroll(DV8) |                 #pragma unroll(DV8) | ||||||
|                 for (short i = 0; i < DV8; ++i) { |                 for (short i = 0; i < DV8; ++i) { | ||||||
|                     simdgroup_multiply(lo[i], mm, lo[i]); |                     simdgroup_multiply(lo[i], ms, lo[i]); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             // O = O + (Q*K^T)*V |             // O = O + (Q*K^T)*V | ||||||
|             { |             { | ||||||
|                 for (short cc = 0; cc < C/8; ++cc) { |                 for (short cc = 0; cc < C/8; ++cc) { | ||||||
|                     s8x8_t ms; |                     s8x8_t vs; | ||||||
|                     simdgroup_load(ms, ss + 8*cc, TS, 0, false); |                     simdgroup_load(vs, ss + 8*cc, TS, 0, false); | ||||||
|  |  | ||||||
|                     if (is_same<vd4x4_t, v4x4_t>::value) { |                     if (is_same<vd4x4_t, v4x4_t>::value) { | ||||||
|                         // we can read directly from global memory |                         // we can read directly from global memory | ||||||
| @@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext( | |||||||
|                             v8x8_t mv; |                             v8x8_t mv; | ||||||
|                             simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 |                             simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 | ||||||
|  |  | ||||||
|                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); |                             simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); | ||||||
|                         } |                         } | ||||||
|                     } else { |                     } else { | ||||||
|                         for (short ii = 0; ii < DV16; ii += 4) { |                         for (short ii = 0; ii < DV16; ii += 4) { | ||||||
| @@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext( | |||||||
|                                     v8x8_t mv; |                                     v8x8_t mv; | ||||||
|  |  | ||||||
|                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); |                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); | ||||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); |                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); | ||||||
|  |  | ||||||
|                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); |                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); | ||||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); |                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); | ||||||
|                                 } |                                 } | ||||||
|                             } else { |                             } else { | ||||||
|                                 if (ii + tx < DV16) { |                                 if (ii + tx < DV16) { | ||||||
| @@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext( | |||||||
|                                     v8x8_t mv; |                                     v8x8_t mv; | ||||||
|  |  | ||||||
|                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); |                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); | ||||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); |                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); | ||||||
|  |  | ||||||
|                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); |                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); | ||||||
|                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); |                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); | ||||||
|                                 } |                                 } | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
| @@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext( | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // these are needed for reducing the results from the simdgroups (reuse the ss buffer) |         // these are needed for reducing the results from the simdgroups (reuse the ss buffer) | ||||||
|         for (short j = 0; j < Q; ++j) { |         for (short j = tiisg; j < Q; j += NW) { | ||||||
|             if (tiisg == 0) { |  | ||||||
|             ss[j*TS + 0] = S[j]; |             ss[j*TS + 0] = S[j]; | ||||||
|             ss[j*TS + 1] = M[j]; |             ss[j*TS + 1] = M[j]; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|  |     threadgroup float  * so  = (threadgroup float  *) (shmem_f16 + 0*DK); // reuse query data for accumulation | ||||||
|  |     threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); | ||||||
|  |  | ||||||
|  |     // store result to shared memory in F32 | ||||||
|  |     if (sgitg == 0) { | ||||||
|  |         for (short i = 0; i < DV8; ++i) { | ||||||
|  |             //simdgroup_store(lo[i], so + i*8, DV, 0, false); | ||||||
|  |             simdgroup_float8x8 t(1.0f); | ||||||
|  |             simdgroup_multiply(t, lo[i], t); | ||||||
|  |             simdgroup_store(t, so + i*8, DV, 0, false); | ||||||
|         } |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|     // reduce the warps sequentially |     // reduce the warps sequentially | ||||||
|     for (ushort sg = 1; sg < nsg; ++sg) { |     for (ushort sg = 1; sg < nsg; ++sg) { | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |  | ||||||
|  |  | ||||||
|         // each simdgroup stores its output to shared memory, reusing sq |  | ||||||
|         if (sgitg == sg) { |         if (sgitg == sg) { | ||||||
|             for (short i = 0; i < DV8; ++i) { |             for (short j = tiisg; j < Q; j += NW) { | ||||||
|                 simdgroup_store(lo[i], so + i*8, DV, 0, false); |                 const float S0 = ss[j*TS - 1*SH + 0]; | ||||||
|             } |                 const float S1 = ss[j*TS        + 0]; | ||||||
|         } |  | ||||||
|  |  | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |                 const float M0 = ss[j*TS - 1*SH + 1]; | ||||||
|  |                 const float M1 = ss[j*TS        + 1]; | ||||||
|         // the first simdgroup accumulates the results from the other simdgroups |  | ||||||
|         if (sgitg == 0) { |  | ||||||
|             for (short j = 0; j < Q; ++j) { |  | ||||||
|                 const float S0 = ss[j*TS +         0]; |  | ||||||
|                 const float S1 = ss[j*TS + sg*SH + 0]; |  | ||||||
|  |  | ||||||
|                 const float M0 = ss[j*TS +         1]; |  | ||||||
|                 const float M1 = ss[j*TS + sg*SH + 1]; |  | ||||||
|  |  | ||||||
|                 const float M = max(M0, M1); |                 const float M = max(M0, M1); | ||||||
|  |  | ||||||
|                 const float ms0 = exp(M0 - M); |                 float ms0 = exp(M0 - M); | ||||||
|                 const float ms1 = exp(M1 - M); |                 float ms1 = exp(M1 - M); | ||||||
|  |  | ||||||
|                 const float S = S0*ms0 + S1*ms1; |                 const float S = S0*ms0 + S1*ms1; | ||||||
|  |  | ||||||
|                 if (tiisg == 0) { |  | ||||||
|                 ss[j*TS + 0] = S; |                 ss[j*TS + 0] = S; | ||||||
|                 ss[j*TS + 1] = M; |                 ss[j*TS + 1] = M; | ||||||
|  |  | ||||||
|                     ss[j*TS + 2*C + j        ] = ms0; |                 ss[j*TS + 2*C + j - 1*SH] = ms0; | ||||||
|                     ss[j*TS + 2*C + j + sg*SH] = ms1; |                 ss[j*TS + 2*C + j       ] = ms1; | ||||||
|                 } |  | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  |             //simdgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 |             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 | ||||||
|             { |             { | ||||||
|                 s8x8_t ms0; |                 s8x8_t ms0; | ||||||
|                 s8x8_t ms1; |                 s8x8_t ms1; | ||||||
|  |  | ||||||
|                 simdgroup_load(ms0, ss + 2*C,         TS, 0, false); |                 simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); | ||||||
|                 simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); |                 simdgroup_load(ms1, ss + 2*C,        TS, 0, false); | ||||||
|  |  | ||||||
|                 #pragma unroll(DV8) |                 #pragma unroll(DV8) | ||||||
|                 for (short i = 0; i < DV8; ++i) { |                 for (short i = 0; i < DV8; ++i) { | ||||||
|                     o8x8_t t; |                     simdgroup_float8x8 t; | ||||||
|  |  | ||||||
|                     simdgroup_load    (t, so + i*8, DV, 0, false); |                     simdgroup_load    (t, so + i*8, DV, 0, false); | ||||||
|                     simdgroup_multiply(t, ms1, t); |                     simdgroup_multiply(t, ms0, t); | ||||||
|  |  | ||||||
|                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); |                     simdgroup_multiply_accumulate(t, ms1, lo[i], t); | ||||||
|  |                     simdgroup_store(t, so + i*8, DV, 0, false); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // store result to shared memory (reuse sq) |  | ||||||
|     if (sgitg == 0) { |  | ||||||
|         for (short i = 0; i < DV8; ++i) { |  | ||||||
|             simdgroup_store(lo[i], so + i*8, DV, 0, false); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK); |     threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); | ||||||
|  |  | ||||||
|     // final rescale with 1/S and store to global memory |     // final rescale with 1/S and store to global memory | ||||||
|     for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { |     for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { | ||||||
| @@ -3723,8 +3718,8 @@ kernel void kernel_flash_attn_ext( | |||||||
|     half,   half4x4,   simdgroup_half8x8,  \ |     half,   half4x4,   simdgroup_half8x8,  \ | ||||||
|     float,             simdgroup_float8x8, \ |     float,             simdgroup_float8x8, \ | ||||||
|     float,             simdgroup_float8x8, \ |     float,             simdgroup_float8x8, \ | ||||||
|     float,  float4,    simdgroup_float8x8 |     half,   half4,     simdgroup_half8x8 | ||||||
|     //half,   half4,     simdgroup_half8x8 |     //float,  float4,    simdgroup_float8x8 | ||||||
|  |  | ||||||
| #define FA_TYPES_BF \ | #define FA_TYPES_BF \ | ||||||
|     bfloat, bfloat4,   simdgroup_bfloat8x8, \ |     bfloat, bfloat4,   simdgroup_bfloat8x8, \ | ||||||
| @@ -3732,8 +3727,8 @@ kernel void kernel_flash_attn_ext( | |||||||
|     bfloat, bfloat4x4, simdgroup_bfloat8x8, \ |     bfloat, bfloat4x4, simdgroup_bfloat8x8, \ | ||||||
|     float,             simdgroup_float8x8,  \ |     float,             simdgroup_float8x8,  \ | ||||||
|     float,             simdgroup_float8x8,  \ |     float,             simdgroup_float8x8,  \ | ||||||
|     float,  float4,    simdgroup_float8x8 |     half,   half4,     simdgroup_half8x8 | ||||||
|     //half,   half4,     simdgroup_half8x8 |     //float,  float4,    simdgroup_float8x8 | ||||||
|  |  | ||||||
| typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; | typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; | ||||||
|  |  | ||||||
|   | |||||||
Reference in a new issue · Block a user

Author: Georgi Gerganov