#ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL // kernel parameters for mat-vec threadgroups // // N_R0: number of src0 rows to process per simdgroup // N_SG: number of simdgroups per threadgroup // // TODO: for optimal performance, become function of the device and work size #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 #define N_R0_Q4_1 4 #define N_SG_Q4_1 2 #define N_R0_Q5_0 4 #define N_SG_Q5_0 2 #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 #define N_R0_Q8_0 4 #define N_SG_Q8_0 2 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 #define N_R0_Q2_K 4 #define N_SG_Q2_K 2 #define N_R0_Q3_K 2 #define N_SG_Q3_K 2 #define N_R0_Q4_K 4 #define N_SG_Q4_K 2 #define N_R0_Q5_K 2 #define N_SG_Q5_K 2 #define N_R0_Q6_K 1 #define N_SG_Q6_K 2 #define N_R0_IQ1_S 4 #define N_SG_IQ1_S 2 #define N_R0_IQ1_M 4 #define N_SG_IQ1_M 2 #define N_R0_IQ2_XXS 4 #define N_SG_IQ2_XXS 2 #define N_R0_IQ2_XS 4 #define N_SG_IQ2_XS 2 #define N_R0_IQ2_S 4 #define N_SG_IQ2_S 2 #define N_R0_IQ3_XXS 4 #define N_SG_IQ3_XXS 2 #define N_R0_IQ3_S 4 #define N_SG_IQ3_S 2 #define N_R0_IQ4_NL 2 #define N_SG_IQ4_NL 2 #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage // however, be careful from int overflows when using those in the kernel implementation // // - strides (e.g. nb00) use uint64_t typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne10; int32_t ne11; int32_t ne12; int32_t ne13; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int32_t ne0; int32_t ne1; int32_t ne2; int32_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; int32_t dim; } ggml_metal_kargs_concat; typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne10; int32_t ne11; int32_t ne12; int32_t ne13; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int32_t ne0; int32_t ne1; int32_t ne2; int32_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; uint64_t offs; uint64_t o1[8]; } ggml_metal_kargs_bin; typedef struct { int64_t ne0; int64_t ne1; size_t nb01; size_t nb02; size_t nb11; size_t nb21; } ggml_metal_kargs_add_id; typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne0; int32_t ne1; int32_t ne2; int32_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; } ggml_metal_kargs_repeat; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; int64_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int64_t ne0; int64_t ne1; int64_t ne2; int64_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; } ggml_metal_kargs_cpy; typedef struct { int64_t ne10; int64_t ne11; int64_t ne12; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; uint64_t nb1; uint64_t nb2; uint64_t nb3; uint64_t offs; bool inplace; } ggml_metal_kargs_set; typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; int32_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne0; int32_t ne1; int32_t ne2; int32_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; int32_t n_past; int32_t n_dims; int32_t n_ctx_orig; float freq_base; float freq_scale; float ext_factor; float attn_factor; float beta_fast; float beta_slow; int32_t sect_0; int32_t sect_1; int32_t sect_2; int32_t sect_3; } ggml_metal_kargs_rope; typedef struct { int32_t ne01; int32_t ne02; int32_t ne03; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; uint64_t nb11; uint64_t nb12; uint64_t nb13; uint64_t nb21; uint64_t nb22; uint64_t nb23; int32_t ne32; int32_t ne33; uint64_t nb31; uint64_t nb32; uint64_t nb33; int32_t ne1; int32_t ne2; float scale; float max_bias; float m0; float m1; int32_t n_head_log2; float logit_softcap; } ggml_metal_kargs_flash_attn_ext; typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne12; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int32_t ne0; int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm; typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne10; int32_t ne11; int32_t ne12; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int32_t ne0; int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mv; typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne10; int32_t ne11; int32_t ne12; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int32_t ne0; int32_t ne1; int16_t r2; int16_t r3; int16_t nsg; int16_t nxpsg; int16_t r1ptg; } ggml_metal_kargs_mul_mv_ext; typedef struct { int32_t ne02; int32_t ne10; int32_t ne11; // n_expert_used (bcast) uint64_t nb11; uint64_t nb12; int32_t ne21; // n_tokens int32_t ne20; // n_expert_used uint64_t nb21; } ggml_metal_kargs_mul_mm_id_map0; typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne11; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int32_t ne20; int32_t ne21; int32_t ne0; int32_t ne1; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm_id; typedef struct { int32_t nei0; int32_t nei1; uint64_t nbi1; int32_t ne00; int32_t ne01; int32_t ne02; uint64_t nb00; uint64_t nb01; uint64_t nb02; int32_t ne10; int32_t ne11; int32_t ne12; int32_t ne13; uint64_t nb10; uint64_t nb11; uint64_t nb12; int32_t ne0; int32_t ne1; uint64_t nb1; } ggml_metal_kargs_mul_mv_id; typedef struct { int32_t ne00; int32_t ne00_4; uint64_t nb01; float eps; } ggml_metal_kargs_norm; typedef struct { int32_t ne00; int32_t ne00_4; uint64_t nb1; uint64_t nb2; uint64_t nb3; float eps; int32_t nef1[3]; int32_t nef2[3]; int32_t nef3[3]; uint64_t nbf1[3]; uint64_t nbf2[3]; uint64_t nbf3[3]; } ggml_metal_kargs_rms_norm; typedef struct { int32_t ne00; int32_t ne00_4; uint64_t nb01; float eps; } ggml_metal_kargs_l2_norm; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; uint64_t nb00; uint64_t nb01; uint64_t nb02; int32_t n_groups; float eps; } ggml_metal_kargs_group_norm; typedef struct { int32_t IC; int32_t IL; int32_t K; int32_t s0; uint64_t nb0; uint64_t nb1; } ggml_metal_kargs_conv_transpose_1d; typedef struct { uint64_t ofs0; uint64_t ofs1; int32_t IW; int32_t IH; int32_t CHW; int32_t s0; int32_t s1; int32_t p0; int32_t p1; int32_t d0; int32_t d1; int32_t N; int32_t KH; int32_t KW; int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; typedef struct{ int32_t ne00; uint64_t nb01; int32_t ne10; uint64_t nb11; int32_t ne0; uint64_t nb1; int32_t i00; int32_t i10; float alpha; float limit; } ggml_metal_kargs_glu; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; int64_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int64_t ne10; int64_t ne11; int64_t ne12; int64_t ne13; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb13; int64_t ne0; int64_t ne1; int64_t ne2; int64_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; } ggml_metal_kargs_sum_rows; typedef struct { int32_t ne00; int32_t ne01; int32_t ne02; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne11; int32_t ne12; int32_t ne13; uint64_t nb11; uint64_t nb12; uint64_t nb13; uint64_t nb1; uint64_t nb2; uint64_t nb3; float scale; float max_bias; float m0; float m1; int32_t n_head_log2; } ggml_metal_kargs_soft_max; typedef struct { int64_t ne00; int64_t ne01; int n_past; } ggml_metal_kargs_diag_mask_inf; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; uint64_t nb00; uint64_t nb01; uint64_t nb02; int64_t ne10; int64_t ne11; uint64_t nb10; uint64_t nb11; int64_t ne0; int64_t ne1; int64_t ne2; uint64_t nb0; uint64_t nb1; uint64_t nb2; } ggml_metal_kargs_ssm_conv; typedef struct { int64_t d_state; int64_t d_inner; int64_t n_head; int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; uint64_t nb11; uint64_t nb12; uint64_t nb13; uint64_t nb21; uint64_t nb22; uint64_t nb31; uint64_t nb41; uint64_t nb42; uint64_t nb43; uint64_t nb51; uint64_t nb52; uint64_t nb53; } ggml_metal_kargs_ssm_scan; typedef struct { int64_t ne00; uint64_t nb01; uint64_t nb02; int64_t ne10; uint64_t nb10; uint64_t nb11; uint64_t nb1; uint64_t nb2; } ggml_metal_kargs_get_rows; typedef struct { int32_t nk0; int32_t ne01; uint64_t nb01; uint64_t nb02; uint64_t nb03; int32_t ne11; int32_t ne12; uint64_t nb10; uint64_t nb11; uint64_t nb12; uint64_t nb1; uint64_t nb2; uint64_t nb3; } ggml_metal_kargs_set_rows; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; int64_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int64_t ne0; int64_t ne1; int64_t ne2; int64_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; float sf0; float sf1; float sf2; float sf3; } ggml_metal_kargs_upscale; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; int64_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int64_t ne0; int64_t ne1; int64_t ne2; int64_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; } ggml_metal_kargs_pad; typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; int64_t ne03; uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; int64_t ne0; int64_t ne1; int64_t ne2; int64_t ne3; uint64_t nb0; uint64_t nb1; uint64_t nb2; uint64_t nb3; int32_t p0; int32_t p1; } ggml_metal_kargs_pad_reflect_1d; typedef struct { uint64_t nb1; int dim; int max_period; } ggml_metal_kargs_timestep_embedding; typedef struct { float slope; } ggml_metal_kargs_leaky_relu; typedef struct { int64_t ncols; int64_t ncols_pad; } ggml_metal_kargs_argsort; typedef struct { int64_t ne0; float start; float step; } ggml_metal_kargs_arange; typedef struct { int32_t k0; int32_t k1; int32_t s0; int32_t s1; int32_t p0; int32_t p1; int64_t IH; int64_t IW; int64_t OH; int64_t OW; int64_t parallel_elements; } ggml_metal_kargs_pool_2d; #endif // GGML_METAL_IMPL