Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
			
		
		
		
	metal : minor fixup in FA kernel (#10143)
* metal : minor fixup in FA kernel

ggml-ci

* metal : use the unrolled loop variable

* metal : remove unused var
This commit is contained in:
		| @@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16( | |||||||
|         const short iv3 = iq3 / rv3; |         const short iv3 = iq3 / rv3; | ||||||
|  |  | ||||||
|         // load the queries from shared memory into local memory |         // load the queries from shared memory into local memory | ||||||
|         float4 mq[D4]; |         float4 mq[D4/NW]; | ||||||
|  |  | ||||||
|         for (short ii = 0; ii < D4; ii += NW) { |         for (short ii = 0; ii < D4; ii += NW) { | ||||||
|             short i = ii + tiisg; |             short i = ii + tiisg; | ||||||
|             mq[i] = (float4) sq4[i]; |             mq[ii/NW] = (float4) sq4[i]; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // pointer to the mask |         // pointer to the mask | ||||||
| @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16( | |||||||
|                         mk[2] = (float4) pk4[i + 2*(nb11/8)]; |                         mk[2] = (float4) pk4[i + 2*(nb11/8)]; | ||||||
|                         mk[3] = (float4) pk4[i + 3*(nb11/8)]; |                         mk[3] = (float4) pk4[i + 3*(nb11/8)]; | ||||||
|  |  | ||||||
|                         mqk += (float4) (mq[i] * mk); |                         mqk += (float4) (mq[ii/NW] * mk); | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     // reduce the results from the threads in the simdgroup |                     // reduce the results from the threads in the simdgroup | ||||||
| @@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16( | |||||||
|                 // O = diag(ms)*O |                 // O = diag(ms)*O | ||||||
| #pragma unroll | #pragma unroll | ||||||
|                 for (short ii = 0; ii < D4; ii += NW) { |                 for (short ii = 0; ii < D4; ii += NW) { | ||||||
|                     const short i = ii + tiisg; |                     lo[ii/NW] *= ms; | ||||||
|                     lo[i/NW] *= ms; |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16( | |||||||
|                     for (short ii = 0; ii < D4; ii += NW) { |                     for (short ii = 0; ii < D4; ii += NW) { | ||||||
|                         const short i = ii + tiisg; |                         const short i = ii + tiisg; | ||||||
|  |  | ||||||
|                         lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; |                         lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; | ||||||
|                         lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; |                         lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; | ||||||
|                         lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; |                         lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; | ||||||
|                         lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; |                         lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov