mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	vulkan: use aligned loads for flash attention mask (#12853)
Rewrite the stride logic for the mask tensor in the FA shader to force the stride to be aligned, to allow using more efficient loads.
This commit is contained in:
		| @@ -201,6 +201,11 @@ void main() { | |||||||
|     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; |     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; | ||||||
|     uint32_t k_stride = p.nb11; |     uint32_t k_stride = p.nb11; | ||||||
|     uint32_t v_stride = p.nb21; |     uint32_t v_stride = p.nb21; | ||||||
|  |     // When using grouped query attention, all rows use the same mask (stride 0). | ||||||
|  |     // "p.gqa_ratio >> 16" is just a roundabout way of writing zero | ||||||
|  |     // that prevents the compiler from folding the "&" through the select | ||||||
|  |     // and breaking the alignment detection. | ||||||
|  |     uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; | ||||||
|     // hint to the compiler that strides are aligned for the aligned variant of the shader |     // hint to the compiler that strides are aligned for the aligned variant of the shader | ||||||
|     if (Clamp != gl_CooperativeMatrixClampModeConstantNV) |     if (Clamp != gl_CooperativeMatrixClampModeConstantNV) | ||||||
|     { |     { | ||||||
| @@ -209,6 +214,7 @@ void main() { | |||||||
|         k_stride &= ~7; |         k_stride &= ~7; | ||||||
|         v_stride &= ~7; |         v_stride &= ~7; | ||||||
| #endif | #endif | ||||||
|  |         m_stride &= ~7; | ||||||
|     } |     } | ||||||
|     tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); |     tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); | ||||||
|     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); |     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); | ||||||
| @@ -261,10 +267,7 @@ void main() { | |||||||
|         if (p.mask != 0) { |         if (p.mask != 0) { | ||||||
|             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); |             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); | ||||||
|             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); |             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); | ||||||
|             // When using grouped query attention, all rows use the same mask. |             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); | ||||||
|             if (p.gqa_ratio > 1) { |  | ||||||
|                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; |             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz