mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
vulkan: Support FA with K/V in F32 (#16543)
This commit is contained in:
@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
|
||||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
#define BINDING_IDX_K 0
|
||||
#define BINDING_IDX_V 1
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
|
||||
#elif defined(A_TYPE_PACKED16)
|
||||
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
|
||||
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
#undef BLOCK_SIZE
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return k_packed.k_data_packed[a_offset + ib];
|
||||
} else {
|
||||
return v_packed.v_data_packed[a_offset + ib];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
|
||||
Reference in New Issue
Block a user