Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-29 08:41:22 +00:00)
	CUDA: add attention sinks for tile and wmma (#15178)
* CUDA: add attention sinks for tile and wmma
* Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma
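For context: in each of the patched kernels the sink is folded into the streaming (online) softmax after the main K/V loop. The running maximum is raised to include the per-head sink logit, the existing normalizer and the unnormalized V accumulator are rescaled to that new maximum, and exp(sink - max) is added to the normalizer only, since the sink contributes probability mass but no value vector. A minimal host-side sketch of the FP32 form of this update (plain C++, names loosely mirroring the kernel's, not the kernel code itself):

    // Minimal sketch (plain C++, not the kernel) of the FP32 sink update
    // applied after the main K/V loop.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Fold one sink logit into an online-softmax state (running max, running
    // normalizer, unnormalized V accumulator). The sink adds probability mass
    // to the normalizer but no value vector, so the accumulator is only rescaled.
    static void apply_attention_sink(float sink, float & kqmax, float & kqsum,
                                     std::vector<float> & VKQ) {
        const float kqmax_new = std::max(kqmax, sink);
        const float scale     = std::exp(kqmax - kqmax_new);

        kqmax = kqmax_new;
        kqsum = kqsum*scale + std::exp(sink - kqmax_new);

        for (float & v : VKQ) {
            v *= scale; // rescale the accumulated V*softmax(KQ) contributions
        }
    }

    int main() {
        // Toy per-row state as it would look after processing all K/V columns:
        float kqmax = 1.5f;                           // running max of the KQ logits
        float kqsum = 3.0f;                           // running softmax normalizer
        std::vector<float> VKQ = {0.5f, -1.0f, 2.0f}; // unnormalized output accumulator

        apply_attention_sink(/*sink=*/2.0f, kqmax, kqsum, VKQ);

        // A final division by kqsum (done in the kernel epilogue, untouched by
        // this patch) yields the normalized output.
        std::printf("kqmax=%.3f kqsum=%.3f out[0]=%.3f\n", kqmax, kqsum, VKQ[0]/kqsum);
        return 0;
    }

The FP16 tile and wmma paths below apply the same update in half/half2 arithmetic (hexp, __hadd, __half2half2 and friends).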
@@ -49,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -242,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
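Two details of the FP16 tile hunk above are easy to miss: the sink's exp term is added under a threadIdx.x == 0 guard, and only to the .x lane of the half2 running sum. Both are what one would expect if each thread holds a partial kqsum whose half2 lanes and per-thread copies are all summed before the final normalization (an assumption about the surrounding kernel, not visible in this hunk): the single per-row sink term must then enter that total exactly once. A tiny plain C++ model of that counting argument, with every name hypothetical:

    // Hypothetical model of the "add the sink exactly once" rule.
    // Assumption (not shown in the diff): kqsum is held per thread and per
    // half2 lane, and the kernel epilogue sums all lanes of all warp threads.
    #include <cmath>
    #include <cstdio>

    int main() {
        const int   warp_size = 32;
        const float sink_term = std::exp(-0.5f); // stands in for exp(sink - kqmax)

        float total = 0.0f;
        for (int tid = 0; tid < warp_size; ++tid) {
            float lane_x = 1.25f, lane_y = 0.75f; // per-thread partial normalizers
            if (tid == 0) {
                lane_x += sink_term;              // contributed once: thread 0, lane .x
            }
            total += lane_x + lane_y;             // models the lane + warp reduction
        }

        // The reduced normalizer contains every partial sum plus one sink term.
        std::printf("normalizer = %.6f (expected %.6f)\n",
                    total, warp_size*(1.25f + 0.75f) + sink_term);
        return 0;
    }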
@@ -60,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -252,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+
+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
@@ -82,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half  * V_h   = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
-    const half2 * mask2 = (const half2 *)  maskh;
+    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h    = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h    = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const half2 * mask2  = (const half2 *)  maskh;
+    const float * sinksf = (const float *) sinks;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -381,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    // Apply attention sinks
+    if (sinksf && blockIdx.y == 0) {
+        const float sinkf = sinksf[head];
+        const half  sinkh = __float2half(sinkf);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+                KQ_max_f[j0/nwarps] = kqmax_new;
+
+                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
+                }
+            } else {
+                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+                half kqmax_new = fmaxf(kqmax_old, sinkh);
+                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);
+
+                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+                const half val = hexp(sinkh - kqmax_new);
+                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+                }
+            }
+        }
+
+        __syncthreads();
+    }
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
@@ -274,23 +274,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K     = dst->src[1];
     const ggml_tensor * V     = dst->src[2];
     const ggml_tensor * mask  = dst->src[3];
-    const ggml_tensor * sinks = dst->src[4];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks && !fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
-        return;
-    }
-
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
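As a quick sanity check on the math (not part of this commit), a sink can be viewed as one extra attention column whose value vector is zero: a direct softmax over the logits plus the sink must give the same output as the streaming update the patched kernels perform. A self-contained plain C++ sketch of that comparison, with all values and names made up for illustration:

    // Standalone numerical check (hypothetical test, not from the commit):
    // the streaming sink correction must equal a direct softmax in which the
    // sink is an extra column whose value vector is zero.
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> logits = {0.3f, -1.2f, 2.1f, 0.0f}; // one KQ row
        const std::vector<float> values = {1.0f,  0.5f, -2.0f, 3.0f}; // 1-D V for simplicity
        const float sink = 1.7f;

        // Streaming version: online softmax over the logits, then the sink fix-up.
        float m = -INFINITY, s = 0.0f, acc = 0.0f;
        for (std::size_t i = 0; i < logits.size(); ++i) {
            const float m_new = std::fmax(m, logits[i]);
            const float scale = std::exp(m - m_new);
            s   = s*scale + std::exp(logits[i] - m_new);
            acc = acc*scale + std::exp(logits[i] - m_new)*values[i];
            m   = m_new;
        }
        {   // sink correction, mirroring the patched kernels
            const float m_new = std::fmax(m, sink);
            const float scale = std::exp(m - m_new);
            s   = s*scale + std::exp(sink - m_new);
            acc = acc*scale; // the sink adds no value contribution
            m   = m_new;
        }
        const float streaming = acc/s;

        // Direct version: softmax over [logits..., sink], sink's value is zero.
        float m_d = sink, s_d = 0.0f, acc_d = 0.0f;
        for (float l : logits) m_d = std::fmax(m_d, l);
        for (std::size_t i = 0; i < logits.size(); ++i) {
            const float p = std::exp(logits[i] - m_d);
            s_d   += p;
            acc_d += p*values[i];
        }
        s_d += std::exp(sink - m_d); // sink column contributes only to the normalizer
        const float direct = acc_d/s_d;

        // The two results should agree to within floating-point rounding.
        std::printf("streaming=%.6f direct=%.6f diff=%.2e\n",
                    streaming, direct, std::fabs(streaming - direct));
        return 0;
    }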
Aman Gupta