	vulkan: Implement grouped query attention in the coopmat2 FA shader (#12559)
When adjacent batches of Q share the same batches of K/V, batch them into the same workgroup. For example, with

    dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))

we previously ran 32 workgroups computing 1 result each; now we run 8 workgroups computing 4 results each. This doesn't directly translate to better performance (at least when you have >= 32 SMs), but in a subsequent change I'll enable split_k, which will scale much better with 4x fewer workgroups.

Author: Jeff Bolz
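A minimal sketch of the batching arithmetic described above, using the shapes from the example (standalone helper for illustration, not the actual ggml_vk_flash_attn code):

    #include <cassert>
    #include <cstdint>

    static bool is_pow2(uint32_t x) { return x > 1 && (x & (x - 1)) == 0; }

    // Sketch: how GQA batching shrinks the y dimension of the dispatch.
    // neq2 = number of Q heads, nek2 = number of K/V heads.
    static uint32_t fa_workgroups_y(uint32_t neq2, uint32_t nek2, uint32_t& N) {
        const uint32_t qk_ratio = neq2 / nek2;   // Q heads per K/V head
        if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && qk_ratio * nek2 == neq2) {
            N = qk_ratio;                        // each workgroup computes qk_ratio rows
            return neq2 / qk_ratio;              // proportionally fewer workgroups in y
        }
        return neq2;
    }

    int main() {
        // dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))
        uint32_t N = 1;                          // single new token, so Q has one row
        const uint32_t wg_y = fa_workgroups_y(/*neq2=*/32, /*nek2=*/8, N);
        assert(N == 4 && wg_y == 8);             // 8 workgroups x 4 results, was 32 x 1
    }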
ggml/src/ggml-vulkan/ggml-vulkan.cpp:

@@ -31,6 +31,7 @@
 
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 #define VK_VENDOR_ID_AMD 0x1002
 #define VK_VENDOR_ID_APPLE 0x106b
@@ -501,6 +502,8 @@ struct vk_flash_attn_push_constants {
     uint32_t n_head_log2;
     float m0;
     float m1;
+
+    uint32_t gqa_ratio;
 };
 
 struct vk_op_push_constants {
@@ -5402,7 +5405,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t nbm1 = mask ? mask->nb[1] : 0;
 
     const uint32_t D = neq0;
-    const uint32_t N = neq1;
+    uint32_t N = neq1;
     const uint32_t KV = nek1;
 
     GGML_ASSERT(ne0 == D);
@@ -5460,6 +5463,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
+    uint32_t gqa_ratio = 1;
+    uint32_t qk_ratio = neq2 / nek2;
+    uint32_t workgroups_x = (uint32_t)neq1;
+    uint32_t workgroups_y = (uint32_t)neq2;
+    uint32_t workgroups_z = (uint32_t)neq3;
+
+    if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
+        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
+        // and change addressing calculations to index Q's dimension 2.
+        gqa_ratio = qk_ratio;
+        N = gqa_ratio;
+        workgroups_y /= N;
+    }
+
     if (dryrun) {
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
@@ -5549,7 +5568,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                                               nbm1,
                                               scale, max_bias, logit_softcap,
-                                              mask != nullptr, n_head_log2, m0, m1 };
+                                              mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                 {
                                     vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -5558,7 +5577,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                     vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
                                     vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
                                 },
-                                sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+                                sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
 }
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp:
@@ -61,6 +61,8 @@ layout (push_constant) uniform parameter {
     uint32_t n_head_log2;
     float m0;
    float m1;
+
+    uint32_t gqa_ratio;
 } p;
 
 layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 #define DECODEFUNC
 #endif
 
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c < D) {
+        uint32_t offset = (iq2 + r) * D + c;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
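perElemOpComputeSlope evaluates the usual ALiBi per-head slope, with tile row r mapped to head iq2 + (r & (gqa_ratio - 1)). A scalar C++ sketch of the same slope formula (hypothetical helper name; m0, m1, and n_head_log2 mirror the push constants):

    #include <cmath>
    #include <cstdint>

    // Scalar version of the slope computed per row in perElemOpComputeSlope.
    // Heads below n_head_log2 use base m0 with exponents 1, 2, 3, ...;
    // the remaining heads use base m1 with odd exponents 1, 3, 5, ...
    static float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? int(h) + 1 : 2 * int(h - n_head_log2) + 1;
        return std::pow(base, float(exph));
    }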
@@ -116,7 +140,9 @@ void main() {
 
     const uint32_t i = gl_WorkGroupID.x;
 
-    const uint32_t iq2 = gl_WorkGroupID.y;
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
     const uint32_t iq3 = gl_WorkGroupID.z;
 
     // broadcast factors
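Since workgroups_y was divided by gqa_ratio on the host, each y workgroup now owns a contiguous block of gqa_ratio heads starting at iq2. A small sketch of the row-to-head mapping the shader relies on (hypothetical helper):

    #include <cstdint>

    // First head of the workgroup is gl_WorkGroupID.y * gqa_ratio; tile row r
    // maps to one of the gqa_ratio consecutive heads (gqa_ratio is a power of 2,
    // so the mask wraps rows beyond gqa_ratio back onto the same heads).
    static uint32_t head_for_row(uint32_t wg_y, uint32_t gqa_ratio, uint32_t r) {
        return wg_y * gqa_ratio + (r & (gqa_ratio - 1));
    }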
@@ -149,8 +175,10 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    // nb?1 are already divided by the type size and are in units of elements
-    uint32_t q_stride = p.nb01;
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
     uint32_t k_stride = p.nb11;
     uint32_t v_stride = p.nb21;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
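The layout strides are in elements, but nb02 is still a byte stride coming from ggml's nb[2]; the division by 4 assumes 4-byte (f32) Q elements. A host-side sketch of that conversion (hypothetical helper, not part of the patch):

    #include <cstdint>

    // Element stride for stepping a tile row to the next Q row.
    // With GQA batching, consecutive rows advance through Q's dim 2, so the
    // byte stride nb02 is converted to elements (assuming f32, 4 bytes each).
    static uint32_t q_row_stride(uint32_t gqa_ratio, uint32_t nb01_elems, uint32_t nb02_bytes) {
        return gqa_ratio > 1 ? nb02_bytes / 4 : nb01_elems;
    }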
@@ -182,16 +210,11 @@ void main() {
     L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
     M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
 
-    ACC_TYPE slope = ACC_TYPE(1.0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        const uint32_t h = iq2;
-
-        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-        const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-        slope = pow(base, ACC_TYPE(exph));
+        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
     [[dont_unroll]]
@@ -215,12 +238,16 @@ void main() {
         if (p.mask != 0) {
             tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+            // When using grouped query attention, all rows use the same mask.
+            if (p.gqa_ratio > 1) {
+                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
+            }
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 
             coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
         }
 
         // Clear padding elements to -inf, so they don't contribute to rowmax
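Forcing the mask layout's row stride to 0 is a zero-stride broadcast: every row of the Br x Bc tile then loads the same mask row, which is exactly what the gqa_ratio query rows sharing one K/V head need. A minimal sketch of the idea (plain array indexing, not the NV tensor-layout API):

    #include <cstdint>

    // With row_stride == 0 every row aliases row 0 of the mask,
    // mirroring setTensorLayoutStrideNV(tensorLayoutM, 0, 1).
    static float mask_at(const float* mask, uint32_t row_stride,
                         uint32_t r, uint32_t c) {
        return mask[r * row_stride + c];
    }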
@@ -297,13 +324,18 @@ void main() {
 
     O = Ldiag*O;
 
-    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
-    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
-
-    // permute dimensions
-    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
     uint32_t o_offset = iq3*p.ne2*p.ne1;
 
     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
-    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
+    if (p.gqa_ratio > 1) {
+        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+    } else {
+        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
+        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+
+        // permute dimensions
+        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
+
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+    }
 }
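In the GQA path the tensor store is replaced by per-element stores: row r of the output tile belongs to head iq2 + r, and only the first N rows are valid. A CPU sketch of the same indexing as perElemOpGqaStore (hypothetical types, assuming the D-contiguous output layout used above):

    #include <cstdint>
    #include <vector>

    // Mirror of perElemOpGqaStore: flatten (head, element) row-major into data_o.
    // Only the first N rows (N == gqa_ratio) of the Br x D tile carry results.
    static void gqa_store(std::vector<float>& data_o, uint32_t o_offset,
                          uint32_t iq2, uint32_t N, uint32_t D,
                          uint32_t r, uint32_t c, float elem) {
        if (r < N && c < D) {
            data_o[o_offset + (iq2 + r) * D + c] = elem;
        }
    }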