metal : add support for non-padded FA KV (#16148)
* metal : pad K, V and Mask when needed
* cont : simplify
* cuda : add TODO about KV padding requirement
* metal : add comments
* metal : remove mask padding requirement
@@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const

    const int cc = ggml_cuda_info().devices[device].cc;

    // TODO: temporary until support is extended
    // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
    if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
        return BEST_FATTN_KERNEL_NONE;
    }

    switch (K->ne[0]) {
        case 64:
        case 128:
@@ -924,6 +924,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
    return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
        ggml_metal_library_t lib,
        const struct ggml_tensor * op,
        bool has_mask,
        int32_t ncpsg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
    GGML_UNUSED(op);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_%s",
            "flash_attn_ext_pad");

    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
            base,
            has_mask,
            ncpsg);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    ggml_metal_cv_t cv = ggml_metal_cv_init();

    ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
  //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
  //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
  //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);

  //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
  //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
  //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
  //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);

    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

    ggml_metal_cv_free(cv);

    return res;
}
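For reference, the pad pipeline is cached under the formatted name built above, so each (has_mask, ncpsg) combination gets its own specialization. A small stand-alone C++ sketch (values are hypothetical, for illustration only) of the cache key this produces:

    #include <cstdio>

    int main() {
        char base[256];
        char name[256];

        const int has_mask = 1;  // hypothetical: a mask is present
        const int ncpsg    = 32; // hypothetical: cache values per simdgroup

        // mirrors the two snprintf calls in the getter above
        std::snprintf(base, sizeof(base), "kernel_%s", "flash_attn_ext_pad");
        std::snprintf(name, sizeof(name), "%s_mask=%d_ncpsg=%d", base, has_mask, ncpsg);

        std::printf("%s\n", name); // kernel_flash_attn_ext_pad_mask=1_ncpsg=32
        return 0;
    }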

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
        ggml_metal_library_t lib,
        const ggml_tensor * op,
@@ -931,6 +975,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
        bool has_sinks,
        bool has_bias,
        bool has_scap,
        bool has_kvpad,
        int32_t nsg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

@@ -943,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
    const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
    const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];

    // do bounds checks for the mask?
    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);

    snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
            "flash_attn_ext",
            ggml_type_name(op->src[1]->type),
            dk,
            dv);

    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
            base,
            has_mask,
            has_sinks,
            has_bias,
            has_scap,
            has_kvpad,
            bc_mask,
            ns10,
            ns20,
            nsg);
@@ -970,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
    ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
    ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
    ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);

    ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);

    ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
    ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -989,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
        bool has_sinks,
        bool has_bias,
        bool has_scap,
        bool has_kvpad,
        int32_t nsg,
        int32_t nwg) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1008,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
            dk,
            dv);

    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
            base,
            has_mask,
            has_sinks,
            has_bias,
            has_scap,
            has_kvpad,
            ns10,
            ns20,
            nsg, nwg);
@@ -1029,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
    ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
    ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
    ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);

    ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
    ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
        ggml_metal_library_t lib,
        const struct ggml_tensor * op,
        bool has_mask,
        int32_t ncpsg);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
        ggml_metal_library_t lib,
        const struct ggml_tensor * op,
@@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
        bool has_sinks,
        bool has_bias,
        bool has_scap,
        bool has_kvpad,
        int32_t nsg);

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
        bool has_sinks,
        bool has_bias,
        bool has_scap,
        bool has_kvpad,
        int32_t nsg,
        int32_t nwg);
@@ -69,11 +69,12 @@
#define N_SG_IQ4_XS 2

// function constants offsets
#define FC_FLASH_ATTN_EXT            100
#define FC_FLASH_ATTN_EXT_VEC        200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
#define FC_MUL_MV                    400
#define FC_MUL_MM                    500
#define FC_FLASH_ATTN_EXT_PAD        100
#define FC_FLASH_ATTN_EXT            200
#define FC_FLASH_ATTN_EXT_VEC        300
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
#define FC_MUL_MV                    500
#define FC_MUL_MM                    600

// kernel argument structs
//
@@ -244,6 +245,24 @@ typedef struct {
    int32_t sect_3;
} ggml_metal_kargs_rope;

typedef struct {
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb23;
    int32_t  ne31;
    int32_t  ne32;
    int32_t  ne33;
    uint64_t nb31;
    uint64_t nb32;
    uint64_t nb33;
} ggml_metal_kargs_flash_attn_ext_pad;

typedef struct {
    int32_t ne01;
    int32_t ne02;
@@ -262,6 +281,7 @@ typedef struct {
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb23;
    int32_t ne31;
    int32_t ne32;
    int32_t ne33;
    uint64_t nb31;
@@ -296,6 +316,7 @@ typedef struct {
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb23;
    int32_t ne31;
    int32_t ne32;
    int32_t ne33;
    uint64_t nb31;
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
    GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
    GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
    GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
    GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
    GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
    GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);

@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
        GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                ggml_is_contiguous(node->src[1]), node->src[1]->name);
    }
    if (node->src[2]) {
        GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
                ggml_is_contiguous(node->src[2]), node->src[2]->name);
    }
    if (node->src[3]) {
        GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
                ggml_is_contiguous(node->src[3]), node->src[3]->name);
    }
    if (node) {
        GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                node->name);
@@ -1889,20 +1901,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
    return (ne01 < 20) && (ne00 % 32 == 0);
}

size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);

    size_t res = 0;

    const bool has_mask = op->src[3] != nullptr;

    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
        const bool has_kvpad = ne11 % 32 != 0;

        if (has_kvpad) {
            res += 32*(
                nb11*ne12*ne13 +
                nb21*ne22*ne23 +
                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
        }
    } else {
        const bool has_kvpad = ne11 % 64 != 0;

        if (has_kvpad) {
            res += 64*(
                nb11*ne12*ne13 +
                nb21*ne22*ne23 +
                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
        }
    }

    return res;
}
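As a rough sanity check on the scratch size computed above, a small stand-alone C++ sketch with hypothetical shapes (F16 K/V, head size 128, 8 KV heads, one sequence, a mask with ne31 = 32 and ne32 = ne33 = 1); the factors 32 and 64 are the chunk sizes of the vec and simdgroup kernels, and none of these numbers are taken from the change itself:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t nb11 = 128*2, nb21 = 128*2;    // bytes per K row / V row (F16)
        const uint64_t ne12 = 8, ne13 = 1;            // K heads, sequences
        const uint64_t ne22 = 8, ne23 = 1;            // V heads, sequences (same as K)
        const uint64_t ne31 = 32, ne32 = 1, ne33 = 1; // mask dims
        const bool     has_mask = true;

        const uint64_t P = 64; // chunk size of the simdgroup kernel (the vec kernel pads to 32)

        const uint64_t res = P*(nb11*ne12*ne13 +
                                nb21*ne22*ne23 +
                                (has_mask ? 2*ne31*ne32*ne33 : 0)); // 2 bytes per F16 mask value

        std::printf("extra pad scratch: %llu bytes\n", (unsigned long long) res); // 266240
        return 0;
    }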

size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    const int64_t nwg = 32;
    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
  //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
  //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);

    const int64_t ne01 = op->src[0]->ne[1];
    const int64_t ne02 = op->src[0]->ne[2];
    const int64_t ne03 = op->src[0]->ne[3];
    const int64_t ne20 = op->src[2]->ne[0];
    size_t res = 0;

    // temp buffer for writing the results from each workgroup
    // - ne20: the size of the Value head
    // - + 2:  the S and M values for each intermediate result
    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
        const int64_t nwg = 32;

        // temp buffer for writing the results from each workgroup
        // - ne20: the size of the Value head
        // - + 2:  the S and M values for each intermediate result
        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
    }

    return res;
}
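The same kind of back-of-the-envelope figure for this temp buffer, with hypothetical decode-time shapes (one query row, 32 heads, one sequence, V head size 128, nwg = 32 as above); sketch only:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t ne01 = 1, ne02 = 32, ne03 = 1; // hypothetical query rows, heads, sequences
        const uint64_t ne20 = 128;                    // size of the V head
        const uint64_t nwg  = 32;                     // workgroups whose partial results get reduced

        // one partial output row per workgroup plus its S and M values, stored as F32
        const uint64_t res = 4*(ne01*ne02*ne03*nwg*(ne20 + 2));

        std::printf("extra tmp scratch: %llu bytes\n", (unsigned long long) res); // 532480
        return 0;
    }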

int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
@@ -1924,8 +1985,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
    GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
    GGML_TENSOR_LOCALS( int32_t, nb, op, nb);

    GGML_ASSERT(ne00 % 4 == 0);
    GGML_ASSERT(ne11 % 32 == 0);
    GGML_ASSERT(ne00 % 4 == 0);

    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1935,8 +1995,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
    GGML_ASSERT(ne12 == ne22);

    GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");

    float scale;
    float max_bias;
@@ -1963,6 +2023,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

    GGML_ASSERT(ne01 < 65536);

    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
    ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
    ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
    ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;

    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);

    ggml_metal_buffer_id bid_pad = bid_dst;
    bid_pad.offs += ggml_nbytes(op);

    ggml_metal_buffer_id bid_tmp = bid_pad;
    bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);

    if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
        // half8x8 kernel
        const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
@@ -1972,6 +2046,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
        GGML_ASSERT(nqptg % 8 == 0);
        GGML_ASSERT(ncpsg % 32 == 0);

        const bool has_kvpad = ne11 % ncpsg != 0;

        if (has_kvpad) {
            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);

            ggml_metal_kargs_flash_attn_ext_pad args0 = {
                /*.ne11    =*/ ne11,
                /*.ne_12_2 =*/ ne12,
                /*.ne_12_3 =*/ ne13,
                /*.nb11    =*/ nb11,
                /*.nb12    =*/ nb12,
                /*.nb13    =*/ nb13,
                /*.nb21    =*/ nb21,
                /*.nb22    =*/ nb22,
                /*.nb23    =*/ nb23,
                /*.ne31    =*/ ne31,
                /*.ne32    =*/ ne32,
                /*.ne33    =*/ ne33,
                /*.nb31    =*/ nb31,
                /*.nb32    =*/ nb32,
                /*.nb33    =*/ nb33,
            };

            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);

            ggml_metal_encoder_set_pipeline(enc, pipeline0);
            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);

            assert(ne12 == ne22);
            assert(ne13 == ne23);

            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);

            ggml_metal_op_concurrency_reset(ctx);
        } else {
            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
        }

        const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;

        // 2*(2*ncpsg)
@@ -2021,6 +2137,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.nb21 =*/ nb21,
            /*.nb22 =*/ nb22,
            /*.nb23 =*/ nb23,
            /*.ne31 =*/ ne31,
            /*.ne32 =*/ ne32,
            /*.ne33 =*/ ne33,
            /*.nb31 =*/ nb31,
@@ -2037,24 +2154,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.logit_softcap =*/ logit_softcap,
        };

        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
        if (op->src[3]) {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
        } else {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
        }
        if (op->src[4]) {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
        } else {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
        }
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 6);
        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);
        ggml_metal_encoder_set_buffer  (enc, bid_pad,  6);
        ggml_metal_encoder_set_buffer  (enc, bid_dst,  7);

        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2070,6 +2180,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
        GGML_ASSERT(nqptg % 1 == 0);
        GGML_ASSERT(ncpsg % 32 == 0);

        const bool has_kvpad = ne11 % ncpsg != 0;

        if (has_kvpad) {
            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);

            ggml_metal_kargs_flash_attn_ext_pad args0 = {
                /*.ne11    =*/ ne11,
                /*.ne_12_2 =*/ ne12,
                /*.ne_12_3 =*/ ne13,
                /*.nb11    =*/ nb11,
                /*.nb12    =*/ nb12,
                /*.nb13    =*/ nb13,
                /*.nb21    =*/ nb21,
                /*.nb22    =*/ nb22,
                /*.nb23    =*/ nb23,
                /*.ne31    =*/ ne31,
                /*.ne32    =*/ ne32,
                /*.ne33    =*/ ne33,
                /*.nb31    =*/ nb31,
                /*.nb32    =*/ nb32,
                /*.nb33    =*/ nb33,
            };

            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);

            ggml_metal_encoder_set_pipeline(enc, pipeline0);
            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
            ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
            ggml_metal_encoder_set_buffer  (enc, bid_src2, 2);
            ggml_metal_encoder_set_buffer  (enc, bid_src3, 3);
            ggml_metal_encoder_set_buffer  (enc, bid_pad,  4);

            assert(ne12 == ne22);
            assert(ne13 == ne23);

            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);

            ggml_metal_op_concurrency_reset(ctx);
        } else {
            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
        }

        // ne00 + 2*ncpsg*(nsg)
        // for each query, we load it as f16 in shared memory (ne00)
        // and store the soft_max values and the mask
@@ -2134,6 +2286,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.nb21 =*/ nb21,
            /*.nb22 =*/ nb22,
            /*.nb23 =*/ nb23,
            /*.ne31 =*/ ne31,
            /*.ne32 =*/ ne32,
            /*.ne33 =*/ ne33,
            /*.nb31 =*/ nb31,
@@ -2150,25 +2303,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
            /*.logit_softcap =*/ logit_softcap,
        };

        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);

        GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
        if (op->src[3]) {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
        } else {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
        }
        if (op->src[4]) {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
        } else {
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
        }
        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
        ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
        ggml_metal_encoder_set_buffer  (enc, bid_src2, 3);
        ggml_metal_encoder_set_buffer  (enc, bid_src3, 4);
        ggml_metal_encoder_set_buffer  (enc, bid_src4, 5);

        const size_t smem = FATTN_SMEM(nsg);
@@ -2176,23 +2321,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
        GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);

        if (nwg == 1) {
            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);

            // using 1 workgroup -> write the result directly into dst
            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
        } else {
            // sanity checks
            assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);

            GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
            GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));

            ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);

            // write the results from each workgroup into a temp buffer
            ggml_metal_buffer_id bid_tmp = bid_dst;
            bid_tmp.offs += ggml_nbytes(op);
            ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);

            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);

@@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
// return true if we should use the FA vector kernel for this op
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);

size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);

int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
@@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
                } break;
            case GGML_OP_FLASH_ATTN_EXT:
                {
                    if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
                        res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
                    }
                    res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
                    res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
                } break;
            default:
                break;
@@ -4349,10 +4349,83 @@ kernel void kernel_leaky_relu_f32_4(
    dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
}

constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];

constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];

// pad the last chunk of C elements of k and v into an extra pad buffer
kernel void kernel_flash_attn_ext_pad(
        constant ggml_metal_kargs_flash_attn_ext_pad & args,
        device const char * k,
        device const char * v,
        device const char * mask,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3   ntg[[threads_per_threadgroup]]) {
    const int32_t C = FC_flash_attn_ext_pad_ncpsg;

    device char * k_pad    = dst;
    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;

    const int32_t icp = args.ne11 % C;
    const int32_t ic0 = args.ne11 - icp;

    const int32_t i1 = tgpig[0];
    const int32_t i2 = tgpig[1];
    const int32_t i3 = tgpig[2];

    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
        device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
        device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;

        device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
        device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;

        if (i1 >= icp) {
            // the exact value here is not important, as we rely on masking out the scores in the attention
            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
                k_dst[i] = 0;
            }
            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
                v_dst[i] = 0;
            }
        } else {
            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
                k_dst[i] = k_src[i];
            }
            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
                v_dst[i] = v_src[i];
            }
        }
    }

    if (FC_flash_attn_ext_pad_has_mask) {
        if (i2 < args.ne32 && i3 < args.ne33) {
            for (int ib = i1; ib < args.ne31; ib += C) {
                device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
                device       half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;

                for (int i = tiitg; i < C; i += ntg.x) {
                    if (i >= icp) {
                        mask_dst[i] = -MAXHALF;
                    } else {
                        mask_dst[i] = mask_src[i];
                    }
                }
            }
        }
    }
}
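For orientation, the pad scratch written by this kernel is one contiguous allocation with three back-to-back sections: one padded chunk of K rows per KV head, then the same for V, then the padded mask. A small host-side C++ sketch of those offsets, using hypothetical sizes that are not taken from the change itself:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t C       = 64;             // hypothetical chunk size (ncpsg)
        const uint64_t nb11    = 128*2;          // bytes per K row (F16, head size 128)
        const uint64_t nb21    = 128*2;          // bytes per V row
        const uint64_t ne_12_2 = 8, ne_12_3 = 1; // KV heads, sequences

        // same pointer arithmetic as k_pad / v_pad / mask_pad above
        const uint64_t off_k    = 0;
        const uint64_t off_v    = off_k + nb11*C*ne_12_2*ne_12_3;
        const uint64_t off_mask = off_v + nb21*C*ne_12_2*ne_12_3;

        std::printf("k_pad @ %llu, v_pad @ %llu, mask_pad @ %llu\n",
                (unsigned long long) off_k,
                (unsigned long long) off_v,
                (unsigned long long) off_mask); // 0, 131072, 262144
        return 0;
    }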

constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];

constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];

//constant float FC_flash_attn_ext_scale    [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
@@ -4399,6 +4472,7 @@ void kernel_flash_attn_ext_impl(
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16,
        uint3  tgpig,
@@ -4523,13 +4597,58 @@

    // loop over the KV cache
    // each simdgroup handles blocks of Q rows and C columns
    for (int ic = 0; ic < args.ne11; ic += C) {
    for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
        int ic = ic0;

        // the last partial chunk uses the pad buffer as source
        if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
            k    = pad;
            v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
            mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;

            const short ikv2 = iq2/(args.ne02/args.ne_12_2);
            const short ikv3 = iq3/(args.ne03/args.ne_12_3);

            k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
            v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;

            if (!FC_flash_attn_ext_has_mask) {
                threadgroup half * sm = (threadgroup half *) (sm2);

                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                    const short j = jj*NSG + sgitg;

                    for (short i = tiisg; i < C; i += NW) {
                        if (ic + i >= args.ne11) {
                            sm[2*j*SH + i] = -MAXHALF;
                        }
                    }
                }
            } else {
                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                    const short j = jj*NSG + sgitg;

                    pm2[jj] = (device const half2 *) ((device const half *) mask +
                            (iq1 + j)*C +
                            (iq2%args.ne32)*(C*args.ne31) +
                            (iq3%args.ne33)*(C*args.ne31*args.ne32));
                }
            }

            ic = 0;
        }

        // read the mask into shared mem
        if (FC_flash_attn_ext_has_mask) {
            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                const short j = jj*NSG + sgitg;

                sm2[j*SH + tiisg] = pm2[jj][tiisg];
                if (FC_flash_attn_ext_bc_mask) {
                    sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
                } else {
                    sm2[j*SH + tiisg] = pm2[jj][tiisg];
                }

                pm2[jj] += NW;
            }
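A minimal host-side sketch of the chunk redirection implemented above, with hypothetical values for ne11 and C: once a chunk would run past the end of the KV data, the kernel switches its K/V/mask base pointers to the compact pad buffer (which holds exactly one chunk of C rows per KV head) and restarts the in-chunk index at 0.

    #include <cstdio>

    int main() {
        const int ne11 = 100; // hypothetical KV length (not a multiple of C)
        const int C    = 32;  // hypothetical chunk size

        for (int ic0 = 0; ic0 < ne11; ic0 += C) {
            // mirrors the FC_flash_attn_ext_has_kvpad branch above
            const bool from_pad = ic0 + C > ne11;
            const int  ic       = from_pad ? 0 : ic0;

            std::printf("chunk starting at %3d -> read %s rows [%d, %d)\n",
                    ic0, from_pad ? "pad" : "KV", ic, ic + C);
        }
        return 0;
    }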

@@ -4557,7 +4676,7 @@
            // this is a compile-time check, so it does not have runtime overhead
            if (is_same<kd4x4_t, k4x4_t>::value) {
                // we can read directly from global memory
                device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
                device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
                threadgroup const q_t * pq = sq;
                threadgroup       s_t * ps = ss;

@@ -4629,7 +4748,7 @@
                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

                    for (short ii = 0; ii < DK16; ii += 4) {
                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));

                        if (DK16%4 == 0) {
                            // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -4751,7 +4870,7 @@
        {
            auto sst = ss;

            device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
            device const v_t * pv = (device const v_t *) (v + ic*args.nb21);

            pv += 8*sgitg;

@@ -4793,7 +4912,7 @@
                simdgroup_load(vs, ss + 8*cc, SH, 0, false);

                for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
                    device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
                    device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));

                    if (DV16%4 == 0) {
                        // no need for bound checks
@@ -4937,13 +5056,14 @@ kernel void kernel_flash_attn_ext(
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
    switch (FC_flash_attn_ext_nsg) {
        // note: disabled cases to reduce library load time
      //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5063,6 +5183,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_
constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
constant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
constant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];

//constant float FC_flash_attn_ext_vec_scale    [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
@@ -5100,6 +5221,7 @@ void kernel_flash_attn_ext_vec_impl(
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
@@ -5206,11 +5328,37 @@
    // loop over the KV cache
    // each simdgroup handles blocks of Q rows and C columns
    for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
        const int ic = ic0 + C*sgitg;
        int ic = ic0 + C*sgitg;
        if (ic >= args.ne11) {
            break;
        }

        // the last partial chunk uses the pad buffer as source
        if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
            k    = pad;
            v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
            mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;

            const short ikv2 = iq2/(args.ne02/args.ne_12_2);
            const short ikv3 = iq3/(args.ne03/args.ne_12_3);

            k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
            v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;

            if (!FC_flash_attn_ext_vec_has_mask) {
                if (ic + tiisg >= args.ne11) {
                    sm[tiisg] = -MAXHALF;
                }
            } else {
                pm = (device const half *) (mask) +
                        iq1*C +
                        (iq2%args.ne32)*(C*args.ne31) +
                        (iq3%args.ne33)*(C*args.ne31*args.ne32);
            }

            ic = 0;
        }

        if (FC_flash_attn_ext_vec_has_mask) {
            sm[tiisg] = pm[ic + tiisg];
        }
@@ -5222,7 +5370,7 @@

        // Q*K^T
        {
            device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
            device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
            threadgroup const q4_t * pq4 = sq4;

            pk4 += ty*NS10/4 + tx;
@@ -5237,7 +5385,7 @@
                    mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
                }
            } else {
                device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
                device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));

                k4_t mk;

@@ -5335,7 +5483,7 @@
        }

        if (is_same<vd4_t, v4_t>::value) {
            device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
            device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);

            pv4 += ty*NS20/4 + tx;

@@ -5348,7 +5496,7 @@
            }
        } else {
            FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
                device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
                device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));

                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    const short i = ii*NL + tx;
@@ -5520,13 +5668,14 @@ kernel void kernel_flash_attn_ext_vec(
        device const char * v,
        device const char * mask,
        device const char * sinks,
        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
    switch (FC_flash_attn_ext_vec_nsg) {
        // note: disabled cases to reduce library load time
        case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;