mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
* metal : pad K, V and Mask when needed * cont : simplify * cuda : add TODO about KV padding requirement * metal : add comments * metal : remove mask padding requirement
758 lines
14 KiB
C
758 lines
14 KiB
C
#ifndef GGML_METAL_IMPL
|
|
#define GGML_METAL_IMPL
|
|
|
|
// kernel parameters for mat-vec threadgroups
|
|
//
|
|
// N_R0: number of src0 rows to process per simdgroup
|
|
// N_SG: number of simdgroups per threadgroup
|
|
//
|
|
// TODO: for optimal performance, become function of the device and work size
|
|
|
|
#define N_R0_Q4_0 4
|
|
#define N_SG_Q4_0 2
|
|
|
|
#define N_R0_Q4_1 4
|
|
#define N_SG_Q4_1 2
|
|
|
|
#define N_R0_Q5_0 4
|
|
#define N_SG_Q5_0 2
|
|
|
|
#define N_R0_Q5_1 4
|
|
#define N_SG_Q5_1 2
|
|
|
|
#define N_R0_Q8_0 2
|
|
#define N_SG_Q8_0 4
|
|
|
|
#define N_R0_MXFP4 2
|
|
#define N_SG_MXFP4 2
|
|
|
|
#define N_R0_Q2_K 4
|
|
#define N_SG_Q2_K 2
|
|
|
|
#define N_R0_Q3_K 2
|
|
#define N_SG_Q3_K 2
|
|
|
|
#define N_R0_Q4_K 2
|
|
#define N_SG_Q4_K 2
|
|
|
|
#define N_R0_Q5_K 2
|
|
#define N_SG_Q5_K 2
|
|
|
|
#define N_R0_Q6_K 2
|
|
#define N_SG_Q6_K 2
|
|
|
|
#define N_R0_IQ1_S 4
|
|
#define N_SG_IQ1_S 2
|
|
|
|
#define N_R0_IQ1_M 4
|
|
#define N_SG_IQ1_M 2
|
|
|
|
#define N_R0_IQ2_XXS 4
|
|
#define N_SG_IQ2_XXS 2
|
|
|
|
#define N_R0_IQ2_XS 4
|
|
#define N_SG_IQ2_XS 2
|
|
|
|
#define N_R0_IQ2_S 4
|
|
#define N_SG_IQ2_S 2
|
|
|
|
#define N_R0_IQ3_XXS 4
|
|
#define N_SG_IQ3_XXS 2
|
|
|
|
#define N_R0_IQ3_S 4
|
|
#define N_SG_IQ3_S 2
|
|
|
|
#define N_R0_IQ4_NL 2
|
|
#define N_SG_IQ4_NL 2
|
|
|
|
#define N_R0_IQ4_XS 2
|
|
#define N_SG_IQ4_XS 2
|
|
|
|
// function constants offsets
|
|
#define FC_FLASH_ATTN_EXT_PAD 100
|
|
#define FC_FLASH_ATTN_EXT 200
|
|
#define FC_FLASH_ATTN_EXT_VEC 300
|
|
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
|
|
#define FC_MUL_MV 500
|
|
#define FC_MUL_MM 600
|
|
|
|
// kernel argument structs
|
|
//
|
|
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
|
// however, be careful from int overflows when using those in the kernel implementation
|
|
//
|
|
// - strides (e.g. nb00) use uint64_t
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
int32_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne10;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
int32_t ne13;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int32_t ne2;
|
|
int32_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
int32_t dim;
|
|
} ggml_metal_kargs_concat;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
int32_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne10;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
int32_t ne13;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int32_t ne2;
|
|
int32_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
uint64_t offs;
|
|
uint64_t o1[8];
|
|
} ggml_metal_kargs_bin;
|
|
|
|
typedef struct {
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
size_t nb01;
|
|
size_t nb02;
|
|
size_t nb11;
|
|
size_t nb21;
|
|
} ggml_metal_kargs_add_id;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
int32_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int32_t ne2;
|
|
int32_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
} ggml_metal_kargs_repeat;
|
|
|
|
typedef struct {
|
|
float scale;
|
|
float bias;
|
|
} ggml_metal_kargs_scale;
|
|
|
|
typedef struct {
|
|
float min;
|
|
float max;
|
|
} ggml_metal_kargs_clamp;
|
|
|
|
typedef struct {
|
|
int64_t nk0;
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
int64_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
int64_t ne2;
|
|
int64_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
} ggml_metal_kargs_cpy;
|
|
|
|
typedef struct {
|
|
int64_t ne10;
|
|
int64_t ne11;
|
|
int64_t ne12;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
uint64_t offs;
|
|
bool inplace;
|
|
} ggml_metal_kargs_set;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
int32_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int32_t ne2;
|
|
int32_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
int32_t n_past;
|
|
int32_t n_dims;
|
|
int32_t n_ctx_orig;
|
|
float freq_base;
|
|
float freq_scale;
|
|
float ext_factor;
|
|
float attn_factor;
|
|
float beta_fast;
|
|
float beta_slow;
|
|
int32_t sect_0;
|
|
int32_t sect_1;
|
|
int32_t sect_2;
|
|
int32_t sect_3;
|
|
} ggml_metal_kargs_rope;
|
|
|
|
typedef struct {
|
|
int32_t ne11;
|
|
int32_t ne_12_2; // assume K and V are same shape
|
|
int32_t ne_12_3;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
uint64_t nb21;
|
|
uint64_t nb22;
|
|
uint64_t nb23;
|
|
int32_t ne31;
|
|
int32_t ne32;
|
|
int32_t ne33;
|
|
uint64_t nb31;
|
|
uint64_t nb32;
|
|
uint64_t nb33;
|
|
} ggml_metal_kargs_flash_attn_ext_pad;
|
|
|
|
typedef struct {
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
int32_t ne03;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne11;
|
|
int32_t ne_12_2; // assume K and V are same shape
|
|
int32_t ne_12_3;
|
|
int32_t ns10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ns20;
|
|
uint64_t nb21;
|
|
uint64_t nb22;
|
|
uint64_t nb23;
|
|
int32_t ne31;
|
|
int32_t ne32;
|
|
int32_t ne33;
|
|
uint64_t nb31;
|
|
uint64_t nb32;
|
|
uint64_t nb33;
|
|
int32_t ne1;
|
|
int32_t ne2;
|
|
int32_t ne3;
|
|
float scale;
|
|
float max_bias;
|
|
float m0;
|
|
float m1;
|
|
int32_t n_head_log2;
|
|
float logit_softcap;
|
|
} ggml_metal_kargs_flash_attn_ext;
|
|
|
|
typedef struct {
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
int32_t ne03;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne11;
|
|
int32_t ne_12_2; // assume K and V are same shape
|
|
int32_t ne_12_3;
|
|
int32_t ns10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ns20;
|
|
uint64_t nb21;
|
|
uint64_t nb22;
|
|
uint64_t nb23;
|
|
int32_t ne31;
|
|
int32_t ne32;
|
|
int32_t ne33;
|
|
uint64_t nb31;
|
|
uint64_t nb32;
|
|
uint64_t nb33;
|
|
int32_t ne1;
|
|
int32_t ne2;
|
|
int32_t ne3;
|
|
float scale;
|
|
float max_bias;
|
|
float m0;
|
|
float m1;
|
|
int32_t n_head_log2;
|
|
float logit_softcap;
|
|
} ggml_metal_kargs_flash_attn_ext_vec;
|
|
|
|
typedef struct {
|
|
int32_t nrows;
|
|
} ggml_metal_kargs_flash_attn_ext_vec_reduce;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne02;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne12;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int16_t r2;
|
|
int16_t r3;
|
|
} ggml_metal_kargs_mul_mm;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne10;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int32_t nr0;
|
|
int16_t r2;
|
|
int16_t r3;
|
|
} ggml_metal_kargs_mul_mv;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne10;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int16_t r2;
|
|
int16_t r3;
|
|
} ggml_metal_kargs_mul_mv_ext;
|
|
|
|
typedef struct {
|
|
int32_t ne02;
|
|
int32_t ne10;
|
|
int32_t ne11; // n_expert_used (bcast)
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
int32_t ne21; // n_tokens
|
|
int32_t ne20; // n_expert_used
|
|
uint64_t nb21;
|
|
} ggml_metal_kargs_mul_mm_id_map0;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne02;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne11;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
int32_t ne20;
|
|
int32_t ne21;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
int16_t r2;
|
|
int16_t r3;
|
|
} ggml_metal_kargs_mul_mm_id;
|
|
|
|
typedef struct {
|
|
int32_t nei0;
|
|
int32_t nei1;
|
|
uint64_t nbi1;
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
int32_t ne10;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
int32_t ne13;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
int32_t ne0;
|
|
int32_t ne1;
|
|
uint64_t nb1;
|
|
int32_t nr0;
|
|
} ggml_metal_kargs_mul_mv_id;
|
|
|
|
// NORM
|
|
// RMS_NORM
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne00_t;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
float eps;
|
|
int32_t nef1[3];
|
|
int32_t nef2[3];
|
|
int32_t nef3[3];
|
|
uint64_t nbf1[3];
|
|
uint64_t nbf2[3];
|
|
uint64_t nbf3[3];
|
|
} ggml_metal_kargs_norm;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne00_4;
|
|
uint64_t nb01;
|
|
float eps;
|
|
} ggml_metal_kargs_l2_norm;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
int32_t ngrp;
|
|
float eps;
|
|
} ggml_metal_kargs_group_norm;
|
|
|
|
typedef struct {
|
|
int32_t IC;
|
|
int32_t IL;
|
|
int32_t K;
|
|
int32_t s0;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
} ggml_metal_kargs_conv_transpose_1d;
|
|
|
|
typedef struct {
|
|
uint64_t ofs0;
|
|
uint64_t ofs1;
|
|
int32_t IW;
|
|
int32_t IH;
|
|
int32_t CHW;
|
|
int32_t s0;
|
|
int32_t s1;
|
|
int32_t p0;
|
|
int32_t p1;
|
|
int32_t d0;
|
|
int32_t d1;
|
|
int32_t N;
|
|
int32_t KH;
|
|
int32_t KW;
|
|
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
|
} ggml_metal_kargs_im2col;
|
|
|
|
typedef struct{
|
|
int32_t ne00;
|
|
uint64_t nb01;
|
|
int32_t ne10;
|
|
uint64_t nb11;
|
|
int32_t ne0;
|
|
uint64_t nb1;
|
|
int32_t i00;
|
|
int32_t i10;
|
|
float alpha;
|
|
float limit;
|
|
} ggml_metal_kargs_glu;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
int64_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
int64_t ne2;
|
|
int64_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
} ggml_metal_kargs_sum_rows;
|
|
|
|
typedef struct {
|
|
int32_t ne00;
|
|
int32_t ne01;
|
|
int32_t ne02;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
int32_t ne13;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb13;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
float scale;
|
|
float max_bias;
|
|
float m0;
|
|
float m1;
|
|
int32_t n_head_log2;
|
|
} ggml_metal_kargs_soft_max;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
int64_t ne10;
|
|
int64_t ne11;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
int64_t ne2;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
} ggml_metal_kargs_ssm_conv;
|
|
|
|
typedef struct {
|
|
int64_t d_state;
|
|
int64_t d_inner;
|
|
int64_t n_head;
|
|
int64_t n_group;
|
|
int64_t n_seq_tokens;
|
|
int64_t n_seqs;
|
|
uint64_t s_off;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t ns12;
|
|
uint64_t nb13;
|
|
uint64_t nb20;
|
|
uint64_t nb21;
|
|
uint64_t ns21;
|
|
uint64_t nb22;
|
|
int64_t ne30;
|
|
uint64_t nb31;
|
|
uint64_t nb41;
|
|
uint64_t nb42;
|
|
uint64_t ns42;
|
|
uint64_t nb43;
|
|
uint64_t nb51;
|
|
uint64_t nb52;
|
|
uint64_t ns52;
|
|
uint64_t nb53;
|
|
uint64_t nb0;
|
|
} ggml_metal_kargs_ssm_scan;
|
|
|
|
typedef struct {
|
|
int32_t ne00t;
|
|
int32_t ne00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne10;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
} ggml_metal_kargs_get_rows;
|
|
|
|
typedef struct {
|
|
int32_t nk0;
|
|
int32_t ne01;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int32_t ne11;
|
|
int32_t ne12;
|
|
uint64_t nb10;
|
|
uint64_t nb11;
|
|
uint64_t nb12;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
} ggml_metal_kargs_set_rows;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
int64_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
int64_t ne2;
|
|
int64_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
float sf0;
|
|
float sf1;
|
|
float sf2;
|
|
float sf3;
|
|
} ggml_metal_kargs_upscale;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
int64_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
int64_t ne2;
|
|
int64_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
} ggml_metal_kargs_pad;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
int64_t ne01;
|
|
int64_t ne02;
|
|
int64_t ne03;
|
|
uint64_t nb00;
|
|
uint64_t nb01;
|
|
uint64_t nb02;
|
|
uint64_t nb03;
|
|
int64_t ne0;
|
|
int64_t ne1;
|
|
int64_t ne2;
|
|
int64_t ne3;
|
|
uint64_t nb0;
|
|
uint64_t nb1;
|
|
uint64_t nb2;
|
|
uint64_t nb3;
|
|
int32_t p0;
|
|
int32_t p1;
|
|
} ggml_metal_kargs_pad_reflect_1d;
|
|
|
|
typedef struct {
|
|
uint64_t nb1;
|
|
int dim;
|
|
int max_period;
|
|
} ggml_metal_kargs_timestep_embedding;
|
|
|
|
typedef struct {
|
|
float slope;
|
|
} ggml_metal_kargs_leaky_relu;
|
|
|
|
typedef struct {
|
|
int64_t ncols;
|
|
int64_t ncols_pad;
|
|
} ggml_metal_kargs_argsort;
|
|
|
|
typedef struct {
|
|
int64_t ne0;
|
|
float start;
|
|
float step;
|
|
} ggml_metal_kargs_arange;
|
|
|
|
typedef struct {
|
|
int32_t k0;
|
|
int32_t k1;
|
|
int32_t s0;
|
|
int32_t s1;
|
|
int32_t p0;
|
|
int32_t p1;
|
|
int64_t IH;
|
|
int64_t IW;
|
|
int64_t OH;
|
|
int64_t OW;
|
|
int64_t np;
|
|
} ggml_metal_kargs_pool_2d;
|
|
|
|
typedef struct {
|
|
int64_t ne00;
|
|
uint64_t nb01;
|
|
} ggml_metal_kargs_argmax;
|
|
|
|
#endif // GGML_METAL_IMPL
|