mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : support bcast add & dup & cont op (#2323)
This commit is contained in:
		
							
								
								
									
										10
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -42,6 +42,7 @@ struct ggml_metal_context { | |||||||
|     id<MTLComputePipelineState> pipeline_##name |     id<MTLComputePipelineState> pipeline_##name | ||||||
|  |  | ||||||
|     GGML_METAL_DECL_KERNEL(add); |     GGML_METAL_DECL_KERNEL(add); | ||||||
|  |     GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast | ||||||
|     GGML_METAL_DECL_KERNEL(mul); |     GGML_METAL_DECL_KERNEL(mul); | ||||||
|     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast |     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast | ||||||
|     GGML_METAL_DECL_KERNEL(scale); |     GGML_METAL_DECL_KERNEL(scale); | ||||||
| @@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | |||||||
|         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); |         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); | ||||||
|  |  | ||||||
|         GGML_METAL_ADD_KERNEL(add); |         GGML_METAL_ADD_KERNEL(add); | ||||||
|  |         GGML_METAL_ADD_KERNEL(add_row); | ||||||
|         GGML_METAL_ADD_KERNEL(mul); |         GGML_METAL_ADD_KERNEL(mul); | ||||||
|         GGML_METAL_ADD_KERNEL(mul_row); |         GGML_METAL_ADD_KERNEL(mul_row); | ||||||
|         GGML_METAL_ADD_KERNEL(scale); |         GGML_METAL_ADD_KERNEL(scale); | ||||||
| @@ -464,10 +466,16 @@ void ggml_metal_graph_compute( | |||||||
|                                 encoder = [command_buffer computeCommandEncoder]; |                                 encoder = [command_buffer computeCommandEncoder]; | ||||||
|                             } |                             } | ||||||
|  |  | ||||||
|  |                             if (ggml_nelements(src1) == ne10) { | ||||||
|  |                                 // src1 is a row | ||||||
|  |                                 [encoder setComputePipelineState:ctx->pipeline_add_row]; | ||||||
|  |                             } else { | ||||||
|                                 [encoder setComputePipelineState:ctx->pipeline_add]; |                                 [encoder setComputePipelineState:ctx->pipeline_add]; | ||||||
|  |                             } | ||||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||||
|                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | ||||||
|                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; |                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; | ||||||
|  |                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; | ||||||
|  |  | ||||||
|                             const int64_t n = ggml_nelements(dst); |                             const int64_t n = ggml_nelements(dst); | ||||||
|  |  | ||||||
| @@ -919,7 +927,9 @@ void ggml_metal_graph_compute( | |||||||
|  |  | ||||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||||
|                         } break; |                         } break; | ||||||
|  |                     case GGML_OP_DUP: | ||||||
|                     case GGML_OP_CPY: |                     case GGML_OP_CPY: | ||||||
|  |                     case GGML_OP_CONT: | ||||||
|                         { |                         { | ||||||
|                             if (encoder == nil) { |                             if (encoder == nil) { | ||||||
|                                 encoder = [command_buffer computeCommandEncoder]; |                                 encoder = [command_buffer computeCommandEncoder]; | ||||||
|   | |||||||
| @@ -67,6 +67,17 @@ kernel void kernel_add( | |||||||
|     dst[tpig] = src0[tpig] + src1[tpig]; |     dst[tpig] = src0[tpig] + src1[tpig]; | ||||||
| } | } | ||||||
|  |  | ||||||
// Row-broadcast addition: src1 holds a single row of ne00 floats which is
// added element-wise to every row of src0, i.e.
//     dst[i] = src0[i] + src1[i mod ne00]
// Precondition: src1 is a contiguous row (ggml_nelements(src1) == ne00).
// TODO: retire this kernel once kernel_add learns to broadcast on its own.
kernel void kernel_add_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    // Wrap the flat grid index into the row so src1 repeats every ne00 elements.
    const int64_t i0 = tpig % ne00;
    dst[tpig] = src0[tpig] + src1[i0];
}
|  |  | ||||||
| kernel void kernel_mul( | kernel void kernel_mul( | ||||||
|         device const float * src0, |         device const float * src0, | ||||||
|         device const float * src1, |         device const float * src1, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jiahao Li
					Jiahao Li