vulkan: support arbitrary KV dimension in flash attention (#16160)

The "Clamp" spec constant is already based on whether KV is a multiple of Bc,
so use that to control whether bounds checking is performed. Add bounds checking
to the scalar and coopmat1 paths. Coopmat2 didn't need any changes (the K/V
tensors are already optionally clamped, nothing else needed to be changed).
This commit is contained in:
Jeff Bolz
2025-09-27 16:43:39 -04:00
committed by GitHub
parent 8656f5de68
commit e6d65fb02d
3 changed files with 38 additions and 9 deletions

View File

@@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;