metal : add GGML_OP_CONV_TRANSPOSE_1D kernels (ggml/1026)

* wip
* wip implementation f32
* kernel conv transpose 1d f32 working
* initial commit
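
For context, GGML_OP_CONV_TRANSPOSE_1D is the transposed (fractionally strided) 1D convolution: each input element i is multiplied by the K-wide kernel and accumulated into output positions i*s0 .. i*s0 + K - 1. The Metal kernel added below evaluates the equivalent gather form, one thread per output element. A minimal CPU sketch of that gather form (an illustration, not part of the commit), assuming the layouts the kernel uses (weights [K, OC, IC], input [IL, IC], output [OL, OC], with OL = (IL - 1)*s0 + K):

    // Hedged reference sketch of the math the new kernel implements.
    //   src0 (weights): [K, OC, IC] -> src0[c*OC*K + o*K + k]
    //   src1 (input):   [IL, IC]    -> src1[c*IL + i]
    //   dst  (output):  [OL, OC]    -> dst[o*OL + x]
    static void conv_transpose_1d_ref(
            const float * src0, const float * src1, float * dst,
            int IC, int IL, int K, int OC, int s0) {
        const int OL = (IL - 1)*s0 + K;
        for (int o = 0; o < OC; o++) {
            for (int x = 0; x < OL; x++) {
                float v = 0.0f;
                for (int c = 0; c < IC; c++) {
                    for (int i = 0; i < IL; i++) {
                        // input element i contributes to outputs [i*s0, i*s0 + K)
                        if (x >= i*s0 && x < i*s0 + K) {
                            v += src0[c*OC*K + o*K + (x - i*s0)] * src1[c*IL + i];
                        }
                    }
                }
                dst[o*OL + x] = v;
            }
        }
    }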
@@ -306,6 +306,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -870,6 +872,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,     conv_transpose_1d_f32_f32,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,     conv_transpose_1d_f16_f32,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
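
The second argument of GGML_METAL_ADD_KERNEL ties the registration to the shader entry point: judging by the [[host_name(...)]] attributes in the .metal hunk further down, the macro resolves the pipeline by the "kernel_"-prefixed function name. A hedged reading of the correspondence (inferred from the names, not spelled out in this diff):

    //   conv_transpose_1d_f32_f32  ->  [[host_name("kernel_conv_transpose_1d_f32_f32")]]
    //   conv_transpose_1d_f16_f32  ->  [[host_name("kernel_conv_transpose_1d_f16_f32")]]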
@@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
+        case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
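
With GGML_OP_CONV_TRANSPOSE_1D now returning true from ggml_metal_supports_op, graphs containing this op can be scheduled onto the Metal backend instead of falling back to the CPU. A hedged usage sketch through the public ggml C API (shapes illustrative; ggml additionally requires p0 == 0 and d0 == 1 for this op):

    #include "ggml.h"

    // Build a 2x temporal upsampling: F16 weights with an F32 input selects
    // the new conv_transpose_1d_f16_f32 kernel on Metal.
    static struct ggml_tensor * build_upsample(struct ggml_context * ctx) {
        struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 4 /*K*/,  8 /*OC*/, 16 /*IC*/);
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32 /*IL*/, 16 /*IC*/);
        struct ggml_tensor * y = ggml_conv_transpose_1d(ctx, w, x, /*s0*/ 2, /*p0*/ 0, /*d0*/ 1);
        // y->ne[0] == (32 - 1)*2 + 4 == 66 (OL), y->ne[1] == 8 (OC)
        return y;
    }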
@@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                 }
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+                GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+
+                const int32_t IC = src1->ne[1];
+                const int32_t IL = src1->ne[0];
+
+                const int32_t K  = src0->ne[0];
+
+                const int32_t OL = dst->ne[0];
+                const int32_t OC = dst->ne[1];
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F32: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
+                    } break;
+                    default: GGML_ABORT("fatal error");
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0         atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1         atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst          atIndex:2];
+                [encoder setBytes:&IC      length:sizeof( int32_t)  atIndex:3];
+                [encoder setBytes:&IL      length:sizeof( int32_t)  atIndex:4];
+                [encoder setBytes:&K       length:sizeof( int32_t)  atIndex:5];
+                [encoder setBytes:&s0      length:sizeof( int32_t)  atIndex:6];
+                [encoder setBytes:&nb0     length:sizeof(uint64_t)  atIndex:7];
+                [encoder setBytes:&nb1     length:sizeof(uint64_t)  atIndex:8];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_UPSCALE:
             {
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
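
A note on the dispatch at the end of this hunk: the grid is OL x OC threadgroups of a single thread each, so every threadgroup computes exactly one output element, and nb0/nb1 are the byte strides of dst. Sketch of the mapping (an annotation, not from the commit):

    // grid:        MTLSizeMake(OL, OC, 1)  ->  tgpig[0] = output position x in [0, OL)
    //                                          tgpig[1] = output channel  o in [0, OC)
    // threadgroup: MTLSizeMake(1, 1, 1)    ->  one thread per output element
    //
    // OL was fixed on the host when dst was created; with p0 = 0 and d0 = 1:
    //   OL = (IL - 1)*s0 + K
    //
    // Each thread stores its accumulated value at the dst byte offset built
    // from the strides passed at buffer indices 7 and 8:
    //   *(float *)((char *)dst + x*nb0 + o*nb1) = v;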
@@ -2671,6 +2671,79 @@ kernel void kernel_im2col_ext(
 template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
 template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
 
+typedef void (conv_transpose_1d_t)(
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_1d(
+        device const     T * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    float v = 0.0f;
+
+    for (int64_t c = 0; c < IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+        const int32_t input_offset = c * IL;
+
+        for (int64_t i = 0; i < IL; i++) {
+            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+            }
+        }
+    }
+
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+    dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+    device const float * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+    device const half  * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,
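
The per-thread loop in kernel_conv_transpose_1d scans all IL input positions and keeps only those whose K-wide window covers the output position tgpig[0]. The window test admits exactly the inputs with i in [max(0, ceil((x - K + 1)/s0)), min(IL - 1, floor(x/s0))]; a scalar equivalent with those bounds made explicit (a hypothetical helper for illustration, not part of the commit):

    // Hedged scalar equivalent of one thread's work: output element (o, x),
    // same layouts as the kernel.
    static float conv_transpose_1d_point(
            const float * src0, const float * src1,
            int IC, int IL, int K, int OC, int s0, int o, int x) {
        int i_min = (x - K + 1 + s0 - 1) / s0;  // ceil-div for non-negative numerators
        if (i_min < 0) i_min = 0;               // negative results are clamped anyway
        int i_max = x / s0;
        if (i_max > IL - 1) i_max = IL - 1;

        float v = 0.0f;
        for (int c = 0; c < IC; c++) {
            for (int i = i_min; i <= i_max; i++) {
                v += src0[c*OC*K + o*K + (x - i*s0)] * src1[c*IL + i];
            }
        }
        return v;
    }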