mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	ggml : add GGML_PAD_REFLECT_1D operation (ggml/1034)
				
					
				
			* ggml_pad_reflect_1d defined in header * implemented on CPU * called the forward pass * impl Metal kernel * added Metal kernel * added OP_PAD_REFLECT_1D in test-backend-ops.cpp * add test-pad-reflect-1d test case * test case support multiple backend
This commit is contained in:
		@@ -310,6 +310,7 @@ enum ggml_metal_kernel_type {
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_PAD_F32,
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_ARANGE_F32,
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
 | 
			
		||||
    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
 | 
			
		||||
@@ -877,6 +878,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,     conv_transpose_1d_f16_f32,      true);
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
 | 
			
		||||
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
 | 
			
		||||
@@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
 | 
			
		||||
        case GGML_OP_POOL_2D:
 | 
			
		||||
        case GGML_OP_UPSCALE:
 | 
			
		||||
        case GGML_OP_PAD:
 | 
			
		||||
        case GGML_OP_PAD_REFLECT_1D:
 | 
			
		||||
        case GGML_OP_ARANGE:
 | 
			
		||||
        case GGML_OP_TIMESTEP_EMBEDDING:
 | 
			
		||||
        case GGML_OP_ARGSORT:
 | 
			
		||||
@@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
 | 
			
		||||
 | 
			
		||||
                const int nth = MIN(1024, ne0);
 | 
			
		||||
 | 
			
		||||
                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 | 
			
		||||
            } break;
 | 
			
		||||
        case GGML_OP_PAD_REFLECT_1D:
 | 
			
		||||
            {
 | 
			
		||||
                GGML_ASSERT(src0->type == GGML_TYPE_F32);
 | 
			
		||||
 | 
			
		||||
                const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
 | 
			
		||||
                const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
 | 
			
		||||
 | 
			
		||||
                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
 | 
			
		||||
 | 
			
		||||
                [encoder setComputePipelineState:pipeline];
 | 
			
		||||
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 | 
			
		||||
                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 | 
			
		||||
                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
 | 
			
		||||
                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
 | 
			
		||||
                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 | 
			
		||||
                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
 | 
			
		||||
                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
 | 
			
		||||
                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
 | 
			
		||||
                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
 | 
			
		||||
                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
 | 
			
		||||
                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
 | 
			
		||||
                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
 | 
			
		||||
                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
 | 
			
		||||
                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
 | 
			
		||||
                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
 | 
			
		||||
                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
 | 
			
		||||
                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
 | 
			
		||||
 | 
			
		||||
                const int nth = MIN(1024, ne0);
 | 
			
		||||
 | 
			
		||||
                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 | 
			
		||||
            } break;
 | 
			
		||||
        case GGML_OP_ARANGE:
 | 
			
		||||
 
 | 
			
		||||
@@ -2897,6 +2897,53 @@ kernel void kernel_pad_f32(
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
kernel void kernel_pad_reflect_1d_f32(
 | 
			
		||||
    device  const char * src0,
 | 
			
		||||
    device        char * dst,
 | 
			
		||||
    constant   int64_t & ne00,
 | 
			
		||||
    constant   int64_t & ne01,
 | 
			
		||||
    constant   int64_t & ne02,
 | 
			
		||||
    constant   int64_t & ne03,
 | 
			
		||||
    constant   int64_t & ne0,
 | 
			
		||||
    constant  uint64_t & nb00,
 | 
			
		||||
    constant  uint64_t & nb01,
 | 
			
		||||
    constant  uint64_t & nb02,
 | 
			
		||||
    constant  uint64_t & nb03,
 | 
			
		||||
    constant  uint64_t & nb0,
 | 
			
		||||
    constant  uint64_t & nb1,
 | 
			
		||||
    constant  uint64_t & nb2,
 | 
			
		||||
    constant  uint64_t & nb3,
 | 
			
		||||
    constant   int32_t & p0,
 | 
			
		||||
    constant   int32_t & p1,
 | 
			
		||||
    uint3 tgpig[[threadgroup_position_in_grid]],
 | 
			
		||||
    uint3  tgpg[[threadgroups_per_grid]],
 | 
			
		||||
    uint3 tpitg[[thread_position_in_threadgroup]],
 | 
			
		||||
    uint3   ntg[[threads_per_threadgroup]]) {
 | 
			
		||||
 | 
			
		||||
    const int64_t i3 = tgpig.z;
 | 
			
		||||
    const int64_t i2 = tgpig.y;
 | 
			
		||||
    const int64_t i1 = tgpig.x;
 | 
			
		||||
 | 
			
		||||
    const int64_t i03 = i3;
 | 
			
		||||
    const int64_t i02 = i2;
 | 
			
		||||
    const int64_t i01 = i1;
 | 
			
		||||
 | 
			
		||||
    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
 | 
			
		||||
    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
 | 
			
		||||
 | 
			
		||||
    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
 | 
			
		||||
        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
 | 
			
		||||
            if (i0 < p0) {
 | 
			
		||||
                dst_ptr[i0] = src0_ptr[p0 - i0];
 | 
			
		||||
            } else if (i0 < ne0 - p1) {
 | 
			
		||||
                dst_ptr[i0] = src0_ptr[i0 - p0];
 | 
			
		||||
            } else {
 | 
			
		||||
                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
kernel void kernel_arange_f32(
 | 
			
		||||
    device        char * dst,
 | 
			
		||||
    constant   int64_t & ne0,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user