mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	vulkan: Handle updated FA dim2/3 definition (#14518)
* vulkan: Handle updated FA dim2/3 definition Pack mask boolean and n_head_log2 into a single dword to keep the push constant block under the 128B limit. * handle null mask for gqa * allow gqa with dim3>1
This commit is contained in:
		@@ -101,8 +101,8 @@ void main() {
 | 
			
		||||
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 | 
			
		||||
#endif
 | 
			
		||||
    uint32_t m_offset = 0;
 | 
			
		||||
    if (p.nem2 != 1) {
 | 
			
		||||
        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
 | 
			
		||||
    if (p.nem2 != 1 || p.nem3 != 1) {
 | 
			
		||||
        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[dont_unroll]]
 | 
			
		||||
@@ -149,7 +149,7 @@ void main() {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (p.mask != 0) {
 | 
			
		||||
        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 | 
			
		||||
 | 
			
		||||
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
 | 
			
		||||
                uint32_t c = (idx + tid) % Bc;
 | 
			
		||||
 
 | 
			
		||||
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
 | 
			
		||||
    uint32_t nev3;
 | 
			
		||||
    uint32_t nem1;
 | 
			
		||||
    uint32_t nem2;
 | 
			
		||||
    uint32_t nem3;
 | 
			
		||||
 | 
			
		||||
    uint32_t nb01;
 | 
			
		||||
    uint32_t nb02;
 | 
			
		||||
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
 | 
			
		||||
    float max_bias;
 | 
			
		||||
    float logit_softcap;
 | 
			
		||||
 | 
			
		||||
    uint32_t mask;
 | 
			
		||||
    uint32_t n_head_log2;
 | 
			
		||||
    uint32_t mask_n_head_log2;
 | 
			
		||||
    float m0;
 | 
			
		||||
    float m1;
 | 
			
		||||
 | 
			
		||||
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
 | 
			
		||||
    uint32_t k_num;
 | 
			
		||||
} p;
 | 
			
		||||
 | 
			
		||||
#define MASK_ENABLE_BIT (1<<16)
 | 
			
		||||
#define N_LOG2_MASK 0xFFFF
 | 
			
		||||
 | 
			
		||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
 | 
			
		||||
 | 
			
		||||
#if defined(A_TYPE_PACKED16)
 | 
			
		||||
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
 | 
			
		||||
{
 | 
			
		||||
    const uint32_t h = iq2 + (r % p.gqa_ratio);
 | 
			
		||||
 | 
			
		||||
    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
 | 
			
		||||
    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
 | 
			
		||||
    uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
 | 
			
		||||
 | 
			
		||||
    const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
 | 
			
		||||
    const int      exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
 | 
			
		||||
 | 
			
		||||
    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -126,8 +126,8 @@ void main() {
 | 
			
		||||
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 | 
			
		||||
#endif
 | 
			
		||||
    uint32_t m_offset = 0;
 | 
			
		||||
    if (p.nem2 != 1) {
 | 
			
		||||
        m_offset = (iq3 % p.nem2) * p.nem1 * KV;
 | 
			
		||||
    if (p.nem2 != 1 || p.nem3 != 1) {
 | 
			
		||||
        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[dont_unroll]]
 | 
			
		||||
@@ -182,7 +182,7 @@ void main() {
 | 
			
		||||
            barrier();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (p.mask != 0) {
 | 
			
		||||
        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 | 
			
		||||
            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
 | 
			
		||||
                uint32_t c = (idx + tid) % Bc;
 | 
			
		||||
                uint32_t r = (idx + tid) / Bc;
 | 
			
		||||
 
 | 
			
		||||
@@ -131,8 +131,8 @@ void main() {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    uint32_t m_offset = 0;
 | 
			
		||||
    if (p.nem2 != 1) {
 | 
			
		||||
        m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
 | 
			
		||||
    if (p.nem2 != 1 || p.nem3 != 1) {
 | 
			
		||||
        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    [[dont_unroll]]
 | 
			
		||||
@@ -153,7 +153,7 @@ void main() {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (p.mask != 0) {
 | 
			
		||||
        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
 | 
			
		||||
            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
 | 
			
		||||
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
 | 
			
		||||
            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user