mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-17 11:37:10 +00:00
metal : refactor kernel args into structs (#10238)
* metal : add kernel arg structs (wip) * metal : fattn args ggml-ci * metal : cont + avoid potential int overflow [no ci] * metal : mul mat struct (wip) * cont : mul mat vec * cont : pass by reference * cont : args is first argument * cont : use char ptr * cont : shmem style * cont : thread counters style * cont : mul mm id ggml-ci * cont : int safety + register optimizations ggml-ci * metal : GGML_OP_CONCAT ggml-ci * metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV * metal : GGML_OP_REPEAT * metal : GGML_OP_CPY * metal : GGML_OP_RMS_NORM * metal : GGML_OP_NORM * metal : add TODOs for rest of ops * ggml : add ggml-metal-impl.h ggml-ci
This commit is contained in:
249
ggml/src/ggml-metal/ggml-metal-impl.h
Normal file
249
ggml/src/ggml-metal/ggml-metal-impl.h
Normal file
@@ -0,0 +1,249 @@
|
||||
#ifndef GGML_METAL_IMPL
|
||||
#define GGML_METAL_IMPL
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
// however, be careful from int overflows when using those in the kernel implementation
|
||||
//
|
||||
// - strides (e.g. nb00) use uint64_t
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
int32_t dim;
|
||||
} ggml_metal_kargs_concat;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
uint64_t offs;
|
||||
} ggml_metal_kargs_bin;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_cpy;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
int32_t n_past;
|
||||
int32_t n_dims;
|
||||
int32_t n_ctx_orig;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
uint64_t nb_12_1;
|
||||
uint64_t nb_12_2;
|
||||
uint64_t nb_12_3;
|
||||
uint64_t nb31;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
uint16_t n_head_log2;
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne12;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mv;
|
||||
|
||||
typedef struct {
|
||||
int32_t nei0;
|
||||
int32_t nei1;
|
||||
uint64_t nbi1;
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
} ggml_metal_kargs_mul_mm_id;
|
||||
|
||||
typedef struct {
|
||||
int32_t nei0;
|
||||
int32_t nei1;
|
||||
uint64_t nbi1;
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
uint64_t nb1;
|
||||
} ggml_metal_kargs_mul_mv_id;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
uint64_t nb01;
|
||||
float eps;
|
||||
} ggml_metal_kargs_norm;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
uint64_t nb01;
|
||||
float eps;
|
||||
} ggml_metal_kargs_rms_norm;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
Reference in New Issue
Block a user