mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
CUDA: refactor FA support/selection code (#15454)
This commit is contained in:
@@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results(
|
|||||||
dst[tid] = VKQ_numerator / VKQ_denominator;
|
dst[tid] = VKQ_numerator / VKQ_denominator;
|
||||||
}
|
}
|
||||||
|
|
||||||
[[noreturn]]
|
|
||||||
static void on_no_fattn_vec_case(const int D) {
|
|
||||||
if (D == 64) {
|
|
||||||
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
|
|
||||||
fprintf(stderr, "By default only f16 KV cache is supported.\n");
|
|
||||||
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
} else if (D == 128) {
|
|
||||||
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
|
|
||||||
fprintf(stderr, "Supported combinations:\n");
|
|
||||||
fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
|
|
||||||
fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
|
|
||||||
fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
|
|
||||||
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
|
|
||||||
fprintf(stderr, "Only f16 is supported.\n");
|
|
||||||
GGML_ABORT("fatal error");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int DV, int ncols1, int ncols2>
|
template <int DV, int ncols1, int ncols2>
|
||||||
void launch_fattn(
|
void launch_fattn(
|
||||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
|
|||||||
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
|
||||||
on_no_fattn_vec_case(Q->ne[0]);
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||||
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
|||||||
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
|
||||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
|
||||||
on_no_fattn_vec_case(Q->ne[0]);
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
// Best FlashAttention kernel for a specific GPU:
|
||||||
|
enum best_fattn_kernel {
|
||||||
|
BEST_FATTN_KERNEL_NONE = 0,
|
||||||
|
BEST_FATTN_KERNEL_TILE_F32 = 200,
|
||||||
|
BEST_FATTN_KERNEL_TILE_F16 = 210,
|
||||||
|
BEST_FATTN_KERNEL_VEC_F32 = 100,
|
||||||
|
BEST_FATTN_KERNEL_VEC_F16 = 110,
|
||||||
|
BEST_FATTN_KERNEL_WMMA_F16 = 300,
|
||||||
|
BEST_FATTN_KERNEL_MMA_F16 = 400,
|
||||||
|
};
|
||||||
|
|
||||||
|
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
|
||||||
|
#ifndef FLASH_ATTN_AVAILABLE
|
||||||
|
GGML_UNUSED(device); GGML_UNUSED(dst);
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
#endif// FLASH_ATTN_AVAILABLE
|
||||||
|
|
||||||
const ggml_tensor * KQV = dst;
|
const ggml_tensor * KQV = dst;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
const ggml_tensor * mask = dst->src[3];
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
ggml_cuda_set_device(ctx.device);
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
|
||||||
|
const int cc = ggml_cuda_info().devices[device].cc;
|
||||||
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||||
|
|
||||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
switch (K->ne[0]) {
|
||||||
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
|
case 64:
|
||||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
case 128:
|
||||||
return;
|
case 256:
|
||||||
}
|
if (V->ne[0] != K->ne[0]) {
|
||||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
|
||||||
if (!fast_fp16_available(cc)) {
|
|
||||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
||||||
} else {
|
|
||||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!fp16_mma_available(cc)) {
|
|
||||||
if (prec == GGML_PREC_DEFAULT) {
|
|
||||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
|
||||||
} else {
|
|
||||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
|
||||||
}
|
}
|
||||||
} else {
|
break;
|
||||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
case 80:
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
case 96:
|
||||||
} else {
|
case 112:
|
||||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
if (V->ne[0] != K->ne[0]) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
}
|
}
|
||||||
}
|
if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
|
||||||
return;
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 576:
|
||||||
|
if (V->ne[0] != 512) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
if (K->type != V->type) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
|
||||||
|
switch (K->type) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
#ifndef GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
if (K->ne[0] != 128 && K->ne[0] != 64) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (K->ne[0] != 128) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (V->type) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
if (K->ne[0] != 128) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mask && mask->ne[2] != 1) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
|
||||||
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
|
||||||
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
|
|
||||||
const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
|
|
||||||
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
|
|
||||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
||||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
|
||||||
if (prec == GGML_PREC_DEFAULT) {
|
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
||||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
if (turing_mma_available(cc)) {
|
||||||
} else {
|
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
|
||||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||||
|
const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
|
||||||
|
const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
|
||||||
|
(cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
|
||||||
|
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||||
|
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
|
return BEST_FATTN_KERNEL_VEC_F16;
|
||||||
|
}
|
||||||
|
return BEST_FATTN_KERNEL_VEC_F32;
|
||||||
}
|
}
|
||||||
return;
|
return BEST_FATTN_KERNEL_MMA_F16;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
// Use kernels specializes for small batch sizes if possible:
|
||||||
if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
|
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
||||||
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
return;
|
return BEST_FATTN_KERNEL_VEC_F16;
|
||||||
|
}
|
||||||
|
return BEST_FATTN_KERNEL_VEC_F32;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
// For large batch sizes, use the WMMA kernel if possible:
|
||||||
|
if (fp16_mma_available(cc)) {
|
||||||
|
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
||||||
|
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||||
|
return BEST_FATTN_KERNEL_TILE_F16;
|
||||||
|
}
|
||||||
|
return BEST_FATTN_KERNEL_TILE_F32;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_set_device(ctx.device);
|
||||||
|
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
||||||
|
case BEST_FATTN_KERNEL_NONE:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
case BEST_FATTN_KERNEL_TILE_F32:
|
||||||
|
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||||
|
break;
|
||||||
|
case BEST_FATTN_KERNEL_TILE_F16:
|
||||||
|
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||||
|
break;
|
||||||
|
case BEST_FATTN_KERNEL_VEC_F32:
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||||
|
break;
|
||||||
|
case BEST_FATTN_KERNEL_VEC_F16:
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||||
|
break;
|
||||||
|
case BEST_FATTN_KERNEL_WMMA_F16:
|
||||||
|
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||||
|
break;
|
||||||
|
case BEST_FATTN_KERNEL_MMA_F16:
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
|
||||||
|
return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
|
||||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);
|
||||||
|
|||||||
@@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
case GGML_OP_GATED_LINEAR_ATTN:
|
case GGML_OP_GATED_LINEAR_ATTN:
|
||||||
case GGML_OP_RWKV_WKV7:
|
case GGML_OP_RWKV_WKV7:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT: {
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
#ifndef FLASH_ATTN_AVAILABLE
|
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
|
||||||
return false;
|
|
||||||
#endif // FLASH_ATTN_AVAILABLE
|
|
||||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
|
||||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
|
||||||
if (!turing_mma_available(cc)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
|
||||||
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
|
|
||||||
}
|
|
||||||
// TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
|
|
||||||
if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
|
|
||||||
&& op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (op->src[0]->ne[0] == 192) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (op->src[0]->ne[0] == 128) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (op->src[3] && op->src[3]->ne[2] != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
|
||||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
|
||||||
}
|
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
|
|||||||
Reference in New Issue
Block a user