mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : add im2col F32 dst support (#5132)
This commit is contained in:
		
							
								
								
									
										13
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -135,6 +135,7 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_ROPE_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_ALIBI_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_IM2COL_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_IM2COL_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_UPSCALE_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_PAD_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, | ||||
| @@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                im2col_f16,             true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                im2col_f32,             true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true); | ||||
| @@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const | ||||
|         case GGML_OP_ALIBI: | ||||
|         case GGML_OP_ROPE: | ||||
|         case GGML_OP_IM2COL: | ||||
|             return true; | ||||
|         case GGML_OP_POOL_1D: | ||||
|         case GGML_OP_POOL_2D: | ||||
|             return false; | ||||
|         case GGML_OP_UPSCALE: | ||||
|         case GGML_OP_PAD: | ||||
|         case GGML_OP_ARGSORT: | ||||
| @@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute( | ||||
|                     { | ||||
|                         GGML_ASSERT(src0->type == GGML_TYPE_F16); | ||||
|                         GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||||
|                         GGML_ASSERT( dst->type == GGML_TYPE_F16); | ||||
|                         GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); | ||||
|  | ||||
|                         const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; | ||||
|                         const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; | ||||
| @@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute( | ||||
|                         const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; | ||||
|                         const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; | ||||
|                         const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; | ||||
|  | ||||
|                         const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; | ||||
|  | ||||
|                         const int32_t N  = src1->ne[is_2D ? 3 : 2]; | ||||
| @@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute( | ||||
|  | ||||
|                         id<MTLComputePipelineState> pipeline = nil; | ||||
|  | ||||
|                         switch (src0->type) { | ||||
|                             case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; | ||||
|                         switch (dst->type) { | ||||
|                             case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; | ||||
|                             case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; | ||||
|                             default: GGML_ASSERT(false); | ||||
|                         }; | ||||
|   | ||||
| @@ -1775,9 +1775,29 @@ kernel void kernel_rope( | ||||
| template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>; | ||||
| template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>; | ||||
|  | ||||
| kernel void kernel_im2col_f16( | ||||
| typedef void (im2col_t)( | ||||
|         device const float * x, | ||||
|         device       half * dst, | ||||
|         device        char * dst, | ||||
|         constant   int32_t & ofs0, | ||||
|         constant   int32_t & ofs1, | ||||
|         constant   int32_t & IW, | ||||
|         constant   int32_t & IH, | ||||
|         constant   int32_t & CHW, | ||||
|         constant   int32_t & s0, | ||||
|         constant   int32_t & s1, | ||||
|         constant   int32_t & p0, | ||||
|         constant   int32_t & p1, | ||||
|         constant   int32_t & d0, | ||||
|         constant   int32_t & d1, | ||||
|         uint3 tgpig[[threadgroup_position_in_grid]], | ||||
|         uint3  tgpg[[threadgroups_per_grid]], | ||||
|         uint3 tpitg[[thread_position_in_threadgroup]], | ||||
|         uint3   ntg[[threads_per_threadgroup]]); | ||||
|  | ||||
| template <typename T> | ||||
| kernel void kernel_im2col( | ||||
|         device const float * x, | ||||
|         device        char * dst, | ||||
|         constant   int32_t & ofs0, | ||||
|         constant   int32_t & ofs1, | ||||
|         constant   int32_t & IW, | ||||
| @@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16( | ||||
|         (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + | ||||
|         (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); | ||||
|  | ||||
|     device T * pdst = (device T *) (dst); | ||||
|  | ||||
|     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { | ||||
|         dst[offset_dst] = 0.0f; | ||||
|         pdst[offset_dst] = 0.0f; | ||||
|     } else { | ||||
|         const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; | ||||
|         dst[offset_dst] = x[offset_src + iih * IW + iiw]; | ||||
|         pdst[offset_dst] = x[offset_src + iih * IW + iiw]; | ||||
|     } | ||||
| } | ||||
|  | ||||
| template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; | ||||
| template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; | ||||
|  | ||||
| kernel void kernel_upscale_f32( | ||||
|     device  const char * src0, | ||||
|     device        char * dst, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov