mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	metal : gemma2 flash attention support (#9159)
This commit is contained in:
		| @@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx | |||||||
|             if (op->src[0]->ne[0] == 256) { |             if (op->src[0]->ne[0] == 256) { | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|             { |  | ||||||
|                 float logit_softcap; |  | ||||||
|  |  | ||||||
|                 memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap)); |  | ||||||
|  |  | ||||||
|                 if (logit_softcap != 0.0f) { |  | ||||||
|                     return false; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels |             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels | ||||||
|         case GGML_OP_MUL_MAT: |         case GGML_OP_MUL_MAT: | ||||||
|         case GGML_OP_MUL_MAT_ID: |         case GGML_OP_MUL_MAT_ID: | ||||||
| @@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute( | |||||||
|  |  | ||||||
|                         float scale; |                         float scale; | ||||||
|                         float max_bias; |                         float max_bias; | ||||||
|  |                         float logit_softcap; | ||||||
|  |                         memcpy(&scale,         ((int32_t *) dst->op_params) + 0, sizeof(scale)); | ||||||
|  |                         memcpy(&max_bias,      ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); | ||||||
|  |                         memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); | ||||||
|  |  | ||||||
|                         memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale)); |                         if (logit_softcap != 0.0f) { | ||||||
|                         memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); |                             scale /= logit_softcap; | ||||||
|  |                         } | ||||||
|  |  | ||||||
|                         const uint32_t n_head      = src0->ne[2]; |                         const uint32_t n_head      = src0->ne[2]; | ||||||
|                         const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); |                         const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||||||
| @@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute( | |||||||
|                         } else { |                         } else { | ||||||
|                             [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3]; |                             [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3]; | ||||||
|                         } |                         } | ||||||
|                         [encoder setBuffer:id_dst      offset:offs_dst            atIndex:4]; |                         [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4]; | ||||||
|                         [encoder setBytes:&ne01        length:sizeof( int64_t)    atIndex:5]; |                         [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5]; | ||||||
|                         [encoder setBytes:&ne02        length:sizeof( int64_t)    atIndex:6]; |                         [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6]; | ||||||
|                         [encoder setBytes:&ne03        length:sizeof( int64_t)    atIndex:7]; |                         [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7]; | ||||||
|                         [encoder setBytes:&nb01        length:sizeof(uint64_t)    atIndex:8]; |                         [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8]; | ||||||
|                         [encoder setBytes:&nb02        length:sizeof(uint64_t)    atIndex:9]; |                         [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9]; | ||||||
|                         [encoder setBytes:&nb03        length:sizeof(uint64_t)    atIndex:10]; |                         [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10]; | ||||||
|                         [encoder setBytes:&ne11        length:sizeof( int64_t)    atIndex:11]; |                         [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11]; | ||||||
|                         [encoder setBytes:&ne12        length:sizeof( int64_t)    atIndex:12]; |                         [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12]; | ||||||
|                         [encoder setBytes:&ne13        length:sizeof( int64_t)    atIndex:13]; |                         [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13]; | ||||||
|                         [encoder setBytes:&nb11        length:sizeof(uint64_t)    atIndex:14]; |                         [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14]; | ||||||
|                         [encoder setBytes:&nb12        length:sizeof(uint64_t)    atIndex:15]; |                         [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15]; | ||||||
|                         [encoder setBytes:&nb13        length:sizeof(uint64_t)    atIndex:16]; |                         [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16]; | ||||||
|                         [encoder setBytes:&nb21        length:sizeof(uint64_t)    atIndex:17]; |                         [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17]; | ||||||
|                         [encoder setBytes:&nb22        length:sizeof(uint64_t)    atIndex:18]; |                         [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18]; | ||||||
|                         [encoder setBytes:&nb23        length:sizeof(uint64_t)    atIndex:19]; |                         [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19]; | ||||||
|                         [encoder setBytes:&nb31        length:sizeof(uint64_t)    atIndex:20]; |                         [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20]; | ||||||
|                         [encoder setBytes:&ne1         length:sizeof( int64_t)    atIndex:21]; |                         [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21]; | ||||||
|                         [encoder setBytes:&ne2         length:sizeof( int64_t)    atIndex:22]; |                         [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22]; | ||||||
|                         [encoder setBytes:&scale       length:sizeof(   float)    atIndex:23]; |                         [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23]; | ||||||
|                         [encoder setBytes:&max_bias    length:sizeof(   float)    atIndex:24]; |                         [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24]; | ||||||
|                         [encoder setBytes:&m0          length:sizeof(m0)          atIndex:25]; |                         [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25]; | ||||||
|                         [encoder setBytes:&m1          length:sizeof(m1)          atIndex:26]; |                         [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26]; | ||||||
|                         [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; |                         [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27]; | ||||||
|  |                         [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; | ||||||
|  |  | ||||||
|                         if (!use_vec_kernel) { |                         if (!use_vec_kernel) { | ||||||
|                             // half8x8 kernel |                             // half8x8 kernel | ||||||
|   | |||||||
| @@ -1976,6 +1976,7 @@ typedef void (flash_attn_ext_f16_t)( | |||||||
|         constant     float & m0, |         constant     float & m0, | ||||||
|         constant     float & m1, |         constant     float & m1, | ||||||
|         constant  uint32_t & n_head_log2, |         constant  uint32_t & n_head_log2, | ||||||
|  |         constant     float & logit_softcap, | ||||||
|         threadgroup   half * shared, |         threadgroup   half * shared, | ||||||
|         uint3  tgpig[[threadgroup_position_in_grid]], |         uint3  tgpig[[threadgroup_position_in_grid]], | ||||||
|         uint3  tpitg[[thread_position_in_threadgroup]], |         uint3  tpitg[[thread_position_in_threadgroup]], | ||||||
| @@ -2014,6 +2015,7 @@ kernel void kernel_flash_attn_ext_f16( | |||||||
|         constant     float & m0, |         constant     float & m0, | ||||||
|         constant     float & m1, |         constant     float & m1, | ||||||
|         constant  uint32_t & n_head_log2, |         constant  uint32_t & n_head_log2, | ||||||
|  |         constant     float & logit_softcap, | ||||||
|         threadgroup   half * shared [[threadgroup(0)]], |         threadgroup   half * shared [[threadgroup(0)]], | ||||||
|         uint3  tgpig[[threadgroup_position_in_grid]], |         uint3  tgpig[[threadgroup_position_in_grid]], | ||||||
|         uint3  tpitg[[thread_position_in_threadgroup]], |         uint3  tpitg[[thread_position_in_threadgroup]], | ||||||
| @@ -2142,14 +2144,19 @@ kernel void kernel_flash_attn_ext_f16( | |||||||
|                     const short tx = tiisg%4; |                     const short tx = tiisg%4; | ||||||
|                     const short ty = tiisg/4; |                     const short ty = tiisg/4; | ||||||
|  |  | ||||||
|  |                     // mqk = mqk*scale | ||||||
|  |                     ss[8*cc + ty*TF + 2*tx + 0] *= scale; | ||||||
|  |                     ss[8*cc + ty*TF + 2*tx + 1] *= scale; | ||||||
|  |  | ||||||
|  |                     if (logit_softcap != 0.0f) { | ||||||
|  |                         ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); | ||||||
|  |                         ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); | ||||||
|  |                     } | ||||||
|  |  | ||||||
|                     if (mask != q) { |                     if (mask != q) { | ||||||
|                         // mqk = mqk*scale + mask*slope |                         // mqk = mqk + mask*slope | ||||||
|                         ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; |                         ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; | ||||||
|                         ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; |                         ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; | ||||||
|                     } else { |  | ||||||
|                         // mqk = mqk*scale |  | ||||||
|                         ss[8*cc + ty*TF + 2*tx + 0] *= scale; |  | ||||||
|                         ss[8*cc + ty*TF + 2*tx + 1] *= scale; |  | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| @@ -2345,6 +2352,7 @@ kernel void kernel_flash_attn_ext_vec_f16( | |||||||
|         constant     float & m0, |         constant     float & m0, | ||||||
|         constant     float & m1, |         constant     float & m1, | ||||||
|         constant  uint32_t & n_head_log2, |         constant  uint32_t & n_head_log2, | ||||||
|  |         constant     float & logit_softcap, | ||||||
|         threadgroup   half * shared [[threadgroup(0)]], |         threadgroup   half * shared [[threadgroup(0)]], | ||||||
|         uint3  tgpig[[threadgroup_position_in_grid]], |         uint3  tgpig[[threadgroup_position_in_grid]], | ||||||
|         uint3  tpitg[[thread_position_in_threadgroup]], |         uint3  tpitg[[thread_position_in_threadgroup]], | ||||||
| @@ -2479,7 +2487,13 @@ kernel void kernel_flash_attn_ext_vec_f16( | |||||||
|  |  | ||||||
|                     // mqk = mqk*scale + mask*slope |                     // mqk = mqk*scale + mask*slope | ||||||
|                     if (tiisg == 0) { |                     if (tiisg == 0) { | ||||||
|                         mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f); |                         mqk *= scale; | ||||||
|  |  | ||||||
|  |                         if (logit_softcap != 0.0f) { | ||||||
|  |                             mqk = logit_softcap*precise::tanh(mqk); | ||||||
|  |                         } | ||||||
|  |  | ||||||
|  |                         mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; | ||||||
|  |  | ||||||
|                         ss4[cc] = mqk; |                         ss4[cc] = mqk; | ||||||
|                     } |                     } | ||||||
|   | |||||||
| @@ -2487,7 +2487,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     GGML_ABORT("fatal error"); |     GGML_ABORT("fatal error"); | ||||||
|     return false; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| static void usage(char ** argv) { | static void usage(char ** argv) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 slaren
					slaren