CUDA: use mma PTX instructions for FlashAttention (#11583)

* CUDA: use mma PTX instructions for FlashAttention

* __shfl_sync workaround for movmatrix

* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <slarengh@gmail.com>
Johannes Gäßler
2025-02-02 19:31:09 +01:00
committed by GitHub
parent 84ec8a58f7
commit 864a0b67a6
29 changed files with 2058 additions and 998 deletions
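For readers skimming the diff below: the core of the change is that the FlashAttention kernels now issue tensor-core mma instructions through inline PTX (the older kernels went through the higher-level WMMA fragment API), and that the movmatrix register transpose gets a warp-shuffle fallback. The following is a minimal, self-contained sketch of both primitives; the function names, the fragment-layout comments, and the m16n8k8 shape are illustrative assumptions, not the exact code added by this commit.

// Illustrative sketch only (assumed names/shapes), not the code added by this commit.

// One m16n8k8 f16 multiply-accumulate issued directly as PTX.
// A is a 16x8 fragment (2 x b32 per lane), B an 8x8 fragment (1 x b32 per lane),
// and the f16 accumulator C/D a 16x8 fragment (2 x b32 per lane). Turing or newer.
static __device__ __forceinline__ void mma_m16n8k8_f16(int acc[2], const int a[2], const int b) {
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{%0, %1}, {%2, %3}, {%4}, {%0, %1};"
                 : "+r"(acc[0]), "+r"(acc[1])
                 : "r"(a[0]), "r"(a[1]), "r"(b));
}

// Shuffle-based stand-in for movmatrix: transpose an 8x8 b16 matrix spread over a warp.
// Assumed layout: each lane holds two consecutive b16 values of row (lane/4), packed into one b32.
static __device__ __forceinline__ int movmatrix_via_shfl(const int x) {
    const int lane = threadIdx.x % 32;
    const int row  = lane / 4;       // row this lane owns before the transpose
    const int col  = 2*(lane % 4);   // first of the two columns it owns

    // After the transpose this lane needs elements (col, row) and (col+1, row),
    // which live on the following source lanes (same half-word position in both):
    const int v_lo = __shfl_sync(0xFFFFFFFF, x, 4*col     + row/2);
    const int v_hi = __shfl_sync(0xFFFFFFFF, x, 4*(col+1) + row/2);

    const int h_lo = (row % 2 == 0) ? (v_lo & 0xFFFF) : ((v_lo >> 16) & 0xFFFF);
    const int h_hi = (row % 2 == 0) ? (v_hi & 0xFFFF) : ((v_hi >> 16) & 0xFFFF);
    return h_lo | (h_hi << 16);
}

The shuffle path also explains the third bullet of the commit message: ROCm exposes a mask-less __shfl rather than __shfl_sync, so the HIP build presumably gains a small shim that forwards __shfl_sync to __shfl and ignores the mask.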

@@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
        const int ne2,
        const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
    NO_DEVICE_CODE;
    return;
#endif // FLASH_ATTN_AVAILABLE

    // Skip unused kernel variants for faster compilation:
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
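The guard added at the top of the kernel follows the same pattern as the logit-softcap check a few lines below it: because the deciding values are compile-time constants (a preprocessor flag here, template parameters there), the early return makes the rest of the body unreachable and no real device code is generated for unsupported variants. A rough sketch of the pattern with hypothetical names; __trap() stands in for whatever NO_DEVICE_CODE expands to in ggml-cuda's headers.

// Sketch of the compile-time variant pruning used above (names are hypothetical).
template <int D, bool use_logit_softcap>
static __global__ void fattn_variant_stub(float * dst) {
#ifndef FLASH_ATTN_AVAILABLE
    // FlashAttention compiled out entirely: the kernel becomes an empty trap stub.
    __trap();
    return;
#endif // FLASH_ATTN_AVAILABLE
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        // The condition is a compile-time constant, so for unsupported combinations
        // everything after this early return is dead-stripped and never costs build time.
        __trap();
        return;
    }
    // ... actual kernel body would go here ...
    dst[0] = D;
}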
@@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
    constexpr bool need_f16_K = D != 128;
    constexpr bool need_f16_V = D != 128 && D != 64;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
    constexpr size_t nbytes_shared = 0;
    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}

template <int D, ggml_type type_K, ggml_type type_V>
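The launcher-signature change is the other recurring edit in this commit: launch_fattn now takes the dynamic shared-memory size as an explicit nbytes_shared argument (0 for the vector kernel above) plus extra template parameters, presumably because the new mma-based kernels need a nontrivial amount of shared memory for their tiles. A minimal sketch of what such a launcher has to do on the plain CUDA side; the kernel and helper names are placeholders, not ggml's actual API.

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Placeholder kernel standing in for a FlashAttention variant that uses dynamic shared memory.
__global__ void fattn_kernel_stub(const half * K, const half * V, float * dst) { /* ... */ }

// Sketch: where an explicit nbytes_shared ends up in a CUDA kernel launch.
static void launch_with_shared_mem(dim3 grid, dim3 block, size_t nbytes_shared,
                                   const half * K, const half * V, float * dst,
                                   cudaStream_t stream) {
    if (nbytes_shared > 48u*1024u) {
        // Above 48 KiB per block the kernel must opt in explicitly (Volta and newer).
        cudaFuncSetAttribute(fattn_kernel_stub,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes_shared);
    }
    // The third launch-configuration parameter is the per-block dynamic shared-memory size.
    fattn_kernel_stub<<<grid, block, nbytes_shared, stream>>>(K, V, dst);
}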