mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	vulkan: use aligned loads for flash attention mask (#12853)
Rewrite the stride logic for the mask tensor in the FA shader to force the stride to be aligned, to allow using more efficient loads.
This commit is contained in:
		| @@ -201,6 +201,11 @@ void main() { | ||||
|     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; | ||||
|     uint32_t k_stride = p.nb11; | ||||
|     uint32_t v_stride = p.nb21; | ||||
|     // When using grouped query attention, all rows use the same mask (stride 0). | ||||
|     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero | ||||
|     // that prevents the compiler from folding the "&" through the select | ||||
|     // and breaking the alignment detection. | ||||
|     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; | ||||
|     // hint to the compiler that strides are aligned for the aligned variant of the shader | ||||
|     if (Clamp != gl_CooperativeMatrixClampModeConstantNV) | ||||
|     { | ||||
| @@ -209,6 +214,7 @@ void main() { | ||||
|         k_stride &= ~7; | ||||
|         v_stride &= ~7; | ||||
| #endif | ||||
|         m_stride &= ~7; | ||||
|     } | ||||
|     tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); | ||||
|     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); | ||||
| @@ -261,10 +267,7 @@ void main() { | ||||
|         if (p.mask != 0) { | ||||
|             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); | ||||
|             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); | ||||
|             // When using grouped query attention, all rows use the same mask. | ||||
|             if (p.gqa_ratio > 1) { | ||||
|                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1); | ||||
|             } | ||||
|             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); | ||||
|  | ||||
|             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz