CUDA: refactor FA support/selection code (#15454)

Author: Johannes Gäßler
Date: 2025-08-20 23:14:14 +02:00
Committed by: GitHub
Parent: 7a6e91ad26
Commit: 13aeb7aef2
4 changed files with 165 additions and 111 deletions


@@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef FLASH_ATTN_AVAILABLE
-            return false;
-#endif // FLASH_ATTN_AVAILABLE
-            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!turing_mma_available(cc)) {
-                    return false;
-                }
-                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
-                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
-            }
-            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
-                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 192) {
-                return false;
-            }
-            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[3] && op->src[3]->ne[2] != 1) {
-                return false;
-            }
-            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-        }
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
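
The new call site delegates to ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op), whose definition lives in the other files touched by this commit and is not shown in this hunk. For orientation only, below is a hedged sketch of such a helper, assuming it simply hoists the removed inline checks behind the signature implied by the call site; the commit's actual implementation reorganizes the selection logic rather than copying it verbatim, and the header name in the sketch is likewise an assumption.

// Sketch only: one plausible shape for the extracted predicate, built from the
// checks deleted in the hunk above. The signature (device index + op tensor) is
// inferred from the new call site; the real function in this commit may differ.
#include "common.cuh" // ggml-cuda internal header (assumed): ggml_cuda_info(), *_mma_available(), FLASH_ATTN_AVAILABLE

bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * op) {
#ifndef FLASH_ATTN_AVAILABLE
    return false;
#endif // FLASH_ATTN_AVAILABLE
    const int cc = ggml_cuda_info().devices[device].cc;

    // Mismatched K/V head sizes are only accepted for the 576/512 combination,
    // and only with Turing mma support and a GQA ratio divisible by 16.
    if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
        if (!turing_mma_available(cc)) {
            return false;
        }
        const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
        return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
    }

    // Attention sinks (src[4]) without fp16 mma are only supported for head sizes 64 and 128.
    if (op->src[4] && !fp16_mma_available(cc) && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
        return false;
    }

    // Remaining head-size/type constraints, carried over from the old inline block.
    if (op->src[0]->ne[0] == 192) {
        return false;
    }
    if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
        return false;
    }
    if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
        return true;
    }
    if (op->src[0]->ne[0] == 128) {
        return true;
    }
    if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
        return true;
    }
    if (op->src[3] && op->src[3]->ne[2] != 1) {
        return false;
    }
    return fp16_mma_available(cc) && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}

Factoring the predicate out keeps ggml_backend_cuda_device_supports_op a thin dispatch switch and moves the FlashAttention support decision next to the kernel-selection code, which appears to be the point of this refactor.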