mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: fix FA out-of-bounds writes (#7465)
This commit is contained in:
		| @@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16( | ||||
|     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { | ||||
|         const int j_VKQ = j_VKQ_0 + threadIdx.y; | ||||
|  | ||||
|         if (ic0 + j_VKQ >= ne01) { | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); | ||||
|         kqsum_j = warp_reduce_sum(kqsum_j); | ||||
|  | ||||
|   | ||||
| @@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32( | ||||
|     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { | ||||
|         const int j_VKQ = j_VKQ_0 + threadIdx.y; | ||||
|  | ||||
|         if (ic0 + j_VKQ >= ne01) { | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         float kqsum_j = kqsum[j_VKQ_0/nwarps]; | ||||
|         kqsum_j = warp_reduce_sum(kqsum_j); | ||||
|  | ||||
|   | ||||
| @@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16( | ||||
|  | ||||
| #pragma unroll | ||||
|     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { | ||||
|         if (ic0 + j_VKQ >= ne01) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; | ||||
|         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); | ||||
|  | ||||
| @@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16( | ||||
|         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; | ||||
|     } | ||||
|  | ||||
|     if (parallel_blocks != 1 && tid < ncols) { | ||||
|     if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) { | ||||
|         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); | ||||
|     } | ||||
| #else | ||||
|   | ||||
| @@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32( | ||||
|  | ||||
| #pragma unroll | ||||
|     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { | ||||
|         if (ic0 + j_VKQ >= ne01) { | ||||
|             break; | ||||
|         } | ||||
|  | ||||
|         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; | ||||
|         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); | ||||
|  | ||||
| @@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32( | ||||
|         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; | ||||
|     } | ||||
|  | ||||
|     if (parallel_blocks != 1 && tid < ncols) { | ||||
|     if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) { | ||||
|         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler