Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)

	vulkan: move common FA code to flash_attn_base.comp (#13556)
* vulkan: move common FA code to flash_attn_base.comp
* vulkan: move common FA index/stride setup code to flash_attn_base.comp
* build fix
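For orientation: the push-constant block that each of the three flash-attention shader variants previously declared, and that now lives once in flash_attn_base.comp (added below), has to be matched byte for byte by a host-side struct. The following C++ sketch is a hypothetical mirror of that layout, written only to make the field grouping visible; the actual struct in ggml-vulkan.cpp may use different names.

```cpp
#include <cstdint>

// Hypothetical host-side mirror of the shared flash-attention push-constant
// block from flash_attn_base.comp. Field order and types must match the GLSL
// declaration; all members are 4-byte scalars, so the struct packs tightly.
struct fa_push_constants {
    uint32_t N;                    // number of query rows
    uint32_t KV;                   // number of key/value rows

    uint32_t ne1, ne2, ne3;        // destination tensor dimensions

    uint32_t neq2, neq3;           // Q dims 2 and 3
    uint32_t nek2, nek3;           // K dims 2 and 3
    uint32_t nev2, nev3;           // V dims 2 and 3
    uint32_t nem1;                 // mask dim 1

    uint32_t nb01, nb02, nb03;     // Q strides
    uint32_t nb11, nb12, nb13;     // K strides
    uint32_t nb21, nb22, nb23;     // V strides
    uint32_t nb31;                 // mask stride

    float scale;                   // softmax scale applied to Q*K^T
    float max_bias;                // ALiBi max bias
    float logit_softcap;           // optional logit soft-capping factor

    uint32_t mask;                 // whether a mask tensor is bound
    uint32_t n_head_log2;          // ALiBi head-count threshold
    float m0, m1;                  // ALiBi slope bases

    uint32_t gqa_ratio;            // query heads per K/V head
    uint32_t split_kv;             // KV rows per split_k slice
    uint32_t k_num;                // number of split_k slices
};

static_assert(sizeof(fa_push_constants) == 32 * sizeof(uint32_t),
              "layout must stay tightly packed to match the GLSL block");
```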
		| @@ -9,60 +9,13 @@ | |||||||
| #extension GL_KHR_shader_subgroup_shuffle : enable | #extension GL_KHR_shader_subgroup_shuffle : enable | ||||||
|  |  | ||||||
| #include "types.comp" | #include "types.comp" | ||||||
|  | #include "flash_attn_base.comp" | ||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; |  | ||||||
|  |  | ||||||
| layout (constant_id = 0) const uint32_t WorkGroupSize = 128; |  | ||||||
| layout (constant_id = 1) const uint32_t Br = 1; |  | ||||||
| layout (constant_id = 2) const uint32_t Bc = 32; |  | ||||||
| layout (constant_id = 3) const uint32_t D = 32; |  | ||||||
|  |  | ||||||
| layout (constant_id = 5) const uint32_t D_split = 16; |  | ||||||
| const uint32_t D_per_thread = D / D_split; | const uint32_t D_per_thread = D / D_split; | ||||||
|  |  | ||||||
| const uint32_t cols_per_iter = WorkGroupSize / D_split; | const uint32_t cols_per_iter = WorkGroupSize / D_split; | ||||||
| const uint32_t cols_per_thread = Bc / cols_per_iter; | const uint32_t cols_per_thread = Bc / cols_per_iter; | ||||||
|  |  | ||||||
| layout (push_constant) uniform parameter { |  | ||||||
|     uint32_t N; |  | ||||||
|     uint32_t KV; |  | ||||||
|  |  | ||||||
|     uint32_t ne1; |  | ||||||
|     uint32_t ne2; |  | ||||||
|     uint32_t ne3; |  | ||||||
|  |  | ||||||
|     uint32_t neq2; |  | ||||||
|     uint32_t neq3; |  | ||||||
|     uint32_t nek2; |  | ||||||
|     uint32_t nek3; |  | ||||||
|     uint32_t nev2; |  | ||||||
|     uint32_t nev3; |  | ||||||
|     uint32_t nem1; |  | ||||||
|  |  | ||||||
|     uint32_t nb01; |  | ||||||
|     uint32_t nb02; |  | ||||||
|     uint32_t nb03; |  | ||||||
|     uint32_t nb11; |  | ||||||
|     uint32_t nb12; |  | ||||||
|     uint32_t nb13; |  | ||||||
|     uint32_t nb21; |  | ||||||
|     uint32_t nb22; |  | ||||||
|     uint32_t nb23; |  | ||||||
|     uint32_t nb31; |  | ||||||
|  |  | ||||||
|     float scale; |  | ||||||
|     float max_bias; |  | ||||||
|     float logit_softcap; |  | ||||||
|  |  | ||||||
|     uint32_t mask; |  | ||||||
|     uint32_t n_head_log2; |  | ||||||
|     float m0; |  | ||||||
|     float m1; |  | ||||||
|  |  | ||||||
|     uint32_t gqa_ratio; |  | ||||||
|     uint32_t split_kv; |  | ||||||
|     uint32_t k_num; |  | ||||||
| } p; |  | ||||||
|  |  | ||||||
| layout (binding = 0) readonly buffer Q {float data_q[];}; | layout (binding = 0) readonly buffer Q {float data_q[];}; | ||||||
| layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; | layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; | ||||||
| @@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; | |||||||
| layout (binding = 2) readonly buffer V {float16_t data_v[];}; | layout (binding = 2) readonly buffer V {float16_t data_v[];}; | ||||||
| layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; | layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; | ||||||
| layout (binding = 3) readonly buffer M {float16_t data_m[];}; | layout (binding = 3) readonly buffer M {float16_t data_m[];}; | ||||||
| layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; |  | ||||||
|  |  | ||||||
| #if defined(A_TYPE_PACKED16) |  | ||||||
| #define BINDING_IDX_K 0 |  | ||||||
| #define BINDING_IDX_V 1 |  | ||||||
| layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #if defined(DATA_A_Q4_0) |  | ||||||
| #define BLOCK_BYTE_SIZE 18 |  | ||||||
|  |  | ||||||
| vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { |  | ||||||
|     uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); |  | ||||||
|     uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); |  | ||||||
|     uint shift = (iqs & 0x10) >> 2; |  | ||||||
|     vui_lo >>= shift; |  | ||||||
|     vui_hi >>= shift; |  | ||||||
|  |  | ||||||
|     return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #if defined(DATA_A_Q8_0) |  | ||||||
| #define BLOCK_BYTE_SIZE 34 |  | ||||||
| vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { |  | ||||||
|     const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 |  | ||||||
|     const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; |  | ||||||
|  |  | ||||||
|     return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) |  | ||||||
|  |  | ||||||
| // Store the output when doing grouped query attention. | // Store the output when doing grouped query attention. | ||||||
| // Rows index by Q's dimension 2, and the first N rows are valid. | // Rows index by Q's dimension 2, and the first N rows are valid. | ||||||
| @@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY | |||||||
|     return elem; |     return elem; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Store column zero. This is used to save per-row m and L values for split_k. |  | ||||||
| ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) |  | ||||||
| { |  | ||||||
|     if (r < N && c == 0) { |  | ||||||
|         uint32_t offset = iq2 + r; |  | ||||||
|         data_o[o_offset + offset] = D_TYPE(elem); |  | ||||||
|     } |  | ||||||
|     return elem; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Load the slope matrix, indexed by Q's dimension 2. |  | ||||||
| ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) |  | ||||||
| { |  | ||||||
|     const uint32_t h = iq2 + (r % p.gqa_ratio); |  | ||||||
|  |  | ||||||
|     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); |  | ||||||
|     const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); |  | ||||||
|  |  | ||||||
|     return ACC_TYPE(pow(base, ACC_TYPE(exph))); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| shared FLOAT_TYPE tmpsh[WorkGroupSize]; | shared FLOAT_TYPE tmpsh[WorkGroupSize]; | ||||||
| shared vec4 tmpshv4[WorkGroupSize]; | shared vec4 tmpshv4[WorkGroupSize]; | ||||||
|  |  | ||||||
| @@ -146,58 +45,12 @@ void main() { | |||||||
|     init_iq_shmem(gl_WorkGroupSize); |     init_iq_shmem(gl_WorkGroupSize); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|     const uint32_t tid = gl_LocalInvocationIndex; |     init_indices(); | ||||||
|     const uint32_t N = p.N; |  | ||||||
|     const uint32_t KV = p.KV; |  | ||||||
|  |  | ||||||
|  |     const uint32_t tid = gl_LocalInvocationIndex; | ||||||
|     const uint32_t d_tid = gl_LocalInvocationIndex % D_split; |     const uint32_t d_tid = gl_LocalInvocationIndex % D_split; | ||||||
|     const uint32_t col_tid = gl_LocalInvocationIndex / D_split; |     const uint32_t col_tid = gl_LocalInvocationIndex / D_split; | ||||||
|  |  | ||||||
|     uint32_t i = gl_WorkGroupID.x; |  | ||||||
|     uint32_t split_k_index = 0; |  | ||||||
|  |  | ||||||
|     if (p.k_num > 1) { |  | ||||||
|         i = 0; |  | ||||||
|         split_k_index = gl_WorkGroupID.x; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     const uint32_t Tr = CEIL_DIV(N, Br); |  | ||||||
|  |  | ||||||
|     const uint32_t start_j = split_k_index * p.split_kv / Bc; |  | ||||||
|     const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); |  | ||||||
|  |  | ||||||
|     // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. |  | ||||||
|     // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. |  | ||||||
|     const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; |  | ||||||
|     const uint32_t iq3 = gl_WorkGroupID.z; |  | ||||||
|  |  | ||||||
|     // broadcast factors |  | ||||||
|     const uint32_t rk2 = p.neq2/p.nek2; |  | ||||||
|     const uint32_t rk3 = p.neq3/p.nek3; |  | ||||||
|  |  | ||||||
|     const uint32_t rv2 = p.neq2/p.nev2; |  | ||||||
|     const uint32_t rv3 = p.neq3/p.nev3; |  | ||||||
|  |  | ||||||
|     // k indices |  | ||||||
|     const uint32_t ik3 = iq3 / rk3; |  | ||||||
|     const uint32_t ik2 = iq2 / rk2; |  | ||||||
|  |  | ||||||
|     // v indices |  | ||||||
|     const uint32_t iv3 = iq3 / rv3; |  | ||||||
|     const uint32_t iv2 = iq2 / rv2; |  | ||||||
|  |  | ||||||
|     // nb?1 are already divided by the type size and are in units of elements. |  | ||||||
|     // When using grouped query attention, Q is indexed by iq2, so the stride |  | ||||||
|     // should be nb02 (which is in bytes). |  | ||||||
|     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; |  | ||||||
|     uint32_t k_stride = p.nb11; |  | ||||||
|     uint32_t v_stride = p.nb21; |  | ||||||
|     // When using grouped query attention, all rows use the same mask (stride 0). |  | ||||||
|     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero |  | ||||||
|     // that prevents the compiler from folding the "&" through the select |  | ||||||
|     // and breaking the alignment detection. |  | ||||||
|     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; |  | ||||||
|  |  | ||||||
|     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; |     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; | ||||||
|  |  | ||||||
|     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { |     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { | ||||||
|   | |||||||
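In the hunk above, the per-workgroup index setup (split_k slicing, GQA indices, strides) is deleted in favour of a single init_indices() call. The split_k part reduces to a small piece of integer arithmetic: the KV rows are cut into Bc-wide column tiles and each split processes the tile range [start_j, end_j). A standalone C++ sketch of that arithmetic, with example sizes invented for illustration:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Same rounding-up division as the CEIL_DIV macro in the shaders.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Reproduces the start_j/end_j computation from init_indices(): workgroup
// `split_k_index` processes KV column tiles [start_j, end_j), each Bc wide.
static void split_k_range(uint32_t KV, uint32_t split_kv, uint32_t Bc,
                          uint32_t split_k_index,
                          uint32_t &start_j, uint32_t &end_j) {
    start_j = split_k_index * split_kv / Bc;
    end_j   = ceil_div(std::min(KV, (split_k_index + 1) * split_kv), Bc);
}

int main() {
    // Example: 1000 KV rows, 4 splits of 256 rows each, Bc = 32.
    const uint32_t KV = 1000, split_kv = 256, Bc = 32, k_num = 4;
    for (uint32_t s = 0; s < k_num; ++s) {
        uint32_t start_j, end_j;
        split_k_range(KV, split_kv, Bc, s, start_j, end_j);
        std::printf("split %u: tiles [%u, %u)\n", s, start_j, end_j);
    }
    return 0;
}
```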
							
								
								
									
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp (new file, 162 lines)
							| @@ -0,0 +1,162 @@ | |||||||
|  |  | ||||||
|  | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
|  | layout (constant_id = 0) const uint32_t WorkGroupSize = 128; | ||||||
|  | layout (constant_id = 1) const uint32_t Br = 1; | ||||||
|  | layout (constant_id = 2) const uint32_t Bc = 32; | ||||||
|  | layout (constant_id = 3) const uint32_t D = 32; | ||||||
|  | layout (constant_id = 4) const uint32_t Clamp = 0; | ||||||
|  | layout (constant_id = 5) const uint32_t D_split = 16; | ||||||
|  |  | ||||||
|  |  | ||||||
|  | layout (push_constant) uniform parameter { | ||||||
|  |     uint32_t N; | ||||||
|  |     uint32_t KV; | ||||||
|  |  | ||||||
|  |     uint32_t ne1; | ||||||
|  |     uint32_t ne2; | ||||||
|  |     uint32_t ne3; | ||||||
|  |  | ||||||
|  |     uint32_t neq2; | ||||||
|  |     uint32_t neq3; | ||||||
|  |     uint32_t nek2; | ||||||
|  |     uint32_t nek3; | ||||||
|  |     uint32_t nev2; | ||||||
|  |     uint32_t nev3; | ||||||
|  |     uint32_t nem1; | ||||||
|  |  | ||||||
|  |     uint32_t nb01; | ||||||
|  |     uint32_t nb02; | ||||||
|  |     uint32_t nb03; | ||||||
|  |     uint32_t nb11; | ||||||
|  |     uint32_t nb12; | ||||||
|  |     uint32_t nb13; | ||||||
|  |     uint32_t nb21; | ||||||
|  |     uint32_t nb22; | ||||||
|  |     uint32_t nb23; | ||||||
|  |     uint32_t nb31; | ||||||
|  |  | ||||||
|  |     float scale; | ||||||
|  |     float max_bias; | ||||||
|  |     float logit_softcap; | ||||||
|  |  | ||||||
|  |     uint32_t mask; | ||||||
|  |     uint32_t n_head_log2; | ||||||
|  |     float m0; | ||||||
|  |     float m1; | ||||||
|  |  | ||||||
|  |     uint32_t gqa_ratio; | ||||||
|  |     uint32_t split_kv; | ||||||
|  |     uint32_t k_num; | ||||||
|  | } p; | ||||||
|  |  | ||||||
|  | layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; | ||||||
|  |  | ||||||
|  | #if defined(A_TYPE_PACKED16) | ||||||
|  | #define BINDING_IDX_K 0 | ||||||
|  | #define BINDING_IDX_V 1 | ||||||
|  | layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #if defined(DATA_A_Q4_0) | ||||||
|  | #define BLOCK_BYTE_SIZE 18 | ||||||
|  |  | ||||||
|  | vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { | ||||||
|  |     uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); | ||||||
|  |     uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); | ||||||
|  |     uint shift = (iqs & 0x10) >> 2; | ||||||
|  |     vui_lo >>= shift; | ||||||
|  |     vui_hi >>= shift; | ||||||
|  |  | ||||||
|  |     return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #if defined(DATA_A_Q8_0) | ||||||
|  | #define BLOCK_BYTE_SIZE 34 | ||||||
|  | vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { | ||||||
|  |     const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 | ||||||
|  |     const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; | ||||||
|  |  | ||||||
|  |     return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | // Store column zero. This is used to save per-row m and L values for split_k. | ||||||
|  | ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) | ||||||
|  | { | ||||||
|  |     if (r < N && c == 0) { | ||||||
|  |         uint32_t offset = iq2 + r; | ||||||
|  |         data_o[o_offset + offset] = D_TYPE(elem); | ||||||
|  |     } | ||||||
|  |     return elem; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Load the slope matrix, indexed by Q's dimension 2. | ||||||
|  | ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) | ||||||
|  | { | ||||||
|  |     const uint32_t h = iq2 + (r % p.gqa_ratio); | ||||||
|  |  | ||||||
|  |     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); | ||||||
|  |     const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); | ||||||
|  |  | ||||||
|  |     return ACC_TYPE(pow(base, ACC_TYPE(exph))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, | ||||||
|  |          iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, | ||||||
|  |          q_stride, k_stride, v_stride, m_stride; | ||||||
|  |  | ||||||
|  | void init_indices() | ||||||
|  | { | ||||||
|  |     N = p.N; | ||||||
|  |     KV = p.KV; | ||||||
|  |  | ||||||
|  |     i = gl_WorkGroupID.x; | ||||||
|  |     split_k_index = 0; | ||||||
|  |  | ||||||
|  |     if (p.k_num > 1) { | ||||||
|  |         i = 0; | ||||||
|  |         split_k_index = gl_WorkGroupID.x; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     Tr = CEIL_DIV(N, Br); | ||||||
|  |  | ||||||
|  |     start_j = split_k_index * p.split_kv / Bc; | ||||||
|  |     end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); | ||||||
|  |  | ||||||
|  |     // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. | ||||||
|  |     // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. | ||||||
|  |     iq2 = gl_WorkGroupID.y * p.gqa_ratio; | ||||||
|  |     iq3 = gl_WorkGroupID.z; | ||||||
|  |  | ||||||
|  |     // broadcast factors | ||||||
|  |     rk2 = p.neq2/p.nek2; | ||||||
|  |     rk3 = p.neq3/p.nek3; | ||||||
|  |  | ||||||
|  |     rv2 = p.neq2/p.nev2; | ||||||
|  |     rv3 = p.neq3/p.nev3; | ||||||
|  |  | ||||||
|  |     // k indices | ||||||
|  |     ik3 = iq3 / rk3; | ||||||
|  |     ik2 = iq2 / rk2; | ||||||
|  |  | ||||||
|  |     // v indices | ||||||
|  |     iv3 = iq3 / rv3; | ||||||
|  |     iv2 = iq2 / rv2; | ||||||
|  |  | ||||||
|  |     // nb?1 are already divided by the type size and are in units of elements. | ||||||
|  |     // When using grouped query attention, Q is indexed by iq2, so the stride | ||||||
|  |     // should be nb02 (which is in bytes). | ||||||
|  |     q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; | ||||||
|  |     k_stride = p.nb11; | ||||||
|  |     v_stride = p.nb21; | ||||||
|  |     // When using grouped query attention, all rows use the same mask (stride 0). | ||||||
|  |     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero | ||||||
|  |     // that prevents the compiler from folding the "&" through the select | ||||||
|  |     // and breaking the alignment detection. | ||||||
|  |     m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; | ||||||
|  | } | ||||||
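perElemOpComputeSlope above encodes the ALiBi slope schedule: head h uses base m0 with exponent h + 1 while h < n_head_log2, and base m1 with the odd exponent 2*(h - n_head_log2) + 1 otherwise. The C++ sketch below restates that expression on the host; the m0/m1 values in main() follow the usual ALiBi parameterization and are an assumption, not something taken from this diff.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors perElemOpComputeSlope from flash_attn_base.comp: the per-head ALiBi
// slope is a power of m0 for the first n_head_log2 heads and a power of m1
// (with odd exponents) for the remaining heads.
static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? int(h) + 1 : 2 * int(h - n_head_log2) + 1;
    return std::pow(base, float(exph));
}

int main() {
    // Illustrative values only: m0/m1 are supplied by the host in the push
    // constants; these follow the common ALiBi parameterization for 8 heads
    // with max_bias = 8 (an assumption, not part of the diff).
    const uint32_t n_head_log2 = 8;
    const float m0 = std::pow(2.0f, -8.0f / n_head_log2);
    const float m1 = std::pow(2.0f, -4.0f / n_head_log2);
    for (uint32_t h = 0; h < 8; ++h) {
        std::printf("head %u: slope %.5f\n", h, alibi_slope(h, n_head_log2, m0, m1));
    }
    return 0;
}
```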
| @@ -11,14 +11,7 @@ | |||||||
| #extension GL_KHR_cooperative_matrix : enable | #extension GL_KHR_cooperative_matrix : enable | ||||||
|  |  | ||||||
| #include "types.comp" | #include "types.comp" | ||||||
|  | #include "flash_attn_base.comp" | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; |  | ||||||
|  |  | ||||||
| layout (constant_id = 1) const uint32_t Br = 1; |  | ||||||
| layout (constant_id = 2) const uint32_t Bc = 32; |  | ||||||
| layout (constant_id = 3) const uint32_t D = 32; |  | ||||||
|  |  | ||||||
| layout (constant_id = 5) const uint32_t D_split = 16; |  | ||||||
|  |  | ||||||
| const uint32_t D_per_thread = D / D_split; | const uint32_t D_per_thread = D / D_split; | ||||||
| const uint32_t row_split = 4; | const uint32_t row_split = 4; | ||||||
| @@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split; | |||||||
| const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; | const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; | ||||||
| const uint32_t cols_per_thread = Bc / cols_per_iter; | const uint32_t cols_per_thread = Bc / cols_per_iter; | ||||||
|  |  | ||||||
| layout (push_constant) uniform parameter { |  | ||||||
|     uint32_t N; |  | ||||||
|     uint32_t KV; |  | ||||||
|  |  | ||||||
|     uint32_t ne1; |  | ||||||
|     uint32_t ne2; |  | ||||||
|     uint32_t ne3; |  | ||||||
|  |  | ||||||
|     uint32_t neq2; |  | ||||||
|     uint32_t neq3; |  | ||||||
|     uint32_t nek2; |  | ||||||
|     uint32_t nek3; |  | ||||||
|     uint32_t nev2; |  | ||||||
|     uint32_t nev3; |  | ||||||
|     uint32_t nem1; |  | ||||||
|  |  | ||||||
|     uint32_t nb01; |  | ||||||
|     uint32_t nb02; |  | ||||||
|     uint32_t nb03; |  | ||||||
|     uint32_t nb11; |  | ||||||
|     uint32_t nb12; |  | ||||||
|     uint32_t nb13; |  | ||||||
|     uint32_t nb21; |  | ||||||
|     uint32_t nb22; |  | ||||||
|     uint32_t nb23; |  | ||||||
|     uint32_t nb31; |  | ||||||
|  |  | ||||||
|     float scale; |  | ||||||
|     float max_bias; |  | ||||||
|     float logit_softcap; |  | ||||||
|  |  | ||||||
|     uint32_t mask; |  | ||||||
|     uint32_t n_head_log2; |  | ||||||
|     float m0; |  | ||||||
|     float m1; |  | ||||||
|  |  | ||||||
|     uint32_t gqa_ratio; |  | ||||||
|     uint32_t split_kv; |  | ||||||
|     uint32_t k_num; |  | ||||||
| } p; |  | ||||||
|  |  | ||||||
| layout (binding = 0) readonly buffer Q {float data_q[];}; | layout (binding = 0) readonly buffer Q {float data_q[];}; | ||||||
| layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; | layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; | ||||||
| @@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; | |||||||
| layout (binding = 2) readonly buffer V {float16_t data_v[];}; | layout (binding = 2) readonly buffer V {float16_t data_v[];}; | ||||||
| layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; | layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; | ||||||
| layout (binding = 3) readonly buffer M {float16_t data_m[];}; | layout (binding = 3) readonly buffer M {float16_t data_m[];}; | ||||||
| layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; |  | ||||||
|  |  | ||||||
| #if defined(A_TYPE_PACKED16) |  | ||||||
| #define BINDING_IDX_K 0 |  | ||||||
| #define BINDING_IDX_V 1 |  | ||||||
| layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #if defined(DATA_A_Q4_0) |  | ||||||
| #define BLOCK_BYTE_SIZE 18 |  | ||||||
|  |  | ||||||
| vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { |  | ||||||
|     uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); |  | ||||||
|     uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); |  | ||||||
|     uint shift = (iqs & 0x10) >> 2; |  | ||||||
|     vui_lo >>= shift; |  | ||||||
|     vui_hi >>= shift; |  | ||||||
|  |  | ||||||
|     return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #if defined(DATA_A_Q8_0) |  | ||||||
| #define BLOCK_BYTE_SIZE 34 |  | ||||||
| vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { |  | ||||||
|     const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 |  | ||||||
|     const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; |  | ||||||
|  |  | ||||||
|     return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) |  | ||||||
|  |  | ||||||
| // Store the output when doing grouped query attention. | // Store the output when doing grouped query attention. | ||||||
| // Rows index by Q's dimension 2, and the first N rows are valid. | // Rows index by Q's dimension 2, and the first N rows are valid. | ||||||
| @@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY | |||||||
|     return elem; |     return elem; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Store column zero. This is used to save per-row m and L values for split_k. |  | ||||||
| ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) |  | ||||||
| { |  | ||||||
|     if (r < N && c == 0) { |  | ||||||
|         uint32_t offset = iq2 + r; |  | ||||||
|         data_o[o_offset + offset] = D_TYPE(elem); |  | ||||||
|     } |  | ||||||
|     return elem; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Load the slope matrix, indexed by Q's dimension 2. |  | ||||||
| ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) |  | ||||||
| { |  | ||||||
|     const uint32_t h = iq2 + (r % p.gqa_ratio); |  | ||||||
|  |  | ||||||
|     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); |  | ||||||
|     const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); |  | ||||||
|  |  | ||||||
|     return ACC_TYPE(pow(base, ACC_TYPE(exph))); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd | // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd | ||||||
| const uint32_t MatBr = 16; | const uint32_t MatBr = 16; | ||||||
| const uint32_t MatBc = 16; | const uint32_t MatBc = 16; | ||||||
| @@ -162,9 +61,9 @@ void main() { | |||||||
|     init_iq_shmem(gl_WorkGroupSize); |     init_iq_shmem(gl_WorkGroupSize); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  |     init_indices(); | ||||||
|  |  | ||||||
|     const uint32_t tid = gl_LocalInvocationIndex; |     const uint32_t tid = gl_LocalInvocationIndex; | ||||||
|     const uint32_t N = p.N; |  | ||||||
|     const uint32_t KV = p.KV; |  | ||||||
|  |  | ||||||
|     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; |     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; | ||||||
|     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; |     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; | ||||||
| @@ -173,51 +72,6 @@ void main() { | |||||||
|  |  | ||||||
| #define tile_row(r) (row_tid * rows_per_thread + (r)) | #define tile_row(r) (row_tid * rows_per_thread + (r)) | ||||||
|  |  | ||||||
|     uint32_t i = gl_WorkGroupID.x; |  | ||||||
|     uint32_t split_k_index = 0; |  | ||||||
|  |  | ||||||
|     if (p.k_num > 1) { |  | ||||||
|         i = 0; |  | ||||||
|         split_k_index = gl_WorkGroupID.x; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     const uint32_t Tr = CEIL_DIV(N, Br); |  | ||||||
|  |  | ||||||
|     const uint32_t start_j = split_k_index * p.split_kv / Bc; |  | ||||||
|     const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); |  | ||||||
|  |  | ||||||
|     // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. |  | ||||||
|     // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. |  | ||||||
|     const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; |  | ||||||
|     const uint32_t iq3 = gl_WorkGroupID.z; |  | ||||||
|  |  | ||||||
|     // broadcast factors |  | ||||||
|     const uint32_t rk2 = p.neq2/p.nek2; |  | ||||||
|     const uint32_t rk3 = p.neq3/p.nek3; |  | ||||||
|  |  | ||||||
|     const uint32_t rv2 = p.neq2/p.nev2; |  | ||||||
|     const uint32_t rv3 = p.neq3/p.nev3; |  | ||||||
|  |  | ||||||
|     // k indices |  | ||||||
|     const uint32_t ik3 = iq3 / rk3; |  | ||||||
|     const uint32_t ik2 = iq2 / rk2; |  | ||||||
|  |  | ||||||
|     // v indices |  | ||||||
|     const uint32_t iv3 = iq3 / rv3; |  | ||||||
|     const uint32_t iv2 = iq2 / rv2; |  | ||||||
|  |  | ||||||
|     // nb?1 are already divided by the type size and are in units of elements. |  | ||||||
|     // When using grouped query attention, Q is indexed by iq2, so the stride |  | ||||||
|     // should be nb02 (which is in bytes). |  | ||||||
|     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; |  | ||||||
|     uint32_t k_stride = p.nb11; |  | ||||||
|     uint32_t v_stride = p.nb21; |  | ||||||
|     // When using grouped query attention, all rows use the same mask (stride 0). |  | ||||||
|     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero |  | ||||||
|     // that prevents the compiler from folding the "&" through the select |  | ||||||
|     // and breaking the alignment detection. |  | ||||||
|     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; |  | ||||||
|  |  | ||||||
|     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; |     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; | ||||||
|  |  | ||||||
|     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { |     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { | ||||||
|   | |||||||
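This cooperative-matrix variant, like the scalar one, now pulls the GQA broadcast factors from init_indices(): with neq2 query heads and nek2 K heads, rk2 = neq2 / nek2 query heads share each K head, and query head iq2 reads K head ik2 = iq2 / rk2 (the same applies to V via rv2/iv2). A throwaway C++ illustration with invented head counts:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative GQA broadcast example (values invented): 32 query heads
    // sharing 8 K/V heads gives rk2 = neq2 / nek2 = 4, and query head iq2
    // reads from K/V head ik2 = iq2 / rk2, exactly as in init_indices().
    const uint32_t neq2 = 32;  // Q heads
    const uint32_t nek2 = 8;   // K/V heads
    const uint32_t rk2  = neq2 / nek2;

    for (uint32_t iq2 = 0; iq2 < neq2; ++iq2) {
        const uint32_t ik2 = iq2 / rk2;
        std::printf("query head %2u -> kv head %u\n", iq2, ik2);
    }
    return 0;
}
```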
| @@ -18,62 +18,12 @@ | |||||||
|  |  | ||||||
| #include "types.comp" | #include "types.comp" | ||||||
| #include "dequant_funcs_cm2.comp" | #include "dequant_funcs_cm2.comp" | ||||||
|  | #include "flash_attn_base.comp" | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; |  | ||||||
|  |  | ||||||
| layout (constant_id = 1) const uint32_t Br = 32; |  | ||||||
| layout (constant_id = 2) const uint32_t Bc = 32; |  | ||||||
| layout (constant_id = 3) const uint32_t D = 32; |  | ||||||
| layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; |  | ||||||
|  |  | ||||||
| layout (push_constant) uniform parameter { |  | ||||||
|     uint32_t N; |  | ||||||
|     uint32_t KV; |  | ||||||
|  |  | ||||||
|     uint32_t ne1; |  | ||||||
|     uint32_t ne2; |  | ||||||
|     uint32_t ne3; |  | ||||||
|  |  | ||||||
|     uint32_t neq2; |  | ||||||
|     uint32_t neq3; |  | ||||||
|     uint32_t nek2; |  | ||||||
|     uint32_t nek3; |  | ||||||
|     uint32_t nev2; |  | ||||||
|     uint32_t nev3; |  | ||||||
|     uint32_t nem1; |  | ||||||
|  |  | ||||||
|     uint32_t nb01; |  | ||||||
|     uint32_t nb02; |  | ||||||
|     uint32_t nb03; |  | ||||||
|     uint32_t nb11; |  | ||||||
|     uint32_t nb12; |  | ||||||
|     uint32_t nb13; |  | ||||||
|     uint32_t nb21; |  | ||||||
|     uint32_t nb22; |  | ||||||
|     uint32_t nb23; |  | ||||||
|     uint32_t nb31; |  | ||||||
|  |  | ||||||
|     float scale; |  | ||||||
|     float max_bias; |  | ||||||
|     float logit_softcap; |  | ||||||
|  |  | ||||||
|     uint32_t mask; |  | ||||||
|     uint32_t n_head_log2; |  | ||||||
|     float m0; |  | ||||||
|     float m1; |  | ||||||
|  |  | ||||||
|     uint32_t gqa_ratio; |  | ||||||
|     uint32_t split_kv; |  | ||||||
|     uint32_t k_num; |  | ||||||
| } p; |  | ||||||
|  |  | ||||||
| layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; | layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; | ||||||
| layout (binding = 1) readonly buffer K {uint8_t data_k[];}; | layout (binding = 1) readonly buffer K {uint8_t data_k[];}; | ||||||
| layout (binding = 2) readonly buffer V {uint8_t data_v[];}; | layout (binding = 2) readonly buffer V {uint8_t data_v[];}; | ||||||
| layout (binding = 3) readonly buffer M {uint8_t data_m[];}; | layout (binding = 3) readonly buffer M {uint8_t data_m[];}; | ||||||
| layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; |  | ||||||
|  |  | ||||||
| #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) |  | ||||||
|  |  | ||||||
| ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { | ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { | ||||||
|     return max(x, y); |     return max(x, y); | ||||||
| @@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY | |||||||
|     return elem; |     return elem; | ||||||
| } | } | ||||||
|  |  | ||||||
| // Store column zero. This is used to save per-row m and L values for split_k. |  | ||||||
| ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) |  | ||||||
| { |  | ||||||
|     if (r < N && c == 0) { |  | ||||||
|         uint32_t offset = iq2 + r; |  | ||||||
|         data_o[o_offset + offset] = D_TYPE(elem); |  | ||||||
|     } |  | ||||||
|     return elem; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Load the slope matrix, indexed by Q's dimension 2. |  | ||||||
| ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) |  | ||||||
| { |  | ||||||
|     const uint32_t h = iq2 + (r % p.gqa_ratio); |  | ||||||
|  |  | ||||||
|     const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); |  | ||||||
|     const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); |  | ||||||
|  |  | ||||||
|     return ACC_TYPE(pow(base, ACC_TYPE(exph))); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void main() { | void main() { | ||||||
| #ifdef NEEDS_INIT_IQ_SHMEM | #ifdef NEEDS_INIT_IQ_SHMEM | ||||||
|     init_iq_shmem(gl_WorkGroupSize); |     init_iq_shmem(gl_WorkGroupSize); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|     const uint32_t N = p.N; |     init_indices(); | ||||||
|     const uint32_t KV = p.KV; |  | ||||||
|  |  | ||||||
|     uint32_t i = gl_WorkGroupID.x; |  | ||||||
|     uint32_t split_k_index = 0; |  | ||||||
|  |  | ||||||
|     if (p.k_num > 1) { |  | ||||||
|         i = 0; |  | ||||||
|         split_k_index = gl_WorkGroupID.x; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     const uint32_t Tr = CEIL_DIV(N, Br); |  | ||||||
|  |  | ||||||
|     const uint32_t start_j = split_k_index * p.split_kv / Bc; |  | ||||||
|     const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); |  | ||||||
|  |  | ||||||
|     // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. |  | ||||||
|     // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. |  | ||||||
|     const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; |  | ||||||
|     const uint32_t iq3 = gl_WorkGroupID.z; |  | ||||||
|  |  | ||||||
|     // broadcast factors |  | ||||||
|     const uint32_t rk2 = p.neq2/p.nek2; |  | ||||||
|     const uint32_t rk3 = p.neq3/p.nek3; |  | ||||||
|  |  | ||||||
|     const uint32_t rv2 = p.neq2/p.nev2; |  | ||||||
|     const uint32_t rv3 = p.neq3/p.nev3; |  | ||||||
|  |  | ||||||
|     // k indices |  | ||||||
|     const uint32_t ik3 = iq3 / rk3; |  | ||||||
|     const uint32_t ik2 = iq2 / rk2; |  | ||||||
|  |  | ||||||
|     // v indices |  | ||||||
|     const uint32_t iv3 = iq3 / rv3; |  | ||||||
|     const uint32_t iv2 = iq2 / rv2; |  | ||||||
|  |  | ||||||
|     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); |     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); | ||||||
|     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); |     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); | ||||||
| @@ -195,17 +90,6 @@ void main() { | |||||||
|     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); |     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); | ||||||
|     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); |     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); | ||||||
|  |  | ||||||
|     // nb?1 are already divided by the type size and are in units of elements. |  | ||||||
|     // When using grouped query attention, Q is indexed by iq2, so the stride |  | ||||||
|     // should be nb02 (which is in bytes). |  | ||||||
|     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; |  | ||||||
|     uint32_t k_stride = p.nb11; |  | ||||||
|     uint32_t v_stride = p.nb21; |  | ||||||
|     // When using grouped query attention, all rows use the same mask (stride 0). |  | ||||||
|     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero |  | ||||||
|     // that prevents the compiler from folding the "&" through the select |  | ||||||
|     // and breaking the alignment detection. |  | ||||||
|     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; |  | ||||||
|     // hint to the compiler that strides are aligned for the aligned variant of the shader |     // hint to the compiler that strides are aligned for the aligned variant of the shader | ||||||
|     if (Clamp != gl_CooperativeMatrixClampModeConstantNV) |     if (Clamp != gl_CooperativeMatrixClampModeConstantNV) | ||||||
|     { |     { | ||||||
|   | |||||||
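The stride selection deleted in this last hunk is the final piece centralized in init_indices(): nb01/nb11/nb21 are already divided by the element size, nb02 is still in bytes (hence the division by 4, since Q is float), and the mask stride collapses to zero when gqa_ratio > 1. A hedged C++ restatement of that selection; the struct and function names are invented for illustration:

```cpp
#include <cstdint>
#include <cstdio>

struct fa_strides {
    uint32_t q_stride;
    uint32_t k_stride;
    uint32_t v_stride;
    uint32_t m_stride;
};

// Restates the stride selection from init_indices(). nb01/nb11/nb21 are in
// elements, nb02 is in bytes (hence the /4 for float Q). With grouped-query
// attention the tile rows index Q's dimension 2, so Q advances by nb02 per
// row and all rows share the same mask (stride 0).
static fa_strides select_strides(uint32_t gqa_ratio, uint32_t KV,
                                 uint32_t nb01, uint32_t nb02,
                                 uint32_t nb11, uint32_t nb21) {
    fa_strides s;
    s.q_stride = gqa_ratio > 1 ? nb02 / 4 : nb01;
    s.k_stride = nb11;
    s.v_stride = nb21;
    // The shader writes this zero as (gqa_ratio >> 16) so the compiler cannot
    // fold the select when proving stride alignment.
    s.m_stride = gqa_ratio > 1 ? 0u : KV;
    return s;
}

int main() {
    // Example with invented values: gqa_ratio 4, KV 512, Q byte stride 4096,
    // Q/K/V element strides 1024/128/128.
    const fa_strides s = select_strides(4, 512, 1024, 4096, 128, 128);
    std::printf("q=%u k=%u v=%u m=%u\n",
                s.q_stride, s.k_stride, s.v_stride, s.m_stride);
    return 0;
}
```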
Author: Jeff Bolz