metal : simplify kernel arguments using a struct (#3229) (#12194)

* metal : refactor im2col parameters into a struct

* metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets

* metal : refactor sum_rows parameters into a struct

* metal : refactor soft_max parameters into a struct

* metal : refactor diag_mask_inf parameters into a struct

* metal : refactor ssm_conv parameters into a struct

* metal : refactor ssm_scan parameters into a struct

* metal : refactor get_rows parameters into a struct

* metal : refactor group_norm parameters into a struct

* metal : refactor conv_transpose_1d parameters into a struct

* metal : refactor upscale parameters into a struct

* metal : refactor pad parameters into a struct

* metal : refactor pad_reflect_1d parameters into a struct

* metal : refactor arange parameters into a struct

* metal : refactor timestep_embedding parameters into a struct

* metal : refactor argsort parameters into a struct

* metal : refactor leaky_relu parameters into a struct

* metal : refactor pool_2d parameters into a struct

* metal : fix trailing whitespace

---------

Co-authored-by: alexju <alexju@tencent.com>
This commit is contained in:
BB-fat
2025-03-07 15:35:57 +08:00
committed by GitHub
parent f1648e91cf
commit 5e2d57b2b2
3 changed files with 685 additions and 643 deletions

View File

@@ -285,4 +285,239 @@ typedef struct {
float eps;
} ggml_metal_kargs_rms_norm;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t n_groups;
float eps;
} ggml_metal_kargs_group_norm;
typedef struct {
int32_t IC;
int32_t IL;
int32_t K;
int32_t s0;
uint64_t nb0;
uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;
typedef struct {
uint64_t ofs0;
uint64_t ofs1;
int32_t IW;
int32_t IH;
int32_t CHW;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int32_t d0;
int32_t d1;
int32_t N;
int32_t KH;
int32_t KW;
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne10;
int64_t ne11;
int64_t ne12;
int64_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_sum_rows;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
float scale;
float max_bias;
float m0;
float m1;
uint32_t n_head_log2;
} ggml_metal_kargs_soft_max;
typedef struct {
int64_t ne00;
int64_t ne01;
int n_past;
} ggml_metal_kargs_diag_mask_inf;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
int64_t ne11;
uint64_t nb10;
uint64_t nb11;
int64_t ne0;
int64_t ne1;
int64_t ne2;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_ssm_conv;
typedef struct {
int64_t d_state;
int64_t d_inner;
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb30;
uint64_t nb31;
uint64_t nb40;
uint64_t nb41;
uint64_t nb42;
uint64_t nb50;
uint64_t nb51;
uint64_t nb52;
} ggml_metal_kargs_ssm_scan;
typedef struct {
int64_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_get_rows;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float sf0;
float sf1;
float sf2;
float sf3;
} ggml_metal_kargs_upscale;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_pad;
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t p0;
int32_t p1;
} ggml_metal_kargs_pad_reflect_1d;
typedef struct {
uint64_t nb1;
int dim;
int max_period;
} ggml_metal_kargs_timestep_embedding;
typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int64_t ncols;
int64_t ncols_pad;
} ggml_metal_kargs_argsort;
typedef struct {
int64_t ne0;
float start;
float step;
} ggml_metal_kargs_arange;
typedef struct {
int32_t k0;
int32_t k1;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int64_t IH;
int64_t IW;
int64_t OH;
int64_t OW;
int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;
#endif // GGML_METAL_IMPL