CUDA: add FP32 FlashAttention vector kernel
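For orientation: the combine kernel below merges the partial FlashAttention results that parallel_blocks separate thread blocks computed for the same sequence position and head, using the standard numerically stable log-sum-exp rescaling. In the notation below (the symbols are mine, not from the source), $m_l$ and $s_l$ are the per-block KQ maximum and softmax denominator stored as the float2 pairs in VKQ_meta, and $v_l$ is the matching partial numerator in VKQ_parts:

$$m = \max_l m_l, \qquad \mathrm{dst} = \frac{\sum_l e^{m_l - m}\, v_l}{\sum_l e^{m_l - m}\, s_l}$$

Scale factors $e^{m_l - m}$ with $m_l - m \le \text{SOFTMAX\_FTZ\_THRESHOLD}$ are flushed to zero by a bit mask in the loop, since $e^{-20}$ is negligible and the flush avoids NaNs.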
#define FATTN_KQ_STRIDE       256
#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst) {
    // Advance all pointers to the data for the current sequence position (blockIdx.x).
    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
    dst       +=                 D * gridDim.y*blockIdx.x;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    // Load the (KQ max, KQ sum) pairs of all parallel blocks for the current head (blockIdx.y) into shared memory.
    __shared__ float2 meta[parallel_blocks];
    if (tid < 2*parallel_blocks) {
        ((float *) meta)[tid] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
    }

    __syncthreads();

    // Determine the global KQ maximum across all parallel blocks.
    float kqmax = meta[0].x;
#pragma unroll
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    // Rescale each partial numerator/denominator to the global maximum and accumulate.
    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
#pragma unroll
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        float KQ_max_scale = expf(diff); // Not const: its bit pattern is masked below.
        // Flush scales with very negative exponents to zero to avoid NaNs.
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
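
Purely as an illustration of the launch geometry implied by the indexing above, here is a hypothetical host-side launcher; the function name, the concrete D/parallel_blocks values, and the layout comments are assumptions read off the kernel's pointer arithmetic, not ggml's actual host code:

#include <cuda_runtime.h>

// Assumed compile-time configuration for this sketch.
constexpr int D               = 128; // head size
constexpr int parallel_blocks = 4;   // partial results per (position, head) pair

// Expected device buffer layouts, per the kernel's pointer arithmetic:
//   VKQ_parts: [n_pos][parallel_blocks][n_heads][D] floats (partial numerators)
//   VKQ_meta:  [n_pos][n_heads][parallel_blocks] float2    ((KQ max, KQ sum) pairs)
//   dst:       [n_pos][n_heads][D] floats                  (combined output)
static void launch_combine(const float *VKQ_parts, const float2 *VKQ_meta,
                           float *dst, const int n_pos, const int n_heads) {
    // One block per (sequence position, head) pair; one thread per element of the
    // head dimension, matching the __builtin_assume(tid < D) in the kernel.
    const dim3 grid(n_pos, n_heads, 1);
    flash_attn_combine_results<D, parallel_blocks><<<grid, D>>>(VKQ_parts, VKQ_meta, dst);
}

Since the kernel is declared static, a launcher like this would have to live in the same translation unit as the kernel itself.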