metal : pad K, V and Mask when needed

This commit is contained in:
Georgi Gerganov
2025-09-21 17:59:31 +03:00
parent d8359f5fde
commit 5d0d2d2289
8 changed files with 420 additions and 42 deletions

View File

@@ -930,6 +930,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_UNUSED(op);
char base[256];
char name[256];
snprintf(base, 256, "kernel_%s",
"flash_attn_ext_pad");
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
base,
has_mask,
ncpsg);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const ggml_tensor * op,
@@ -937,6 +981,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
ns10,
ns20,
nsg);
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
dk,
dv);
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
base,
has_mask,
has_sinks,
has_bias,
has_scap,
has_kvpad,
ns10,
ns20,
nsg, nwg);
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);

View File

@@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
bool has_mask,
int32_t ncpsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
ggml_metal_library_t lib,
const struct ggml_tensor * op,
@@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
bool has_sinks,
bool has_bias,
bool has_scap,
bool has_kvpad,
int32_t nsg,
int32_t nwg);

View File

@@ -72,11 +72,12 @@
#define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV 400
#define FC_MUL_MM 500
#define FC_FLASH_ATTN_EXT_PAD 100
#define FC_FLASH_ATTN_EXT 200
#define FC_FLASH_ATTN_EXT_VEC 300
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
#define FC_MUL_MV 500
#define FC_MUL_MM 600
// kernel argument structs
//
@@ -246,6 +247,24 @@ typedef struct {
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;
typedef struct {
int32_t ne01;
int32_t ne02;
@@ -264,6 +283,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
@@ -298,6 +318,7 @@ typedef struct {
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne31;
int32_t ne32;
int32_t ne33;
uint64_t nb31;

View File

@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
ggml_is_contiguous(node->src[1]), node->src[1]->name);
}
if (node->src[2]) {
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
ggml_is_contiguous(node->src[2]), node->src[2]->name);
}
if (node->src[3]) {
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
ggml_is_contiguous(node->src[3]), node->src[3]->name);
}
if (node) {
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
node->name);
@@ -1873,20 +1885,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
return (ne01 < 20) && (ne00 % 32 == 0);
}
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
size_t res = 0;
const bool has_mask = op->src[3] != nullptr;
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const bool has_kvpad = ne11 % 32 != 0;
if (has_kvpad) {
res += 32*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
} else {
const bool has_kvpad = ne11 % 64 != 0;
if (has_kvpad) {
res += 64*(
nb11*ne12*ne13 +
nb21*ne22*ne23 +
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
}
}
return res;
}
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t nwg = 32;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
const int64_t ne01 = op->src[0]->ne[1];
const int64_t ne02 = op->src[0]->ne[2];
const int64_t ne03 = op->src[0]->ne[3];
const int64_t ne20 = op->src[2]->ne[0];
size_t res = 0;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
const int64_t nwg = 32;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
}
return res;
}
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1908,8 +1969,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1947,6 +2007,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne01 < 65536);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_pad = bid_dst;
bid_pad.offs += ggml_nbytes(op);
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
@@ -1956,6 +2021,52 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
}
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
// 2*(2*ncpsg)
@@ -2005,6 +2116,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
@@ -2021,7 +2133,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@@ -2038,7 +2150,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
}
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2054,6 +2167,52 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
const bool has_kvpad = ne11 % ncpsg != 0;
if (has_kvpad) {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
ggml_metal_kargs_flash_attn_ext_pad args0 = {
/*.ne11 =*/ne11,
/*.ne_12_2 =*/ne12,
/*.ne_12_3 =*/ne13,
/*.nb11 =*/nb11,
/*.nb12 =*/nb12,
/*.nb13 =*/nb13,
/*.nb21 =*/nb21,
/*.nb22 =*/nb22,
/*.nb23 =*/nb23,
/*.ne31 =*/ne31,
/*.ne32 =*/ne32,
/*.ne33 =*/ne33,
/*.nb31 =*/nb31,
/*.nb32 =*/nb32,
/*.nb33 =*/nb33,
};
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
ggml_metal_encoder_set_pipeline(enc, pipeline0);
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
if (op->src[3]) {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
} else {
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
}
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
assert(ne12 == ne22);
assert(ne13 == ne23);
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
} else {
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
}
// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
@@ -2118,6 +2277,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne31 =*/ ne31,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
@@ -2134,7 +2294,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.logit_softcap =*/ logit_softcap,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@@ -2161,7 +2321,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
if (nwg == 1) {
// using 1 workgroup -> write the result directly into dst
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2171,12 +2332,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
// write the results from each workgroup into a temp buffer
ggml_metal_buffer_id bid_tmp = bid_dst;
bid_tmp.offs += ggml_nbytes(op);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
bid_tmp.offs += ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op);
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);

View File

@@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
// return true if we should use the FA vector kernel for this op
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);

View File

@@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
}
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
} break;
default:
break;

View File

@@ -4416,10 +4416,79 @@ kernel void kernel_leaky_relu_f32_4(
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
kernel void kernel_flash_attn_ext_pad(
constant ggml_metal_kargs_flash_attn_ext_pad & args,
device const char * k,
device const char * v,
device const char * mask,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t C = FC_flash_attn_ext_pad_ncpsg;
device char * k_pad = dst;
device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
const int32_t icp = args.ne11 % C;
const int32_t ic0 = args.ne11 - icp;
const int32_t i1 = tgpig[0];
const int32_t i2 = tgpig[1];
const int32_t i3 = tgpig[2];
if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
if (i1 >= icp) {
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = 0;
}
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
v_dst[i] = 0;
}
} else {
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = k_src[i];
}
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
v_dst[i] = v_src[i];
}
}
}
if (FC_flash_attn_ext_pad_has_mask) {
if (i2 < args.ne32 && i3 < args.ne33) {
for (int ib = i1; ib < args.ne31; ib += C) {
device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
for (int i = tiitg; i < C; i += ntg.x) {
if (i >= icp) {
mask_dst[i] = -MAXHALF;
} else {
mask_dst[i] = mask_src[i];
}
}
}
}
}
}
constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
@@ -4466,6 +4535,7 @@ void kernel_flash_attn_ext_impl(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16,
uint3 tgpig,
@@ -4521,6 +4591,7 @@ void kernel_flash_attn_ext_impl(
// mask storage in shared mem
threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
threadgroup half * sm = (threadgroup half *) (sm2);
// per-query mask pointers
device const half2 * pm2[NQ];
@@ -4590,7 +4661,44 @@ void kernel_flash_attn_ext_impl(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic = 0; ic < args.ne11; ic += C) {
for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
int ic = ic0;
if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
if (!FC_flash_attn_ext_has_mask) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
for (short i = tiisg; i < C; i += NW) {
if (ic + i >= args.ne11) {
sm[2*j*SH + i] = -MAXHALF;
}
}
}
} else {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
pm2[jj] = (device const half2 *) ((device const half *) mask +
(iq1 + j)*C +
(iq2%args.ne32)*(C*args.ne31) +
(iq3%args.ne33)*(C*args.ne31*args.ne32));
}
}
ic = 0;
}
// read the mask into shared mem
if (FC_flash_attn_ext_has_mask) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
@@ -4624,7 +4732,7 @@ void kernel_flash_attn_ext_impl(
// this is compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
threadgroup const q_t * pq = sq;
threadgroup s_t * ps = ss;
@@ -4696,7 +4804,7 @@ void kernel_flash_attn_ext_impl(
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
for (short ii = 0; ii < DK16; ii += 4) {
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
if (DK16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -4818,7 +4926,7 @@ void kernel_flash_attn_ext_impl(
{
auto sst = ss;
device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
pv += 8*sgitg;
@@ -4860,7 +4968,7 @@ void kernel_flash_attn_ext_impl(
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
if (DV16%4 == 0) {
// no need for bound checks
@@ -5004,13 +5112,14 @@ kernel void kernel_flash_attn_ext(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_nsg) {
// note: disabled cases to reduce library load time
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5130,6 +5239,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
@@ -5167,6 +5277,7 @@ void kernel_flash_attn_ext_vec_impl(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@@ -5273,11 +5384,36 @@ void kernel_flash_attn_ext_vec_impl(
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
const int ic = ic0 + C*sgitg;
int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
}
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
if (!FC_flash_attn_ext_vec_has_mask) {
if (ic + tiisg >= args.ne11) {
sm[tiisg] = -MAXHALF;
}
} else {
pm = (device const half *) (mask) +
iq1*C +
(iq2%args.ne32)*(C*args.ne31) +
(iq3%args.ne33)*(C*args.ne31*args.ne32);
}
ic = 0;
}
if (FC_flash_attn_ext_vec_has_mask) {
sm[tiisg] = pm[ic + tiisg];
}
@@ -5289,7 +5425,7 @@ void kernel_flash_attn_ext_vec_impl(
// Q*K^T
{
device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
threadgroup const q4_t * pq4 = sq4;
pk4 += ty*NS10/4 + tx;
@@ -5304,7 +5440,7 @@ void kernel_flash_attn_ext_vec_impl(
mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
}
} else {
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
k4_t mk;
@@ -5402,7 +5538,7 @@ void kernel_flash_attn_ext_vec_impl(
}
if (is_same<vd4_t, v4_t>::value) {
device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
pv4 += ty*NS20/4 + tx;
@@ -5415,7 +5551,7 @@ void kernel_flash_attn_ext_vec_impl(
}
} else {
FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
const short i = ii*NL + tx;
@@ -5587,13 +5723,14 @@ kernel void kernel_flash_attn_ext_vec(
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_vec_nsg) {
// note: disabled cases to reduce library load time
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;

View File

@@ -6627,7 +6627,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
for (int nr2 : { 1, 4, 16 }) {
if (nr2 == 16 && hsk != 128) continue;
for (int kv : { 512, 1024, }) {
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
for (int kv : { 113, 512, 1024, }) {
if (nr2 != 1 && kv != 512) continue;
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {