metal : refactor + optimize (#15857)

* metal : refactor

ggml-ci

* cont : refactor FA-vec kernel

* cont : print metal library load time

* minor : warn to debug + better kernel names

ggml-ci

* metal : optimize mul_mv q8_0

ggml-ci

* metal : simplify FA pipeline creation functions

ggml-ci

* metal : improve naming consistency

* metal : safer function constants offsets

ggml-ci

* metal : comments

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-09-08 13:34:56 +03:00
committed by GitHub
parent 9fcb29f22f
commit f28d4f4ac9
4 changed files with 1415 additions and 1325 deletions

View File

@@ -20,8 +20,8 @@
#define N_R0_Q5_1 4
#define N_SG_Q5_1 2
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
#define N_R0_Q8_0 2
#define N_SG_Q8_0 4
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
@@ -68,6 +68,11 @@
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2
// function constants offsets
// NOTE(review): presumably each value is the base index of one kernel family's
// function constants, spaced 100 apart so the per-family ranges cannot overlap
// - confirm against the code that consumes these offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -236,9 +241,11 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
@@ -258,10 +265,43 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
// Kernel argument struct for the "vec" variant of the flash-attention kernel
// (per the type name).
//
// Element counters are int32_t and the nb* fields are uint64_t, following the
// convention stated for kernel argument structs in this header.
//
// NOTE(review): field order and types are presumably a layout contract with the
// shader side - do not reorder or resize fields without confirming.
typedef struct {
// presumably element counts (ne0x) and byte strides (nb0x) of the first source tensor - TODO confirm
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
// second/third source tensors; per the comment below they are K and V and share one shape,
// so a single ne_12_* value is stored for both
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
// NOTE(review): ns10 is int32_t while the neighbouring nb1x are uint64_t - looks like a
// stride expressed in a different unit than bytes; confirm against the kernel
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
// same pattern as ns10/nb1x above, for the next source tensor
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
// presumably the optional mask tensor's counts and strides - TODO confirm
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
// presumably destination element counts - TODO confirm
int32_t ne1;
int32_t ne2;
int32_t ne3;
// scalar parameters; exact semantics are defined by the kernel, not visible here
float scale;
float max_bias;
float m0;
float m1;
int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext_vec;
typedef struct {
int32_t nrows;
int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_reduce;
} ggml_metal_kargs_flash_attn_ext_vec_reduce;
typedef struct {
int32_t ne00;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff