CUDA: fix FA occupancy, optimize tile kernel (#15982)

This commit is contained in:
Johannes Gäßler
2025-09-17 15:32:42 +02:00
committed by GitHub
parent cd08fc3ecc
commit c959b676be
4 changed files with 361 additions and 253 deletions

View File

@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
}
template<int D> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP)
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
float VKQ_numerator = 0.0f;
float VKQ_denominator = 0.0f;
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
const float KQ_max_scale = expf(meta[l].x - kqmax);
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
@@ -836,11 +831,10 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());
}
int parallel_blocks = 1;
const dim3 block_dim(warp_size, nwarps, 1);
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
int parallel_blocks = max_blocks_per_sm;
dim3 blocks_num;
if (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);