metal : remove mask padding requirement
@@ -994,19 +994,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
     const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
 
+    // do bounds checks for the mask?
+    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
+
     snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
             "flash_attn_ext",
             ggml_type_name(op->src[1]->type),
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
             has_kvpad,
+            bc_mask,
             ns10,
             ns20,
             nsg);
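The new bc_mask flag is set only when a mask is present and its row count is not a multiple of 8, and it becomes part of the pipeline name ("bcm=%d"), so padded and unpadded masks select separately specialized kernels. A minimal standalone C sketch of the predicate (the helper needs_bc_mask is illustrative, not part of ggml):

    // illustrative only: when does the flash-attention pipeline need
    // bounds checks on the mask?
    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    static bool needs_bc_mask(bool has_mask, int64_t mask_ne1) {
        // mirrors: op->src[3] && (op->src[3]->ne[1] % 8 != 0)
        return has_mask && (mask_ne1 % 8 != 0);
    }

    int main(void) {
        printf("%d\n", needs_bc_mask(true, 32));  // 0: already a multiple of 8
        printf("%d\n", needs_bc_mask(true, 33));  // 1: unpadded -> checked path
        printf("%d\n", needs_bc_mask(false, 0));  // 0: no mask at all
        return 0;
    }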
@@ -1024,6 +1028,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
     ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
 
+    ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
+
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
     ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT + 22);
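These calls bake the values into the Metal pipeline as function constants, so the bc_mask branch in the kernel below is resolved when the pipeline is compiled rather than per thread. Read off this diff, the index layout appears to put the boolean feature flags in the low slots and the int32 shape parameters from +20 up; a sketch of that layout (the base value and the names other than bc_mask are assumptions, not ggml definitions):

    // assumed layout, reconstructed from the indices in this diff only
    enum {
        FC_FLASH_ATTN_EXT = 0,                  // hypothetical base value
        FC_HAS_SCAP  = FC_FLASH_ATTN_EXT + 3,   // bool
        FC_HAS_KVPAD = FC_FLASH_ATTN_EXT + 4,   // bool
        FC_BC_MASK   = FC_FLASH_ATTN_EXT + 10,  // bool (new in this commit)
        FC_NS10      = FC_FLASH_ATTN_EXT + 20,  // int32
        FC_NS20      = FC_FLASH_ATTN_EXT + 21,  // int32
        FC_NSG       = FC_FLASH_ATTN_EXT + 22,  // int32
    };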
@@ -1979,8 +1979,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ne12 == ne22);
 
     GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
-    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
-            "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
+            "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
 
     float scale;
     float max_bias;
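The relaxed assertion is the user-visible change: previously the mask needed at least n_queries rows rounded up to a multiple of 8, now n_queries rows suffice. A small worked example, using the GGML_PAD macro as defined in ggml.h (copied here for illustration):

    #include <stdio.h>

    // round x up to the next multiple of n (n a power of two);
    // matches the GGML_PAD macro in ggml.h at the time of writing
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main(void) {
        const int n_queries = 33;
        printf("old minimum mask rows: %d\n", GGML_PAD(n_queries, 8)); // 40
        printf("new minimum mask rows: %d\n", n_queries);              // 33
        return 0;
    }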
@@ -4492,6 +4492,8 @@ constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT
 constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
 constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
 
+constant bool FC_flash_attn_ext_bc_mask   [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
+
 //constant float FC_flash_attn_ext_scale         [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
 //constant float FC_flash_attn_ext_max_bias      [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
 //constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
@@ -4708,7 +4710,12 @@ void kernel_flash_attn_ext_impl(
     FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
         const short j = jj*NSG + sgitg;
 
-        sm2[j*SH + tiisg] = pm2[jj][tiisg];
+        if (FC_flash_attn_ext_bc_mask) {
+            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+        } else {
+            sm2[j*SH + tiisg] = pm2[jj][tiisg];
+        }
 
         pm2[jj] += NW;
     }
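In the specialized kernel, mask rows that fall past the real mask height (args.ne31) are loaded as -MAXHALF, so the corresponding attention logits behave like minus infinity and softmax assigns them zero weight; pipelines compiled with FC_flash_attn_ext_bc_mask == false keep the unchecked load and pay nothing for the new feature. A plain C sketch of the same pattern (names simplified, simdgroup and half2 mechanics omitted):

    #include <stdio.h>

    #define NEG_LIMIT_H (-65504.0f)  // most negative finite half, like -MAXHALF

    // illustrative only: copy one mask row into a shared tile, substituting
    // a "minus infinity"-like value for rows beyond the real mask height
    static void load_mask_row(float *tile, const float *mask,
                              int row, int n_rows, int width) {
        for (int col = 0; col < width; ++col) {
            tile[col] = (row < n_rows) ? mask[(size_t)row*width + col]
                                       : NEG_LIMIT_H;
        }
    }

    int main(void) {
        const float mask[2][4] = { {0, 0, 0, 0}, {0, 0, NEG_LIMIT_H, NEG_LIMIT_H} };
        float tile[4];
        load_mask_row(tile, &mask[0][0], 2, 2, 4); // row 2 is out of bounds
        printf("%f\n", tile[0]);                   // -65504.000000
        return 0;
    }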