CUDA: fix FA occupancy, optimize tile kernel (#15982)

This commit is contained in:
Johannes Gäßler
2025-09-17 15:32:42 +02:00
committed by GitHub
parent cd08fc3ecc
commit c959b676be
4 changed files with 361 additions and 253 deletions

View File

@@ -75,6 +75,8 @@
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
// Moore Threads
@@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
}
// Maximum number of bytes that can be copied in a single instruction.
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
#ifdef GGML_USE_HIP
return 16;
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
return 16;
#else
return 8;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // GGML_USE_HIP
}
[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

View File

@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
}
template<int D> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP)
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
float VKQ_numerator = 0.0f;
float VKQ_denominator = 0.0f;
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
const float KQ_max_scale = expf(meta[l].x - kqmax);
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
@@ -836,11 +831,10 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());
}
int parallel_blocks = 1;
const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
int parallel_blocks = max_blocks_per_sm;
dim3 blocks_num;
if (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

View File

@@ -2,20 +2,30 @@
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
#define FATTN_TILE_NTHREADS 256
// kq_stride == number of KQ rows to process per iteration
// kq_nbatch == number of K columns to load in parallel for KQ calculation
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA(cc)) {
switch (D) {
case 64:
return 128;
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return 64;
return ncols == 32 ? 128 : 64;
case 128:
return ncols == 32 ? 64 : 32;
case 256:
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return ncols <= 16 ? 64 : 32;
} else {
return 64;
}
return 32;
default:
GGML_ABORT("fatal error");
return -1;
@@ -49,24 +59,28 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
#ifdef RDNA
switch (D) {
case 64:
return 64;
return 128;
case 128:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 32;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
case 256:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 32;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols == 32 ? 128 : 64;
case 128:
return ncols == 32 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
@@ -100,17 +114,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
case 64:
return 64;
case 128:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 128;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
case 256:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 128;
#else
return ncols <= 16 ? 64 : 256;
#endif // defined(GCN) || defined(CDNA)
return 128;
default:
return -1;
}
@@ -120,9 +125,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
case 64:
return 64;
case 128:
return ncols <= 16 ? 128 : 64;
case 256:
return ncols <= 16 ? 64 : 128;
return 128;
default:
return -1;
}
@@ -142,12 +146,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
GGML_UNUSED_VARS(ncols, warp_size);
}
template<int D, int ncols, bool use_logit_softcap> // D == head size
#ifdef GGML_USE_HIP
__launch_bounds__(FATTN_TILE_NTHREADS, 1)
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
return 256;
GGML_UNUSED_VARS(cc, ncols);
}
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
return 256;
GGML_UNUSED(ncols);
}
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
#ifdef RDNA
return 3;
#else
__launch_bounds__(FATTN_TILE_NTHREADS, 2)
#endif // GGML_USE_HIP
return ncols <= 16 ? 3 : 2;
#endif // RDNA
GGML_UNUSED(ncols);
}
template<int D, int ncols, bool use_logit_softcap> // D == head size
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
static __global__ void flash_attn_tile(
const char * __restrict__ Q,
const char * __restrict__ K,
@@ -193,7 +212,7 @@ static __global__ void flash_attn_tile(
}
constexpr int warp_size = 32;
constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
@@ -206,90 +225,126 @@ static __global__ void flash_attn_tile(
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
#if defined(GGML_USE_HIP)
constexpr int cpy_nb = 16;
#else
constexpr int cpy_nb = 8;
#endif // defined(GGML_USE_HIP) && defined(GCN)
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
__shared__ float KQ[ncols][kq_stride];
constexpr int cpw = ncols/nwarps; // cols per warp
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
#ifdef FAST_FP16_AVAILABLE
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
__shared__ half2 Q_tmp[ncols][D/2];
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#else
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
__shared__ float Q_tmp[ncols][D];
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#endif // FAST_FP16_AVAILABLE
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
float kqmax[ncols/nwarps];
float KQ_max[cpw];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
}
float kqsum[ncols/nwarps] = {0.0f};
float KQ_sum[cpw] = {0.0f};
// Load Q data, convert to FP16 if fast.
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
const int j = j0 + threadIdx.y*cpw;
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
float tmp_f[cpy_ne_D] = {0.0f};
if (ic0 + j < ne01) {
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp_f[i1] *= scale;
}
#ifdef FAST_FP16_AVAILABLE
Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
}
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
#else
Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
// Main loop over KV cache:
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
float KQ_max_new[cpw];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
for (int j = 0; j < cpw; ++j) {
KQ_max_new[j] = KQ_max[j];
}
float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
// KQ = K @ Q matrix multiplication:
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
#else
const float2 tmp_f2 = __half22float2(tmp_h2);
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
#endif // FAST_FP16_AVAILABLE
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
}
#else
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
half2 tmp_h2[cpy_ne_kqnb/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
float2 tmp_f2[cpy_ne_kqnb/2];
#pragma unroll
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
}
#endif // FAST_FP16_AVAILABLE
}
__syncthreads();
@@ -298,12 +353,12 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
half2 K_k[kq_stride/warp_size][cpy_ne];
half2 Q_k[ncols/nwarps][cpy_ne];
half2 Q_k[cpw][cpy_ne];
#else
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
float K_k[kq_stride/warp_size][cpy_ne];
float Q_k[ncols/nwarps][cpy_ne];
float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
@@ -311,29 +366,29 @@ static __global__ void flash_attn_tile(
const int i_KQ = i_KQ_0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
#pragma unroll
for (int k = 0; k < cpy_ne; ++k) {
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
}
}
}
@@ -344,104 +399,77 @@ static __global__ void flash_attn_tile(
}
}
// Apply logit softcap, mask, update KQ_max:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
if (use_logit_softcap) {
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
}
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
}
}
__syncthreads();
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
float kqsum_add = 0.0f;
if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
const int i = i0 + 4*threadIdx.x;
float4 val = *(const float4 *) &KQ[j][i];
val.x = expf(val.x - kqmax[j0/nwarps]);
val.y = expf(val.y - kqmax[j0/nwarps]);
val.z = expf(val.z - kqmax[j0/nwarps]);
val.w = expf(val.w - kqmax[j0/nwarps]);
kqsum_add += val.x + val.y + val.z + val.w;
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
#ifdef FAST_FP16_AVAILABLE
const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
half tmp[kq_stride/warp_size][softmax_iter_j];
#else
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
float tmp[kq_stride/warp_size][softmax_iter_j];
#endif // FAST_FP16_AVAILABLE
}
} else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
const int i = i0 + 2*threadIdx.x;
float2 val = *(const float2 *) &KQ[j][i];
val.x = expf(val.x - kqmax[j0/nwarps]);
val.y = expf(val.y - kqmax[j0/nwarps]);
kqsum_add += val.x + val.y;
#ifdef FAST_FP16_AVAILABLE
const half2 tmp = make_half2(val.x, val.y);
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
#else
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
#endif // FAST_FP16_AVAILABLE
}
} else {
#pragma unroll
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
KQ_max[j0+j1] = KQ_max_new[j0+j1];
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const float diff = KQ[j][i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
#ifdef FAST_FP16_AVAILABLE
((half *) KQ[j])[i] = val;
#else
KQ[j][i] = val;
#endif // FAST_FP16_AVAILABLE
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
KQ_sum_add += val;
tmp[i0/warp_size][j1] = val;
}
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
}
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
}
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
}
}
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
// VKQ = V @ KQ matrix multiplication:
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
#pragma unroll
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
@@ -449,65 +477,96 @@ static __global__ void flash_attn_tile(
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
const int k_tile = k1 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[k_tile*(D/2) + i] = tmp;
#else
KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
#endif // FAST_FP16_AVAILABLE
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
}
#else
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
half2 tmp_h2[cpy_ne_D/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
float2 tmp_f2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
}
#endif // FAST_FP16_AVAILABLE
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
#ifdef FAST_FP16_AVAILABLE
half2 V_k[(D/2)/warp_size];
half2 KQ_k[ncols/nwarps];
#else
float2 V_k[(D/2)/warp_size];
float KQ_k[ncols/nwarps];
#endif // FAST_FP16_AVAILABLE
half2 KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
#else
V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
#endif // FAST_FP16_AVAILABLE
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
#ifdef FAST_FP16_AVAILABLE
KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
#else
KQ_k[j0/nwarps] = KQ[j][k0 + k1];
#endif // FAST_FP16_AVAILABLE
half tmp[softmax_iter_j];
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
&tmp, KQ[j][k0 + k1]);
#pragma unroll
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
KQ_k[j0+j1] = __half2half2(tmp[j1]);
}
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
#ifdef FAST_FP16_AVAILABLE
VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps];
#else
VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
#endif // FAST_FP16_AVAILABLE
for (int j0 = 0; j0 < cpw; ++j0) {
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
}
}
}
#else
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
float2 V_k[(D/2)/warp_size];
float KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
&KQ_k[j0], KQ[j][k0 + k1]);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
}
}
}
#endif // FAST_FP16_AVAILABLE
__syncthreads();
}
@@ -519,69 +578,92 @@ static __global__ void flash_attn_tile(
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
for (int j0 = 0; j0 < cpw; ++j0) {
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
KQ_max[j0] = KQ_max_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
const float val = expf(sink - KQ_max[j0]);
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
KQ_sum[j0] += val;
}
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
}
if (gridDim.y == 1) {
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
#pragma unroll
for (int i = 0; i < (D/2)/warp_size; ++i) {
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
}
#else
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
#pragma unroll
for (int i = 0; i < (D/2)/warp_size; ++i) {
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}
// Write back results:
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
if (ic0 + j_VKQ >= ne01) {
return;
}
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += warp_size) {
const int i0 = i00 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
float2 tmp[cpy_ne_D];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
}
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
}
#else
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
}
#endif // FAST_FP16_AVAILABLE
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
}
}
#else
@@ -602,15 +684,29 @@ template <int D, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = 32;
const int nwarps = FATTN_TILE_NTHREADS / warp_size;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = 32;
constexpr size_t nbytes_shared = 0;
#ifdef GGML_USE_HIP
if constexpr (D <= 128) {
if (Q->ne[1] > 32) {
constexpr int cols_per_block = 64;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
}
#endif // GGML_USE_HIP
if (Q->ne[1] > 16) {
constexpr int cols_per_block = 32;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
@@ -619,6 +715,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
}
constexpr int cols_per_block = 16;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>

View File

@@ -158,41 +158,41 @@
#define __CUDA_ARCH__ 1300
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif
#if defined(__gfx900__) || defined(__gfx906__)
#define GCN5
#endif
#endif // defined(__gfx900__) || defined(__gfx906__)
#if defined(__gfx803__)
#define GCN4
#endif
#endif // defined(__gfx803__)
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA // For the entire family
#endif
#if defined(GCN5) || defined(GCN4)
#define GCN
#endif // defined(GCN5) || defined(GCN4)
#if defined(__gfx942__)
#define CDNA3
#endif
#endif // defined(__gfx942__)
#if defined(__gfx90a__)
#define CDNA2
#endif
#endif // defined(__gfx90a__)
#if defined(__gfx908__)
#define CDNA1
#endif
#endif // defined(__gfx908__)
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
#define CDNA // For the entire family
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
#if defined(__GFX12__)
#define RDNA4
#endif
#endif // defined(__GFX12__)
#if defined(__GFX11__)
#define RDNA3
#endif
#endif // defined(__GFX11__)
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
@@ -201,7 +201,11 @@
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#endif // defined(__gfx1010__) || defined(__gfx1012__)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
#define RDNA // For the entire family
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
#ifndef __has_builtin
#define __has_builtin(x) 0