vulkan: Support FA with K/V in F32 (#16543)

This commit is contained in:
Jeff Bolz
2025-10-14 08:53:37 -05:00
committed by GitHub
parent 7ea15bb64c
commit 4258e0cfe7
4 changed files with 49 additions and 8 deletions

View File

@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
} else {
return v_packed.v_data_packed[a_offset + ib];
}
}
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18