mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: fix negative KV_max values in FA (#15321)
This commit is contained in:
		| @@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max( | |||||||
|         all_inf = warp_reduce_all(all_inf); |         all_inf = warp_reduce_all(all_inf); | ||||||
|  |  | ||||||
|         if (!all_inf) { |         if (!all_inf) { | ||||||
|             KV_max_sj += FATTN_KQ_STRIDE; |  | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE. | ||||||
|  |     // If the break was triggered it's the lower edge of the tile with the first non-masked values. | ||||||
|  |     // In either case, walk back the decrementation by FATTN_KQ_STRIDE. | ||||||
|  |     KV_max_sj += FATTN_KQ_STRIDE; | ||||||
|  |  | ||||||
|     if (threadIdx.x != 0) { |     if (threadIdx.x != 0) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler