mirror of https://github.com/ggml-org/llama.cpp.git

ggml-metal.m
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                    repeat_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                    repeat_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                    repeat_i32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                    repeat_i16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
             switch (dst->op) {
                 case GGML_OP_CONCAT:
                     {
-                        const int64_t nb = ne00;
-
                         id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;

                         [encoder setComputePipelineState:pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                         [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                         [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                        [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];

                         const int nth = MIN(1024, ne0);

@@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                 case GGML_OP_DIV:
                     {
+                        GGML_ASSERT(src0t == GGML_TYPE_F32);
+                        GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                         const size_t offs = 0;

                         bool bcast_row = false;

-                        int64_t nb = ne00;
+                        int64_t nb = ne00; // used by the "row" kernels

                         id<MTLComputePipelineState> pipeline = nil;

@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         }
                     } break;
+                case GGML_OP_REPEAT:
+                    {
+                        id<MTLComputePipelineState> pipeline;
+
+                        switch (src0t) {
+                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                            case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                            case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                            default: GGML_ASSERT(false);
+                        }
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                        [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                        [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                        [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                        [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                        [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                        [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                        [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                        [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                 case GGML_OP_ACC:
                     {
                         GGML_ASSERT(src0t == GGML_TYPE_F32);
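Note on the GGML_OP_REPEAT dispatch above: the host code binds src0 and dst plus the four source extents/strides (ne00..ne03, nb00..nb03) and the four destination extents/strides (ne0..ne3, nb0..nb3), then launches one threadgroup per output (i1, i2, i3) coordinate with up to maxTotalThreadsPerThreadgroup threads striding across ne0. Each output element reads the source element at the same indices taken modulo the source extents. Below is a minimal serial C sketch of that indexing, using hypothetical names (repeat_ref, ne_src, nb_src, ...); it is not ggml's actual CPU implementation.

#include <stdint.h>
#include <string.h>

// dst(i0,i1,i2,i3) = src(i0 % ne_src[0], i1 % ne_src[1], i2 % ne_src[2], i3 % ne_src[3]).
// The Metal kernel maps (i1,i2,i3) to the threadgroup grid and strides i0 across the
// threads of one threadgroup; this loop nest is the serial equivalent (sketch only).
static void repeat_ref(const char * src, char * dst,
                       const int64_t ne_src[4], const uint64_t nb_src[4],  // source extents / byte strides
                       const int64_t ne_dst[4], const uint64_t nb_dst[4],  // destination extents / byte strides
                       size_t type_size) {
    for (int64_t i3 = 0; i3 < ne_dst[3]; ++i3) {
        for (int64_t i2 = 0; i2 < ne_dst[2]; ++i2) {
            for (int64_t i1 = 0; i1 < ne_dst[1]; ++i1) {
                const char * src_row = src + (i3 % ne_src[3])*nb_src[3]
                                           + (i2 % ne_src[2])*nb_src[2]
                                           + (i1 % ne_src[1])*nb_src[1];
                char * dst_row = dst + i3*nb_dst[3] + i2*nb_dst[2] + i1*nb_dst[1];
                for (int64_t i0 = 0; i0 < ne_dst[0]; ++i0) {
                    // copy one element, honoring arbitrary byte strides on both sides
                    memcpy(dst_row + i0*nb_dst[0], src_row + (i0 % ne_src[0])*nb_src[0], type_size);
                }
            }
        }
    }
}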
ggml-metal.metal

@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }

+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
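For reference, here is a small usage sketch that would exercise GGML_OP_REPEAT through the graph API. It assumes the ggml C API of this period (ggml_init, ggml_new_tensor_2d, ggml_repeat, ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx) and runs on the CPU backend for simplicity; a Metal-backed graph would now route the same node to kernel_repeat_f32 instead of falling back.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // sketch only: assumes the ggml C API of this era; error handling omitted
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // arbitrary scratch size for this example
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a: 2x3 source, b: 4x6 shape reference -> c = repeat(a) tiles a 2x2 times
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 6);
    for (int i = 0; i < 2*3; ++i) {
        ((float *) a->data)[i] = (float) i;
    }

    struct ggml_tensor * c = ggml_repeat(ctx, a, b);   // creates a GGML_OP_REPEAT node

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // broadcast semantics: c(i0,i1) == a(i0 % 2, i1 % 3), so c[3,5] == a[1,2] == 5
    printf("c[3,5] = %f\n", ((float *) c->data)[5*4 + 3]);

    ggml_free(ctx);
    return 0;
}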