CUDA: optimize FA for GQA + large batches (#12014)

Author:    Johannes Gäßler
Date:      2025-02-22 12:20:17 +01:00
Committer: GitHub
Parent:    335eb04a91
Commit:    5fa07c2f93

32 changed files with 940 additions and 411 deletions
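
The main idea of the change: the flash-attention tile size ncols is split into ncols1 Q columns times ncols2 adjacent Q heads, so that with grouped-query attention (GQA) one CUDA block can cover several Q heads that attend to the same K/V head, and the stream-k work is partitioned over ne02/ncols2 head groups instead of over individual heads. A rough host-side sketch of the resulting tile count, assuming the usual ggml layout Q->ne = {D, n_columns, n_heads, n_sequences}; the helper name is illustrative, not part of the patch:

// Illustrative only: mirrors the ntiles_total computation in launch_fattn after this change.
// ncols1 = Q columns per tile, ncols2 = Q heads handled together per K/V head (GQA).
static int total_tiles(int ne01, int ne02, int ne03, int ncols1, int ncols2) {
    const int ntiles_x = (ne01 + ncols1 - 1) / ncols1;  // ceil-divide over Q columns
    return ntiles_x * (ne02 / ncols2) * ne03;           // (column tile, head group, sequence)
}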


@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-template<int D, int ncols, int KQ_stride> // D == head size
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
     const int bidx0 = blockIdx.x;
+    const int j = blockIdx.y;
+    const int c = blockIdx.z;
+    const int jc = j*ncols2 + c;
+    const int tid = threadIdx.x;
-    const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+    const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
     const bool did_not_have_any_data = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
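
The kbc0/kbc0_stop arithmetic above is the stream-k work split: the flattened range of iter_k*iter_j*(ne02/ncols2) K-block units is divided as evenly as possible over gridDim.x blocks. A minimal sketch of the same partitioning, with illustrative names:

// Illustrative only: each of n_blocks blocks gets the contiguous range [kbc0, kbc0_stop)
// out of iter_k*iter_j*(ne02/ncols2) work units, the same integer math as in the kernel above.
static void stream_k_range(int bidx, int n_blocks, int iter_k, int iter_j, int ne02, int ncols2,
                           int * kbc0, int * kbc0_stop) {
    *kbc0      = (bidx + 0)*iter_k*iter_j*(ne02/ncols2) / n_blocks;
    *kbc0_stop = (bidx + 1)*iter_k*iter_j*(ne02/ncols2) / n_blocks;
}
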
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
     // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols] = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum = 0.0f;
+    {
+        dst_val = *dst;
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j] = tmp.y;
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum = tmp.y;
     }
     // Iterate over previous blocks and compute the combined results.
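
With blockIdx.y/blockIdx.z supplying one (j, c) pair per block, each thread now owns exactly one output element instead of looping over ncols rows. The pointer arithmetic above is the flat offset of Q column jt*ncols1 + j and head channel*ncols2 + c; a hedged restatement, assuming the output layout [ne01 columns][ne02 heads][D elements] implied by the old dst[j*ne02*D + threadIdx.x] indexing:

// Illustrative only: algebraically equal to
// dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid.
static size_t dst_offset(int jt, int j, int channel, int c, int tid,
                         int ncols1, int ncols2, int ne02, int D) {
    const int column = jt*ncols1 + j;        // position along Q->ne[1]
    const int head   = channel*ncols2 + c;   // position along Q->ne[2]
    return ((size_t) column*ne02 + head)*D + tid;
}
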
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x - max_val_new;
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum = scale_val*rowsum + scale_add*tmp.y;
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
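
The loop above merges partial attention results in the usual numerically stable way: both accumulators are rescaled to a common maximum before the weighted sums and the softmax denominators are added (the kernel additionally flushes scales below SOFTMAX_FTZ_THRESHOLD to zero). A small worked sketch of the combination rule:

#include <algorithm>
#include <cmath>

// One partial result for a single output element:
// val = sum_i exp(s_i - max) * v_i, rowsum = sum_i exp(s_i - max), max = running maximum of the scores s_i.
struct Partial { float val, rowsum, max; };

// Merge two partials; the final output element is combined.val / combined.rowsum,
// which is what the write-back below stores.
static Partial combine(Partial a, Partial b) {
    const float m  = std::max(a.max, b.max);
    const float sa = std::exp(a.max - m);   // <= 1, rescales the accumulator with the smaller max
    const float sb = std::exp(b.max - m);
    return { sa*a.val + sb*b.val, sa*a.rowsum + sb*b.rowsum, m };
}
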
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
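
The reworked heuristic above launches at most 2*nsm blocks (two per SM); tiles_efficiency_percent measures how well ntiles_total fills whole waves of that many blocks, and stream-k is used whenever the waves would be poorly filled, or unconditionally from Ada Lovelace on. The same decision as a compact sketch (illustrative names):

// Illustrative only: mirrors the use_stream_k condition above.
static bool choose_stream_k(int ntiles_total, int nsm, bool cc_is_ada_or_newer) {
    const int max_blocks = 2*nsm;                                       // at most 2 blocks per SM
    const int nwaves     = (ntiles_total + max_blocks - 1) / max_blocks;
    const int efficiency = 100 * ntiles_total / (max_blocks*nwaves);    // % of launched blocks doing useful work
    return cc_is_ada_or_newer || efficiency < 75;
}
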
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
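
The fixup pass is only needed when ntiles_total does not divide evenly among the stream-k blocks, because only then can a tile be split across two blocks and leave partial results to merge. Its launch shape matches the kernel's new indexing: one thread per element of the head dimension, grid x over the stream-k blocks, grid y over the ncols1 Q columns, grid z over the ncols2 heads of a group. A short restatement with an illustrative helper:

// Illustrative only: with stream-k, partial tiles (and thus a fixup pass) exist exactly when
// the tile count does not divide evenly among the launched blocks.
static bool fixup_needed(int ntiles_total, int nblocks_stream_k) {
    return ntiles_total % nblocks_stream_k != 0;
}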