Files
llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
Georgi Gerganov 1d8d83deaa metal : improve MUL_MAT_ID (#15541)
* metal : mul_mm_id remove hdst

* metal : remove mul_mm_id hsrc1

* metal : mul_mm_id simplify + add test

* metal : opt mul_mm_id map0

* metal : optimize mul_mm_id id gathering

* metal : mul/div opt

* metal : optimize mul_mm_id_map0

ggml-ci
2025-08-26 12:46:15 +03:00

680 lines
12 KiB
C

#ifndef GGML_METAL_IMPL
#define GGML_METAL_IMPL
// kernel parameters for mat-vec threadgroups
//
// N_R0: number of src0 rows to process per simdgroup
// N_SG: number of simdgroups per threadgroup
//
// one N_R0_<type> / N_SG_<type> pair is defined per type suffix (Q4_0, Q4_1, ...)
//
// TODO: for optimal performance, these should become a function of the device and work size
#define N_R0_Q4_0 4
#define N_SG_Q4_0 2
#define N_R0_Q4_1 4
#define N_SG_Q4_1 2
#define N_R0_Q5_0 4
#define N_SG_Q5_0 2
#define N_R0_Q5_1 4
#define N_SG_Q5_1 2
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_R0_Q2_K 4
#define N_SG_Q2_K 2
#define N_R0_Q3_K 2
#define N_SG_Q3_K 2
#define N_R0_Q4_K 4
#define N_SG_Q4_K 2
#define N_R0_Q5_K 2
#define N_SG_Q5_K 2
#define N_R0_Q6_K 1
#define N_SG_Q6_K 2
#define N_R0_IQ1_S 4
#define N_SG_IQ1_S 2
#define N_R0_IQ1_M 4
#define N_SG_IQ1_M 2
#define N_R0_IQ2_XXS 4
#define N_SG_IQ2_XXS 2
#define N_R0_IQ2_XS 4
#define N_SG_IQ2_XS 2
#define N_R0_IQ2_S 4
#define N_SG_IQ2_S 2
#define N_R0_IQ3_XXS 4
#define N_SG_IQ3_XXS 2
#define N_R0_IQ3_S 4
#define N_SG_IQ3_S 2
#define N_R0_IQ4_NL 2
#define N_SG_IQ4_NL 2
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, be careful of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
// kernel arguments for the concat kernel
// ne* fields are element counters, nb* fields are strides
// NOTE(review): the 0x / 1x / bare-digit suffixes presumably denote src0 / src1 / dst (ggml naming) - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t dim; // NOTE(review): presumably the dimension to concatenate along - confirm
} ggml_metal_kargs_concat;
// kernel arguments for the "bin" kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): "bin" presumably means binary element-wise ops, and 0x / 1x / bare digits src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t offs; // NOTE(review): an offset; what it is applied to is not visible here - confirm
uint64_t o1[8]; // NOTE(review): up to 8 offsets, presumably one per additional operand - confirm
} ggml_metal_kargs_bin;
// kernel arguments for the add_id kernel
// ne* fields are element counters, nb* fields are strides
//
// strides are declared as uint64_t (not size_t) so the layout is the same fixed
// width on the host and in the shader, matching every other struct in this header
// NOTE(review): the meaning of the 0x / 1x / 2x suffixes is defined by the kernel - confirm there
typedef struct {
    int64_t  ne0;
    int64_t  ne1;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb11;
    uint64_t nb21;
} ggml_metal_kargs_add_id;
// kernel arguments for the repeat kernel
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x fields presumably describe the source and bare-digit fields the destination - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_repeat;
// kernel arguments for the cpy kernel
// ne* fields are element counters, nb* fields are strides
// note: the element counters here are int64_t, unlike the int32_t used by most structs in this header
// NOTE(review): 0x fields presumably describe the source and bare-digit fields the destination - confirm against the kernel
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_cpy;
// kernel arguments for the set kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 1x fields presumably describe src1 and bare-digit fields the destination - confirm against the kernel
typedef struct {
int64_t ne10;
int64_t ne11;
int64_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t offs; // NOTE(review): an offset; presumably where the destination write starts - confirm
bool inplace; // NOTE(review): presumably selects in-place operation - confirm against the kernel
} ggml_metal_kargs_set;
// kernel arguments for the rope kernels
// ne* fields are element counters, nb* fields are strides; the remaining fields are scalar parameters
// NOTE(review): 0x fields presumably describe src0 and bare-digit fields the destination - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
// scalar parameters - their exact semantics are defined by the kernel, not by this header
int32_t n_past;
int32_t n_dims;
int32_t n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
// NOTE(review): four "section" values; presumably used by a multi-section rope variant - confirm
int32_t sect_0;
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
} ggml_metal_kargs_rope;
// kernel arguments for the flash_attn_ext kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): the 0x / 1x / 2x / 3x suffixes presumably denote Q / K / V / mask and bare digits the destination - confirm against the kernel
typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
int32_t ne1;
int32_t ne2;
// scalar parameters - their exact semantics are defined by the kernel, not by this header
float scale;
float max_bias;
float m0;
float m1;
int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
// kernel arguments for the mul_mm kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2; // NOTE(review): presumably a broadcast ratio for dim 2 - confirm
int16_t r3; // NOTE(review): presumably a broadcast ratio for dim 3 - confirm
} ggml_metal_kargs_mul_mm;
// kernel arguments for the mul_mv kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2; // NOTE(review): presumably a broadcast ratio for dim 2 - confirm
int16_t r3; // NOTE(review): presumably a broadcast ratio for dim 3 - confirm
} ggml_metal_kargs_mul_mv;
// kernel arguments for the mul_mv_ext kernels
// same fields as ggml_metal_kargs_mul_mv, followed by three extra int16_t parameters
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
int16_t nsg; // NOTE(review): presumably the number of simdgroups - confirm
int16_t nxpsg; // NOTE(review): meaning not visible in this header - confirm against the kernel
int16_t r1ptg; // NOTE(review): meaning not visible in this header - confirm against the kernel
} ggml_metal_kargs_mul_mv_ext;
// kernel arguments for the mul_mm_id_map0 kernel
// ne* fields are element counters, nb* fields are strides
// NOTE(review): the 0x / 1x / 2x suffixes presumably denote src0 / src1 / ids tensor - confirm against the kernel
typedef struct {
int32_t ne02;
int32_t ne10;
int32_t ne11; // n_expert_used (bcast)
uint64_t nb11;
uint64_t nb12;
int32_t ne21; // n_tokens
int32_t ne20; // n_expert_used
uint64_t nb21;
} ggml_metal_kargs_mul_mm_id_map0;
// kernel arguments for the mul_mm_id kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): the 0x / 1x / 2x suffixes presumably denote src0 / src1 / ids tensor and bare digits dst - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne0;
int32_t ne1;
int16_t r2; // NOTE(review): presumably a broadcast ratio for dim 2 - confirm
int16_t r3; // NOTE(review): presumably a broadcast ratio for dim 3 - confirm
} ggml_metal_kargs_mul_mm_id;
// kernel arguments for the mul_mv_id kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): nei* / nbi* presumably describe the ids tensor; 0x / 1x / bare digits src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
// kernel arguments for the norm kernel
typedef struct {
int32_t ne00; // element counter
int32_t ne00_4; // NOTE(review): presumably ne00 / 4 (count in 4-wide vectors) - confirm against the host code
uint64_t nb01; // stride
float eps; // NOTE(review): presumably the epsilon added for numerical stability - confirm
} ggml_metal_kargs_norm;
// kernel arguments for the rms_norm kernels
// ne* fields are element counters, nb* fields are strides
typedef struct {
int32_t ne00;
int32_t ne00_4; // NOTE(review): presumably ne00 / 4 (count in 4-wide vectors) - confirm against the host code
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps; // NOTE(review): presumably the epsilon added for numerical stability - confirm
// three-slot arrays of element counters (nef*) and strides (nbf*)
// NOTE(review): presumably one slot per operand of a fused norm + mul/add kernel - confirm
int32_t nef1[3];
int32_t nef2[3];
int32_t nef3[3];
uint64_t nbf1[3];
uint64_t nbf2[3];
uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;
// kernel arguments for the l2_norm kernel (same fields as ggml_metal_kargs_norm)
typedef struct {
int32_t ne00; // element counter
int32_t ne00_4; // NOTE(review): presumably ne00 / 4 (count in 4-wide vectors) - confirm against the host code
uint64_t nb01; // stride
float eps; // NOTE(review): presumably the epsilon added for numerical stability - confirm
} ggml_metal_kargs_l2_norm;
// kernel arguments for the group_norm kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t n_groups; // NOTE(review): presumably the number of normalization groups - confirm
float eps; // NOTE(review): presumably the epsilon added for numerical stability - confirm
} ggml_metal_kargs_group_norm;
// kernel arguments for the conv_transpose_1d kernel
// NOTE(review): IC / IL / K / s0 presumably mean input channels / input length / kernel size / stride - confirm against the kernel
typedef struct {
int32_t IC;
int32_t IL;
int32_t K;
int32_t s0;
uint64_t nb0; // stride
uint64_t nb1; // stride
} ggml_metal_kargs_conv_transpose_1d;
// kernel arguments for the im2col kernels
// NOTE(review): s* / p* / d* presumably mean stride / padding / dilation per axis, and IW / IH / KW / KH
// input and kernel width / height - confirm against the kernel
typedef struct {
uint64_t ofs0;
uint64_t ofs1;
int32_t IW;
int32_t IH;
int32_t CHW;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int32_t d0;
int32_t d1;
int32_t N;
int32_t KH;
int32_t KW;
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col;
// kernel arguments for the glu kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct{
int32_t ne00;
uint64_t nb01;
int32_t ne10;
uint64_t nb11;
int32_t ne0;
uint64_t nb1;
int32_t i00; // NOTE(review): presumably a starting index into src0 - confirm
int32_t i10; // NOTE(review): presumably a starting index into src1 - confirm
float alpha; // NOTE(review): meaning is defined by the kernel variant that reads it - confirm
float limit; // NOTE(review): meaning is defined by the kernel variant that reads it - confirm
} ggml_metal_kargs_glu;
// kernel arguments for the sum_rows kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne10;
int64_t ne11;
int64_t ne12;
int64_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_sum_rows;
// kernel arguments for the soft_max kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
// scalar parameters; the same five fields also appear in ggml_metal_kargs_flash_attn_ext
// NOTE(review): max_bias / m0 / m1 / n_head_log2 presumably parameterize a per-head bias slope - confirm
float scale;
float max_bias;
float m0;
float m1;
int32_t n_head_log2;
} ggml_metal_kargs_soft_max;
// kernel arguments for the diag_mask_inf kernel
// ne* fields are element counters (int64_t here)
//
// n_past is declared as int32_t (not plain int) so its width is fixed on both the
// host and the shader side, consistent with n_past in ggml_metal_kargs_rope
typedef struct {
    int64_t ne00;
    int64_t ne01;
    int32_t n_past;
} ggml_metal_kargs_diag_mask_inf;
// kernel arguments for the ssm_conv kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
int64_t ne11;
uint64_t nb10;
uint64_t nb11;
int64_t ne0;
int64_t ne1;
int64_t ne2;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_ssm_conv;
// kernel arguments for the ssm_scan kernels
// the leading int64_t fields are sizes / counts; nb* fields are strides
// NOTE(review): the first digit of each nbXY presumably selects the source tensor (src0..src5) - confirm against the kernel
typedef struct {
int64_t d_state;
int64_t d_inner;
int64_t n_head;
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
int64_t s_off; // NOTE(review): an offset; what it is applied to is not visible here - confirm
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb31;
uint64_t nb41;
uint64_t nb42;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t nb53;
} ggml_metal_kargs_ssm_scan;
// kernel arguments for the get_rows kernels
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int64_t ne00;
uint64_t nb01;
uint64_t nb02;
int64_t ne10;
uint64_t nb10;
uint64_t nb11;
uint64_t nb1;
uint64_t nb2;
} ggml_metal_kargs_get_rows;
// kernel arguments for the set_rows kernels
// ne* fields are element counters, nb* fields are strides
// NOTE(review): 0x / 1x / bare digits presumably denote src0 / src1 / dst - confirm against the kernel
typedef struct {
int32_t nk0; // NOTE(review): meaning not visible in this header - confirm against the host code
int32_t ne01;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_set_rows;
// kernel arguments for the upscale kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 0x fields presumably describe the source and bare-digit fields the destination - confirm against the kernel
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
// NOTE(review): presumably one scale factor per dimension - confirm
float sf0;
float sf1;
float sf2;
float sf3;
} ggml_metal_kargs_upscale;
// kernel arguments for the pad kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 0x fields presumably describe the source and bare-digit fields the destination - confirm against the kernel
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_pad;
// kernel arguments for the pad_reflect_1d kernel
// ne* fields are element counters (int64_t here), nb* fields are strides
// NOTE(review): 0x fields presumably describe the source and bare-digit fields the destination - confirm against the kernel
typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t p0; // NOTE(review): presumably the padding amount on one side - confirm
int32_t p1; // NOTE(review): presumably the padding amount on the other side - confirm
} ggml_metal_kargs_pad_reflect_1d;
// kernel arguments for the timestep_embedding kernel
//
// dim and max_period are declared as int32_t (not plain int) so their width is fixed
// on both the host and the shader side, consistent with the other structs in this header
typedef struct {
    uint64_t nb1; // stride
    int32_t  dim;
    int32_t  max_period;
} ggml_metal_kargs_timestep_embedding;
// kernel arguments for the leaky_relu kernel
typedef struct {
float slope; // NOTE(review): presumably the multiplier applied to negative inputs - confirm against the kernel
} ggml_metal_kargs_leaky_relu;
// kernel arguments for the argsort kernel
typedef struct {
int64_t ncols;
int64_t ncols_pad; // NOTE(review): presumably ncols rounded up to a padded size chosen by the host - confirm
} ggml_metal_kargs_argsort;
// kernel arguments for the arange kernel
typedef struct {
int64_t ne0; // element counter
float start; // NOTE(review): presumably the first generated value - confirm against the kernel
float step; // NOTE(review): presumably the increment between consecutive values - confirm against the kernel
} ggml_metal_kargs_arange;
// kernel arguments for the pool_2d kernels
// NOTE(review): k* / s* / p* presumably mean kernel size / stride / padding per axis, and IH / IW / OH / OW
// input and output height / width - confirm against the kernel
typedef struct {
int32_t k0;
int32_t k1;
int32_t s0;
int32_t s1;
int32_t p0;
int32_t p1;
int64_t IH;
int64_t IW;
int64_t OH;
int64_t OW;
int64_t parallel_elements; // NOTE(review): presumably the total number of output elements processed in parallel - confirm
} ggml_metal_kargs_pool_2d;
#endif // GGML_METAL_IMPL