CUDA: determine FA parallel blocks at runtime

This commit is contained in:
Johannes Gäßler
2025-03-06 16:47:33 +01:00
parent 3d652bfddf
commit f6711cef44
10 changed files with 166 additions and 257 deletions

View File

@@ -3218,6 +3218,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}