mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	vulkan: Implement grouped query attention in the coopmat2 FA shader (#12559)
When adjacent batches of Q share the same batches of K/V, batch them into the same workgroup. For example, for dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1)), we previously ran 32 workgroups computing 1 result each; now we run 8 workgroups computing 4 results each. This doesn't directly translate to better performance (at least when you have >=32 SMs), but in a subsequent change I'll enable split_k, which will scale much better with 4x fewer workgroups.
This commit is contained in:
		@@ -61,6 +61,8 @@ layout (push_constant) uniform parameter {
 | 
			
		||||
    uint32_t n_head_log2;
 | 
			
		||||
    float m0;
 | 
			
		||||
    float m1;
 | 
			
		||||
 | 
			
		||||
    uint32_t gqa_ratio;
 | 
			
		||||
} p;
 | 
			
		||||
 | 
			
		||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
 | 
			
		||||
@@ -103,6 +105,28 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 | 
			
		||||
#define DECODEFUNC
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Per-element callback that writes the output tile when grouped query
// attention is in use.
//
// Tile rows are indexed along Q's dimension 2: row r is stored at head
// (iq2 + r). Only the first N rows and the first D columns are written;
// everything else in the tile is skipped.
//
//   r, c      row / column of the element within the tile
//   elem      value to store
//   o_offset  base element offset into data_o
//   iq2       first head handled by this tile
//   N         number of valid rows
//
// Returns elem unchanged.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    const bool in_bounds = (r < N) && (c < D);
    if (in_bounds) {
        const uint32_t head = iq2 + r;
        const uint32_t dst_index = o_offset + (head * D + c);
        data_o[dst_index] = D_TYPE(elem);
    }
    return elem;
}
 | 
			
		||||
 | 
			
		||||
// Per-element callback that fills the ALiBi slope matrix, indexed by Q's
// dimension 2.
//
// Row r of the tile maps to head h = iq2 + (r mod p.gqa_ratio). The slope is
// base^exph, where:
//   h <  p.n_head_log2 : base = p.m0, exph = h + 1
//   h >= p.n_head_log2 : base = p.m1, exph = 2*(h - p.n_head_log2) + 1
//
//   r     row of the element within the tile
//   c     column of the element (unused; every column in a row shares a slope)
//   elem  incoming element value (unused; the result replaces it)
//   iq2   first head handled by this tile
//
// Returns the slope for row r as ACC_TYPE.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    // Use a real modulo instead of the mask (p.gqa_ratio - 1): the mask is only
    // equivalent when gqa_ratio is a power of two and selects the wrong head
    // for ratios such as 3 or 6. With gqa_ratio == 1 this is still h == iq2.
    // NOTE(review): assumes p.gqa_ratio >= 1 - confirm the host never passes 0.
    const uint32_t h = iq2 + (r % p.gqa_ratio);

    const bool     low_head = h < p.n_head_log2;
    const ACC_TYPE base     = ACC_TYPE(low_head ? p.m0 : p.m1);
    const int      exph     = int(low_head ? h + 1 : 2*(h - p.n_head_log2) + 1);

    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
 | 
			
		||||
 | 
			
		||||
void main() {
 | 
			
		||||
#ifdef NEEDS_INIT_IQ_SHMEM
 | 
			
		||||
    init_iq_shmem(gl_WorkGroupSize);
 | 
			
		||||
@@ -116,7 +140,9 @@ void main() {
 | 
			
		||||
 | 
			
		||||
    const uint32_t i = gl_WorkGroupID.x;
 | 
			
		||||
 | 
			
		||||
    const uint32_t iq2 = gl_WorkGroupID.y;
 | 
			
		||||
    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
 | 
			
		||||
    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
 | 
			
		||||
    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
 | 
			
		||||
    const uint32_t iq3 = gl_WorkGroupID.z;
 | 
			
		||||
 | 
			
		||||
    // broadcast factors
 | 
			
		||||
@@ -149,8 +175,10 @@ void main() {
 | 
			
		||||
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
 | 
			
		||||
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 | 
			
		||||
 | 
			
		||||
    // nb?1 are already divided by the type size and are in units of elements
 | 
			
		||||
    uint32_t q_stride = p.nb01;
 | 
			
		||||
    // nb?1 are already divided by the type size and are in units of elements.
 | 
			
		||||
    // When using grouped query attention, Q is indexed by iq2, so the stride
 | 
			
		||||
    // should be nb02 (which is in bytes).
 | 
			
		||||
    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
 | 
			
		||||
    uint32_t k_stride = p.nb11;
 | 
			
		||||
    uint32_t v_stride = p.nb21;
 | 
			
		||||
    // hint to the compiler that strides are aligned for the aligned variant of the shader
 | 
			
		||||
@@ -182,16 +210,11 @@ void main() {
 | 
			
		||||
    L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 | 
			
		||||
    M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
 | 
			
		||||
 | 
			
		||||
    ACC_TYPE slope = ACC_TYPE(1.0);
 | 
			
		||||
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
 | 
			
		||||
 | 
			
		||||
    // ALiBi
 | 
			
		||||
    if (p.max_bias > 0.0f) {
 | 
			
		||||
        const uint32_t h = iq2;
 | 
			
		||||
 | 
			
		||||
        const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
 | 
			
		||||
        const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
 | 
			
		||||
 | 
			
		||||
        slope = pow(base, ACC_TYPE(exph));
 | 
			
		||||
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[dont_unroll]]
 | 
			
		||||
@@ -215,12 +238,16 @@ void main() {
 | 
			
		||||
        if (p.mask != 0) {
 | 
			
		||||
            tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
 | 
			
		||||
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
 | 
			
		||||
            // When using grouped query attention, all rows use the same mask.
 | 
			
		||||
            if (p.gqa_ratio > 1) {
 | 
			
		||||
                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
 | 
			
		||||
 | 
			
		||||
            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 | 
			
		||||
 | 
			
		||||
            S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
 | 
			
		||||
            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Clear padding elements to -inf, so they don't contribute to rowmax
 | 
			
		||||
@@ -297,13 +324,18 @@ void main() {
 | 
			
		||||
 | 
			
		||||
    O = Ldiag*O;
 | 
			
		||||
 | 
			
		||||
    tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
 | 
			
		||||
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
 | 
			
		||||
 | 
			
		||||
    // permute dimensions
 | 
			
		||||
    tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 | 
			
		||||
    uint32_t o_offset = iq3*p.ne2*p.ne1;
 | 
			
		||||
 | 
			
		||||
    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
 | 
			
		||||
    coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
 | 
			
		||||
    if (p.gqa_ratio > 1) {
 | 
			
		||||
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 | 
			
		||||
    } else {
 | 
			
		||||
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
 | 
			
		||||
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
 | 
			
		||||
 | 
			
		||||
        // permute dimensions
 | 
			
		||||
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 | 
			
		||||
 | 
			
		||||
        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user