mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-19 11:57:07 +00:00
CUDA: add attention sinks for tile and wmma (#15178)
* CUDA: add attention sinks for tile and wmma * Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma
This commit is contained in:
@@ -49,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
@@ -242,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
//Attention sink: adjust running max and sum once per head
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const half sink = __float2half(sinksf[head]);
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
|
||||
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
|
||||
kqmax[j0/nwarps] = kqmax_new_j;
|
||||
|
||||
const half val = hexp(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
|
||||
Reference in New Issue
Block a user