mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
619 lines
11 KiB
C
619 lines
11 KiB
C
#ifndef GGML_METAL_IMPL
|
|
#define GGML_METAL_IMPL
|
|
|
|
// kernel parameters for mat-vec threadgroups
//
// N_R0: number of src0 rows to process per simdgroup
// N_SG: number of simdgroups per threadgroup
//
// TODO: for optimal performance, these should become a function of the device and work size
|
|
|
|
// Per-quantization-type mat-vec tuning constants.
// For each type, N_R0_<type> is the number of src0 rows processed per simdgroup
// and N_SG_<type> is the number of simdgroups per threadgroup.
#define N_R0_Q4_0 4
#define N_SG_Q4_0 2

#define N_R0_Q4_1 4
#define N_SG_Q4_1 2

#define N_R0_Q5_0 4
#define N_SG_Q5_0 2

#define N_R0_Q5_1 4
#define N_SG_Q5_1 2

#define N_R0_Q8_0 4
#define N_SG_Q8_0 2

#define N_R0_Q2_K 4
#define N_SG_Q2_K 2

#define N_R0_Q3_K 2
#define N_SG_Q3_K 2

#define N_R0_Q4_K 4
#define N_SG_Q4_K 2

#define N_R0_Q5_K 2
#define N_SG_Q5_K 2

#define N_R0_Q6_K 1
#define N_SG_Q6_K 2

#define N_R0_IQ1_S 4
#define N_SG_IQ1_S 2

#define N_R0_IQ1_M 4
#define N_SG_IQ1_M 2

#define N_R0_IQ2_XXS 4
#define N_SG_IQ2_XXS 2

#define N_R0_IQ2_XS 4
#define N_SG_IQ2_XS 2

#define N_R0_IQ2_S 4
#define N_SG_IQ2_S 2

#define N_R0_IQ3_XXS 4
#define N_SG_IQ3_XXS 2

#define N_R0_IQ3_S 4
#define N_SG_IQ3_S 2

#define N_R0_IQ4_NL 2
#define N_SG_IQ4_NL 2

#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2
|
|
|
|
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
|
|
|
|
// Arguments for the concat kernel (per the type name).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file. The 0x / 1x / unprefixed index
// groups presumably describe src0 / src1 / dst - confirm against the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;
|
|
|
|
// Arguments for the "bin" kernels (per the type name; presumably element-wise
// binary ops - confirm against the kernel).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file. `offs` is a 64-bit offset whose
// exact meaning is defined by the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
} ggml_metal_kargs_bin;
|
|
|
|
// Arguments for the repeat kernel (per the type name).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;
|
|
|
|
// Arguments for the cpy kernel (per the type name).
// Unlike the int32_t convention noted at the top of this file, the element
// counters (ne*) here are int64_t; strides (nb*) are uint64_t.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;
|
|
|
|
// Arguments for the set kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) and `offs` are
// uint64_t. `inplace` is a boolean flag interpreted by the kernel.
// NOTE(review): `bool` requires <stdbool.h> when this header is compiled as C.
typedef struct {
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
    bool     inplace;
} ggml_metal_kargs_set;
|
|
|
|
// Arguments for the rope kernel (per the type name).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file. The trailing integer and float
// fields are scalar parameters whose semantics are defined by the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;
|
|
|
|
// Arguments for the flash_attn_ext kernel (per the type name).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file. The float fields and
// n_head_log2 are scalar parameters whose semantics are defined by the kernel.
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb23;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
|
|
|
|
// Arguments for the mul_mm kernel (per the type name).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file. r2 / r3 are 16-bit scalars,
// presumably broadcast ratios for dims 2 and 3 - confirm against the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;
|
|
|
|
// Arguments for the mul_mv kernel (per the type name).
// ne* fields are element counters and nb* fields are strides, following the
// conventions noted at the top of this file. r2 / r3 are 16-bit scalars,
// presumably broadcast ratios for dims 2 and 3 - confirm against the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;
|
|
|
|
// Arguments for the mul_mv_ext kernel (per the type name).
// Same leading fields as ggml_metal_kargs_mul_mv, followed by three extra
// 16-bit scalars (nsg, nxpsg, r1ptg) whose semantics are defined by the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;
    int16_t  nxpsg;
    int16_t  r1ptg;
} ggml_metal_kargs_mul_mv_ext;
|
|
|
|
// Arguments for the mul_mm_id_map0 kernel (per the type name).
// ne*/neh* fields are element counters and nb*/nbh* fields are strides.
typedef struct {
    int32_t  ne10;
    int32_t  ne11;  // n_expert_used (bcast)
    uint64_t nb11;
    uint64_t nb12;
    int32_t  neh11; // n_tokens
    uint64_t nbh11;
    int32_t  ne20;  // n_expert_used
    uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
|
|
|
|
// Arguments for the mul_mm_id_map1 kernel (per the type name).
// ne*/neh* fields are element counters and nb*/nbh* fields are strides.
typedef struct {
    int32_t  ne20; // n_expert_used
    int32_t  neh0;
    int32_t  neh1;
    uint64_t nbh1;
    uint64_t nbh2;
    int32_t  ne0;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_mul_mm_id_map1;
|
|
|
|
// Arguments for the mul_mm_id kernel (per the type name).
// ne*/neh* fields are element counters and nb*/nbh* fields are strides.
// r2 / r3 are 16-bit scalars whose semantics are defined by the kernel.
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  neh12;
    uint64_t nbh10;
    uint64_t nbh11;
    uint64_t nbh12;
    uint64_t nbh13;
    int32_t  neh0;
    int32_t  neh1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm_id;
|
|
|
|
// Arguments for the mul_mv_id kernel (per the type name).
// ne*/nei* fields are element counters and nb*/nbi* fields are strides,
// following the conventions noted at the top of this file.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
|
|
|
|
// Arguments for the norm kernel (per the type name).
// ne00 / ne00_4 are element counters (ne00_4 is presumably ne00 / 4 - confirm
// against the host code), nb01 is a stride and eps is a float scalar.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;
|
|
|
|
// Arguments for the rms_norm kernel (per the type name).
// ne00 / ne00_4 are element counters (ne00_4 is presumably ne00 / 4 - confirm
// against the host code), nb01 is a stride and eps is a float scalar.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_rms_norm;
|
|
|
|
// Arguments for the l2_norm kernel (per the type name).
// ne00 / ne00_4 are element counters (ne00_4 is presumably ne00 / 4 - confirm
// against the host code), nb01 is a stride and eps is a float scalar.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_l2_norm;
|
|
|
|
// Arguments for the group_norm kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) are uint64_t.
// n_groups and eps are scalar parameters consumed by the kernel.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  n_groups;
    float    eps;
} ggml_metal_kargs_group_norm;
|
|
|
|
// Arguments for the conv_transpose_1d kernel (per the type name).
// IC, IL, K and s0 are 32-bit scalars (presumably input channels, input
// length, kernel size and stride - confirm against the kernel); nb0 / nb1
// are strides.
typedef struct {
    int32_t  IC;
    int32_t  IL;
    int32_t  K;
    int32_t  s0;
    uint64_t nb0;
    uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;
|
|
|
|
// Arguments for the im2col kernel (per the type name).
// ofs0 / ofs1 are 64-bit offsets; the remaining fields are 32-bit scalars
// whose exact semantics are defined by the kernel.
typedef struct {
    uint64_t ofs0;
    uint64_t ofs1;
    int32_t  IW;
    int32_t  IH;
    int32_t  CHW;
    int32_t  s0;
    int32_t  s1;
    int32_t  p0;
    int32_t  p1;
    int32_t  d0;
    int32_t  d1;
    int32_t  N;
    int32_t  KH;
    int32_t  KW;
    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
|
|
|
|
// Arguments for the sum_rows kernel (per the type name).
// Unlike the int32_t convention noted at the top of this file, the element
// counters (ne*) here are int64_t; strides (nb*) are uint64_t.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne10;
    int64_t  ne11;
    int64_t  ne12;
    int64_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_sum_rows;
|
|
|
|
// Arguments for the soft_max kernel (per the type name).
// Element counters (ne*) are int64_t here. scale, max_bias, m0, m1 and
// n_head_log2 are scalar parameters whose semantics are defined by the kernel.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint32_t n_head_log2;
} ggml_metal_kargs_soft_max;
|
|
|
|
// Arguments for the diag_mask_inf kernel (per the type name).
// ne00 / ne01 are int64_t element counters; n_past is a plain int scalar.
typedef struct {
    int64_t ne00;
    int64_t ne01;
    int     n_past;
} ggml_metal_kargs_diag_mask_inf;
|
|
|
|
// Arguments for the ssm_conv kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) are uint64_t.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int64_t  ne10;
    int64_t  ne11;
    uint64_t nb10;
    uint64_t nb11;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_ssm_conv;
|
|
|
|
// Arguments for the ssm_scan kernel (per the type name).
// Four int64_t size scalars followed by uint64_t strides; the leading digit
// of each nbXY presumably selects one of six source tensors - confirm
// against the kernel.
typedef struct {
    int64_t  d_state;
    int64_t  d_inner;
    int64_t  n_seq_tokens;
    int64_t  n_seqs;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb20;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb30;
    uint64_t nb31;
    uint64_t nb40;
    uint64_t nb41;
    uint64_t nb42;
    uint64_t nb50;
    uint64_t nb51;
    uint64_t nb52;
} ggml_metal_kargs_ssm_scan;
|
|
|
|
// Arguments for the get_rows kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) are uint64_t.
typedef struct {
    int64_t  ne00;
    uint64_t nb01;
    uint64_t nb02;
    int64_t  ne10;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_get_rows;
|
|
|
|
// Arguments for the upscale kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) are uint64_t.
// sf0..sf3 are float scalars, presumably per-dimension scale factors -
// confirm against the kernel.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    float    sf0;
    float    sf1;
    float    sf2;
    float    sf3;
} ggml_metal_kargs_upscale;
|
|
|
|
// Arguments for the pad kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) are uint64_t.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_pad;
|
|
|
|
// Arguments for the pad_reflect_1d kernel (per the type name).
// Element counters (ne*) are int64_t here; strides (nb*) are uint64_t.
// p0 / p1 are 32-bit scalars, presumably the two padding amounts - confirm
// against the kernel.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  p0;
    int32_t  p1;
} ggml_metal_kargs_pad_reflect_1d;
|
|
|
|
// Arguments for the timestep_embedding kernel (per the type name).
// nb1 is a stride; dim and max_period are plain int scalars consumed by the
// kernel.
typedef struct {
    uint64_t nb1;
    int      dim;
    int      max_period;
} ggml_metal_kargs_timestep_embedding;
|
|
|
|
// Arguments for the leaky_relu kernel (per the type name): a single float
// `slope` parameter.
typedef struct {
    float slope;
} ggml_metal_kargs_leaky_relu;
|
|
|
|
// Arguments for the argsort kernel (per the type name).
// ncols and ncols_pad are int64_t counts; ncols_pad is presumably ncols
// rounded up to a padded size - confirm against the host code.
typedef struct {
    int64_t ncols;
    int64_t ncols_pad;
} ggml_metal_kargs_argsort;
|
|
|
|
// Arguments for the arange kernel (per the type name).
// ne0 is an int64_t element counter; start and step are float scalars.
typedef struct {
    int64_t ne0;
    float   start;
    float   step;
} ggml_metal_kargs_arange;
|
|
|
|
// Arguments for the pool_2d kernel (per the type name).
// k0/k1, s0/s1 and p0/p1 are 32-bit scalars (presumably kernel size, stride
// and padding per axis - confirm against the kernel); IH/IW/OH/OW and
// parallel_elements are int64_t sizes.
typedef struct {
    int32_t k0;
    int32_t k1;
    int32_t s0;
    int32_t s1;
    int32_t p0;
    int32_t p1;
    int64_t IH;
    int64_t IW;
    int64_t OH;
    int64_t OW;
    int64_t parallel_elements;
} ggml_metal_kargs_pool_2d;
|
|
|
|
#endif // GGML_METAL_IMPL
|