CUDA: broadcasting for FlashAttention mask (#14500)

This commit is contained in:
Johannes Gäßler
2025-07-02 13:42:12 +02:00
committed by Georgi Gerganov
parent 8875523eb3
commit 12a81af45f
7 changed files with 43 additions and 25 deletions

View File

@@ -6,7 +6,7 @@
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ Q,
@@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne12,
const int ne13,
const int ne31,
const int ne32,
const int nb31,
const int nb32,
const int nb01,
const int nb02,
const int nb03,
@@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16(
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
const int stride_KV2 = nb11 / sizeof(half2);
@@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);