mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : improve clarity (minor) (#10171)
This commit is contained in:
		| @@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|     const short D4  = D/4; |     const short D4  = D/4; | ||||||
|     const short D16 = D/16; |     const short D16 = D/16; | ||||||
|     const short NW  = N_SIMDWIDTH; |     const short NW  = N_SIMDWIDTH; | ||||||
|     const short NW4 = NW/4; |     const short NL  = NW/4; | ||||||
|     const short SH  = 2*C; // shared memory per simdgroup |     const short SH  = 2*C; // shared memory per simdgroup | ||||||
|  |  | ||||||
|     const short T = D + nsg*SH; // shared memory size per query in (half) |     const short T = D + nsg*SH; // shared memory size per query in (half) | ||||||
| @@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|     threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results |     threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results | ||||||
|  |  | ||||||
|     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) |     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) | ||||||
|     o4x4_t lo[D16/NW4]; |     o4x4_t lo[D16/NL]; | ||||||
|  |  | ||||||
|     // load heads from Q to shared memory |     // load heads from Q to shared memory | ||||||
|     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); |     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); | ||||||
| @@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     // zero out lo |     // zero out lo | ||||||
|     for (short i = 0; i < D16/NW4; i += NW4) { |     for (short i = 0; i < D16/NL; ++i) { | ||||||
|         lo[i] = (o4x4_t) 0.0f; |         lo[i] = (o4x4_t) 0.0f; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|         half M = -__FLT16_MAX__/2; |         half M = -__FLT16_MAX__/2; | ||||||
|  |  | ||||||
|         // thread indices inside the simdgroup |         // thread indices inside the simdgroup | ||||||
|         const short tx = tiisg%8; |         const short tx = tiisg%NL; | ||||||
|         const short ty = tiisg/8; |         const short ty = tiisg/NL; | ||||||
|  |  | ||||||
|         // broadcast kv |         // broadcast kv | ||||||
|         //const short rk2 = ne02/ne12; |         //const short rk2 = ne02/ne12; | ||||||
| @@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|         const short ikv3 = iq3/(ne03/ne_12_3); |         const short ikv3 = iq3/(ne03/ne_12_3); | ||||||
|  |  | ||||||
|         // load the queries from shared memory into local memory |         // load the queries from shared memory into local memory | ||||||
|         q4x4_t mq[D16/NW4]; |         q4x4_t mq[D16/NL]; | ||||||
|  |  | ||||||
|         for (short ii = 0; ii < D16; ii += NW4) { |         for (short ii = 0; ii < D16; ii += NL) { | ||||||
|             mq[ii/NW4] = sq4x4[ii + tx]; |             mq[ii/NL] = sq4x4[ii + tx]; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         const bool has_mask = mask != q; |         const bool has_mask = mask != q; | ||||||
| @@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|                     device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); |                     device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); | ||||||
|  |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|                     for (short ii = 0; ii < D16; ii += NW4) { |                     for (short ii = 0; ii < D16; ii += NL) { | ||||||
|                         const short i = ii + tx; |                         const short i = ii + tx; | ||||||
|  |  | ||||||
|                         k4x4_t mk; |                         k4x4_t mk; | ||||||
|                         deq_k(pk + i/nl_k, i%nl_k, mk); |                         deq_k(pk + i/nl_k, i%nl_k, mk); | ||||||
|  |  | ||||||
|                         mqk += |                         mqk += | ||||||
|                             dot(mq[ii/NW4][0], mk[0]) + |                             dot(mq[ii/NL][0], mk[0]) + | ||||||
|                             dot(mq[ii/NW4][1], mk[1]) + |                             dot(mq[ii/NL][1], mk[1]) + | ||||||
|                             dot(mq[ii/NW4][2], mk[2]) + |                             dot(mq[ii/NL][2], mk[2]) + | ||||||
|                             dot(mq[ii/NW4][3], mk[3]); |                             dot(mq[ii/NL][3], mk[3]); | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     // simdgroup reduce |                     // simdgroup reduce | ||||||
| @@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|  |  | ||||||
|                 // O = diag(ms)*O |                 // O = diag(ms)*O | ||||||
| #pragma unroll | #pragma unroll | ||||||
|                 for (short ii = 0; ii < D16; ii += NW4) { |                 for (short ii = 0; ii < D16; ii += NL) { | ||||||
|                     lo[ii/NW4] *= ms; |                     lo[ii/NL] *= ms; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|                     const s4x4_t ms(ss[4*cc + ty]); |                     const s4x4_t ms(ss[4*cc + ty]); | ||||||
|  |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|                     for (short ii = 0; ii < D16; ii += NW4) { |                     for (short ii = 0; ii < D16; ii += NL) { | ||||||
|                         const short i = ii + tx; |                         const short i = ii + tx; | ||||||
|  |  | ||||||
|                         v4x4_t mv; |                         v4x4_t mv; | ||||||
|                         deq_v(pv4 + i/nl_v, i%nl_v, mv); |                         deq_v(pv4 + i/nl_v, i%nl_v, mv); | ||||||
|  |  | ||||||
|                         lo[ii/NW4] += mv*ms; |                         lo[ii/NL] += mv*ms; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|     // [ 5, 13, 21, 29] -> [ 5] |     // [ 5, 13, 21, 29] -> [ 5] | ||||||
|     // [ 6, 14, 22, 30] -> [ 6] |     // [ 6, 14, 22, 30] -> [ 6] | ||||||
|     // [ 7, 15, 23, 31] -> [ 7] |     // [ 7, 15, 23, 31] -> [ 7] | ||||||
|     for (short ii = 0; ii < D16; ii += NW4) { |     for (short ii = 0; ii < D16; ii += NL) { | ||||||
|         lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16); |         lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); | ||||||
|         lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0],  8); |         lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8); | ||||||
|  |       //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4); | ||||||
|  |       //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2); | ||||||
|  |       //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1); | ||||||
|  |  | ||||||
|         lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16); |         lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); | ||||||
|         lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1],  8); |         lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8); | ||||||
|  |       //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4); | ||||||
|  |       //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2); | ||||||
|  |       //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1); | ||||||
|  |  | ||||||
|         lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16); |         lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); | ||||||
|         lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2],  8); |         lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8); | ||||||
|  |       //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4); | ||||||
|  |       //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2); | ||||||
|  |       //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1); | ||||||
|  |  | ||||||
|         lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16); |         lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); | ||||||
|         lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3],  8); |         lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8); | ||||||
|  |       //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4); | ||||||
|  |       //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2); | ||||||
|  |       //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|  |  | ||||||
|     // store results to shared memory |     // store results to shared memory | ||||||
|     for (short i = tiisg; i < D16; i += NW4) { |     for (short i = tiisg; i < D16; i += NL) { | ||||||
|         sr4x4[i] = lo[i/NW4]; |         sr4x4[i] = lo[i/NL]; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov