Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-16 11:27:03 +00:00)
metal : make the FA extra sizes consistent (#17143)
@@ -1975,7 +1975,9 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
     const bool has_mask = op->src[3] != nullptr;
 
     if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
-        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+        // note: always reserve the padding space to avoid graph reallocations
+        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
+        const bool has_kvpad = true;
 
         if (has_kvpad) {
             res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
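Why the unconditional reservation matters: when the pad is only counted if ne11 is not already a multiple of the chunk size, the measured graph size can differ from one KV length to the next, and the backend has to reallocate its compute buffer. Below is a minimal standalone sketch of that effect in C++; extra_pad() is a hypothetical stand-in with made-up sizes, not the actual ggml-metal code.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the real size calculation: reserve one chunk of
// padding rows only when the KV length (ne11) is not a multiple of the chunk size.
static size_t extra_pad(int64_t ne11, int64_t ncpsg, size_t row_bytes, bool always_reserve) {
    const bool has_kvpad = always_reserve || (ne11 % ncpsg != 0);
    return has_kvpad ? (size_t) ncpsg*row_bytes : 0;
}

int main() {
    const int64_t ncpsg     = 32;   // stand-in for OP_FLASH_ATTN_EXT_VEC_NCPSG
    const size_t  row_bytes = 1024; // made-up per-row byte cost

    // data-dependent reservation: the result flips between 0 and 32768 bytes
    // depending on the current KV length, which can force a buffer reallocation
    printf("old, ne11=512: %zu bytes\n", extra_pad(512, ncpsg, row_bytes, false));
    printf("old, ne11=500: %zu bytes\n", extra_pad(500, ncpsg, row_bytes, false));

    // unconditional reservation: the same size for every KV length
    printf("new, ne11=512: %zu bytes\n", extra_pad(512, ncpsg, row_bytes, true));
    printf("new, ne11=500: %zu bytes\n", extra_pad(500, ncpsg, row_bytes, true));
    return 0;
}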
@@ -1984,7 +1986,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
                 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
         }
     } else {
-        const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+        //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
+        const bool has_kvpad = true;
 
         if (has_kvpad) {
             res += OP_FLASH_ATTN_EXT_NCPSG*(
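The non-vec path gets the same treatment. For a sense of scale, the only part of the pad expression visible in this hunk is the mask term; the short computation below evaluates just that term with made-up mask dimensions (ne31/ne32/ne33 values are illustrative only, not taken from the source).

#include <cstdint>
#include <cstdio>

int main() {
    // made-up values for illustration
    const int64_t ncpsg = 32;                   // stand-in for OP_FLASH_ATTN_EXT_NCPSG
    const size_t  f16   = 2;                    // ggml_type_size(GGML_TYPE_F16)
    const int64_t ne31  = 64, ne32 = 8, ne33 = 1;

    // mask contribution to the pad, mirroring the visible part of the expression:
    //   OP_FLASH_ATTN_EXT_NCPSG*(... + ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33)
    const size_t mask_pad = (size_t) ncpsg*f16*ne31*ne32*ne33;
    printf("mask pad term: %zu bytes (%.1f KiB)\n", mask_pad, mask_pad/1024.0);
    return 0;
}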
@@ -2020,9 +2023,10 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
     const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
 
     // this optimization is not useful for the vector kernels
-    if (is_vec) {
-        return res;
-    }
+    // note: always reserve the blk buffer to avoid graph reallocations
+    //if (is_vec) {
+    //    return res;
+    //}
 
     const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
     const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
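After this hunk, the function goes on to size the per-block structure from nqptg (queries per threadgroup) and ncpsg (cache positions per simdgroup); that expression is outside the hunk. The sketch below only illustrates how block counts follow from those constants via rounded-up division; extra_blk_sketch() and the one-byte-per-block assumption are hypothetical, not the real formula.

#include <cstdint>
#include <cstdio>

// hypothetical: one byte-sized entry per (query-block, KV-block) pair
static size_t extra_blk_sketch(int64_t ne01, int64_t ne11, int nqptg, int ncpsg) {
    const int64_t nblk_q  = (ne01 + nqptg - 1)/nqptg; // query blocks, rounded up
    const int64_t nblk_kv = (ne11 + ncpsg - 1)/ncpsg; // KV blocks, rounded up
    return (size_t) (nblk_q*nblk_kv);
}

int main() {
    // made-up dims: 512 queries, 4096 KV positions, 8x32 block shape
    printf("blk sketch: %zu bytes\n", extra_blk_sketch(512, 4096, 8, 32));
    return 0;
}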
@@ -2049,13 +2053,16 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
 
     size_t res = 0;
 
-    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    // note: always reserve the temp buffer to avoid graph reallocations
+    //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+    if (true) {
         const int64_t nwg = 32;
+        const int64_t ne01_max = std::min(ne01, 32);
 
         // temp buffer for writing the results from each workgroup
         // - ne20: the size of the Value head
         // - + 2:  the S and M values for each intermediate result
-        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+        res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
     }
 
     return res;
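The temp-buffer formula in this last hunk is fully visible, so its size can be checked directly. The standalone computation below uses the expression from the + line, with ne01 capped at 32 via ne01_max as introduced by this commit; the dimensions themselves are made up for illustration.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // example dims (made up): 512 queries, 32 heads, batch 1, V head size 128
    const int64_t ne01 = 512, ne02 = 32, ne03 = 1, ne20 = 128;

    const int64_t nwg      = 32;                          // number of workgroups, fixed at 32 here
    const int64_t ne01_max = std::min<int64_t>(ne01, 32); // cap introduced by this commit
    const size_t  f32      = 4;                           // ggml_type_size(GGML_TYPE_F32)

    // per-workgroup partial results: ne20 values plus the S and M accumulators
    const size_t res = f32*(size_t)(ne01_max*ne02*ne03*nwg*(ne20 + 2));
    printf("tmp buffer (new): %zu bytes (%.2f MiB)\n", res, res/(1024.0*1024.0));

    // before the cap, ne01 itself was used, so the size grew linearly with the query count
    const size_t res_old = f32*(size_t)(ne01*ne02*ne03*nwg*(ne20 + 2));
    printf("tmp buffer (old): %zu bytes (%.2f MiB)\n", res_old, res_old/(1024.0*1024.0));
    return 0;
}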