mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
metal : pad K, V and Mask when needed
This commit is contained in:
@@ -930,6 +930,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
GGML_UNUSED(op);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_%s",
|
||||
"flash_attn_ext_pad");
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
||||
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
||||
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
||||
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
||||
|
||||
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
||||
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
||||
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
||||
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
||||
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const ggml_tensor * op,
|
||||
@@ -937,6 +981,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
has_scap,
|
||||
has_kvpad,
|
||||
ns10,
|
||||
ns20,
|
||||
nsg);
|
||||
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
||||
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
||||
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
||||
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
||||
|
||||
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
||||
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
||||
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg,
|
||||
int32_t nwg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
has_scap,
|
||||
has_kvpad,
|
||||
ns10,
|
||||
ns20,
|
||||
nsg, nwg);
|
||||
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
||||
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
||||
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
||||
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
||||
|
||||
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
||||
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
||||
|
||||
@@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
@@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
|
||||
@@ -72,11 +72,12 @@
|
||||
#define N_SG_IQ4_XS 2
|
||||
|
||||
// function constants offsets
|
||||
#define FC_FLASH_ATTN_EXT 100
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
#define FC_MUL_MM 500
|
||||
#define FC_FLASH_ATTN_EXT_PAD 100
|
||||
#define FC_FLASH_ATTN_EXT 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC 300
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
|
||||
#define FC_MUL_MV 500
|
||||
#define FC_MUL_MM 600
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
@@ -246,6 +247,24 @@ typedef struct {
|
||||
int32_t sect_3;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
} ggml_metal_kargs_flash_attn_ext_pad;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
@@ -264,6 +283,7 @@ typedef struct {
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
@@ -298,6 +318,7 @@ typedef struct {
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
|
||||
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
|
||||
|
||||
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
||||
ggml_is_contiguous(node->src[1]), node->src[1]->name);
|
||||
}
|
||||
if (node->src[2]) {
|
||||
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
|
||||
ggml_is_contiguous(node->src[2]), node->src[2]->name);
|
||||
}
|
||||
if (node->src[3]) {
|
||||
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
|
||||
ggml_is_contiguous(node->src[3]), node->src[3]->name);
|
||||
}
|
||||
if (node) {
|
||||
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
||||
node->name);
|
||||
@@ -1873,20 +1885,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
|
||||
return (ne01 < 20) && (ne00 % 32 == 0);
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const bool has_kvpad = ne11 % 32 != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += 32*(
|
||||
nb11*ne12*ne13 +
|
||||
nb21*ne22*ne23 +
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
} else {
|
||||
const bool has_kvpad = ne11 % 64 != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += 64*(
|
||||
nb11*ne12*ne13 +
|
||||
nb21*ne22*ne23 +
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
const int64_t nwg = 32;
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
const int64_t ne01 = op->src[0]->ne[1];
|
||||
const int64_t ne02 = op->src[0]->ne[2];
|
||||
const int64_t ne03 = op->src[0]->ne[3];
|
||||
const int64_t ne20 = op->src[2]->ne[0];
|
||||
size_t res = 0;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const int64_t nwg = 32;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
@@ -1908,8 +1969,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ne11 % 32 == 0);
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
|
||||
@@ -1947,6 +2007,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
GGML_ASSERT(ne01 < 65536);
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_pad = bid_dst;
|
||||
bid_pad.offs += ggml_nbytes(op);
|
||||
|
||||
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
// half8x8 kernel
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
@@ -1956,6 +2021,52 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
const bool has_kvpad = ne11 % ncpsg != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
||||
/*.ne11 =*/ne11,
|
||||
/*.ne_12_2 =*/ne12,
|
||||
/*.ne_12_3 =*/ne13,
|
||||
/*.nb11 =*/nb11,
|
||||
/*.nb12 =*/nb12,
|
||||
/*.nb13 =*/nb13,
|
||||
/*.nb21 =*/nb21,
|
||||
/*.nb22 =*/nb22,
|
||||
/*.nb23 =*/nb23,
|
||||
/*.ne31 =*/ne31,
|
||||
/*.ne32 =*/ne32,
|
||||
/*.ne33 =*/ne33,
|
||||
/*.nb31 =*/nb31,
|
||||
/*.nb32 =*/nb32,
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
|
||||
if (op->src[3]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||
|
||||
assert(ne12 == ne22);
|
||||
assert(ne13 == ne23);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
|
||||
|
||||
// 2*(2*ncpsg)
|
||||
@@ -2005,6 +2116,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
@@ -2021,7 +2133,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
@@ -2038,7 +2150,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
@@ -2054,6 +2167,52 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
const bool has_kvpad = ne11 % ncpsg != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
||||
/*.ne11 =*/ne11,
|
||||
/*.ne_12_2 =*/ne12,
|
||||
/*.ne_12_3 =*/ne13,
|
||||
/*.nb11 =*/nb11,
|
||||
/*.nb12 =*/nb12,
|
||||
/*.nb13 =*/nb13,
|
||||
/*.nb21 =*/nb21,
|
||||
/*.nb22 =*/nb22,
|
||||
/*.nb23 =*/nb23,
|
||||
/*.ne31 =*/ne31,
|
||||
/*.ne32 =*/ne32,
|
||||
/*.ne33 =*/ne33,
|
||||
/*.nb31 =*/nb31,
|
||||
/*.nb32 =*/nb32,
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
|
||||
if (op->src[3]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||
|
||||
assert(ne12 == ne22);
|
||||
assert(ne13 == ne23);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
// ne00 + 2*ncpsg*(nsg)
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the soft_max values and the mask
|
||||
@@ -2118,6 +2277,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
@@ -2134,7 +2294,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
|
||||
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
@@ -2161,7 +2321,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
if (nwg == 1) {
|
||||
// using 1 workgroup -> write the result directly into dst
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
@@ -2171,12 +2332,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
||||
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
// write the results from each workgroup into a temp buffer
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
|
||||
bid_tmp.offs += ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op);
|
||||
|
||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
||||
|
||||
@@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
|
||||
// return true if we should use the FA vector kernel for this op
|
||||
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
}
|
||||
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -4416,10 +4416,79 @@ kernel void kernel_leaky_relu_f32_4(
|
||||
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
|
||||
}
|
||||
|
||||
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
|
||||
|
||||
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
|
||||
|
||||
kernel void kernel_flash_attn_ext_pad(
|
||||
constant ggml_metal_kargs_flash_attn_ext_pad & args,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t C = FC_flash_attn_ext_pad_ncpsg;
|
||||
|
||||
device char * k_pad = dst;
|
||||
device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
|
||||
device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
|
||||
|
||||
const int32_t icp = args.ne11 % C;
|
||||
const int32_t ic0 = args.ne11 - icp;
|
||||
|
||||
const int32_t i1 = tgpig[0];
|
||||
const int32_t i2 = tgpig[1];
|
||||
const int32_t i3 = tgpig[2];
|
||||
|
||||
if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
|
||||
device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
|
||||
device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
|
||||
|
||||
device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
|
||||
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
|
||||
|
||||
if (i1 >= icp) {
|
||||
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
|
||||
k_dst[i] = 0;
|
||||
}
|
||||
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
|
||||
v_dst[i] = 0;
|
||||
}
|
||||
} else {
|
||||
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
|
||||
k_dst[i] = k_src[i];
|
||||
}
|
||||
for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
|
||||
v_dst[i] = v_src[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_flash_attn_ext_pad_has_mask) {
|
||||
if (i2 < args.ne32 && i3 < args.ne33) {
|
||||
for (int ib = i1; ib < args.ne31; ib += C) {
|
||||
device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
|
||||
device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
|
||||
|
||||
for (int i = tiitg; i < C; i += ntg.x) {
|
||||
if (i >= icp) {
|
||||
mask_dst[i] = -MAXHALF;
|
||||
} else {
|
||||
mask_dst[i] = mask_src[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
|
||||
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
|
||||
constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
|
||||
constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
|
||||
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
|
||||
|
||||
//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
|
||||
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
|
||||
@@ -4466,6 +4535,7 @@ void kernel_flash_attn_ext_impl(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device const char * sinks,
|
||||
device const char * pad,
|
||||
device char * dst,
|
||||
threadgroup half * shmem_f16,
|
||||
uint3 tgpig,
|
||||
@@ -4521,6 +4591,7 @@ void kernel_flash_attn_ext_impl(
|
||||
|
||||
// mask storage in shared mem
|
||||
threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
|
||||
threadgroup half * sm = (threadgroup half *) (sm2);
|
||||
|
||||
// per-query mask pointers
|
||||
device const half2 * pm2[NQ];
|
||||
@@ -4590,7 +4661,44 @@ void kernel_flash_attn_ext_impl(
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic = 0; ic < args.ne11; ic += C) {
|
||||
for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
|
||||
int ic = ic0;
|
||||
|
||||
if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
|
||||
k = pad;
|
||||
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
|
||||
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
|
||||
|
||||
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||
|
||||
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
|
||||
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
|
||||
|
||||
if (!FC_flash_attn_ext_has_mask) {
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
const short j = jj*NSG + sgitg;
|
||||
|
||||
for (short i = tiisg; i < C; i += NW) {
|
||||
if (ic + i >= args.ne11) {
|
||||
sm[2*j*SH + i] = -MAXHALF;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
const short j = jj*NSG + sgitg;
|
||||
|
||||
pm2[jj] = (device const half2 *) ((device const half *) mask +
|
||||
(iq1 + j)*C +
|
||||
(iq2%args.ne32)*(C*args.ne31) +
|
||||
(iq3%args.ne33)*(C*args.ne31*args.ne32));
|
||||
}
|
||||
}
|
||||
|
||||
ic = 0;
|
||||
}
|
||||
|
||||
// read the mask into shared mem
|
||||
if (FC_flash_attn_ext_has_mask) {
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
@@ -4624,7 +4732,7 @@ void kernel_flash_attn_ext_impl(
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<kd4x4_t, k4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
|
||||
device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
|
||||
threadgroup const q_t * pq = sq;
|
||||
threadgroup s_t * ps = ss;
|
||||
|
||||
@@ -4696,7 +4804,7 @@ void kernel_flash_attn_ext_impl(
|
||||
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
||||
|
||||
for (short ii = 0; ii < DK16; ii += 4) {
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
|
||||
|
||||
if (DK16%4 == 0) {
|
||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||
@@ -4818,7 +4926,7 @@ void kernel_flash_attn_ext_impl(
|
||||
{
|
||||
auto sst = ss;
|
||||
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
|
||||
device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
|
||||
|
||||
pv += 8*sgitg;
|
||||
|
||||
@@ -4860,7 +4968,7 @@ void kernel_flash_attn_ext_impl(
|
||||
simdgroup_load(vs, ss + 8*cc, SH, 0, false);
|
||||
|
||||
for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
|
||||
|
||||
if (DV16%4 == 0) {
|
||||
// no need for bound checks
|
||||
@@ -5004,13 +5112,14 @@ kernel void kernel_flash_attn_ext(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device const char * sinks,
|
||||
device const char * pad,
|
||||
device char * dst,
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
|
||||
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
|
||||
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
|
||||
switch (FC_flash_attn_ext_nsg) {
|
||||
// note: disabled cases to reduce library load time
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
@@ -5130,6 +5239,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_
|
||||
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
|
||||
constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
|
||||
constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
|
||||
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
|
||||
|
||||
//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
|
||||
//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
|
||||
@@ -5167,6 +5277,7 @@ void kernel_flash_attn_ext_vec_impl(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device const char * sinks,
|
||||
device const char * pad,
|
||||
device char * dst,
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
@@ -5273,11 +5384,36 @@ void kernel_flash_attn_ext_vec_impl(
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
int ic = ic0 + C*sgitg;
|
||||
if (ic >= args.ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
|
||||
k = pad;
|
||||
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
|
||||
mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
|
||||
|
||||
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||
|
||||
k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
|
||||
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
|
||||
|
||||
if (!FC_flash_attn_ext_vec_has_mask) {
|
||||
if (ic + tiisg >= args.ne11) {
|
||||
sm[tiisg] = -MAXHALF;
|
||||
}
|
||||
} else {
|
||||
pm = (device const half *) (mask) +
|
||||
iq1*C +
|
||||
(iq2%args.ne32)*(C*args.ne31) +
|
||||
(iq3%args.ne33)*(C*args.ne31*args.ne32);
|
||||
}
|
||||
|
||||
ic = 0;
|
||||
}
|
||||
|
||||
if (FC_flash_attn_ext_vec_has_mask) {
|
||||
sm[tiisg] = pm[ic + tiisg];
|
||||
}
|
||||
@@ -5289,7 +5425,7 @@ void kernel_flash_attn_ext_vec_impl(
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
|
||||
device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
|
||||
threadgroup const q4_t * pq4 = sq4;
|
||||
|
||||
pk4 += ty*NS10/4 + tx;
|
||||
@@ -5304,7 +5440,7 @@ void kernel_flash_attn_ext_vec_impl(
|
||||
mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
|
||||
}
|
||||
} else {
|
||||
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
|
||||
device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
|
||||
|
||||
k4_t mk;
|
||||
|
||||
@@ -5402,7 +5538,7 @@ void kernel_flash_attn_ext_vec_impl(
|
||||
}
|
||||
|
||||
if (is_same<vd4_t, v4_t>::value) {
|
||||
device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
|
||||
device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
|
||||
|
||||
pv4 += ty*NS20/4 + tx;
|
||||
|
||||
@@ -5415,7 +5551,7 @@ void kernel_flash_attn_ext_vec_impl(
|
||||
}
|
||||
} else {
|
||||
FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
|
||||
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
|
||||
device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
|
||||
|
||||
FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
|
||||
const short i = ii*NL + tx;
|
||||
@@ -5587,13 +5723,14 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device const char * sinks,
|
||||
device const char * pad,
|
||||
device char * dst,
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
|
||||
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
|
||||
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
|
||||
switch (FC_flash_attn_ext_vec_nsg) {
|
||||
// note: disabled cases to reduce library load time
|
||||
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
|
||||
@@ -6627,7 +6627,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
|
||||
for (int nr2 : { 1, 4, 16 }) {
|
||||
if (nr2 == 16 && hsk != 128) continue;
|
||||
for (int kv : { 512, 1024, }) {
|
||||
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
|
||||
for (int kv : { 113, 512, 1024, }) {
|
||||
if (nr2 != 1 && kv != 512) continue;
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
|
||||
Reference in New Issue
Block a user