Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)
* Added gqa8 kernel to allow llama-2-70B on metal
* Update ggml-metal.m
* Extend kernel_mul_mat_f16_f32 to handle gqa broadcast
* Added ne03==ne13 assertion

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
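For context, "gqa8" refers to grouped-query attention with a group size of 8: in the mul_mat call below, src1 has ne12 heads while src0 has only ne02, and each group of ne12/ne02 consecutive src1 heads reuses the same src0 head. Below is a minimal sketch of that index mapping. The head counts are llama-2-70B's published values (64 query heads, 8 KV heads) and are not stated in this diff; only the i02/(ne12/ne02) formula is taken from the patch.

/* Sketch of the GQA broadcast mapping used by this patch (see caveats above). */
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne02 = 8;   // src0 heads (K/V side, dim 2)
    const int64_t ne12 = 64;  // src1 heads (Q side, dim 2)

    // One matrix multiplication per src1 head; every group of
    // ne12/ne02 = 8 consecutive src1 heads maps to the same src0 head.
    for (int64_t i02 = 0; i02 < ne12; ++i02) {
        const int64_t i02_src0 = i02 / (ne12 / ne02); // formula from the diff
        printf("src1 head %2lld -> src0 head %lld\n",
               (long long) i02, (long long) i02_src0);
    }
    return 0;
}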
ggml-metal.m (33 lines changed):
@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
                             // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne02 == ne12);
+                            // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+                            GGML_ASSERT(ne03 == ne13);
 
                             if (ggml_is_contiguous(src0) &&
                                 ggml_is_contiguous(src1) &&
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
                                     initWithDevice:ctx->device transposeLeft:false transposeRight:true
                                         resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
 
-                                // we need to do ne02 multiplications
+                                // we need to do ne12 multiplications
                                 // TODO: is there a way to do this in parallel - currently very slow ..
                                 // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                                for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                                    size_t offs_src0_cur = offs_src0 + i02*nb02;
+                                for (int64_t i02 = 0; i02 < ne12; ++i02) {
+                                    size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
                                     size_t offs_src1_cur = offs_src1 + i02*nb12;
                                     size_t offs_dst_cur  = offs_dst  + i02*nb2;
 
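Worked through with the head counts assumed above (ne02 = 8, ne12 = 64): iterations i02 = 0..7 all read the src0 plane at offs_src0 + 0*nb02, i02 = 8..15 read offs_src0 + 1*nb02, and so on, while src1 and dst advance by a full nb12 / nb2 plane on every iteration. The loop therefore performs ne12 multiplications instead of ne02; note the "gqa not used for now" comment indicating this MPS path does not yet exercise the broadcast.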
@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            GGML_ASSERT(ne02 == ne12);
-
                                             nth0 = 64;
                                             nth1 = 1;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                                 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                                 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
-                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
-                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
-                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
-                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
-                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+                                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
ggml-metal.metal:
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
+        constant   int64_t & ne02,
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
+        constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
@@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
 
-    device const half  * x = (device const half  *) (src0 + r0*nb01 + im*nb02);
+    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     sum[tpitg.x] = 0.0f;
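The device-side kernel applies the same fold: im (tgpig.z) indexes src1 and dst planes and therefore runs over ne12, while im/(ne12/ne02) maps it back onto one of the ne02 src0 planes. The broadcast is thus handled by pointer arithmetic alone, without materializing repeated K/V heads in memory.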
@@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
Author: Matteo Boschini