	metal : add GGML_OP_CONV_TRANSPOSE_1D kernels (ggml/1026)
				
					
				
* wip
* wip implementation f32
* kernel conv transpose 1d f32 working
* initial commit
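The diff has two parts: the host-side changes that register and dispatch the new pipelines, and the Metal shader source that implements them. For context, here is a minimal sketch — an illustrative assumption, not code from this commit — of how the operator is driven from the public ggml C API. The shape layout matches the host-side code below: src0 is the kernel with ne = [K, OC, IC], src1 the input with ne = [IL, IC], and dst has ne = [OL, OC]. The concrete sizes and the CPU compute call are placeholders.

// Minimal sketch (assumption, not part of this commit): building and running
// GGML_OP_CONV_TRANSPOSE_1D through the public ggml C API.
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // kernel: K = 3, OC = 2, IC = 4; input: IL = 8 samples, IC = 4 channels
    struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3, 2, 4);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    ggml_set_f32(w, 1.0f); // placeholder data
    ggml_set_f32(x, 1.0f);

    // stride s0 = 2; ggml_conv_transpose_1d requires p0 == 0 and d0 == 1
    struct ggml_tensor * y = ggml_conv_transpose_1d(ctx, w, x, /*s0 =*/ 2, /*p0 =*/ 0, /*d0 =*/ 1);
    // y->ne[0] == OL == (IL - 1)*s0 + K == 17,  y->ne[1] == OC == 2

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1); // CPU reference path

    ggml_free(ctx);
    return 0;
}

With these kernels in place, the Metal backend can take the same graph instead of the CPU path; as the host-side code below shows, the f16_f32 pipeline variant is chosen when the kernel tensor (src0) is stored as F16, while the input and output stay F32.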
@@ -306,6 +306,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -870,6 +872,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,     conv_transpose_1d_f32_f32,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,     conv_transpose_1d_f16_f32,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
@@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
+        case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                 }
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+                GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+
+                const int32_t IC = src1->ne[1];
+                const int32_t IL = src1->ne[0];
+
+                const int32_t K  = src0->ne[0];
+
+                const int32_t OL = dst->ne[0];
+                const int32_t OC = dst->ne[1];
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F32: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
+                    } break;
+                    default: GGML_ABORT("fatal error");
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0         atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1         atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst          atIndex:2];
+                [encoder setBytes:&IC      length:sizeof( int32_t)  atIndex:3];
+                [encoder setBytes:&IL      length:sizeof( int32_t)  atIndex:4];
+                [encoder setBytes:&K       length:sizeof( int32_t)  atIndex:5];
+                [encoder setBytes:&s0      length:sizeof( int32_t)  atIndex:6];
+                [encoder setBytes:&nb0     length:sizeof(uint64_t)  atIndex:7];
+                [encoder setBytes:&nb1     length:sizeof(uint64_t)  atIndex:8];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_UPSCALE:
             {
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
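For context (not part of the diff): the dispatch above launches an OL x OC grid of threadgroups with a single thread each, so one thread produces exactly one output element dst[ol, oc]. Given ggml's conv_transpose_1d constraints (p0 == 0, d0 == 1), the output length is OL = (IL - 1)*s0 + K, which is exactly the range covered by the window test i*s0 <= ol < i*s0 + K in the shader below. A hypothetical helper spelling out that arithmetic:

// Illustrative helper (assumption; not part of the commit): output length of
// ggml's conv_transpose_1d when p0 == 0 and d0 == 1, matching the kernel's loop bounds.
#include <stdint.h>

static inline int32_t conv_transpose_1d_out_len(int32_t IL, int32_t K, int32_t s0) {
    return (IL - 1)*s0 + K; // e.g. IL = 8, K = 3, s0 = 2  ->  OL = 17
}

The Metal shader source follows. Each thread loops over all IC input channels and all IL input positions, so this is a straightforward first implementation rather than a tuned one, consistent with the wip notes in the commit message.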
@@ -2671,6 +2671,79 @@ kernel void kernel_im2col_ext(
 template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
 template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
 
+typedef void (conv_transpose_1d_t)(
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_1d(
+        device const     T * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    float v = 0.0f;
+
+    for (int64_t c = 0; c < IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+        const int32_t input_offset = c * IL;
+
+        for (int64_t i = 0; i < IL; i++) {
+            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+            }
+        }
+    }
+
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+    dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+    device const float * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+    device const half  * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,