mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : separate scale and mask from QKT in FA kernel (#9189)
* metal : separate scale and mask from QKT in FA kernel * metal : ne01 check no longer necessary * metal : keep data in local memory
This commit is contained in:
		| @@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16( | ||||
|                     } | ||||
|  | ||||
|                     simdgroup_store(mqk, ss + 8*cc, TF, 0, false); | ||||
|  | ||||
|                     const short tx = tiisg%4; | ||||
|                     const short ty = tiisg/4; | ||||
|  | ||||
|                     // mqk = mqk*scale | ||||
|                     ss[8*cc + ty*TF + 2*tx + 0] *= scale; | ||||
|                     ss[8*cc + ty*TF + 2*tx + 1] *= scale; | ||||
|  | ||||
|                     if (logit_softcap != 0.0f) { | ||||
|                         ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); | ||||
|                         ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); | ||||
|                     } | ||||
|  | ||||
|                     if (mask != q) { | ||||
|                         // mqk = mqk + mask*slope | ||||
|                         ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; | ||||
|                         ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
| @@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16( | ||||
|                 float ms[Q]; | ||||
|  | ||||
|                 for (short j = 0; j < Q; ++j) { | ||||
|                     const short p = tiisg; | ||||
|  | ||||
|                     const float m = M[j]; | ||||
|                     const float s = ss[j*TF + p]; | ||||
|  | ||||
|                     // scale and apply the logitcap / mask | ||||
|                     float s = ss[j*TF + tiisg]*scale; | ||||
|  | ||||
|                     if (logit_softcap != 0.0f) { | ||||
|                         s = logit_softcap*precise::tanh(s); | ||||
|                     } | ||||
|  | ||||
|                     if (mask != q) { | ||||
|                         // mqk = mqk + mask*slope | ||||
|                         s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; | ||||
|                     } | ||||
|  | ||||
|                     smax = simd_max(max(smax, s)); | ||||
|                     M[j] = simd_max(max(M[j], s)); | ||||
| @@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16( | ||||
|                     S[j] = S[j]*ms[j] + simd_sum(vs); | ||||
|  | ||||
|                     // the P matrix from the paper (Q rows, C columns) | ||||
|                     ss[j*TF + p] = vs; | ||||
|                     ss[j*TF + tiisg] = vs; | ||||
|                 } | ||||
|  | ||||
|                 // create a QxQ diagonal matrix for rescaling the output | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov