mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	| @@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16( | ||||
|         const short iv3 = iq3 / rv3; | ||||
|  | ||||
|         // load the queries from shared memory into local memory | ||||
|-        half4 mq[D4]; | ||||
|+        float4 mq[D4]; | ||||
|  | ||||
|         for (short ii = 0; ii < D4; ii += NW) { | ||||
|             short i = ii + tiisg; | ||||
|-            mq[i] = sq4[i]; | ||||
|+            mq[i] = (float4) sq4[i]; | ||||
|         } | ||||
|  | ||||
|         // pointer to the mask | ||||
| @@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16( | ||||
|                     for (short ii = 0; ii < D4; ii += NW) { | ||||
|                         const short i = ii + tiisg; | ||||
|  | ||||
|-                        half4x4 mk; | ||||
|-                        mk[0] = pk4[i + 0*(nb11/8)]; | ||||
|-                        mk[1] = pk4[i + 1*(nb11/8)]; | ||||
|-                        mk[2] = pk4[i + 2*(nb11/8)]; | ||||
|-                        mk[3] = pk4[i + 3*(nb11/8)]; | ||||
|+                        float4x4 mk; | ||||
|+                        mk[0] = (float4) pk4[i + 0*(nb11/8)]; | ||||
|+                        mk[1] = (float4) pk4[i + 1*(nb11/8)]; | ||||
|+                        mk[2] = (float4) pk4[i + 2*(nb11/8)]; | ||||
|+                        mk[3] = (float4) pk4[i + 3*(nb11/8)]; | ||||
|  | ||||
|                         mqk += (float4) (mq[i] * mk); | ||||
|                     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov