mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
metal : refactor + optimize (#15857)
* metal : refactor ggml-ci * cont : refactor FA-vec kernel * cont : print metal library load time * minor : warn to debug + bettern kernel names ggml-ci * metal : optimize mul_mv q8_0 ggml-ci * metal : simplify FA pipeline creation functions ggml-ci * metal : improve naming consistency * metal : safer function constants offsets ggml-ci * metal : comments ggml-ci
This commit is contained in:
@@ -20,8 +20,8 @@
|
||||
#define N_R0_Q5_1 4
|
||||
#define N_SG_Q5_1 2
|
||||
|
||||
#define N_R0_Q8_0 4
|
||||
#define N_SG_Q8_0 2
|
||||
#define N_R0_Q8_0 2
|
||||
#define N_SG_Q8_0 4
|
||||
|
||||
#define N_R0_MXFP4 2
|
||||
#define N_SG_MXFP4 2
|
||||
@@ -68,6 +68,11 @@
|
||||
#define N_R0_IQ4_XS 2
|
||||
#define N_SG_IQ4_XS 2
|
||||
|
||||
// function constants offsets
|
||||
#define FC_FLASH_ATTN_EXT 100
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
@@ -236,9 +241,11 @@ typedef struct {
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
int32_t ns10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ns20;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
@@ -258,10 +265,43 @@ typedef struct {
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
int32_t ns10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ns20;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
int32_t n_head_log2;
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext_vec;
|
||||
|
||||
typedef struct {
|
||||
int32_t nrows;
|
||||
int32_t ne20;
|
||||
} ggml_metal_kargs_flash_attn_ext_reduce;
|
||||
} ggml_metal_kargs_flash_attn_ext_vec_reduce;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user