Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-29 08:41:22 +00:00)
vulkan: support arbitrary KV dimension in flash attention (#16160)
The "Clamp" spec constant is already based on whether KV is a multiple of Bc, so use that to control whether bounds checking is performed. Add bounds checking to the scalar and coopmat1 paths. Coopmat2 needed no changes, since its K/V tensors are already optionally clamped.
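As a rough CPU-side illustration of the idea in the message above (not the shader code from this commit), the sketch below derives a clamp flag from whether KV is a multiple of the tile height and uses it to select a bounds-checked or unchecked loop. The names `clamp_flag`, `sum_tiles`, and `sum_kv`, and the tile height `Bc = 4`, are hypothetical and chosen only for the example.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical tile height: number of KV rows handled per iteration (the shader's Bc).
constexpr uint32_t Bc = 4;

// Host side: a bounds check is only needed when KV is not a multiple of Bc.
inline uint32_t clamp_flag(uint32_t KV) {
    return (KV % Bc != 0) ? 1u : 0u;
}

// "Kernel" side: walk the KV rows one tile at a time.
// KV_bounds_check is a compile-time constant, mirroring a specialization constant,
// so the unchecked variant carries no per-row comparison.
template <bool KV_bounds_check>
float sum_tiles(const std::vector<float>& k, uint32_t KV) {
    float acc = 0.0f;
    const uint32_t num_tiles = (KV + Bc - 1) / Bc;  // round up to cover a partial last tile
    for (uint32_t t = 0; t < num_tiles; ++t) {
        for (uint32_t r = 0; r < Bc; ++r) {
            const uint32_t row = t * Bc + r;
            if (KV_bounds_check && row >= KV) {
                continue;  // skip rows past the end of the last, partial tile
            }
            acc += k[row];
        }
    }
    return acc;
}

// Dispatch: pick the variant the same way the clamp flag would pick a pipeline.
inline float sum_kv(const std::vector<float>& k, uint32_t KV) {
    return clamp_flag(KV) != 0 ? sum_tiles<true>(k, KV) : sum_tiles<false>(k, KV);
}
```

When KV is a multiple of `Bc`, the unchecked variant never indexes past the end, which is why the check can be compiled out in that case.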
@@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16;
 const uint32_t HSK_pad = (HSK + 15) & ~15;
 const uint32_t HSV_pad = (HSV + 15) & ~15;
 
+const bool KV_bounds_check = Clamp != 0;
+
 layout (push_constant) uniform parameter {
     uint32_t N;
     uint32_t KV;
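For readers unfamiliar with the `(x + 15) & ~15` idiom in the context lines of the hunk above, here is a minimal C++ sketch of the same arithmetic, together with a boolean derived from a nonzero flag in the style of `KV_bounds_check`. The helper names `pad_to_16` and `bounds_check_enabled` are hypothetical; they are not part of the shader.

```cpp
#include <cstdint>

// Round n up to the next multiple of 16: adding 15 carries into the next
// multiple unless n is already aligned, and masking with ~15 clears the low 4 bits.
constexpr uint32_t pad_to_16(uint32_t n) {
    return (n + 15u) & ~15u;
}

// Any nonzero clamp value enables bounds checking (same shape as `Clamp != 0`).
constexpr bool bounds_check_enabled(uint32_t clamp) {
    return clamp != 0u;
}

static_assert(pad_to_16(64) == 64, "aligned sizes are unchanged");
static_assert(pad_to_16(72) == 80, "unaligned sizes round up");
static_assert(!bounds_check_enabled(0), "zero disables the check");
```

Because both helpers are `constexpr`, the results are fixed at compile time, loosely analogous to how constants derived from specialization constants are fixed when the pipeline is created.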