diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e46f0e2081..d4ed938391 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max( all_inf = warp_reduce_all(all_inf); if (!all_inf) { - KV_max_sj += FATTN_KQ_STRIDE; break; } } + // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE. + // If the break was triggered it's the lower edge of the tile with the first non-masked values. + // In either case, walk back the decrementation by FATTN_KQ_STRIDE. + KV_max_sj += FATTN_KQ_STRIDE; + if (threadIdx.x != 0) { return; }