mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: refactor FA support/selection code (#15454)
This commit is contained in:
		| @@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results( | ||||
|     dst[tid] = VKQ_numerator / VKQ_denominator; | ||||
| } | ||||
|  | ||||
| [[noreturn]] | ||||
| static void on_no_fattn_vec_case(const int D) { | ||||
|     if (D == 64) { | ||||
|         fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); | ||||
|         fprintf(stderr, "By default only f16 KV cache is supported.\n"); | ||||
|         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); | ||||
|         GGML_ABORT("fatal error"); | ||||
|     } else if (D == 128) { | ||||
|         fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); | ||||
|         fprintf(stderr, "Supported combinations:\n"); | ||||
|         fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n"); | ||||
|         fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n"); | ||||
|         fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n"); | ||||
|         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); | ||||
|         GGML_ABORT("fatal error"); | ||||
|     } else { | ||||
|         fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); | ||||
|         fprintf(stderr, "Only f16 is supported.\n"); | ||||
|         GGML_ABORT("fatal error"); | ||||
|     } | ||||
| } | ||||
|  | ||||
| template <int DV, int ncols1, int ncols2> | ||||
| void launch_fattn( | ||||
|     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, | ||||
|   | ||||
| @@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg | ||||
|     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) | ||||
| #endif // GGML_CUDA_FA_ALL_QUANTS | ||||
|  | ||||
|     on_no_fattn_vec_case(Q->ne[0]); | ||||
|     GGML_ABORT("fatal error"); | ||||
| } | ||||
|  | ||||
| #define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \ | ||||
| @@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg | ||||
|     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) | ||||
| #endif // GGML_CUDA_FA_ALL_QUANTS | ||||
|  | ||||
|     on_no_fattn_vec_case(Q->ne[0]); | ||||
|     GGML_ABORT("fatal error"); | ||||
| } | ||||
|  | ||||
| void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||
| // Best FlashAttention kernel for a specific GPU: | ||||
| enum best_fattn_kernel { | ||||
|     BEST_FATTN_KERNEL_NONE     =   0, | ||||
|     BEST_FATTN_KERNEL_TILE_F32 = 200, | ||||
|     BEST_FATTN_KERNEL_TILE_F16 = 210, | ||||
|     BEST_FATTN_KERNEL_VEC_F32  = 100, | ||||
|     BEST_FATTN_KERNEL_VEC_F16  = 110, | ||||
|     BEST_FATTN_KERNEL_WMMA_F16 = 300, | ||||
|     BEST_FATTN_KERNEL_MMA_F16  = 400, | ||||
| }; | ||||
|  | ||||
| static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { | ||||
| #ifndef FLASH_ATTN_AVAILABLE | ||||
|     GGML_UNUSED(device); GGML_UNUSED(dst); | ||||
|     return BEST_FATTN_KERNEL_NONE; | ||||
| #endif// FLASH_ATTN_AVAILABLE | ||||
|  | ||||
|     const ggml_tensor * KQV   = dst; | ||||
|     const ggml_tensor * Q     = dst->src[0]; | ||||
|     const ggml_tensor * K     = dst->src[1]; | ||||
|     const ggml_tensor * V     = dst->src[2]; | ||||
|     const ggml_tensor * mask  = dst->src[3]; | ||||
|  | ||||
|     ggml_cuda_set_device(ctx.device); | ||||
|     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; | ||||
|     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; | ||||
|     const int gqa_ratio = Q->ne[2] / K->ne[2]; | ||||
|     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); | ||||
|  | ||||
|     const int cc = ggml_cuda_info().devices[device].cc; | ||||
|     const int warp_size = ggml_cuda_info().devices[device].warp_size; | ||||
|     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); | ||||
|  | ||||
| #if defined(GGML_HIP_ROCWMMA_FATTN) | ||||
|     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { | ||||
|         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); | ||||
|         return; | ||||
|     switch (K->ne[0]) { | ||||
|         case  64: | ||||
|         case 128: | ||||
|         case 256: | ||||
|             if (V->ne[0] != K->ne[0]) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
| #endif // defined(GGML_HIP_ROCWMMA_FATTN) | ||||
|  | ||||
|     if (!fast_fp16_available(cc)) { | ||||
|         if (Q->ne[1] <= 8 || Q->ne[0] == 256) { | ||||
|             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); | ||||
|         } else { | ||||
|             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); | ||||
|             break; | ||||
|         case  80: | ||||
|         case  96: | ||||
|         case 112: | ||||
|             if (V->ne[0] != K->ne[0]) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
|         return; | ||||
|             if (!fp16_mma_available(cc) && !turing_mma_available(cc)) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
|             break; | ||||
|         case 576: | ||||
|             if (V->ne[0] != 512) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
|             if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
|             break; | ||||
|         default: | ||||
|             return BEST_FATTN_KERNEL_NONE; | ||||
|     } | ||||
|  | ||||
|     if (!fp16_mma_available(cc)) { | ||||
|         if (prec == GGML_PREC_DEFAULT) { | ||||
|             if (Q->ne[1] <= 8 || Q->ne[0] == 256) { | ||||
|                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); | ||||
|             } else { | ||||
|                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); | ||||
| #ifndef GGML_CUDA_FA_ALL_QUANTS | ||||
|     if (K->type != V->type) { | ||||
|         return BEST_FATTN_KERNEL_NONE; | ||||
|     } | ||||
|         } else { | ||||
|             if (Q->ne[1] <= 8 || Q->ne[0] == 256) { | ||||
|                 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); | ||||
|             } else { | ||||
|                 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); | ||||
| #endif // GGML_CUDA_FA_ALL_QUANTS | ||||
|  | ||||
|     switch (K->type) { | ||||
|         case GGML_TYPE_F16: | ||||
|             break; | ||||
|         case GGML_TYPE_Q4_1: | ||||
|         case GGML_TYPE_Q5_0: | ||||
|         case GGML_TYPE_Q5_1: | ||||
| #ifndef GGML_CUDA_FA_ALL_QUANTS | ||||
|             return BEST_FATTN_KERNEL_NONE; | ||||
| #endif // GGML_CUDA_FA_ALL_QUANTS | ||||
|         case GGML_TYPE_Q4_0: | ||||
|         case GGML_TYPE_Q8_0: | ||||
| #ifdef GGML_CUDA_FA_ALL_QUANTS | ||||
|             if (K->ne[0] != 128 && K->ne[0] != 64) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
| #else | ||||
|             if (K->ne[0] != 128) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
|         return; | ||||
| #endif // GGML_CUDA_FA_ALL_QUANTS | ||||
|             break; | ||||
|         default: | ||||
|             return BEST_FATTN_KERNEL_NONE; | ||||
|     } | ||||
|  | ||||
|     switch (V->type) { | ||||
|         case GGML_TYPE_F16: | ||||
|             break; | ||||
|         case GGML_TYPE_Q4_1: | ||||
|         case GGML_TYPE_Q5_0: | ||||
|         case GGML_TYPE_Q5_1: | ||||
|         case GGML_TYPE_Q4_0: | ||||
|         case GGML_TYPE_Q8_0: | ||||
|             if (K->ne[0] != 128) { | ||||
|                 return BEST_FATTN_KERNEL_NONE; | ||||
|             } | ||||
|             break; | ||||
|         default: | ||||
|             return BEST_FATTN_KERNEL_NONE; | ||||
|     } | ||||
|  | ||||
|     if (mask && mask->ne[2] != 1) { | ||||
|         return BEST_FATTN_KERNEL_NONE; | ||||
|     } | ||||
|  | ||||
|     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations | ||||
|     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; | ||||
|     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192); | ||||
|     const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion && | ||||
|         (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); | ||||
|     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; | ||||
|  | ||||
|     // If Turing tensor cores available, use them except for some cases with batch size 1: | ||||
|     if (turing_mma_available(cc)) { | ||||
|         const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations | ||||
|         const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; | ||||
|         const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192); | ||||
|         const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion && | ||||
|             (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); | ||||
|         if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { | ||||
|         if (prec == GGML_PREC_DEFAULT) { | ||||
|             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); | ||||
|         } else { | ||||
|             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); | ||||
|             if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { | ||||
|                 return BEST_FATTN_KERNEL_VEC_F16; | ||||
|             } | ||||
|         return; | ||||
|             return BEST_FATTN_KERNEL_VEC_F32; | ||||
|         } | ||||
|         return BEST_FATTN_KERNEL_MMA_F16; | ||||
|     } | ||||
|  | ||||
|     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: | ||||
|     if (fp16_mma_available(cc) && !turing_mma_available(cc)) { | ||||
|         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); | ||||
|         return; | ||||
|     // Use kernels specializes for small batch sizes if possible: | ||||
|     if (Q->ne[1] <= 8 && can_use_vector_kernel) { | ||||
|         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { | ||||
|             return BEST_FATTN_KERNEL_VEC_F16; | ||||
|         } | ||||
|         return BEST_FATTN_KERNEL_VEC_F32; | ||||
|     } | ||||
|  | ||||
|     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); | ||||
|     // For large batch sizes, use the WMMA kernel if possible: | ||||
|     if (fp16_mma_available(cc)) { | ||||
|         return BEST_FATTN_KERNEL_WMMA_F16; | ||||
|     } | ||||
|  | ||||
|     // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: | ||||
|     if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { | ||||
|         return BEST_FATTN_KERNEL_TILE_F16; | ||||
|     } | ||||
|     return BEST_FATTN_KERNEL_TILE_F32; | ||||
| } | ||||
|  | ||||
| void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||
|     ggml_cuda_set_device(ctx.device); | ||||
|     switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { | ||||
|         case BEST_FATTN_KERNEL_NONE: | ||||
|             GGML_ABORT("fatal error"); | ||||
|         case BEST_FATTN_KERNEL_TILE_F32: | ||||
|             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); | ||||
|             break; | ||||
|         case BEST_FATTN_KERNEL_TILE_F16: | ||||
|             ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); | ||||
|             break; | ||||
|         case BEST_FATTN_KERNEL_VEC_F32: | ||||
|             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); | ||||
|             break; | ||||
|         case BEST_FATTN_KERNEL_VEC_F16: | ||||
|             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); | ||||
|             break; | ||||
|         case BEST_FATTN_KERNEL_WMMA_F16: | ||||
|             ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); | ||||
|             break; | ||||
|         case BEST_FATTN_KERNEL_MMA_F16: | ||||
|             ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); | ||||
|             break; | ||||
|     } | ||||
| } | ||||
|  | ||||
| bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) { | ||||
|     return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE; | ||||
| } | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| #include "common.cuh" | ||||
|  | ||||
| void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | ||||
|  | ||||
| bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); | ||||
|   | ||||
| @@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | ||||
|         case GGML_OP_GATED_LINEAR_ATTN: | ||||
|         case GGML_OP_RWKV_WKV7: | ||||
|             return true; | ||||
|         case GGML_OP_FLASH_ATTN_EXT: { | ||||
| #ifndef FLASH_ATTN_AVAILABLE | ||||
|             return false; | ||||
| #endif // FLASH_ATTN_AVAILABLE | ||||
|             if (op->src[1]->ne[0] != op->src[2]->ne[0]) { | ||||
|                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; | ||||
|                 if (!turing_mma_available(cc)) { | ||||
|                     return false; | ||||
|                 } | ||||
|                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; | ||||
|                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; | ||||
|             } | ||||
|             // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS] | ||||
|             if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) | ||||
|                     && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { | ||||
|                 return false; | ||||
|             } | ||||
|             if (op->src[0]->ne[0] == 192) { | ||||
|                 return false; | ||||
|             } | ||||
|             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { | ||||
|                 return false; | ||||
|             } | ||||
|             if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) { | ||||
|                 return true; | ||||
|             } | ||||
|             if (op->src[0]->ne[0] == 128) { | ||||
|                 return true; | ||||
|             } | ||||
|             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { | ||||
|                 return true; | ||||
|             } | ||||
|             if (op->src[3] && op->src[3]->ne[2] != 1) { | ||||
|                 return false; | ||||
|             } | ||||
|             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && | ||||
|                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; | ||||
|         } | ||||
|         case GGML_OP_FLASH_ATTN_EXT: | ||||
|             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); | ||||
|         case GGML_OP_CROSS_ENTROPY_LOSS: | ||||
|         case GGML_OP_CROSS_ENTROPY_LOSS_BACK: | ||||
|         case GGML_OP_OPT_STEP_ADAMW: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler