CUDA: refactor FA support/selection code (#15454)

This commit is contained in:
Johannes Gäßler
2025-08-20 23:14:14 +02:00
committed by GitHub
parent 7a6e91ad26
commit 13aeb7aef2
4 changed files with 165 additions and 111 deletions

View File

@@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results(
dst[tid] = VKQ_numerator / VKQ_denominator;
}
[[noreturn]]
static void on_no_fattn_vec_case(const int D) {
if (D == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
fprintf(stderr, "By default only f16 KV cache is supported.\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
GGML_ABORT("fatal error");
} else if (D == 128) {
fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
fprintf(stderr, "Supported combinations:\n");
fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
GGML_ABORT("fatal error");
} else {
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
fprintf(stderr, "Only f16 is supported.\n");
GGML_ABORT("fatal error");
}
}
template <int DV, int ncols1, int ncols2>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,