mirror of https://github.com/ggml-org/llama.cpp.git
metal : add mean kernel (#14267)

* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci
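For context: `GGML_OP_MEAN` averages each row of a tensor, i.e. it computes `sum_rows` divided by the row length `ne00`. A minimal CPU-side sketch of the op the new kernel accelerates, assuming the public ggml C API (`ggml_mean`, with `ggml_graph_compute_with_ctx` declared in `ggml-cpu.h` in current trees):

// mean_demo.c - CPU sketch of what the new Metal kernel computes.
#include "ggml.h"
#include "ggml-cpu.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // small scratch arena
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 2 rows of 4 elements; mean reduces each row to a single value
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
    for (int i = 0; i < 8; i++) {
        ((float *) a->data)[i] = (float) i; // row 0: 0..3, row 1: 4..7
    }

    struct ggml_tensor * m = ggml_mean(ctx, a); // shape [1, 2]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, m);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // expect 1.5 and 5.5
    printf("mean row 0 = %f, row 1 = %f\n",
        ggml_get_f32_1d(m, 0), ggml_get_f32_1d(m, 1));

    ggml_free(ctx);
    return 0;
}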
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00);
 
                 ggml_metal_kargs_sum_rows args = {
                    /*.ne00 =*/ ne00,
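The threadgroup size `nth` starts at one SIMD group and doubles until it covers the row or hits the pipeline's thread limit, then is clamped to `ne00`. A standalone restatement of that heuristic (the name `pick_nth` and parameter `max_threads` are hypothetical, standing in for `pipeline.maxTotalThreadsPerThreadgroup`):

#include <stdint.h>

// Sketch of the threadgroup-size heuristic in the hunk above.
static int pick_nth(int64_t ne00, int max_threads) {
    int nth = 32; // one SIMD group on Apple GPUs
    while (nth < ne00 && nth < max_threads) {
        nth *= 2; // grow in powers of two while the row is wider
    }
    // MIN(nth, ne00): never launch more threads than there are columns
    return nth < ne00 ? nth : (int) ne00;
}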
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
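Note the re-ordered bindings: `args` moves to buffer index 0 to match the new kernel signature, and the 32 floats of threadgroup memory hold one partial sum per SIMD group. For reference, a hedged reconstruction of what `ggml_metal_kargs_sum_rows` must carry, inferred only from the fields the kernel dereferences (the real struct may order or type these differently and may hold more fields):

#include <stdint.h>

// Assumed layout, reconstructed from args.ne00..ne03, args.nb01..nb03
// and args.nb1..nb3 as used in the kernel below. Not the real definition.
typedef struct {
    int64_t  ne00, ne01, ne02, ne03; // element counts per dimension of src0
    uint64_t nb01, nb02, nb03;       // src0 byte strides for dims 1..3
    uint64_t nb1,  nb2,  nb3;        // dst byte strides for dims 1..3
} ggml_metal_kargs_sum_rows_sketch;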
@@ -993,31 +993,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device       float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;
 
     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float row_sum = 0;
+    float sumf = 0;
 
-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }
 
-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }
 
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
 template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
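Each threadgroup now reduces one row cooperatively: threads stride across the row, `simd_sum` folds each SIMD group, the threadgroup array collects one partial per SIMD group, and a second `simd_sum` over those partials yields the row total, divided by `ne00` when `norm` is set. A plain-C reference of the result, assuming contiguous f32 data (helper name hypothetical; the Metal kernel additionally handles strided rows via the `nb*` fields):

#include <stdbool.h>
#include <stdint.h>

// CPU reference for kernel_sum_rows<norm>: one output per (i1, i2, i3) row.
static void sum_rows_ref(const float * src, float * dst,
                         int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                         bool norm) {
    for (int64_t i3 = 0; i3 < ne03; i3++) {
        for (int64_t i2 = 0; i2 < ne02; i2++) {
            for (int64_t i1 = 0; i1 < ne01; i1++) {
                const int64_t row = (i3*ne02 + i2)*ne01 + i1;
                float sumf = 0.0f;
                for (int64_t i0 = 0; i0 < ne00; i0++) {
                    sumf += src[row*ne00 + i0];
                }
                dst[row] = norm ? sumf / ne00 : sumf; // norm == true -> mean
            }
        }
    }
}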