// Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03 09:22:01 +00:00).
// Original listing: 163 lines, 5.7 KiB.
#include "common.cuh"
#include <cstdint>
// KV-cache padding granularity: launch_fattn asserts that K->ne[1] is a multiple
// of this value.
#define FATTN_KQ_STRIDE 256

// Half of the largest finite half-precision value (65504). Its negation is used
// instead of -INFINITY to initialize running KQ maxima, so subtracting two such
// initial values gives 0 (and -HALF_MAX_HALF - HALF_MAX_HALF stays finite)
// instead of producing NaN from -inf - (-inf).
#define HALF_MAX_HALF __float2half(65504.0f/2)

// In the softmax rescaling, exp() of arguments not above this threshold is
// flushed to exactly zero to avoid NaNs.
#define SOFTMAX_FTZ_THRESHOLD (-20.0f)

// Function-pointer type shared by the FlashAttention CUDA kernels that
// launch_fattn launches.
//
// Q/K/V/mask arrive as raw byte pointers together with ggml nb strides. dst
// receives either the final output or, when several blocks share a row, the
// partial outputs. dst_meta receives the per-partial (.x = KQ max,
// .y = softmax denominator) pairs that flash_attn_combine_results consumes.
// NOTE(review): V has no stride parameters of its own; kernels presumably
// address V with K's strides — confirm in the kernel implementations.
using fattn_kernel_t = void (*)(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        // softmax scale and the bias slope parameters derived from max_bias
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        // Q extents
        const int ne00, const int ne01, const int ne02, const int ne03,
        // K extents
        const int ne10, const int ne11, const int ne12, const int ne13,
        // mask row count and row stride (0 when there is no mask)
        const int ne31, const int nb31,
        // Q strides
        const int nb01, const int nb02, const int nb03,
        // K strides
        const int nb11, const int nb12, const int nb13,
        // dst extents
        const int ne0, const int ne1, const int ne2, const int ne3);

// Merges the partial results that `parallel_blocks` independent FlashAttention
// blocks produced for the same output row into the final, normalized row.
//
// Launch layout (as used by launch_fattn): grid = (Q->ne[1], Q->ne[2], Q->ne[3]),
// block = (D, 1, 1). Block (x, y) handles query column x and head y; each thread
// writes one of the D elements of that output row.
//
// VKQ_parts: un-normalized partial outputs. For each column x the layout is
//            [partial l][head][d], i.e. parallel_blocks*gridDim.y*D floats.
// VKQ_meta:  one float2 per partial result, with .x = its running KQ maximum
//            and .y = its softmax denominator. For each column x the layout is
//            [head][partial l].
// dst:       final output, gridDim.y*D floats per column, layout [head][d].
//
// NOTE(review): blockIdx.z does not enter the addressing, so every z-slice
// reads and writes the same rows. Presumably Q->ne[3] == 1 on this path; confirm.
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst) {
    // Advance all pointers to the data of query column blockIdx.x.
    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
    dst       +=                 D * gridDim.y*blockIdx.x;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    // Stage this head's (max, denominator) pairs in shared memory. The loop is
    // strided by the block size D, so every entry is loaded even when
    // parallel_blocks exceeds the number of threads.
    __shared__ float2 meta[parallel_blocks];
    for (int l = tid; l < parallel_blocks; l += D) {
        meta[l] = VKQ_meta[blockIdx.y*parallel_blocks + l];
    }

    // Reached unconditionally by all threads; makes the shared writes visible.
    __syncthreads();

    // The maximum over all partial maxima is the common reference for rescaling.
    float kqmax = meta[0].x;
#pragma unroll
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    // Rescale each partial result to the common maximum and accumulate.
    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
#pragma unroll
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        // exp(diff), flushed to exactly zero when diff <= SOFTMAX_FTZ_THRESHOLD
        // or diff is NaN. A value select is used rather than AND-ing the float's
        // bit pattern through a pointer cast, which would modify a const object
        // (undefined behavior).
        const float KQ_max_scale = diff > SOFTMAX_FTZ_THRESHOLD ? expf(diff) : 0.0f;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}

// Host-side launcher for a FlashAttention kernel with the fattn_kernel_t signature.
//
// Takes Q, K, V and the optional mask from dst->src[0..3] and validates their
// types and padding. It then launches `fattn_kernel` on the context's stream with:
//   grid  = (parallel_blocks * ceil(Q->ne[1] / cols_per_block), Q->ne[2], Q->ne[3])
//   block = (WARP_SIZE, nwarps, 1)
// When parallel_blocks == 1 the kernel writes directly into dst. Otherwise it
// writes partial results into pool-allocated scratch buffers, and
// flash_attn_combine_results<D, parallel_blocks> is launched afterwards to merge
// them into dst. dst->op_params[0] and [1] are read as the softmax scale and max_bias.
template <int D, int parallel_blocks>
void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    ggml_tensor * KQV = dst;

    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(K->type == GGML_TYPE_F16);
    GGML_ASSERT(V->type == GGML_TYPE_F16);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

    // With exactly one block per output row, the kernel writes the final result itself.
    const bool direct_output = parallel_blocks == 1;

    ggml_cuda_pool & pool   = ctx.pool();
    cudaStream_t     stream = ctx.stream();

    // Scratch for the partial outputs and their (max, denominator) metadata.
    // Only allocated when partial results need to be combined.
    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
    if (parallel_blocks > 1) {
        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
    }

    // op_params: [0] = softmax scale, [1] = max_bias.
    float scale    = 1.0f;
    float max_bias = 0.0f;
    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

    // Largest power of two not exceeding the head count, and the two slope bases
    // derived from max_bias.
    const uint32_t n_head      = Q->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // Ceil-divide the query columns among blocks; each group gets parallel_blocks blocks.
    const int64_t n_col_groups = (Q->ne[1] + cols_per_block - 1) / cols_per_block;
    const dim3 grid(parallel_blocks*n_col_groups, Q->ne[2], Q->ne[3]);
    const dim3 block(WARP_SIZE, nwarps, 1);

    const char * mask_data  = mask ? ((const char *) mask->data) : nullptr;
    const int    mask_rows  = mask ? mask->ne[1] : 0;
    const int    mask_nb1   = mask ? mask->nb[1] : 0;
    float      * kernel_dst = direct_output ? (float *) KQV->data : dst_tmp.ptr;

    fattn_kernel<<<grid, block, 0, stream>>>(
        (const char *) Q->data, (const char *) K->data, (const char *) V->data, mask_data,
        kernel_dst, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
        mask_rows, mask_nb1,
        Q->nb[1], Q->nb[2], Q->nb[3],
        K->nb[1], K->nb[2], K->nb[3],
        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]);
    CUDA_CHECK(cudaGetLastError());

    if (direct_output) {
        return;
    }

    // One block per (query column, head, ne[3] slice); one thread per output element.
    const dim3 combine_grid(Q->ne[1], grid.y, grid.z);
    const dim3 combine_block(D, 1, 1);

    flash_attn_combine_results<D, parallel_blocks>
        <<<combine_grid, combine_block, 0, stream>>>
        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
    CUDA_CHECK(cudaGetLastError());
}