CUDA: refactor FA support/selection code (#15454)

This commit is contained in:
Johannes Gäßler
2025-08-20 23:14:14 +02:00
committed by GitHub
parent 7a6e91ad26
commit 13aeb7aef2
4 changed files with 165 additions and 111 deletions

View File

@@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results(
dst[tid] = VKQ_numerator / VKQ_denominator;
}
[[noreturn]]
static void on_no_fattn_vec_case(const int D) {
if (D == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
fprintf(stderr, "By default only f16 KV cache is supported.\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
GGML_ABORT("fatal error");
} else if (D == 128) {
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
fprintf(stderr, "Supported combinations:\n");
fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
GGML_ABORT("fatal error");
} else {
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
fprintf(stderr, "Only f16 is supported.\n");
GGML_ABORT("fatal error");
}
}
template <int DV, int ncols1, int ncols2>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,

View File

@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0]);
GGML_ABORT("fatal error");
}
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS
on_no_fattn_vec_case(Q->ne[0]);
GGML_ABORT("fatal error");
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
BEST_FATTN_KERNEL_NONE = 0,
BEST_FATTN_KERNEL_TILE_F32 = 200,
BEST_FATTN_KERNEL_TILE_F16 = 210,
BEST_FATTN_KERNEL_VEC_F32 = 100,
BEST_FATTN_KERNEL_VEC_F16 = 110,
BEST_FATTN_KERNEL_WMMA_F16 = 300,
BEST_FATTN_KERNEL_MMA_F16 = 400,
};
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
#ifndef FLASH_ATTN_AVAILABLE
GGML_UNUSED(device); GGML_UNUSED(dst);
return BEST_FATTN_KERNEL_NONE;
#endif// FLASH_ATTN_AVAILABLE
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int cc = ggml_cuda_info().devices[device].cc;
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
}
return;
}
if (!fp16_mma_available(cc)) {
if (prec == GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
switch (K->ne[0]) {
case 64:
case 128:
case 256:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
} else {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
break;
case 80:
case 96:
case 112:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
}
return;
if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 576:
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
default:
return BEST_FATTN_KERNEL_NONE;
}
#ifndef GGML_CUDA_FA_ALL_QUANTS
if (K->type != V->type) {
return BEST_FATTN_KERNEL_NONE;
}
#endif // GGML_CUDA_FA_ALL_QUANTS
switch (K->type) {
case GGML_TYPE_F16:
break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
#ifndef GGML_CUDA_FA_ALL_QUANTS
return BEST_FATTN_KERNEL_NONE;
#endif // GGML_CUDA_FA_ALL_QUANTS
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
#ifdef GGML_CUDA_FA_ALL_QUANTS
if (K->ne[0] != 128 && K->ne[0] != 64) {
return BEST_FATTN_KERNEL_NONE;
}
#else
if (K->ne[0] != 128) {
return BEST_FATTN_KERNEL_NONE;
}
#endif // GGML_CUDA_FA_ALL_QUANTS
break;
default:
return BEST_FATTN_KERNEL_NONE;
}
switch (V->type) {
case GGML_TYPE_F16:
break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
if (K->ne[0] != 128) {
return BEST_FATTN_KERNEL_NONE;
}
break;
default:
return BEST_FATTN_KERNEL_NONE;
}
if (mask && mask->ne[2] != 1) {
return BEST_FATTN_KERNEL_NONE;
}
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
// If Turing tensor cores available, use them except for some cases with batch size 1:
if (turing_mma_available(cc)) {
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
return BEST_FATTN_KERNEL_VEC_F16;
}
return BEST_FATTN_KERNEL_VEC_F32;
}
return;
return BEST_FATTN_KERNEL_MMA_F16;
}
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
// Use kernels specializes for small batch sizes if possible:
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
return BEST_FATTN_KERNEL_VEC_F16;
}
return BEST_FATTN_KERNEL_VEC_F32;
}
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
// For large batch sizes, use the WMMA kernel if possible:
if (fp16_mma_available(cc)) {
return BEST_FATTN_KERNEL_WMMA_F16;
}
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
return BEST_FATTN_KERNEL_TILE_F16;
}
return BEST_FATTN_KERNEL_TILE_F32;
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_set_device(ctx.device);
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
case BEST_FATTN_KERNEL_NONE:
GGML_ABORT("fatal error");
case BEST_FATTN_KERNEL_TILE_F32:
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
break;
case BEST_FATTN_KERNEL_TILE_F16:
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
break;
case BEST_FATTN_KERNEL_VEC_F32:
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
break;
case BEST_FATTN_KERNEL_VEC_F16:
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
break;
case BEST_FATTN_KERNEL_WMMA_F16:
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
break;
case BEST_FATTN_KERNEL_MMA_F16:
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
break;
}
}
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}

View File

@@ -1,3 +1,5 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);

View File

@@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (!turing_mma_available(cc)) {
return false;
}
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
// TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
&& op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
if (op->src[0]->ne[0] == 128) {
return true;
}
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
if (op->src[3] && op->src[3]->ne[2] != 1) {
return false;
}
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW: