metal : add comments

This commit is contained in:
Georgi Gerganov
2025-09-28 18:08:13 +03:00
parent 0629437601
commit 50d2b21d7c

View File

@@ -4420,6 +4420,7 @@ constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_E
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
// pad the last chunk of C elements of k and v into a an extra pad buffer
kernel void kernel_flash_attn_ext_pad(
constant ggml_metal_kargs_flash_attn_ext_pad & args,
device const char * k,
@@ -4450,6 +4451,7 @@ kernel void kernel_flash_attn_ext_pad(
device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
if (i1 >= icp) {
// here it is not important the exact value that will be used as we rely on masking out the scores in the attention
for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
k_dst[i] = 0;
}
@@ -4663,6 +4665,7 @@ void kernel_flash_attn_ext_impl(
for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
int ic = ic0;
// the last partial chunk uses the pad buffer as source
if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
@@ -5390,6 +5393,7 @@ void kernel_flash_attn_ext_vec_impl(
break;
}
// the last partial chunk uses the pad buffer as source
if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
k = pad;
v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;