diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index d636f7d776..7acc5c29cf 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4420,6 +4420,7 @@ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_E constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]]; +// pad the last chunk of C elements of k and v into a an extra pad buffer kernel void kernel_flash_attn_ext_pad( constant ggml_metal_kargs_flash_attn_ext_pad & args, device const char * k, @@ -4450,6 +4451,7 @@ kernel void kernel_flash_attn_ext_pad( device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { k_dst[i] = 0; } @@ -4663,6 +4665,7 @@ void kernel_flash_attn_ext_impl( for (int ic0 = 0; ic0 < args.ne11; ic0 += C) { int ic = ic0; + // the last partial chunk uses the pad buffer as source if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) { k = pad; v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; @@ -5390,6 +5393,7 @@ void kernel_flash_attn_ext_vec_impl( break; } + // the last partial chunk uses the pad buffer as source if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { k = pad; v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;