	CUDA: fix race conditions FlashAttention kernels (#13438)
Author: Johannes Gäßler
@@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             }
         }

+        __syncthreads();
+
         // Write back combined meta data:
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
@@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int j = 0; j < ncols; ++j) {
         KQ[j*D + tid] = -HALF_MAX_HALF;
     }
+    __syncthreads();

     half2 VKQ[ncols] = {{0.0f, 0.0f}};

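Both hunks add a __syncthreads() barrier between a phase in which threads write to shared memory and a later phase in which threads read slots written by other threads; without the barrier, a reader can observe stale or uninitialized values. The following is a minimal sketch of that pattern, not the actual llama.cpp kernels: the kernel name, the reduction in phase 2, and the -65504.0f stand-in for -HALF_MAX_HALF are illustrative assumptions.

#include <cuda_fp16.h>

// Hypothetical kernel illustrating the class of race the commit fixes.
// Launch with blockDim.x == D and D*sizeof(half) bytes of dynamic shared memory.
__global__ void init_then_reduce(half * out, const int D) {
    extern __shared__ half KQ[];        // shared scratch buffer, D elements
    const int tid = threadIdx.x;

    // Phase 1: each thread initializes its own slot.
    KQ[tid] = __float2half(-65504.0f);  // stand-in for -HALF_MAX_HALF

    // The fix: make all phase-1 writes visible before any thread
    // reads slots written by other threads.
    __syncthreads();

    // Phase 2: thread 0 reads the whole buffer, including slots it did not write.
    if (tid == 0) {
        float maxval = __half2float(KQ[0]);
        for (int i = 1; i < D; ++i) {
            maxval = fmaxf(maxval, __half2float(KQ[i]));
        }
        *out = __float2half(maxval);
    }
}

Without the __syncthreads() call, thread 0 may read KQ entries before the owning threads have written them, which is the kind of race condition the two hunks above close.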