mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : fix fa kernel
This commit is contained in:
		| @@ -2144,19 +2144,26 @@ kernel void kernel_flash_attn_ext_f16( | ||||
|                     const short tx = tiisg%4; | ||||
|                     const short ty = tiisg/4; | ||||
|  | ||||
|-                    // mqk = mqk*scale | ||||
|-                    ss[8*cc + ty*TF + 2*tx + 0] *= scale; | ||||
|-                    ss[8*cc + ty*TF + 2*tx + 1] *= scale; | ||||
|- | ||||
|-                    if (logit_softcap != 0.0f) { | ||||
|-                        ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); | ||||
|-                        ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); | ||||
|-                    } | ||||
|- | ||||
|-                    if (mask != q) { | ||||
|-                        // mqk = mqk + mask*slope | ||||
|-                        ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; | ||||
|-                        ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; | ||||
|+                    if (logit_softcap == 0.0f) { | ||||
|+                        if (mask != q) { | ||||
|+                            // mqk = mqk*scale + mask*slope | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; | ||||
|+                        } else { | ||||
|+                            // mqk = mqk*scale | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 0] *= scale; | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 1] *= scale; | ||||
|+                        } | ||||
|+                    } else { | ||||
|+                        if (mask != q) { | ||||
|+                            // mqk = ls*tanh(mqk*scale) + mask*slope | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 0]) + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 1]) + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; | ||||
|+                        } else { | ||||
|+                            // mqk = ls*tanh(mqk*scale) | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 0]); | ||||
|+                            ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 1]); | ||||
|+                        } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov